# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List, Tuple

import torch


# Add the directory two levels above this file (the test/ directory) to the
# module search path, so the helper files that live in test/ can be imported.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # This module only defines test cases; per the message below they are meant
    # to be run through test/test_jit.py, so direct execution is refused.
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)


class TestHash(JitTestCase):
    def test_hash_tuple(self):
        def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
            return hash(t1) == hash(t2)

        self.checkScript(fn, ((1, 2), (1, 2)))
        self.checkScript(fn, ((1, 2), (3, 4)))
        self.checkScript(fn, ((1, 2), (2, 1)))

    def test_hash_tuple_nested_unhashable_type(self):
        # Tuples may contain unhashable types like `list`, check that we error
        # properly in that case.
        @torch.jit.script
        def fn_unhashable(t1: Tuple[int, List[int]]):
            return hash(t1)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
            fn_unhashable((1, [1]))

    def test_hash_tensor(self):
        """Tensors should hash by identity"""

        def fn(t1, t2):
            return hash(t1) == hash(t2)

        tensor1 = torch.tensor(1)
        tensor1_clone = torch.tensor(1)
        tensor2 = torch.tensor(2)

        self.checkScript(fn, (tensor1, tensor1))
        self.checkScript(fn, (tensor1, tensor1_clone))
        self.checkScript(fn, (tensor1, tensor2))

    def test_hash_none(self):
        def fn():
            n1 = None
            n2 = None
            return hash(n1) == hash(n2)

        self.checkScript(fn, ())

    def test_hash_bool(self):
        def fn(b1: bool, b2: bool):
            return hash(b1) == hash(b2)

        self.checkScript(fn, (True, False))
        self.checkScript(fn, (True, True))
        self.checkScript(fn, (False, True))
        self.checkScript(fn, (False, False))

    def test_hash_float(self):
        def fn(f1: float, f2: float):
            return hash(f1) == hash(f2)

        self.checkScript(fn, (1.2345, 1.2345))
        self.checkScript(fn, (1.2345, 6.789))
        self.checkScript(fn, (1.2345, float("inf")))
        self.checkScript(fn, (float("inf"), float("inf")))
        self.checkScript(fn, (1.2345, float("nan")))
        if sys.version_info < (3, 10):
            # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
            # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
            self.checkScript(fn, (float("nan"), float("nan")))
        self.checkScript(fn, (float("nan"), float("inf")))

    def test_hash_int(self):
        def fn(i1: int, i2: int):
            return hash(i1) == hash(i2)

        self.checkScript(fn, (123, 456))
        self.checkScript(fn, (123, 123))
        self.checkScript(fn, (123, -123))
        self.checkScript(fn, (-123, -123))
        self.checkScript(fn, (123, 0))

    def test_hash_string(self):
        def fn(s1: str, s2: str):
            return hash(s1) == hash(s2)

        self.checkScript(fn, ("foo", "foo"))
        self.checkScript(fn, ("foo", "bar"))
        self.checkScript(fn, ("foo", ""))

    def test_hash_device(self):
        def fn(d1: torch.device, d2: torch.device):
            return hash(d1) == hash(d2)

        gpu0 = torch.device("cuda:0")
        gpu1 = torch.device("cuda:1")
        cpu = torch.device("cpu")
        self.checkScript(fn, (gpu0, gpu0))
        self.checkScript(fn, (gpu0, gpu1))
        self.checkScript(fn, (gpu0, cpu))
        self.checkScript(fn, (cpu, cpu))
