# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
from contextlib import contextmanager
from time import perf_counter
from unittest import skip, skipIf

from attn_ft import BertSelfAttention as BertSelfAttentionA, Linear
from attn_positional import BertSelfAttention as BertSelfAttentionB

import torch
from functorch._C import dim as _C
from functorch.dim import (
    Dim,
    DimensionBindError,
    DimList,
    dimlists,
    dims,
    stack,
    Tensor,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TestCase,
)


try:
    from torchvision.models import resnet18
except ImportError:
    resnet18 = None

_test_c, _parse_test, _set_pointwise_optimize = (
    _C._test_c,
    _C._parse_test,
    _C._set_pointwise_optimize,
)


measure_perf = False
if measure_perf:
    from torchdim.magic_trace import magic_trace
else:

    @contextmanager
    def magic_trace(*args, **kwargs):
        yield


@contextmanager
def measure(what):
    b = perf_counter()
    yield
    e = perf_counter()
    print(f"{what}: {e - b:.20f} seconds")


def triu(A):
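    # upper-triangular via first-class dims: keep a[i, j] where i <= j, else 0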
    i, j = dims()
    a = A[i, j]
    # torch.where is picky about scalar/tensor arguments, so pass an explicit Tensor
    zero = torch.tensor(0, dtype=torch.float)
    return torch.where(i <= j, a, zero).order(i, j)


def gpu_time(lmb, name, r=100):
    b = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    # with magic_trace(name + ".fxt"):
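    # warmup pass so the timed region below excludes one-time startup costs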
    for _ in range(r):
        lmb()
    b.record()
    for _ in range(r):
        lmb()
    e.record()
    e.synchronize()
    elapsed = b.elapsed_time(e)
    # with torch.profiler.profile(schedule=torch.profiler.schedule(
    #     wait=0,
    #     warmup=1,
    #     active=2), on_trace_ready=tensorboard_trace_handler(name), with_stack=True) as profiler:
    #     for _ in range(3):
    #         lmb()
    #         profiler.step()
    print(name, elapsed / r)
    return elapsed / r


@skipIfTorchDynamo("Bad interaction")
class TestMin(TestCase):
    def setUp(self):
        super().setUp()
        gc.disable()
        gc.collect()
        self.interesting = set()
        for o in gc.get_objects():
            if isinstance(o, (torch.Tensor, Dim, Tensor, DimList)):
                self.interesting.add(id(o))
        if "cuda" in self._testMethodName:
            self.mem_allocated = torch.cuda.memory_allocated()

    def tearDown(self):
        interesting = []
        for o in gc.get_objects():
            if (
                isinstance(o, (torch.Tensor, Dim, Tensor, DimList))
                and id(o) not in self.interesting
            ):
                interesting.append(o)

        extra_memory = 0
        if "cuda" in self._testMethodName:
            extra_memory += torch.cuda.memory_allocated() - self.mem_allocated

        #  nolevels = _n_levels_in_use() == 0
        if extra_memory != 0 or len(interesting) != 0:
            import refcycle

            refcycle.garbage().export_image("garbage.pdf")
        gc.collect()
        # assert nolevels, f"cleanup failed? {_n_levels_in_use()}"
        assert extra_memory == 0, f"extra cuda memory left allocated: {extra_memory}"
        assert len(interesting) == 0, (
            f"extra torch.Tensor, Dim, DimList, or functorch Tensor left allocated: "
            f"{len(interesting)} objects of types: {[type(t) for t in interesting]}"
        )

    def test_manual_stuff(self):
        A_ = torch.rand(3, 4)
        B_ = torch.rand(4, 5)
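        # matrix multiply written with first-class dims: bind the shared dim k,
        # broadcast each operand over the other's free dim, multiply, reduce k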
        i, j, k = dims()
        A = A_[i, k]
        B = B_[k, j]
        C = (A.expand(j) * B.expand(i)).sum(k)
        self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_)))
        self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_)))

        D_ = torch.randint(0, 3, (6,))
        d = dims()
        D = D_[d]

        A.index([i], [D]).order(k, d)

    def attn(
        self,
        batch_size=1,
        sequence_length=4,
        hidden_size=6,
        num_attention_heads=3,
        linear=Linear,
        device=None,
        time=False,
    ):
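        # Compare the first-class-dims attention (BertSelfAttentionA) against
        # the positional reference (BertSelfAttentionB) on shared weights,
        # across the plain, relative-key, and past-key-value configurations.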
        def maybe_to(x):
            return x if device is None else x.to(device)

        attention_probs_dropout_prob = 0.0
        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size, num_attention_heads, attention_probs_dropout_prob
            )
        )

        A.load_state_dict(B.state_dict())
        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        b_out = B(hidden_state)
        a_out = A(hidden_state)
        self.assertTrue(
            torch.allclose(a_out, b_out)
        )  # why does a simple matmul not do the right thing?

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

        for approach in ("relative_key", "relative_key_query"):
            A = maybe_to(
                BertSelfAttentionA(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                    linear=linear,
                )
            )
            B = maybe_to(
                BertSelfAttentionB(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                )
            )
            A.load_state_dict(B.state_dict())

            hidden_state = maybe_to(
                torch.rand(batch_size, sequence_length, hidden_size)
            )
            b_out = B(hidden_state)
            a_out = A(hidden_state)
            self.assertTrue(torch.allclose(a_out, b_out))

            if time:
                gpu_time(lambda: B(hidden_state), "positional", r=3)
                gpu_time(lambda: A(hidden_state), "first_class", r=3)

        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
            )
        )
        A.load_state_dict(B.state_dict())

        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        past_key_value = (
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
        )

        b_out = B(hidden_state, past_key_value=past_key_value)
        a_out = A(hidden_state, past_key_value=past_key_value)
        self.assertTrue(torch.allclose(a_out, b_out))

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

    def test_attn(self):
        self.attn()

    def test_inplace(self):
        # some embeddings table
        embeddings = torch.zeros(10, 3)

        # some sparse updates to the embeddings
        indices = torch.arange(2) + 1
        values = torch.rand(2, 3)

        i, n, f = dims()

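        # indexed in-place add via first-class dims; with unique indices this
        # matches embeddings.index_add_(0, indices, values)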
        embeddings[indices[i], f] += values[i, f]

    def test_adapt(self):
        def f():
            ci, co = dims()

        # CPython 3.11 specializes (adapts) bytecode after a number of iterations;
        # dims() infers names from the calling frame's bytecode, so check that
        # the names still match once adaptation kicks in
        for i in range(10):
            f()

    @skipIf(not TEST_CUDA, "no CUDA")
    def test_attn_cuda(self):
        # size from the BERT paper, 90% pretraining of sequence length 128
        self.attn(
            batch_size=256,
            hidden_size=768,
            sequence_length=128,
            num_attention_heads=12,
            device="cuda",
            time=measure_perf,
            linear=torch.nn.Linear,
        )

    def test_stack(self):
        i, j, d = dims()
        A = torch.rand(4, 5)
        r = stack([A[i, j]], d, j)
        # a, b = r.unbind(d)
        # self.assertTrue(torch.allclose(a.order(i, j), i.expand(j).order(i, j)))
        # self.assertTrue(torch.allclose(b.order(i, j), j.expand(i).order(i, j)))

    def test_max(self):
        ap = torch.rand(2, 3, 2)
        i, j, k = dims()
        a = ap[i, j, k]
        r, i0 = a.max(dim=k)
        self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0]))

    def test_mm(self):
        i, j, k, q = dims()
        a = torch.rand(3, 4)
        b = torch.rand(4, 5)
        a_ = a[i, k]
        b_ = b[k, j]
        q.size = 1
        r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j)
        # r = (a_ * b_).sum(k).order(q, i, j)
        # the size-1 dim q broadcasts through; dropping it should recover a @ b
        assert torch.allclose(r[0], a @ b)

    def test_with_dims_split(self):
        a = torch.arange(3 * 12).view(3, 12)
        i, j, k = dims()
        k.size = 4
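        # indexing with the list [j, k] splits the size-12 dim; k.size is 4,
        # so j.size is inferred as 3, and order(i, [j, k]) flattens it back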
        r = a[i, [j, k]]
        x = r.order(i, [j, k])
        self.assertTrue(torch.allclose(a, x))

    def test_hello(self):
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        # r = A[i]*4
        r = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(r, A @ B)

        assert A.sum() == A[i].sum((0, i))
        assert A.sum() == A[i].sum((-1, i))

        assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i)))
        assert torch.allclose(A[i].std(i, True), A.std(0, True))

        assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0])
        assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k))
        # XXX - chunk changes the size of a dimension, has to take a new dimension...
        # assert torch.allclose(A.chunk(2,1)[0], A[i, k].chunk(2, k)[0].order(i, k))
        assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7))
        kk = dims()
        # assert torch.allclose( torch.stack([A, A], 1), stack([A[i,k], A[i, k]], kk, k).order(i, kk, k))

        k2 = dims()
        # r = cat((A[i, k], A[i,k]), k, k2)
        # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2))
        # assert k2.size == 2*k.size

        assert torch.allclose(A.expand(5, -1, -1), A[i, k].expand(j).order(j, i, k))
        z = dims()
        C = torch.arange(2)
        assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z))

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, (o, l))
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index((o, l), k)
        assert torch.allclose(A, rr.order(i, k))

        r = i + k - 1
        r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1
        assert torch.allclose(r.order(i, k), r2)

        # test with ...
        assert torch.allclose(A.T, A[..., k].order(k))

        # test with dimlist
        a_, b_ = dimlists()
        assert torch.allclose(A[i, a_].order(*a_, i), A.T)
        # test with one bound dimlist
        assert torch.allclose(A[:, a_].order(*a_), A.T)
        # test with a dimlist that will end up empty
        assert torch.allclose(A[i, b_, k].order(i, k, *b_), A)
        # test with too few indices
        (A[i] + i)
        assert torch.allclose((A[i] + i).order(i), A + torch.arange(3)[:, None])
        # test with too many indices
        with self.assertRaises(IndexError):
            A[1, ..., 1, 1]
        c, d = dims()
        c.size = 2
        assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2))

        assert torch.allclose(
            A[c + 1, c + 0].order(c), A[torch.arange(2) + 1, torch.arange(2)]
        )
        with self.assertRaises(DimensionBindError):
            A[..., 3, ...]

        C = torch.rand(4, 7)
        c_, x, y, z = dims()

        s = dims()
        ref = C.split((3, 3, 1), dim=1)
        t = C[s, c_].split((x, y, z), dim=c_)
        for a, b, d in zip(ref, t, (x, y, z)):
            assert torch.allclose(a, b.order(s, d))

        D = torch.rand(3, 4, 5)
        assert torch.allclose(
            D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k)
        )

        r = [id(x) for x in torch.rand_like(A[i, k]).dims]
        assert id(i) in r and id(k) in r
        r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims]
        assert id(i) in r and id(k) in r

    def test_simple(self):
        i, j, k = dims()
        x = torch.rand(3, 4)
        z = x[i, j]
        (z + z + z + z)
        (z.order(i, j))

    def test_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        C = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(C, A @ B)

    def test_time_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        for _ in range(10):
            r0 = A @ B

        for _ in range(10):
            a = A[i, k]
            b = B[k, j]
            r1 = (a * b).sum(k)

        with measure("pp"):
            for _ in range(10000):
                A @ B
        # magic_trace_stop_indicator()

        with measure("fc"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("f.fxt"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("p.fxt"):
            for _ in range(10000):
                A @ B

        # magic_trace_stop_indicator()

        assert torch.allclose(r1.order(i, j), r0)

    def test_compare_dims(self):
        i, j = dims()
        i.size = 3
        j.size = 4
        (i < j)  # noqa: B015

    def test_c(self):
        _test_c()

    def test_seg(self):
        A = torch.rand(3, 4)
        i, k = dims()
        i.size = 4
        k.size = 3
        r = i + k - 1

    def test_expand(self):
        A = torch.rand(3, 4)
        i = dims()
        assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]

    def test_parse(self):
        self.assertEqual(("x", None, None, None), _parse_test(1, 0, "x"))
        self.assertEqual(("x", None, "y", None), _parse_test(1, 0, "x", c="y"))
        self.assertEqual(("x", None, "y", "z"), _parse_test(1, 0, "x", d="z", c="y"))

        self.assertEqual(("x", "4", None, None), _parse_test(2, 0, "x", b="4"))
        self.assertEqual(("x", "y", "z", "q"), _parse_test(2, 0, "x", "y", "z", "q"))
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", "z", "q", "5")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", b="y")

        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", c="y")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x")

    def test_network(self):
        if resnet18 is None:
            self.skipTest("no torchvision")
        rn = resnet18(
            norm_layer=lambda x: torch.nn.BatchNorm2d(x, track_running_stats=False)
        )
        rn.train()
        img = torch.rand(1, 1, 2, 3, 224, 224)
        imgf = img.view(2, 3, 224, 224)

        i, j = dims()
        r = rn(img[i, j])
        r = r.order(i, j).view(2, 1000)
        r2 = rn(imgf)
        assert torch.allclose(r2, r, atol=1e-06)

    def test_dim_args(self):
        a = dimlists()
        assert isinstance(a, DimList)
        a = dims()
        b = dimlists()
        assert isinstance(a, Dim)
        assert isinstance(b, DimList)
        assert str(a) == "a"
        a, b = dims(sizes=[3, 4])
        assert a.size == 3
        assert b.size == 4
        a = dims(sizes=[3])
        b = dimlists(sizes=[4])
        assert len(b) == 4
        a = dims()
        b = dimlists(sizes=[[4, 5]])
        assert b[0].size == 4
        assert b[1].size == 5

    def test_diag(self):
        i = dims()
        A = torch.rand(4, 4)
        (A[i, i])

    def test_softmax_split(self):
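        # split softmax: factor the vector into a group dim g and an element dim
        # i, take a numerically stable softmax within each group, then merge the
        # per-group statistics (max m_b, normalizer l_b) into the global result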
        a = torch.rand(16)
        g, i = dims(sizes=[2, None])
        a2 = a[[i, g],]

        m_b, _ = a2.max(i)
        f_b = torch.exp(a2 - m_b)
        l_b = f_b.sum(i)

        m, _ = m_b.max(g)
        c = torch.exp(m_b - m)
        f = (c * f_b).order((i, g))
        l = (c * l_b).sum(g)
        assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))

    def test_index(self):
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, [o, l])
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index([o, l], k)
        assert torch.allclose(A, rr.order(i, k))
        z = dims()
        C = torch.arange(2)
        x = A[i, k].index(k, C[z]).order(i, z)
        assert torch.allclose(A[:, 0:2], x)

        C = torch.rand(3, 4, 5)
        ik = dims()
        assert torch.allclose(
            C.index((0, 2), ik).order(ik), C.permute(0, 2, 1).reshape(15, 4)
        )

    # regression tests for failures that came up when monkey-patching some operators
    def test_monkey(self):
        A = torch.rand(3, 4)
        A[0, 0] = 5
        x = torch.randn(3, 4, 4, 4, 3)
        x_clone1 = x.clone()
        ia = torch.tensor([0, 2, 1])
        ib = torch.tensor([0, 2, 1])
        first_shape = x[:, ia, None, ib, 0].shape
        x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
        x = torch.autograd.Variable(torch.tensor([]))
        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
        a = [z[2], z[0] + 3]
        x.new(a)
        # self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])

    def test_index_placement(self):
        A = torch.rand(1, 2, 3, 4)

        i, j = dims(sizes=[2, 4])

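        # indexing two axes with dim expressions (i + 0, j + 0) binds i and j;
        # order(i, j) puts them first and the sliced axes after: permute(1, 3, 0, 2)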
        a = A[:, i + 0, :, j + 0]
        r = a.order(i, j)

        assert torch.allclose(A.permute(1, 3, 0, 2), r)

    def test_order(self):
        i, j = dims()
        A = torch.rand(3, 4, 5)
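        # order accepts positional offsets too: remaining positional dim 1 (the
        # size-5 axis) comes first, then i, then the rest, i.e. permute(2, 0, 1)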
        assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1))

    def test_mask(self):
        a = torch.rand(5)
        i, j = dims(sizes=[a.size(0), a.size(0)])
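        # a comparison of dims acts as a triangular mask over the (i, j) grid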
        ((i >= j) * a[i]).sum(j).order(i)

    def test_eq(self):
        i, j = dims(sizes=[3, 3])
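        # dims compare like index (iota) tensors: (i == j) is the 3x3 identity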
        assert (i == j).sum((i, j)) == 3

    def test_dims_with_size(self):
        x = dims(3)
        assert len(x) == 3 and isinstance(x[0], Dim)

        class Foo:
            pass

        y = Foo()
        z, y.x, q = dims(3)
        assert str(z) == "z"
        assert str(y.x) == "d1"
        assert str(q) == "d2"

    def test_dir(self):
        i, j = dims(sizes=[3, 3])
        dir(i <= j)

    def test_doc(self):
        assert Tensor.clamp.__doc__ == torch.Tensor.clamp.__doc__

    def test_embed(self):
        embeddings = torch.rand(8, 32)
        ids = torch.tensor([1, 0, 3, 4])

        # slow but Pythonic
        values_ = torch.empty(4, 32)
        for batch in range(ids.size(0)):
            for feature in range(embeddings.size(1)):
                values_[batch, feature] = embeddings[ids[batch], feature]

        # with torchdim, single indexing kernel
        batch, feature = dims(2)
        values = embeddings[ids[batch], feature].order(batch, feature)

        assert torch.allclose(values, values_)

    def test_functorch(self):
        A = torch.rand(3, 4, 5)
        B = torch.rand(3, 4, 5)
        C = torch.rand(5, 2)

        i, j = dims()

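        # mm batches over the free first-class dims: AA carries i, BB carries j,
        # so the product carries both and order(i, j) has shape (3, 3, 2, 2)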
        AA = torch.mm(A[i], C)  # 3, 4, 2
        BB = torch.mm(B[j], C)  # 3, 4, 2
        assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]

    def test_permute_orig(self):
        d = dims(1)
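        # torch.Tensor.permute's keyword argument is also named "dims"; check it
        # still means positional dims rather than first-class Dim objects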
        t_fc = torch.rand(1, 2, 3, 4)[d]
        assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape

    def test_order_keyword(self):
        d = dims(1)
        t = torch.rand(3)[d]
        self.assertRaises(TypeError, lambda: t.order(wrong=3))

    def test_big_split(self):
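        # exercise split with roughly a thousand small, random chunk sizes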
        total = 0
        l = []
        while total < 6400:
            l.append(torch.randint(2, 10, (1,)).item())
            total += l[-1]
        x = torch.randn(total, 1)
        x.split(l, 0)


skip_functorch_only = ["test_time_mm_fuse", "test_attn_cuda"]


class TestMinFunctorchOnly(TestMin):
    def setUp(self):
        super().setUp()
        _set_pointwise_optimize(False)

    def tearDown(self):
        _set_pointwise_optimize(True)
        super().tearDown()


for n in skip_functorch_only:
    setattr(TestMinFunctorchOnly, n, skip("skip_functorch_only")(lambda self: None))

if __name__ == "__main__":
    run_tests()
