# Owner(s): ["module: scatter & gather ops"]

from itertools import product
from functools import partial

import numpy as np
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    gradcheck,
    parametrize,
    skipIfRocm,
)


# Every reduction name accepted by torch._segment_reduce's `reduce` argument
# that the tests below iterate over.
reductions = [
    "max",
    "mean",
    "min",
    "sum",
    "prod",
]


def get_default_value(initial_value, reduction):
    """Return the value an empty segment is expected to reduce to.

    An explicit ``initial_value`` (anything other than ``None``) is returned
    unchanged. Otherwise the reduction's identity is used: ``-inf`` for
    ``max``, ``inf`` for ``min``, ``0.0`` for ``sum``, ``1.0`` for ``prod``
    and ``nan`` for ``mean``. Unrecognized reductions yield ``None``.
    """
    if initial_value is None:
        identities = {
            "max": -float("Inf"),
            "mean": float("nan"),
            "min": float("Inf"),
            "sum": 0.0,
            "prod": 1.0,
        }
        return identities.get(reduction)
    return initial_value


class TestSegmentReductions(TestCase):
    """Device-generic tests for ``torch._segment_reduce``.

    The class is instantiated per device by ``instantiate_device_type_tests``.
    Most tests drive the op through both its ``lengths`` and ``offsets``
    inputs.
    """

    # (value dtype, lengths/offsets dtype) pairs shared by the
    # dtype-parametrized tests below.
    _VALUE_AND_INDEX_DTYPES = tuple(
        product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Check ``torch._segment_reduce`` against expected forward/backward values.

        The reduction runs twice. The first run passes ``lengths`` built from
        ``lengths_arr``. The second passes the equivalent ``offsets``: a
        leading zero followed by the running sum of the lengths along the last
        dim. Both results are compared to ``expected_arr``, treating NaNs as
        equal.

        When ``check_backward`` is true, each run also compares the gradient of
        ``result.sum()`` w.r.t. the data against ``expected_grad_arr``. For
        dtypes other than half/bfloat16/float it additionally runs
        ``gradcheck`` on a copy of the data with NaNs replaced by 10.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths: prepend a zero along the last dim, then cumsum
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
        for mode in ("lengths", "offsets"):
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value,
            )
            segment_reduce_kwargs[mode] = lengths if mode == "lengths" else offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            # Skip only this mode's backward checks. Returning here would
            # silently skip the forward check of the remaining ("offsets") mode.
            if not check_backward:
                continue

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            # Fresh leaf tensor so gradients do not accumulate across modes.
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types;
            # there is also a small numerical difference with fp32.
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, so NaNs are replaced by 10.
                new_data = torch.tensor(
                    np.nan_to_num(data_arr, nan=10),
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch._segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(*_VALUE_AND_INDEX_DTYPES)
    def test_simple_1d(self, device, dtypes):
        """1-D data containing a NaN, with segment lengths [1, 2, 3, 0].

        Every reduction is tried with ``initial`` 0 and None, along axis 0 and
        -1, safe and unsafe. Backward is only checked when an initial value is
        supplied.
        """
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(*_VALUE_AND_INDEX_DTYPES)
    def test_simple_zero_length(self, device, dtypes):
        """Empty data split into two zero-length segments.

        Every output element should equal the reduction's default / initial
        value, and the expected gradient is empty.
        """
        val_dtype, length_type = dtypes
        lengths = [0, 0]
        # A plain empty list (instead of torch.ones(0)) avoids torch.tensor's
        # copy-construct-from-tensor warning inside _test_common.
        data = []
        axis = 0

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                if initial is not None:
                    if reduction == "min":
                        initial_value = 1000  # some high number
                    elif reduction == "prod":
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                default_value = get_default_value(initial_value, reduction)
                expected_result = [default_value, default_value]
                expected_grad = []
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @skipIfRocm
    @dtypes(*_VALUE_AND_INDEX_DTYPES)
    def test_multi_d_simple(self, device, dtypes):
        """6x2 data with NaNs, reduced along axis 0 with segment lengths [1, 2, 3, 0]."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    # Forward the lengths dtype so the int64 variants are
                    # really exercised (it previously fell back to torch.int).
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @dtypes(*_VALUE_AND_INDEX_DTYPES)
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        """Table-driven cases in the style of pytorch_scatter's segment tests.

        Each case gives ``indptr`` (used directly as ``offsets`` and, via
        ``torch.diff``, as ``lengths``) and the expected result per reduction.
        Segments run along the last dim of ``indptr``. For float64, gradcheck
        is also run through both the lengths and offsets paths.
        """
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            # forward through both the lengths and the offsets inputs
            for segment_kwargs in ({'lengths': lengths}, {'offsets': indptr}):
                actual_result = torch._segment_reduce(
                    data=data,
                    reduce=reduce,
                    axis=dim,
                    unsafe=True,
                    **segment_kwargs
                )
                self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                # Supply initial values to prevent gradcheck from failing for
                # 0-length segments, where nan/inf reduction identities produce
                # nans when calculating the numerical jacobian.
                initial = {'min': 1000, 'max': -1000}.get(reduce, 1)

                def fn(x, mode='lengths'):
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    # Positional args must be an ordered tuple: a set (as used
                    # previously) has no defined iteration order.
                    return torch._segment_reduce(x, reduce, **segment_reduce_kwargs)

                for mode in ('lengths', 'offsets'):
                    # gradcheck inputs as an explicit 1-tuple
                    inputs = (data.clone().detach().requires_grad_(True),)
                    self.assertTrue(gradcheck(partial(fn, mode=mode), inputs))

    @dtypes(*_VALUE_AND_INDEX_DTYPES)
    def test_multi_d(self, device, dtypes):
        """Reduce a (5, 2, 5) tensor along axis 0 with segment lengths [0, 2, 3, 0].

        Expected values come from the matching NumPy reduction over each
        non-empty segment. Empty segments expect a (2, 5) block filled with
        the initial value. Backward is not checked.
        """
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        np_reducers = {
            "max": np.max,
            "mean": np.mean,
            "min": np.min,
            "sum": np.sum,
            "prod": np.prod,
        }
        for reduction in reductions:
            # min uses some high number so the initial value does not undercut
            # the data; a 0 initial value would zero out everything for prod.
            initial_value = {"min": 1000, "prod": 1}.get(reduction, 0)
            np_reduce = np_reducers[reduction]
            empty_segment = np.full((2, 5), initial_value).tolist()
            expected_result = [
                empty_segment,
                np_reduce(data[:2], axis=0).tolist(),
                np_reduce(data[2:], axis=0).tolist(),
                empty_segment,
            ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    length_type,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        """With ``unsafe=False``, lengths that do not sum to the data size must raise."""
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths (they sum to 5, data has 6 elements)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths (second row sums to 5, not 6)
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)



# Generate the per-device variants of TestSegmentReductions and register them
# in this module's namespace so the test runner can discover them.
instantiate_device_type_tests(TestSegmentReductions, globals())

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
