# Owner(s): ["oncall: jit"]

import io
import os
import sys
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestUnion(JitTestCase):
    """
    This class tests the functionality of `Union`.

    Note: It's important to be able to refine the type of a `Union` to
    one of its internal types. Currently, there are differences in the
    way Python expects `isinstance` checks and the way TorchScript
    expects `isinstance` checks. This means that we can't use
    `checkScript` in our test cases because either the eager mode or the
    script mode wouldn't run! So, some test cases have separate but
    equivalent functions to emulate `checkScript`.
    """

    def test_check_union_annotation(self):
        def test_func(a: Union[int, float], b: Optional[int]):
            return 0

        scripted_func = torch.jit.script(test_func)
        graph_rep = str(scripted_func.graph)
        code_rep = str(scripted_func.code)
        # TS graph IR for Union should be annotated as Union()
        FileCheck().check("Union(").check("int?").run(graph_rep)
        # Serialized code for Union should be annotated as Union[]
        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
        self.checkScript(test_func, (5, 6))
        # this shouldn't error out
        torch._C.parse_ir(str(scripted_func.graph))

    def test_union_with_scalar_values(self):
        def fn(x: Union[int, float]) -> str:
            return "foo"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[float, int\] but "
            "instead found type str",
        ):
            scripted("1")

    def test_union_with_collections(self):
        def fn(x: Union[Dict[str, int], List[int]]) -> str:
            return "foo"

        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
        self.checkScript(fn, ([1, 2, 3],))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            r"Dict\[str, str\]",
        ):
            scripted({"foo": "bar", "baz": "qux"})

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            r"List\[str\]",
        ):
            scripted(["foo", "bar", "baz"])

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            "str",
        ):
            scripted("1")

    def test_union_with_enum(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        self.checkScript(fn, (Color.RED,))
        self.checkScript(fn, ("red",))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[__torch__.jit.test_union."
            r"Color, str\] but instead found "
            "type int",
        ):
            scripted(1)

    def test_union_in_class_constructor(self):
        @torch.jit.script  # noqa: B903
        class A:  # noqa: B903
            def __init__(self, x: Union[int, str]) -> None:
                self.x = x

        def fn(x: Union[str, int]) -> A:
            return A(x)

        self.assertEqual(fn("foo").x, "foo")
        self.assertEqual(fn(1).x, 1)

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[int, str\] but instead "
            r"found type List\[str\]",
        ):
            scripted(["foo", "bar", "baz"])

    def test_union_return_type(self):
        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))

    def test_union_as_annotation(self):
        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_annotation_in_typed_container(self):
        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())

    def test_union_as_annotation_py2(self):
        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_internal_tuple_type(self):
        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())

    def test_union_variable_can_be_reassigned(self):
        @torch.jit.script
        def aux1(i: int):
            return int(i**2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())

    def test_union_does_not_replace_existing_annotated_type(self):
        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_union(self):
        def fn():
            x: List[Union[int, str]] = [1, "foo", 3]
            x.append(2.0)
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_unions_of_unions_are_flattened(self):
        @torch.jit.script
        def fn(x: Union[Union[int, str], float]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, str)").run(s)

    def test_unions_of_a_single_argument_vanish(self):
        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : int").run(s)

    def test_union_redundant_arguments_are_skipped(self):
        @torch.jit.script
        def fn(x: Union[int, str, int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(int, str)").run(s)

    def test_union_redundant_arguments_are_skipped_optional(self):
        @torch.jit.script
        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, NoneType)").run(s)

    def test_union_redundant_arguments_are_skipped_subtyping(self):
        @torch.jit.script
        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union((int?, int), str)").run(s)

    def test_union_redundant_arguments_are_skipped_container(self):
        @torch.jit.script
        def fn(x: Union[List[str], List[float], List[str]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float[], str[])").run(s)

    def test_union_argument_order_is_ignored(self):
        @torch.jit.script
        def fn1(x: Union[int, str]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[str, int]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int, str)").run(s)

    def test_union_argument_order_is_ignored_container(self):
        @torch.jit.script
        def fn1(x: Union[List[str], List[int]]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[List[int], List[str]]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int[], str[])").run(s)

    def test_union_T_None_is_equivalent_to_optional_T(self):
        @torch.jit.script
        def inner(x: Union[int, None]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn1() -> int:
            a: Optional[int] = 5
            b: Optional[int] = None
            a_ = inner(a)
            b_ = inner(b)
            return a_ + b_

        self.assertEqual(fn1(), 10)

        @torch.jit.script
        def inner2(x: Optional[int]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn2() -> int:
            a: Union[int, None] = 5
            b: Union[int, None] = None
            a_ = inner(a)
            b_ = inner(b)
            return a_ + b_

        self.assertEqual(fn2(), 10)

    def test_union_optional_of_union_is_flattened(self):
        @torch.jit.script
        def fn(flag: int) -> Union[str, int, None]:
            y: Union[int, str, None] = "foo"
            if flag == 0:
                x: Optional[Union[int, str]] = y
            elif flag == 1:
                x: Optional[Union[int, str]] = 1
            else:
                x: Optional[Union[int, str]] = None
            return x

        # Can't use `checkScript` because it will flag the fact that
        # the original code has `Optional[Union[int, str]]` but the
        # saved/loaded code has `Union[int, NoneType, str]` (even
        # though this is exactly what we want)
        self.assertEqual(fn(0), "foo")
        self.assertEqual(fn(1), 1)
        self.assertEqual(fn(2), None)

        buffer = io.BytesIO()
        torch.jit.save(fn, buffer)
        buffer = io.BytesIO(buffer.getvalue())
        l = torch.jit.load(buffer)

        s = l.code

        FileCheck().check("Union[int, NoneType, str]").check(
            "Union[int, NoneType, str]"
        ).run(s)

    def test_union_subclasses_larger_union(self):
        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(
            RuntimeError,
            "only int, float, "
            "complex, Tensor, device and string keys "
            "are supported",
        ):
            torch.jit.script(fn)

    def test_union_as_dict_value(self):
        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())

    def test_union_module_with_union_instance_variable(self):
        class M(torch.nn.Module):
            x: Union[int, str]

            def __init__(self, x: Union[int, str]):
                super().__init__()
                self.x: Union[int, str] = x

            def forward(self, y: Union[int, str]):
                self.x = y
                return self.x

        self.checkModule(
            M(
                2,
            ),
            (1,),
        )
        self.checkModule(M("bar"), ("foo",))

    def test_union_module_with_union_class_variable(self):
        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                x = y

            def forward(self, z: str):
                x = z
                return x

        self.checkModule(M(1), ("foo",))

    def test_union_type_refinement(self):
        def fn(x: Union[int, str]) -> str:
            if isinstance(x, str):
                z = x + "bar"
                return x
            else:
                return "baz"

        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_union_rhs(self):
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))

    def test_union_type_refinement_tuple_rhs(self):
        def fn(x: Union[int, float, List[str]]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        def fn(x: Union[int, List[str]]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_union(self):
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")

    def test_union_type_refinement_statically_false(self):
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)

    def test_union_type_refinement_statically_true(self):
        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)

    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, Union[int, float, str]):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_internal_declaration(self):
        def fn(flag: bool) -> str:
            x: Union[int, str, None] = None
            if flag:
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))

    def test_union_branching_with_union_return_and_homogenous_types(self):
        def fn(x: int) -> Union[int, str]:
            if x % 2:
                return "foo"
            else:
                return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        with self.assertRaisesRegex(
            RuntimeError,
            "y is set to type str"
            " in the true branch and type int "
            "in the false branch",
        ):
            torch.jit.script(fn)

    def test_union_branching_does_not_widen_existing_inferred_type(self):
        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(
            RuntimeError,
            "previously had type "
            "str but is now being assigned to a"
            " value of type int",
        ):
            torch.jit.script(fn)

    def test_union_schema_matching_on_internal_type(self):
        def fn(x: Union[List[int], Dict[str, int]]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))

    def test_union_subtractive_refinement(self):
        def fn(x: Union[List[int], int]) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_subtractive_refinement_with_container(self):
        def fn(x: Union[List[int], int]) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_memory_aliasing(self):
        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())

    def test_union_serialization_preserves_type_annotations(self):
        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: Union[str, int] = "bar"
            else:
                y: Union[str, int] = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def _assert_passes(self, template: str, ann: str, lhs: str):
        code = template.format(ann=ann, lhs=lhs)
        self.checkScript(code, (), name="fn")

    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
        code = template.format(ann=ann, lhs=lhs)
        with self.assertRaisesRegex(RuntimeError, msg):
            cu = torch.jit.CompilationUnit(code, _frames_up=1)
            string_frontend = getattr(cu, "fn")  # noqa: B009

    def test_union_with_list_assignment(self):
        template = dedent(
            """
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, List[torch.Tensor]):
                    x.append(torch.tensor(3))
                return x
        """
        )

        lhs = {
            "list_literal_empty": "[]",
            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
            "list_literal_of_str": '["foo", "bar", "baz"]',
            "list_literal_of_mixed": "[torch.arange(5), 1]",
            "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
            "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
            "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
        }

        """
        Union[List[str], List[torch.Tensor]]
        """
        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_empty"],
            "there are multiple possible List type "
            "candidates in the Union annotation",
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_of_tensor"],
        )

        self._assert_passes(
            template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
        )

        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_of_mixed"],
            "none of those types match the types of the" " given list elements",
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_tensor"],
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_str"],
        )

        # TODO: Support mixed list comprehensions
        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )

        """
        Union[int, torch.Tensor]
        """
        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_literal_empty"],
            "Expected an Union type annotation with an " "inner List type",
        )

        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_literal_of_tensor"],
            "Expected an Union type annotation with an " "inner List type",
        )

        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_comprehension_of_tensor"],
            "Expected an Union type annotation with an " "inner List type",
        )

        """
        Union[List[torch.Tensor], int]
        """
        self._assert_passes(
            template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
        )

        self._assert_passes(
            template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_literal_of_str"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_literal_of_mixed"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        self._assert_passes(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_tensor"],
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_str"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        # TODO(@ansley): Support mixed list comprehensions
        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )

    def test_union_with_dict_assignment(self):
        template = dedent(
            """
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
                    x["foo"] = torch.tensor(3)
                return x
        """
        )

        lhs = {
            "dict_literal_empty": "{}",
            "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
            "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
            "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
            "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
                    zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
            "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
                    zip(["foo", "bar"], [1, 2]}',
            "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
                    zip(["foo", "bar"], [torch.arange(3), 2])}',
            "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
            "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
            "dict_keyword_with_empty_iterable": "dict([])",
            "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
            "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
            "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
        }

        """
        Union[Dict[str, torch.Tensor], Dict[str, int]]
        """
        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["dict_literal_empty"],
            "Expected an Union type annotation with an " "inner Dict type",
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_literal_of_str_tensor"],
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_literal_of_str_int"],
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_literal_of_mixed"],
            "none of those dict types can hold the "
            "types of the given keys and values",
        )

        # TODO: String frontend does not support tuple unpacking
        # https://github.com/pytorch/pytorch/issues/64096
        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_tensor"])

        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_int"])

        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_mixed"],
        #              "foobar")

        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])

        # TODO(@ansley): Follow-up project needed for full type
        # inference with dict keyword (supported for dict comprehension
        # and dict literal already; should not be a blocker for anyone)
        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_keyword"],
            "full type inference is not yet supported",
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_keyword_with_iterable"],
            "full type inference is not yet supported",
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_keyword_with_empty_iterable"],
            "full type inference is not yet supported",
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_keyword_with_mapping"],
            "full type inference is not yet supported",
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
            lhs["dict_keyword_with_mapping_and_kwargs"],
            "full type inference is not yet supported",
        )

        """
        Union[int, torch.Tensor]
        """
        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["dict_literal_empty"],
            "Expected an Union type annotation with " "an inner Dict type",
        )

        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["dict_literal_of_str_tensor"],
            "Expected an Union type annotation with " "an inner Dict type",
        )

        # See above--string frontend does not support tuple unpacking
        # self._assert_raises(template, "Union[int, torch.Tensor]",
        #              lhs["dict_comprehension_of_tensor"],
        #              "foobar")

        """
        Union[Dict[str, torch.Tensor], int]
        """
        self._assert_passes(
            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_literal_of_str_tensor"],
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_literal_of_str_int"],
            "Type annotation was inferred to be "
            r"`Dict\[str, Tensor\]`, but the type of "
            "values given by the dict literal is",
        )

        self._assert_raises(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_literal_of_mixed"],
            "Type annotation was inferred to be "
            r"`Dict\[str, Tensor\]`, but the type of "
            "values given by the dict literal is",
        )

        self._assert_passes(
            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_keyword_with_iterable"],
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_keyword_with_empty_iterable"],
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_keyword_with_mapping"],
        )

        self._assert_passes(
            template,
            "Union[Dict[str, torch.Tensor], int]",
            lhs["dict_keyword_with_mapping_and_kwargs"],
        )

        # See above--string frontend does not support tuple unpacking
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])
        #
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_tensor"])

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_int"],
        #                    "foobar")

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_mixed"],
        #                    "foobar")
