#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# mypy: allow-untyped-defs

import io
import textwrap
from typing import Dict, List, Optional

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import run_tests, TestCase


def model_size(sm):
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


def save_and_load(sm):
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


class TestBundledInputs(TestCase):
    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr
        )
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != 5:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[5][0].shape, (1 << 16,))
        self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)]
                )

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertTrue(loaded(*inflated[0]) == "first 1")

    def test_multiple_methods_with_inputs(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        info = [
            "Tensor with small numel and small storage.",
            "Tensor with large numel and small storage.",
            "Tensor with small numel and large storage.",
            "Large zero tensor.",
            "Large channels-last ones tensor.",
            "Special encoding of random tensor.",
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={mm.forward: samples, mm.foo: samples},
            info={mm.forward: info, mm.foo: info},
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helper that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
        self.assertEqual(
            all_info["forward"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_forward"],
        )
        self.assertEqual(
            all_info["foo"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_foo"],
        )
        self.assertEqual(all_info["forward"]["info"], info)
        self.assertEqual(all_info["foo"]["info"], info)

        # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
        for func_name in all_info.keys():
            input_func_name = all_info[func_name]["get_inputs_function_name"][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs defined 2 ways so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            definition = textwrap.dedent(
                """
                def _generate_bundled_inputs_for_forward(self):
                    return []
                """
            )
            mm.define(definition)
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs not defined so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            mm._generate_bundled_inputs_for_forward()
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Non list for input list
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo",  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2)],  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(
            Exception, "Models can only be augmented with bundled inputs once."
        ):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m, inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model, inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            # ret.storage()[0] = 0
            return ret

        def bundle_optional_dict_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

        def bundle_optional_list_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: List[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [big_inputs, small_inputs],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = (
            torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
                loaded
            )
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )


if __name__ == "__main__":
    run_tests()
