# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule
from typing import List


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replacement for ``F.linear`` expressed as ``matmul`` plus an optional add.

    Computes ``input @ weight.t()``; when ``bias`` is provided it is added
    in place to the freshly computed product before returning it.
    """
    result = input.matmul(weight.t())
    # In-place add is safe: ``result`` is a new tensor owned by this call.
    if bias is not None:
        result += bias
    return result


# Replace torch.nn.functional.linear process-wide with linear_shim, so code that
# resolves F.linear after this import runs matmul + in-place add instead.
# NOTE(review): presumably intended so that the nn.Linear layers scripted in this
# file lower to matmul/add ops for Static Runtime; confirm, since the patch also
# leaks into anything else running in the same process.
torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention used as a Static Runtime workload.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: accepted for signature compatibility only; dropout is disabled
            in this test model and the value is never used.
        device: device the ``scale`` tensor is moved to.

    ``forward`` also accepts a ``mask`` argument that it ignores (masking is
    disabled), and returns ``(output, attention_weights)``.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        # Projection layers are created in this order so parameter init
        # consumes the RNG identically across runs.
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # One-element tensor holding sqrt(head_dim); scores are divided by it.
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def _split_heads(self, t: torch.Tensor, batch_size: int) -> torch.Tensor:
        # View as (batch, -1, n_heads, head_dim), then swap axes 1 and 2 so the
        # head axis precedes the -1 axis.
        return t.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        q_heads = self._split_heads(self.fc_q(query), batch_size)
        k_heads = self._split_heads(self.fc_k(key), batch_size)
        v_heads = self._split_heads(self.fc_v(value), batch_size)
        scores = torch.matmul(q_heads, k_heads.permute(0, 1, 3, 2)) / self.scale
        attention = torch.softmax(scores, dim=-1)
        # Undo the head split: bring the head axis back next to head_dim and
        # flatten the heads into a single hid_dim-wide feature axis.
        merged = torch.matmul(attention, v_heads).permute(0, 2, 1, 3).contiguous()
        merged = merged.view(batch_size, -1, self.hid_dim)
        return self.fc_o(merged), attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    """Build a scripted, eval-mode MLP in the DLRM style.

    Args:
        ln: layer widths; consecutive pairs ``(ln[i], ln[i + 1])`` become
            ``nn.Linear`` layers.
        sigmoid_layer: index ``i`` of the linear layer followed by a Sigmoid;
            every other linear layer is followed by a ReLU.

    Weights are drawn from N(0, sqrt(2 / (fan_out + fan_in))) and biases from
    N(0, sqrt(1 / fan_out)) using NumPy's global RNG.

    Returns:
        A ``torch.jit.script``-ed ``nn.Sequential`` in eval mode.
    """
    modules = nn.ModuleList()
    for idx, (fan_in, fan_out) in enumerate(zip(ln[:-1], ln[1:])):
        linear = nn.Linear(int(fan_in), int(fan_out), bias=True)

        # Weight is drawn before bias so NumPy's RNG stream is consumed in a
        # fixed order for every layer.
        weight_std = np.sqrt(2 / (fan_out + fan_in))
        weight = np.random.normal(0.0, weight_std, size=(fan_out, fan_in)).astype(np.float32)
        bias_std = np.sqrt(1 / fan_out)
        bias = np.random.normal(0.0, bias_std, size=fan_out).astype(np.float32)
        linear.weight.data = torch.tensor(weight, requires_grad=True)
        linear.bias.data = torch.tensor(bias, requires_grad=True)

        modules.append(linear)
        modules.append(nn.Sigmoid() if idx == sigmoid_layer else nn.ReLU())

    with torch.no_grad():
        scripted = torch.jit.script(torch.nn.Sequential(*modules))
    scripted.eval()
    return scripted


def trivial_graph(a, b, c):
    """Return ``a + b * c`` plus a constant 2x2 tensor filled with 3."""
    threes = torch.tensor([[3, 3], [3, 3]])
    scaled = b * c
    return a + scaled + threes

def elementwise_square_addition(input1, input2):
    """Return ``input1**2 + input2**2`` computed with plain multiplications."""
    first_sq = input1 * input1
    second_sq = input2 * input2
    return first_sq + second_sq

def fork_wait_graph1(input1, input2):
    """Run ``elementwise_square_addition`` in a forked task and return its result."""
    pending = torch.jit.fork(elementwise_square_addition, input1, input2)
    result = torch.jit.wait(pending)
    return result

def fork_wait_graph2(input1, input2):
    """Run ``loop_graph`` (5 iterations) in a forked task and return its result."""
    pending = torch.jit.fork(loop_graph, input1, input2, 5)
    result = torch.jit.wait(pending)
    return result

"""
   graph with multiple fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures : List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    """Graph with multi-level (nested) fork/wait operations.

    Forks ``num_forks`` top-level ``fork_wait_graph3(input, num_child_forks)``
    calls; each of those forks ``num_child_forks`` child tasks of its own.
    The top-level results are waited on in order, stacked and summed.

    Args:
        input: tensor passed down to every child fork.
        num_forks: number of top-level forks.
        num_child_forks: number of child forks per top-level fork.

    Returns:
        A 0-dim tensor, mathematically equal to
        ``-num_forks * num_child_forks * input.sum()``.
    """
    # The explicit annotation lets TorchScript type the initially empty list.
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

def add_tensor(input1, input2):
    """Return the elementwise sum of the two inputs (forked by fork_wait_graph_exception)."""
    total = input1 + input2
    return total

def fork_wait_graph_exception(input1, input2):
    """Fork ``add_tensor`` and wait on it; used to check error propagation from forks."""
    pending = torch.jit.fork(add_tensor, input1, input2)
    result = torch.jit.wait(pending)
    return result

def loop_graph(a, b, iters: int):
    """Start from ``a + 2 * b`` and apply ``acc = (acc + b) * 2 - a`` ``iters`` times.

    The multiply and subtract are in-place on ``acc``, which is always a tensor
    created inside this function, so the inputs are never mutated.
    """
    acc = a + b * 2
    for _ in range(iters):
        acc = acc + b
        acc *= 2
        acc -= a
    return acc


def output_graph(a, b, c, iters: int):
    """Return ``{i: a + b * c + 3 + i for i in range(iters)}`` as a dict of tensors.

    The ``+ 3`` comes from a constant 2x2 tensor of threes.
    """
    offset = torch.tensor([[3, 3], [3, 3]])
    base = a + b * c + offset
    # Annotated so TorchScript can type the initially empty dict.
    out: Dict[int, torch.Tensor] = {}
    for idx in range(iters):
        out[idx] = base + idx
    return out


class SubModule(nn.Module):
    """Module that adds the sum of its integer attributes ``a`` (11) and ``b`` (2) to its input."""

    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        offset = self.a + self.b
        return offset + x


class SubModule2(nn.Module):
    """Submodule whose forward overwrites its own ``b`` attribute before reading it.

    Used through TestModule to exercise prim::SetAttr / prim::GetAttr on a
    submodule in Static Runtime (see test_attr).
    """

    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        # Deliberate attribute mutation: the sum below uses b == 30, not the initial 2.
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    """Top-level module for test_attr, composing SubModule and SubModule2.

    ``forward`` overwrites its own ``b`` before reading it. It also calls
    ``sub2``, which overwrites its own ``b``, so the scripted graph contains
    SetAttr/GetAttr on both this module and a submodule.
    """

    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        # Deliberate attribute mutation: the sum below uses b == 20, not the initial 4.
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


class TestStaticModule(TestCase):
    """Compare Static Runtime (``StaticModule``) outputs with TorchScript references."""

    # Substring expected in the error raised when fork_wait_graph_exception
    # adds a (4, 7) tensor to a (4, 5) tensor inside its forked subgraph.
    _ADD_SHAPE_MISMATCH_MSG = (
        "The size of tensor a (7) must match the size "
        "of tensor b (5) at non-singleton dimension 1"
    )

    def _check_sync(self, fn, *args):
        """Script ``fn`` and assert StaticModule's synchronous output matches it."""
        torch_graph = torch.jit.script(fn)
        output_ref = torch_graph(*args)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(*args)
        torch.testing.assert_close(output_test, output_ref)

    def _check_async(self, fn, *args):
        """Script ``fn`` and assert StaticModule's runAsync future resolves to its output."""
        torch_graph = torch.jit.script(fn)
        output_ref = torch_graph(*args)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync(args, {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def _check_add_shape_error(self, run):
        """Assert ``run()`` raises and that the error names the add shape mismatch.

        Fails if nothing is raised, unlike a bare try/except, which would
        silently pass.
        """
        # Any exception type is accepted, matching how the runtime surfaces
        # errors from forked subgraphs; the message is checked below.
        with self.assertRaises(Exception) as ctx:  # noqa: B017
            run()
        self.assertIn(
            self._ADD_SHAPE_MISMATCH_MSG,
            str(ctx.exception),
            "Tried execution of add.Tensors with incompatible shape. "
            "Exception raised by forked runtime execution does "
            "not contain expected substring",
        )

    @staticmethod
    def _attention_layer_and_inputs():
        """Return an eval-mode MultiHeadAttentionLayer and its ``(src, src_mask)`` inputs."""
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
        attention.eval()
        return attention, src, src_mask

    def test_fork_wait_1(self):
        """Simple fork/wait: fork an elementwise square-and-add on the inputs."""
        self._check_sync(fork_wait_graph1, torch.ones(5, 5), torch.randn(5, 5))

    def test_fork_wait_1_async(self):
        """Simple fork/wait through the runAsync API, which returns a future."""
        self._check_async(fork_wait_graph1, torch.ones(5, 5), torch.randn(5, 5))

    def test_fork_wait_2(self):
        """Fork/wait on a loop subgraph performing a mix of operations."""
        self._check_sync(fork_wait_graph2, torch.randn(5, 5), torch.randn(5, 5))

    def test_fork_wait_2_async(self):
        """Fork/wait on a loop subgraph through the runAsync API."""
        self._check_async(fork_wait_graph2, torch.randn(5, 5), torch.randn(5, 5))

    def test_fork_wait_3(self):
        """Graph with multiple fork/wait operations."""
        num_forks = 10
        self._check_sync(fork_wait_graph3, torch.ones(3, 3), num_forks)

    def test_fork_wait_3_async(self):
        """Graph with multiple fork/wait operations through the runAsync API."""
        num_forks = 10
        self._check_async(fork_wait_graph3, torch.ones(3, 3), num_forks)

    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        """Graph with multiple nested fork/wait operations."""
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        """Nested fork/wait operations through the runAsync API."""
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    def test_fork_wait_exception(self):
        """Exception handling in fork/wait.

        add.Tensor is called on tensors with non-matching dims in the forked
        subgraph. The exception raised there is set on the future returned by
        prim::fork, so running the parent graph must raise an error that
        contains the shape-mismatch message.
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        static_runtime_module = StaticModule(torch_graph)
        self._check_add_shape_error(lambda: static_runtime_module(input1, input2))

    def test_fork_wait_exception_async(self):
        """Exception handling in fork/wait through the runAsync API.

        Same shape-mismatch scenario as test_fork_wait_exception. Here the error
        may surface either when the run is submitted or when the returned future
        is resolved, so both happen inside the asserted region.
        """
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        static_runtime_module = StaticModule(torch_graph)

        def run():
            output_test = static_runtime_module.runAsync((input1, input2), {})
            output_test.wait()
            output_test.value()

        self._check_add_shape_error(run)

    def test_multihead_attention_layer(self):
        attention, src, src_mask = self._attention_layer_and_inputs()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        attention, src, src_mask = self._attention_layer_and_inputs()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        attention_a.benchmark_individual_ops([src, src, src, src_mask], {}, 2, 2)

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        # One initial comparison plus five repeats on fresh inputs, so that
        # Static Runtime is also exercised after it has already run once.
        for _ in range(6):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        attention, src, src_mask = self._attention_layer_and_inputs()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)

# Run the tests through PyTorch's run_tests when this file is executed directly.
if __name__ == "__main__":
    run_tests()
