# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest


try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


class BucketizeMod(torch.nn.Module):
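    """Bucketize each feature's weights against fixed per-key boundaries and
    gather learnable per-bucket weights, returning a new KJT with the same
    jagged structure but replaced weights."""
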
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
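            # torch.bucketize (right=False, the default) returns the smallest i
            # with w <= boundaries[i], so outputs range over [0, len(boundaries)],
            # matching the len(boundaries) + 1 rows of bucket_w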
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # Hashing the bucket ids is optional for this test; an internal
            # index-hash op could be used here instead, e.g.:
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )

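# For reference, a minimal sketch of the bucketize -> gather step above,
# independent of torchrec (values here are illustrative):
#
#   boundaries = torch.tensor([0.0, 0.5, 1.0])
#   weights = torch.tensor([0.0, 0.1, 0.7, 1.2])
#   buckets = torch.bucketize(weights, boundaries)   # tensor([0, 1, 2, 3])
#   bucket_w = torch.ones(len(boundaries) + 1)
#   torch.gather(bucket_w, dim=0, index=buckets)     # one weight per value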

if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
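            # each emb_module call mean-pools (nn.EmbeddingBag's default mode)
            # one feature's embeddings: values are the ids, offsets mark bag starts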
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    jt = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        jt.values(), jt.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
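        # each batch's .sparse_features is a KeyedJaggedTensor over the keys
        # above, populated with random ids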
        di = iter(dataset)

        # unsync'd KJTs should work: three batches, one dynamic compilation

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync'd KJTs (with cached per-key lengths/offsets) should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    @torch._dynamo.config.patch(
        force_unspec_int_unbacked_size_like_on_torchrec_kjt=True
    )
    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # the ShapeEnv that the unbacked lengths are all size-like
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(
            features
        ).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
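        # (the 6 lengths split as 3 per key: [0, 0, 1] -> 1 and [1, 1, 3] -> 5)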
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
