# Owner(s): ["oncall: jit"]

import operator
import unittest
from textwrap import dedent
from typing import Any, List

import torch
from torch import nn, Tensor
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch.testing._internal.common_utils import make_tensor
from torch.testing._internal.jit_utils import execWrapper, JitTestCase


if __name__ == "__main__":
    # Refuse to run as a standalone script; the message points callers at
    # test/test_jit.py instead.
    raise RuntimeError(
        "\n\n".join(
            (
                "This test file is not meant to be run directly, use:",
                "\tpython test/test_jit.py TESTNAME",
                "instead.",
            )
        )
    )


# XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase):
    def setUp(self):
        super(JitTestCase, self).setUp()
        self.prev_symbolic_shapes_test_enabled = (
            torch._C._jit_symbolic_shapes_test_mode_enabled()
        )
        torch._C._jit_set_symbolic_shapes_test_mode(True)

    def tearDown(self):
        torch._C._jit_set_symbolic_shapes_test_mode(
            self.prev_symbolic_shapes_test_enabled
        )

    def test_shape_analysis(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        inputs = list(foo.graph.inputs())

        def prop_shapes_on_graph(inp0, inp1):
            inputs[0].setType(inputs[0].type().with_sizes(inp0))
            inputs[1].setType(inputs[1].type().with_sizes(inp1))
            torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

        prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
        FileCheck().check("1, 7, 6, 5").run(foo.graph)

        # None implicitly creates a new symbolic symbol
        prop_shapes_on_graph([None, None], [None, None, None])
        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
        inp0_shape = inputs[0].type().symbolic_sizes()
        inp1_shape = inputs[1].type().symbolic_sizes()

        # output shape dim 0 should be taken from the second inp dim0
        # other two dims we cannot infer and are given a new symbolic shape
        self.assertEqual(output_shape[0], inp1_shape[0])
        self.assertFalse(output_shape[1] in inp0_shape + inp1_shape)
        self.assertFalse(output_shape[2] in inp0_shape + inp1_shape)

        # XXX: symbolic shapes are represented with an increasing counter of unique
        # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
        # dimensions directly so there is no chance of collision between manual number
        # and current counter value.
        sym1 = torch._C._new_symbolic_shape_symbol()
        sym2 = torch._C._new_symbolic_shape_symbol()
        sym3 = torch._C._new_symbolic_shape_symbol()
        prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
        self.assertEqual(output_shape[0], sym1)
        self.assertEqual(output_shape[1], sym2)
        self.assertEqual(output_shape[2], sym3)

    def test_shared_shape_graph(self):
        @torch.jit.script
        def foo(x, y):
            return x * y, x / y

        mul_node = foo.graph.findNode("aten::mul")
        div_node = foo.graph.findNode("aten::div")

        mul_graph = torch._C._jit_shape_compute_graph_for_node(mul_node)
        div_graph = torch._C._jit_shape_compute_graph_for_node(div_node)
        self.assertIsNotNone(mul_graph)
        self.assertIs(mul_graph, div_graph)

    def test_write(self):
        @torch.jit.script
        def foo(a, b):
            return a * b

        # broadcast appends cant be removed, so we bail on propagation
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
        FileCheck().check("Tensor = aten::mul").run(foo.graph)

        @torch.jit.script
        def foo(y):
            x = [1, 2, 3, 4]
            x[0] = 5
            return y.view(x)

        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
        FileCheck().check("Tensor = aten::view").run(foo.graph)

    def test_if_propagation(self):
        @torch.jit.script
        def foo(i: int, z):
            x = torch.ones([2, 3, 4, 5])
            y = z.view([z.size(i), 3, 2, z.size(i)])
            if i == 4:
                return x
            else:
                return y

        torch._C._jit_pass_constant_propagation(foo.graph)
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
        view = foo.graph.findNode("aten::view")

        def neg_to_one(li):
            return [elem if elem >= 0 else -1 for elem in li]

        self.assertEqual(
            neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]
        )
        if_out = next(foo.graph.findNode("prim::If").outputs())
        self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1])

    def test_unary_shape_functions(self):
        unary_ops = [
            torch.nn.functional.hardtanh,
        ]
        for fn in unary_ops:
            t = torch.jit.trace(fn, (torch.rand([4, 4])))
            ten_input = next(t.graph.inputs())
            ten_input.setType(ten_input.type().with_sizes([2, 2]))
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2])

    def test_unary_shape_fns_inplace(self):
        def mul_inplace(x: torch.Tensor):
            y = x.mul_(2)
            return y

        unary_ops = [mul_inplace]
        for fn in unary_ops:
            # t = torch.jit.trace(fn, torch.rand([4, 4]))  # For some reason tracing is erroring out.
            t = torch.jit.script(fn)
            ten_input = next(t.graph.inputs())
            ten_input.setType(ten_input.type().with_sizes([2, 2]))
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2])

    def test_binary_shape_functions(self):
        binary_ops = [
            operator.__mul__,
            operator.__truediv__,
            operator.__gt__,
            operator.__add__,
        ]

        for fn in binary_ops:
            size_1 = [1, 4, 8]
            size_2 = [4, 1, 8]
            t = torch.jit.trace(fn, (torch.rand([4]), torch.rand([4])))
            inputs = list(t.graph.inputs())
            inputs[0].setType(inputs[0].type().with_sizes(size_1))
            inputs[1].setType(inputs[1].type().with_sizes(size_2))
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])

    def test_binary_shape_fns_inplace(self):
        def div_inplace_tensor(x: torch.Tensor, y: torch.Tensor):
            z = x.div_(y)
            return z

        def add_inplace_tensor(x: torch.Tensor, y: torch.Tensor):
            z = x.add_(y)
            return z

        binary_ops = [
            div_inplace_tensor,
            add_inplace_tensor,
        ]

        for fn in binary_ops:
            size_1 = [4, 4, 8]  # x (can't broadcast because it's an inplace op)
            t = torch.jit.script(fn)
            inputs = list(t.graph.inputs())
            inputs[0].setType(inputs[0].type().with_sizes(size_1))
            # Intentionally not populate the type of inputs[1]
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])

    def test_size_and_sizes(self):
        @torch.jit.script
        def foo(x, y):
            return x.view(y.size(0), 8, y.size(-1))

        @torch.jit.script
        def foo2(x, y):
            return x.view(y.size())

        for graph in [foo.graph, foo2.graph]:
            inputs = list(graph.inputs())
            sym1 = torch._C._new_symbolic_shape_symbol()

            inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
            torch._C._jit_pass_propagate_shapes_on_graph(graph)
            self.assertEqual(
                next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]
            )

    def test_adaptive_avg_pool2d(self):
        inps = [
            [(1, 64, 8, 9), (5, 7)],
            [(1, 64, 10, 9), (7)],
            [(1, 64, 10, 9), (5, None)],
            [(1, 8, 4, 3), (None, None)],
            [(1, 8, 4, 3), (None, 5)],
        ]

        for inp in inps:
            t = torch.randn(*inp[0])
            out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size()

            def foo(x):
                return torch.nn.functional.adaptive_avg_pool2d(x, inp[1])

            fn = torch.jit.trace(foo, (t,))
            torch._C._jit_erase_non_input_shape_information(fn.graph)
            torch._C._jit_pass_peephole(fn.graph)
            torch._C._jit_pass_constant_propagation(fn.graph)
            self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)

    def test_conv_deconv(self):
        for (
            inp_shape,
            weight_shape,
            bias,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
            mod,
        ) in [
            ([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
            (
                [32, 16, 10],
                [16, 3, 3],
                None,
                2,
                2,
                1,
                1,
                2,
                torch.nn.functional.conv_transpose1d,
            ),
            (
                [1, 32, 5, 10],
                [30, 16, 3, 3],
                None,
                [2, 2],
                [0, 0],
                0,
                1,
                2,
                torch.nn.functional.conv2d,
            ),
            (
                [1, 30, 5, 10],
                [30, 16, 3, 3],
                None,
                [2, 2],
                [0, 0],
                0,
                1,
                2,
                torch.nn.functional.conv_transpose2d,
            ),
            (
                [3, 14, 10, 66, 55],
                [2, 7, 7, 4, 4],
                None,
                1,
                1,
                2,
                1,
                2,
                torch.nn.functional.conv3d,
            ),
            (
                [3, 2, 10, 66, 55],
                [2, 7, 7, 4, 4],
                None,
                1,
                1,
                0,
                1,
                2,
                torch.nn.functional.conv_transpose3d,
            ),
        ]:
            inp = torch.rand(inp_shape)
            weight = torch.rand(weight_shape)
            if mod in [
                torch.nn.functional.conv1d,
                torch.nn.functional.conv2d,
                torch.nn.functional.conv3d,
            ]:
                res = mod(inp, weight, bias, stride, padding, dilation, groups).size()
            else:
                res = mod(
                    inp, weight, bias, stride, padding, output_padding, dilation, groups
                ).size()

            def foo(inp, weight):
                if mod in [
                    torch.nn.functional.conv1d,
                    torch.nn.functional.conv2d,
                    torch.nn.functional.conv3d,
                ]:
                    return mod(inp, weight, bias, stride, padding, dilation, groups)
                else:
                    return mod(
                        inp,
                        weight,
                        bias,
                        stride,
                        padding,
                        output_padding,
                        dilation,
                        groups,
                    )

            fn = torch.jit.trace(foo, (inp, weight))
            torch._C._jit_erase_non_input_shape_information(fn.graph)
            torch._C._jit_pass_peephole(fn.graph)
            torch._C._jit_pass_constant_propagation(fn.graph)
            self.checkShapeAnalysis(res, fn.graph, assert_propagation=True)

    def test_arange_shape(self):
        # no opinfo for tensor constructors
        inps = [
            (10,),
            (10, 10),
            (0, 10),
            (0, 1000),
            (1, -1, -1),
            (1, 0, -1),
            (1, 2, 1),
            (0.6, 0.89, 0.1),
            (1, 10, 0.3),
            (1, 10, 4),
            (0.6, 0.7, 0.8),
            (1, 10, 0.3),
            # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405
            # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405
            (0, 5),
            (0, 5, 2),
            (0, 5 + 1e-6),
            (0, 5 - 1e-6),
            (10, -1 + 1e-6, -1),
            (10, -1, -1),
            (10, -1 - 1e-6, -1),
        ]

        for inp in inps:
            funcs_template = dedent(
                """
            def func():
                return torch.arange({args})
            """
            )

            inp_s = str(inp)[1:-1]  # remove tuple parens
            funcs_str = funcs_template.format(args=inp_s)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            self.checkShapeAnalysis(
                list(cu.func().size()),
                cu.func.graph,
                assert_propagation=True,
                constant_prop=False,
            )

    def test_shape_embedding_bag(self):
        # TODO: merge into opinfos, having difficulties there
        with torch.no_grad():

            def make_arg(shape, low=None, high=None):
                return make_tensor(
                    shape,
                    device="cpu",
                    dtype=torch.int64,
                    low=low,
                    high=high,
                    requires_grad=False,
                )

            nn_inps = (
                (
                    make_arg((40,), 0, 9),
                    torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0),
                ),
                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
                (make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)),
                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
                (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
                (
                    make_arg((2,), 0, 1),
                    torch.nn.Embedding.from_pretrained(
                        torch.arange(6.0).view(2, 3),
                        max_norm=2.0,
                        norm_type=0.5,
                        scale_grad_by_freq=False,
                        sparse=True,
                    ),
                ),
            )

            for inp, module in nn_inps:
                kwargs = {
                    "weight": module.weight.detach(),
                    "padding_idx": module.padding_idx,
                    "max_norm": module.max_norm,
                    "norm_type": module.norm_type,
                    "scale_grad_by_freq": module.scale_grad_by_freq,
                    "sparse": module.sparse,
                }

                out_size = torch.nn.functional.embedding(inp, **kwargs).size()

                def foo(x):
                    return torch.nn.functional.embedding(inp, **kwargs)

                fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)

                self.checkShapeAnalysis(
                    out_size, fn.graph, assert_propagation=True, constant_prop=False
                )

    def test_shape_concat(self):
        # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
        sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)

        class CatMod(nn.Module):
            __constants__ = ["dim"]

            def __init__(self, dim=0):
                super().__init__()
                self.dim = dim

            def forward(self, x, y):
                return torch.cat([x, y], dim=self.dim)

        for inp in sample_inputs:
            mod = torch.jit.script(CatMod(**inp.kwargs).eval())

            args = inp.input

            # This test is hard-coded only to work with two sample inputs
            # but the OpInfo may have more/less
            if len(args) != 2:
                continue

            out_size = mod(*args).size()
            inps = list(mod.graph.inputs())
            inps[1].setType(inps[1].type().with_sizes(args[0].size()))
            inps[2].setType(inps[2].type().with_sizes(args[1].size()))
            self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True)

    def assert_shape_equal_scripted(self, script_fn, given_ins):
        expected_res = script_fn(*given_ins)
        g = script_fn.graph
        graph_ins = list(g.inputs())
        self.assertEqual(len(given_ins), len(graph_ins))
        for inp, graph_in in zip(given_ins, graph_ins):
            graph_in.setType(graph_in.type().with_sizes(inp.size()))

        out_sizes = [out.size() for out in expected_res]
        self.checkShapeAnalysis(out_sizes, g, assert_propagation=True)

    def test_convolution_backward(self):
        # No opinfos for ops that are not part of the Python API
        # Also, as the return shapes are the input, weight, and bias shape, there is no point
        # in a really complicated test

        input = torch.randn(
            (16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True
        )
        weight = torch.randn(
            (8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True
        )
        out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")

        @torch.jit.script
        def conv_bwd(input, weight, grad):
            bias_sizes = [
                8,
            ]
            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
            return torch.ops.aten.convolution_backward(
                grad, input, weight, bias_sizes, *args
            )

        self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))

        @torch.jit.script
        def conv_bwd_2(input, weight, grad):
            bias_sizes = None
            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
            return torch.ops.aten.convolution_backward(
                grad, input, weight, bias_sizes, *args
            )

        self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))

    def test_returning_input_symbolic_shapes(self):
        mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
        inps = list(mm.graph.inputs())
        inps[1].setType(inps[1].type().with_sizes([None, None, None, None]))
        shape_compute_graph = (
            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
        )
        g = shape_compute_graph.partial_eval_shape_graph()
        # to make into a jit function cant have multiple outputs
        g.makeMultiOutputIntoTuple()
        func = torch._C._create_function_from_graph("partial_eval_graph", g)
        out = func([20, 16, 5, 10])
        # first four outputs should be unknown symbolic shapes from input
        self.assertEqual(out[0:4], [20, 16, 5, 10])
        # last two are two new symbolic dims - height and width
        self.assertEqual(out[4:], list(mm(torch.rand([20, 16, 5, 10])).size()[2:]))

    def test_partial_eval_graph_conv(self):
        mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
        shape_compute_graph = (
            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
        )
        output_sizes = (
            mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
        )
        # calculating 0, 2 and 3 index
        for i in [0, 2, 3]:
            self.assertTrue(output_sizes[i] < 0)
        self.assertTrue(output_sizes[1] >= 0)
        g = shape_compute_graph.partial_eval_shape_graph()
        # to make into a jit function cant have multiple outputs
        g.makeMultiOutputIntoTuple()
        func = torch._C._create_function_from_graph("partial_eval_graph", g)
        inp = torch.randn(20, 16, 5, 10)
        output = func([20, 16, 5, 10])
        output_eager = list(mm(inp).size())
        for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
            self.assertEqual(o, oe)

    def checkSymShapeCompute(
        self, shape_compute_graph, nodes, node_output_sizes, shape_inputs
    ):
        g = shape_compute_graph.partial_eval_shape_graph()
        self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
        output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
        # map from sym shape -> index
        sym_shape_to_index = {}
        for index, output in enumerate(g.outputs()):
            sym_shape_to_index[output_sym_map[output]] = index

        g.makeMultiOutputIntoTuple()
        func = torch._C._create_function_from_graph("partial_eval_graph", g)
        sym_outputs = func(*shape_inputs)

        for node, output_shape in zip(nodes, node_output_sizes):
            output_type_sizes = node.output().type().symbolic_sizes()
            for i, sym_shape in enumerate(output_type_sizes):
                if sym_shape >= 0:
                    self.assertEqual(sym_shape, output_shape[i])
                else:
                    sym_shape_index = sym_shape_to_index[sym_shape]
                    self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])

    def test_partial_eval_stitching(self):
        conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        max_pool = torch.nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
        )
        conv2 = nn.Conv2d(
            64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

        mod = torch.jit.freeze(
            torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())
        )

        conv1_output = conv1(torch.rand(1, 3, 224, 224))
        max_pool_output = max_pool(conv1_output)
        conv2_output = conv2(max_pool_output)

        shape_compute_graph = (
            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
        )
        nodes = [mod.graph.findNode("aten::max_pool2d")] + list(
            mod.graph.findAllNodes("aten::conv2d")
        )
        output_shapes = [
            max_pool_output.size(),
            conv1_output.size(),
            conv2_output.size(),
        ]
        self.checkSymShapeCompute(
            shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)
        )

    def test_refinement_through_graph_stitching(self):
        class TwoConvs(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )
                self.conv2 = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

            def forward(self, x):
                a = self.conv1(x)
                b = self.conv2(x)
                return a + b

        mod = torch.jit.freeze(torch.jit.script(TwoConvs()).eval())
        inp_tensor = list(mod.graph.inputs())[1]
        inp_tensor.setType(inp_tensor.type().with_sizes([None, None, None, None]))
        torch._C._jit_pass_propagate_shapes_on_graph(mod.graph)
        outs = list(next(mod.graph.outputs()).node().inputs())
        out1 = outs[0].type().symbolic_sizes()
        out2 = outs[1].type().symbolic_sizes()
        self.assertTrue(out1[2] != out2[2])
        self.assertTrue(out1[3] != out2[3])
        # by joining partial eval graphs of both convs we are able to recognize the output shapes
        # are equivalent
        torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
        out1 = outs[0].type().symbolic_sizes()
        out2 = outs[1].type().symbolic_sizes()
        self.assertEqual(out1, out2)

    def test_stitching_multi_output(self):
        max_pool = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1,
            dilation=1,
            ceil_mode=False,
            return_indices=True,
        )
        tensor = torch.rand(1, 3, 224, 224)
        mod = torch.jit.trace(max_pool, (tensor,))
        mod = torch.jit.freeze(mod.eval())
        inp = list(mod.graph.inputs())[1]
        inp.setType(inp.type().with_sizes([None, None, None, None]))
        output_tensor = list(mod(tensor)[0].size())
        self.run_pass("lower_all_tuples", mod.graph)
        shape_compute_graph = (
            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
        )
        max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
        outs = list(max_pool_node.outputs())
        self.assertEqual(
            outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()
        )
        g = shape_compute_graph.partial_eval_shape_graph()
        # to make into a jit function cant have multiple outputs
        g.makeMultiOutputIntoTuple()
        func = torch._C._create_function_from_graph("partial_eval_graph", g)
        mapping = shape_compute_graph.graph_output_to_symbolic_shape_dim()
        output_shape = func(tensor.size())
        # the first 4 dims are input sym dimensions, then the ,
        self.assertEqual(list(output_shape[0:4]), list(tensor.size()))
        self.assertEqual(list(output_shape[4:]), output_tensor[2:])

    def test_sym_ir_parsing(self):
        graph_str1 = """graph(%x.1 : Float(SS(-2), SS(-3))):
                        %3 : int = prim::Constant[value=1]()
                        %4 : Tensor = aten::add(%x.1, %x.1, %3)
                        return (%4)"""
        g = torch._C.parse_ir(graph_str1)
        inp = next(g.inputs())
        out = inp.type().symbolic_sizes()
        self.assertEqual(out, [-2, -3])

    def test_stitching_concat(self):
        @torch.jit.script
        def foo1(a, b, x, y):
            return (a / b) + torch.cat([x, y])

        @torch.jit.script
        def foo2(a, b, x, y):
            return (a / b) + torch.cat([x, y], dim=-2)

        for foo in [foo1, foo2]:
            g = foo.graph
            for inp in foo.graph.inputs():
                inp.setType(inp.type().with_sizes([None, None]))

            shape_compute_graph = (
                torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
                    foo.graph
                )
            )
            nodes = (
                [g.findNode("aten::div")]
                + [g.findNode("aten::add")]
                + [g.findNode("aten::cat")]
            )

            inps = [1, 10], [20, 10], [15, 1], [5, 1]
            output_shapes = [[20, 10], [20, 10], [20, 1]]

            self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)

    @unittest.skipIf(
        not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python"
    )
    def test_shape_function_includes(self):
        inp_shape = [1, 16, 5, 10]
        weight_shape = [33, 16, 3, 3]
        bias = None
        stride = [2, 2]
        padding = [0, 0]
        dilation = [1, 1]
        groups = 1
        res = torch.jit._shapes.conv2d(
            inp_shape, weight_shape, bias, stride, padding, dilation, groups
        )
        self.assertEqual(res, [1, 33, 2, 4])

        m1_shape = [10, 20]
        m2_shape = [20, 10]
        res = torch.jit._shapes.matmul(m1_shape, m2_shape)
        self.assertEqual(res, [10, 10])

    def test_register_function_error_checking(self):
        # this will error before registering on global map, so
        # no issue in overwriting schema mappings
        @torch.jit.script
        def foo(x, y):
            return x + y

        node = foo.graph.findNode("aten::add")

        @torch.jit.script
        def wrong_input_types(x, y):
            x: List[int] = []
            return x

        with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
            torch._C._jit_register_shape_compute_graph_for_node(
                node, wrong_input_types.graph
            )

        @torch.jit.script
        def wrong_output_types(x: List[int], y: List[int]):
            x: List[Tensor] = []
            return x

        with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
            torch._C._jit_register_shape_compute_graph_for_node(
                node, wrong_output_types.graph
            )

        @torch.jit.script
        def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
            x: List[int] = []
            return x

        with self.assertRaises(RuntimeError) as error:
            torch._C._jit_register_shape_compute_graph_for_node(
                node, too_many_inputs.graph
            )

        self.assertTrue("fewer arguments than schema" in str(error.exception))

    def test_cross_entropy_loss(self):
        @torch.jit.script
        def foo(x, y):
            return torch.ops.aten.cross_entropy_loss(x, y, reduction=0)

        inputs = list(foo.graph.inputs())
        inputs[0].setType(inputs[0].type().with_sizes([8, 2]))
        inputs[1].setType(
            inputs[1]
            .type()
            .with_sizes(
                [
                    8,
                ]
            )
        )
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
        self.assertEqual(
            next(foo.graph.outputs()).type().sizes(),
            [
                8,
            ],
        )

    def test_squeeze_dims(self):
        @torch.jit.script
        def foo(x):
            return torch.ops.aten.squeeze(x, dim=0)

        input = next(foo.graph.inputs())
        input.setType(input.type().with_sizes([1, 5, 8]))
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
        self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8])
