# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List

import torch


# Make the helper files in test/ importable. Starting from this file's
# resolved (symlink-free) path, drop the file name, then drop one more
# directory level, and append that directory to the module search path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # This module only defines test cases; executing it as a script is
    # rejected with instructions pointing at the test_jit.py driver.
    message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(message)


# Tests that the Python `slice` builtin and slice syntax (`a[i:j:k]`) are supported in TorchScript
class TestSlice(JitTestCase):
    """TorchScript coverage for the ``slice`` builtin and slice syntax.

    Most tests hand a small function plus an example input to
    ``checkScript``. The remaining tests inspect the scripted graph (tuple
    slicing) or the objects returned by a scripted module (``ModuleList``
    slicing).
    """

    def test_slice_kwarg(self):
        # Only compilation is exercised: scripting must reject keyword
        # arguments passed to slice().
        def slice_kwarg(x: List[int]):
            return x[slice(1, stop=2)]

        with self.assertRaisesRegex(
            RuntimeError, "Slice does not accept any keyword arguments"
        ):
            torch.jit.script(slice_kwarg)

    def test_slice_three_nones(self):
        def three_nones(x: List[int]):
            return x[slice(None, None, None)]

        self.checkScript(three_nones, (range(10),))

    def test_slice_two_nones(self):
        def two_nones(x: List[int]):
            return x[slice(None, None)]

        self.checkScript(two_nones, (range(10),))

    def test_slice_one_none(self):
        def one_none(x: List[int]):
            return x[slice(None)]

        self.checkScript(one_none, (range(10),))

    def test_slice_stop_only(self):
        def fn(x: List[int]):
            return x[slice(5)]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_only_with_nones(self):
        def fn(x: List[int]):
            return x[slice(None, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop(self):
        def fn(x: List[int]):
            return x[slice(1, 5)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_with_none(self):
        def fn(x: List[int]):
            return x[slice(1, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_step(self):
        def fn(x: List[int]):
            return x[slice(0, 6, 2)]

        self.checkScript(fn, (range(10),))

    def test_slice_string(self):
        def fn(x: str):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, ("foo_bar",))

    def test_slice_tensor(self):
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, (torch.ones(10),))

    def test_slice_tensor_multidim(self):
        # A slice object combined with an integer index in one subscript.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), 0]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_tensor_multidim_with_dots(self):
        # A slice object combined with an Ellipsis in one subscript.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), ...]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_as_variable(self):
        # The slice is bound to a local first instead of appearing inline.
        def fn(x: List[int]):
            a = slice(1)
            return x[a]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_clipped(self):
        # A stop far past the end of the 10-element input.
        def fn(x: List[int]):
            return x[slice(1000)]

        self.checkScript(fn, (range(10),))

    def test_slice_dynamic_index(self):
        # Slice bounds held in local variables rather than written as literals.
        def t(x):
            slice1 = x[0:1]
            zero = 0
            one = zero + 1
            slice2 = x[zero:one]
            return slice1 + slice2

        self.checkScript(t, (torch.zeros(3, 2, 3),))

    def test_tuple_slicing(self):
        def tuple_slice(a):
            if bool(a):
                b = (1, 2, 3, 4)
            else:
                b = (4, 3, 2, 1)
            c = b[-4:4]
            e = c[1:-1]
            return e

        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
        scripted_fn = torch.jit.script(tuple_slice)
        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
        tuple_graph = scripted_fn.graph
        slices = tuple_graph.findAllNodes("prim::TupleConstruct")
        num_outputs = {len(x.output().type().elements()) for x in slices}
        # Every prim::TupleConstruct left in the graph should build a
        # 2-element tuple, i.e. no 4-element tuple survives the slicing.
        # (The set collapses duplicates, so this does not count the nodes.)
        self.assertEqual(num_outputs, {2})
        # After lowering, no tuple types should remain in the graph.
        self.run_pass("lower_all_tuples", tuple_graph)
        self.assertNotIn("Tuple", str(tuple_graph))

    def test_module_list_slicing(self):
        # Negative-step and open-ended slices of a ModuleList inside a
        # scripted forward(); element order is checked via `identifier`.
        class Bar(torch.nn.Module):
            def __init__(self, identifier: str):
                super().__init__()
                self.identifier = identifier

            def forward(self):
                return 0

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
                self.test = torch.nn.ModuleList(module_list)

            def forward(self):
                return self.test[::-2], self.test[1:4:]

        scripted_foo = torch.jit.script(Foo())
        result1, result2 = scripted_foo()

        # [::-2] walks backwards from the last element, taking every other one.
        self.assertEqual(len(result1), 3)
        self.assertEqual(result1[0].identifier, "E")
        self.assertEqual(result1[1].identifier, "C")
        self.assertEqual(result1[2].identifier, "A")

        # [1:4:] takes indices 1, 2, 3.
        self.assertEqual(len(result2), 3)
        self.assertEqual(result2[0].identifier, "B")
        self.assertEqual(result2[1].identifier, "C")
        self.assertEqual(result2[2].identifier, "D")
