# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

from torch.package import PackageExporter, PackageImporter
from torch.package._mangling import (
    demangle,
    get_mangle_prefix,
    is_mangled,
    PackageMangler,
)
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestMangling(PackageTestCase):
    """Tests for the module-name mangling helpers in ``torch.package._mangling``."""

    def test_unique_manglers(self):
        """
        Each mangler instance should generate a unique mangled name for a given input.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertNotEqual(a.mangle("foo.bar"), b.mangle("foo.bar"))

    def test_mangler_is_consistent(self):
        """
        Mangling the same name twice should produce the same result.
        """
        a = PackageMangler()
        self.assertEqual(a.mangle("abc.def"), a.mangle("abc.def"))

    def test_roundtrip_mangling(self):
        """
        Demangling a mangled name should recover the original name.
        """
        a = PackageMangler()
        self.assertEqual("foo", demangle(a.mangle("foo")))

    def test_is_mangled(self):
        """
        ``is_mangled`` should accept names produced by any mangler and reject
        both plain names and names that have been demangled again.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertTrue(is_mangled(a.mangle("foo.bar")))
        self.assertTrue(is_mangled(b.mangle("foo.bar")))

        self.assertFalse(is_mangled("foo.bar"))
        self.assertFalse(is_mangled(demangle(a.mangle("foo.bar"))))

    def test_demangler_multiple_manglers(self):
        """
        The module-level ``demangle`` function should be able to demangle names
        generated by any PackageMangler instance.
        """
        a = PackageMangler()
        b = PackageMangler()

        self.assertEqual("foo.bar", demangle(a.mangle("foo.bar")))
        self.assertEqual("bar.foo", demangle(b.mangle("bar.foo")))

    def test_mangle_empty_errors(self):
        """
        Mangling an empty name should raise an AssertionError.
        """
        a = PackageMangler()
        with self.assertRaises(AssertionError):
            a.mangle("")

    def test_demangle_base(self):
        """
        Demangling a mangle parent directly should currently return an empty string.
        """
        a = PackageMangler()
        mangled = a.mangle("foo")
        # Everything before the first "." is the mangle parent.
        mangle_parent = mangled.partition(".")[0]
        self.assertEqual("", demangle(mangle_parent))

    def test_mangle_prefix(self):
        """
        ``get_mangle_prefix`` should return the prefix that, joined to the
        original name with a ".", reconstructs the mangled name.
        """
        a = PackageMangler()
        mangled = a.mangle("foo.bar")
        mangle_prefix = get_mangle_prefix(mangled)
        self.assertEqual(f"{mangle_prefix}.foo.bar", mangled)

    def test_unique_module_names(self):
        """
        Each load of a package should produce module names distinct from the
        original module and from every other load of the same package.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = BytesIO()
        with PackageExporter(f1) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)
        # Rewind the buffer before each importer so that both read the
        # archive from the beginning.
        f1.seek(0)
        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")
        f1.seek(0)
        importer2 = PackageImporter(f1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        self.assertNotEqual(type(obj2).__module__, type(loaded1).__module__)
        self.assertNotEqual(type(loaded1).__module__, type(loaded2).__module__)

    def test_package_mangler(self):
        """
        ``PackageMangler.demangle`` should only demangle names produced by that
        same mangler instance and leave all other names untouched.
        """
        a = PackageMangler()
        b = PackageMangler()
        a_mangled = a.mangle("foo.bar")
        # Since `a` mangled this string, it should demangle properly.
        self.assertEqual(a.demangle(a_mangled), "foo.bar")
        # Since `b` did not mangle this string, demangling should leave it alone.
        self.assertEqual(b.demangle(a_mangled), a_mangled)


if __name__ == "__main__":
    # Allow running this test file directly (see the import fallback above for
    # the non-package case); delegates to the shared torch test runner.
    run_tests()
