# Owner(s): ["module: sparse"]
#
# Test to ensure sparsity information propagates properly into traced graph.
#

import sys
import unittest

import torch
from torch._dynamo.config import is_fbcode
from torch._subclasses.fake_tensor import FakeTensor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)


# Various data types (preserved over operations).
DTYPES = [
    torch.int64,
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
]

# Various index types.
ITYPES = [torch.int32, torch.int64]


# Constructs a subtest for every sparse layout currently supported in torch.sparse.
def all_sparse_layouts(test_name="layout"):
    return parametrize(
        test_name,
        [
            subtest(torch.sparse_coo, name="SparseCOO"),
            subtest(torch.sparse_csr, name="SparseCSR"),
            subtest(torch.sparse_csc, name="SparseCSC"),
            subtest(torch.sparse_bsr, name="SparseBSR"),
            subtest(torch.sparse_bsc, name="SparseBSC"),
        ],
    )


#
# Various network examples.
#


class IdNet(torch.nn.Module):
    def forward(self, x):
        return x


class SumNet(torch.nn.Module):
    def forward(self, x):
        return x.sum()


class EltwiseNet(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(2 * torch.abs(-x))


class ToDenseNet(torch.nn.Module):
    def forward(self, x):
        return x.to_dense()


class AddNet(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


class SparseActivationCOO(torch.nn.Module):
    def forward(self, x):
        return [xi.to_sparse() for xi in x]


class SparseActivationCSR(torch.nn.Module):
    def forward(self, x):
        return [xi.to_sparse_csr() for xi in x]


#
# The test driver.
#


@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestSparseProp(TestCase):
    def setUp(self):
        TestCase.setUp(self)

    def assertEqualMeta(self, x, y):
        self.assertIsInstance(x, FakeTensor)
        self.assertIsInstance(y, torch.Tensor)

        # Convert expected value to meta for comparison.
        y = y.to("meta")
        self.assertEqual(x, y, exact_layout=True, exact_is_coalesced=True)

        # When x or y is a meta tensor (say, `x.device == "meta"`), then
        # assertEqual(x, y) compares only x and y attributes but skips
        # comparing their values. In the case of sparse tensors, this means
        # that comparing indices and values attributes are skipped as well,
        # which is why we are doing that explicitly below.
        if x.layout is torch.strided:
            pass
        elif x.layout is torch.sparse_coo:
            self.assertEqual(x._indices(), y._indices(), exact_layout=True)
            self.assertEqual(x._values(), y._values(), exact_layout=True)
        else:
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                x_meta1, y_meta1 = (x.crow_indices(), y.crow_indices())
                x_meta2, y_meta2 = (x.col_indices(), y.col_indices())
            elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
                x_meta1, y_meta1 = (x.ccol_indices(), y.ccol_indices())
                x_meta2, y_meta2 = (x.row_indices(), y.row_indices())
            else:
                assert 0  # unreachable
            self.assertEqual(x_meta1, y_meta1, exact_layout=True)
            self.assertEqual(x_meta2, y_meta2, exact_layout=True)
            self.assertEqual(x.values(), y.values(), exact_layout=True)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_idnet(self, dtype, itype, layout):
        net = IdNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_sumnet(self, dtype, itype, layout):
        net = SumNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/sum/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_eltwisenet(self, dtype, itype, layout):
        net = EltwiseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/neg/abs/mul/relu/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i <= 4:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_todensenet(self, dtype, itype, layout):
        net = ToDenseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/todense/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    def test_add(self):
        net = AddNet()
        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
        A = torch.tensor(
            [
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 1.0, 1.0],
                [3.0, 0.0, 3.0, 0.0],
            ],
            dtype=torch.float32,
        )
        S = A.to_sparse_csr()
        result = net(S, Y)
        # Build the traced graph.
        prog = torch.export.export(net, (S, Y))
        # Test args/add/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i == 0:
                self.assertEqualMeta(meta, S)
            elif i == 1:
                self.assertEqualMeta(meta, Y)
            elif i == 2:
                self.assertEqualMeta(meta, result)
            else:
                self.assertEqual(meta, None)

    def test_activation_coo(self):
        net = SparseActivationCOO()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i <= 2:
                self.assertEqualMeta(meta, x[i])
            elif i <= 5:
                self.assertEqualMeta(meta, result[i - 3])
            else:
                self.assertEqual(meta, None)

    def test_activation_csr(self):
        net = SparseActivationCSR()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i <= 2:
                self.assertEqualMeta(meta, x[i])
            elif i <= 5:
                self.assertEqualMeta(meta, result[i - 3])
            else:
                self.assertEqual(meta, None)


instantiate_parametrized_tests(TestSparseProp)

if __name__ == "__main__":
    run_tests()
