# Owner(s): ["oncall: mobile"]

import fnmatch
import io
import shutil
import tempfile
from pathlib import Path

import torch
import torch.utils.show_pickle

# from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import (
    _backport_for_mobile,
    _backport_for_mobile_to_buffer,
    _get_mobile_model_contained_types,
    _get_model_bytecode_version,
    _get_model_ops_and_info,
    _load_for_lite_interpreter,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# Parent of the directory containing this file (symlinks resolved). The
# checked-in mobile models used below are found under
# ``pytorch_test_dir / "cpp" / "jit"``.
_this_test_file = Path(__file__).resolve()
pytorch_test_dir = _this_test_file.parents[1]

# script_module_v4.ptl and script_module_v5.ptl source code
# class TestModule(torch.nn.Module):
#     def __init__(self, v):
#         super().__init__()
#         self.x = v

#     def forward(self, y: int):
#         increment = torch.ones([2, 4], dtype=torch.float64)
#         return self.x + y + increment

# output_model_path = Path(tmpdirname, "script_module_v5.ptl")
# script_module = torch.jit.script(TestModule(1))
# optimized_scripted_module = optimize_for_mobile(script_module)
# exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
#   str(output_model_path))

# Expected show_pickle dump of ``bytecode.pkl`` for a bytecode-version-4 model
# built from the TestModule source shown above. This is not compared byte for
# byte: all whitespace is stripped from both this text and the actual dump, and
# the result is matched as an fnmatch-style pattern, so ``*`` is a wildcard
# (e.g. the module qualifier in '__torch__.*.TestModule').
SCRIPT_MODULE_V4_BYTECODE_PKL = """
(4,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

# Expected show_pickle dump of ``bytecode.pkl`` for a bytecode-version-5 model
# (whitespace-insensitive, fnmatch-style pattern; ``*`` is a wildcard). Apart
# from the leading version number, the only difference from the v4 pattern is
# the storage record of the constant tensor, which is keyed 'constants/0'
# instead of '0'. Not currently listed in SCRIPT_MODULE_BYTECODE_PKL.
SCRIPT_MODULE_V5_BYTECODE_PKL = """
(5,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

# Expected show_pickle dump of ``bytecode.pkl`` for a bytecode-version-6 model
# (whitespace-insensitive, fnmatch-style pattern; ``*`` is a wildcard).
# Compared with the v4 pattern:
#   * each operator entry carries a third field (2 here), presumably the number
#     of schema arguments (cf. ``num_schema_args`` checked in
#     test_get_model_ops_and_info) -- TODO confirm;
#   * the second ('LOADC', 1, 0) before ('OP', 1, 0) is absent.
# Not currently listed in SCRIPT_MODULE_BYTECODE_PKL.
SCRIPT_MODULE_V6_BYTECODE_PKL = """
(6,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
    """

# Checked-in models, keyed by bytecode version. For each version:
#   bytecode_pkl -- the expected bytecode.pkl pattern for a model of that version
#   model_name   -- file name of the checked-in model under <pytorch_test_dir>/cpp/jit
SCRIPT_MODULE_BYTECODE_PKL = {
    4: dict(
        bytecode_pkl=SCRIPT_MODULE_V4_BYTECODE_PKL,
        model_name="script_module_v4.ptl",
    ),
}

# Lowest bytecode version the tests backport models down to. Raise this
# whenever a bytecode version is completely retired.
MINIMUM_TO_VERSION: int = 4


class testVariousModelVersions(TestCase):
    """Tests for the mobile bytecode-version utilities in ``torch.jit.mobile``.

    Checked-in models are described by ``SCRIPT_MODULE_BYTECODE_PKL`` and are
    resolved relative to ``<pytorch_test_dir>/cpp/jit``.
    """

    @staticmethod
    def _checked_in_model_path(model_name):
        """Return the path of the checked-in model file ``model_name``."""
        return pytorch_test_dir / "cpp" / "jit" / model_name

    def _assert_bytecode_pkl_matches(self, model_path, expected_bytecode_pkl):
        """Assert that the ``bytecode.pkl`` of ``model_path`` matches a pattern.

        The archive entry is dumped with ``torch.utils.show_pickle``. All
        whitespace is removed from both the dump and ``expected_bytecode_pkl``,
        and the two are then compared as an fnmatch pattern. This is not a
        byte-for-byte check: ``*`` in the expected text is a wildcard for
        content that should not be compared.
        """
        buf = io.StringIO()
        # "<archive>@*/bytecode.pkl" selects the bytecode.pkl entry of the
        # model archive, whatever path prefix precedes it.
        torch.utils.show_pickle.main(
            ["", f"{model_path}@*/bytecode.pkl"], output_stream=buf
        )
        actual_result_clean = "".join(buf.getvalue().split())
        expect_result_clean = "".join(expected_bytecode_pkl.split())
        # fnmatchcase keeps the comparison case-sensitive on every platform;
        # plain fnmatch applies os.path.normcase, which lowercases on Windows.
        self.assertTrue(
            fnmatch.fnmatchcase(actual_result_clean, expect_result_clean),
            f"bytecode.pkl of {model_path} does not match the expected pattern\n"
            f"actual:   {actual_result_clean}\n"
            f"expected: {expect_result_clean}",
        )

    def test_get_model_bytecode_version(self):
        for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
            with self.subTest(version=version):
                model_path = self._checked_in_model_path(model_info["model_name"])
                self.assertEqual(_get_model_bytecode_version(model_path), version)

    def test_bytecode_values_for_all_backport_functions(self):
        # Start from the newest checked-in model and backport each checked-in
        # model by one version, down to the minimum supported version,
        # comparing the resulting bytecode.pkl content with the expected pattern.
        # This can't be merged into `test_all_backport_functions`, because
        # optimization is dynamic and the content might change when the
        # optimize function changes. This test focuses on bytecode.pkl content
        # validation via wildcard pattern matching (see
        # `_assert_bytecode_pkl_matches`).
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Every backported model is written to (and overwrites) this path.
            # TemporaryDirectory removes it on exit, so no manual cleanup.
            tmp_output_model_path_backport = Path(
                tmpdirname, "tmp_script_module_backport.ptl"
            )

            for current_from_version in range(
                maximum_checked_in_model_version, MINIMUM_TO_VERSION, -1
            ):
                current_to_version = current_from_version - 1
                input_model_path = self._checked_in_model_path(
                    SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
                )

                backport_success = _backport_for_mobile(
                    input_model_path, tmp_output_model_path_backport, current_to_version
                )
                self.assertTrue(
                    backport_success,
                    f"backport from v{current_from_version} to v{current_to_version} failed",
                )

                self._assert_bytecode_pkl_matches(
                    tmp_output_model_path_backport,
                    SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"],
                )

    # Please run this test manually when working on backport.
    # This test passes in OSS, but fails internally, likely due to missing step in build.
    # It also needs the `optimize_for_mobile` import at the top of the file.
    # def test_all_backport_functions(self):
    #     # Backport from the latest bytecode version to the minimum support version
    #     # Load, run the backport model, and check version
    #     class TestModule(torch.nn.Module):
    #         def __init__(self, v):
    #             super().__init__()
    #             self.x = v

    #         def forward(self, y: int):
    #             increment = torch.ones([2, 4], dtype=torch.float64)
    #             return self.x + y + increment

    #     module_input = 1
    #     expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

    #     # temporary input model file and output model file will be exported in the temporary folder
    #     with tempfile.TemporaryDirectory() as tmpdirname:
    #         tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
    #         script_module = torch.jit.script(TestModule(1))
    #         optimized_scripted_module = optimize_for_mobile(script_module)
    #         optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

    #         current_from_version = _get_model_bytecode_version(tmp_input_model_path)
    #         current_to_version = current_from_version - 1
    #         tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

    #         while current_to_version >= MINIMUM_TO_VERSION:
    #             # Backport the latest model to `to_version` to a tmp file "tmp_script_module_backport"
    #             backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, current_to_version)
    #             self.assertTrue(backport_success)

    #             backport_version = _get_model_bytecode_version(tmp_output_model_path)
    #             self.assertEqual(backport_version, current_to_version)

    #             # Load the *backported* model and run forward method
    #             mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
    #             mobile_module_result = mobile_module(module_input)
    #             torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
    #             current_to_version -= 1

    #         # Check backport failure case
    #         backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
    #         self.assertFalse(backport_success)

    # Check just the _backport_for_mobile (file -> file) mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_file(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # The newest checked-in model is already at the minimum version;
            # there is nothing to backport.
            return

        target_version = maximum_checked_in_model_version - 1
        latest_model_path = self._checked_in_model_path(
            SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_backport_model_path = Path(
                tmpdirname,
                f"tmp_script_module_v{maximum_checked_in_model_version}"
                f"_backported_to_v{target_version}.ptl",
            )

            # Backport from file to file.
            success = _backport_for_mobile(
                latest_model_path, tmp_backport_model_path, target_version
            )
            self.assertTrue(success)
            self.assertEqual(
                _get_model_bytecode_version(tmp_backport_model_path), target_version
            )

            # Validate against the pattern of the *target* version rather than
            # a hard-coded one, so this keeps working when newer models are added.
            self._assert_bytecode_pkl_matches(
                tmp_backport_model_path,
                SCRIPT_MODULE_BYTECODE_PKL[target_version]["bytecode_pkl"],
            )

            # Load the backported model and run its forward method:
            # x (=1) + y (=1) + ones -> 3 * ones.
            mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
            torch.testing.assert_close(
                mobile_module(1), 3 * torch.ones([2, 4], dtype=torch.float64)
            )

    # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_buffer(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # The newest checked-in model is already at the minimum version;
            # there is nothing to backport.
            return

        target_version = maximum_checked_in_model_version - 1
        latest_model_path = self._checked_in_model_path(
            SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )

        # Backport the newest checked-in model by one version into memory.
        backported_buffer = _backport_for_mobile_to_buffer(
            latest_model_path, target_version
        )

        # Each consumer gets its own BytesIO in case the first one advances the
        # stream position.
        self.assertEqual(
            _get_model_bytecode_version(io.BytesIO(backported_buffer)), target_version
        )

        # Load the backported model and run its forward method:
        # x (=1) + y (=1) + ones -> 3 * ones.
        mobile_module = _load_for_lite_interpreter(io.BytesIO(backported_buffer))
        torch.testing.assert_close(
            mobile_module(1), 3 * torch.ones([2, 4], dtype=torch.float64)
        )

    def test_get_model_ops_and_info(self):
        # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
        script_module_v6 = self._checked_in_model_path("script_module_v6.ptl")
        ops_v6 = _get_model_ops_and_info(script_module_v6)
        self.assertEqual(ops_v6["aten::add.int"].num_schema_args, 2)
        self.assertEqual(ops_v6["aten::add.Scalar"].num_schema_args, 2)

    def test_get_mobile_model_contained_types(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        sample_input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        # Sanity-check the scripted module before serializing it.
        torch.testing.assert_close(script_module(sample_input), torch.tensor([11]))

        # A freshly constructed BytesIO is already positioned at offset 0.
        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        type_list = _get_mobile_model_contained_types(buffer)
        # NOTE(review): this only checks that the call succeeds and returns a
        # sized container; pin the expected contents once they are known.
        self.assertGreaterEqual(len(type_list), 0)


if __name__ == "__main__":
    run_tests()
