# Owner(s): ["module: dynamo"]
import contextlib
import functools
import logging
import os
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.distributed as dist
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
from torch._dynamo.trace_rules import _as_posix_path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import (
    find_free_port,
    munge_exc,
    skipIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)


def example_fn(a):
    """Multiply ``a`` by a 1000x1000 ones tensor, then add a 1000x1000 ones tensor.

    This is the function compiled by the record-counting tests in this file
    (``multi_record_test`` / ``within_range_record_test``).
    NOTE(review): those tests assert on the number of log records produced
    while compiling this function, so restructuring the body may change the
    counts they expect.
    """
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


def dynamo_error_fn(a):
    """Multiply ``a`` by a 1000x1000 ones tensor, then add a 10x10 ones tensor.

    The two operand shapes are deliberately mismatched; ``test_dynamo_error``
    expects a broadcast failure to be reported for the ``add`` call.
    The source text of the ``output.add(...)`` line appears verbatim in that
    test's expected log message, so do not reformat it.
    """
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


def inductor_error_fn(a):
    """Return ``torch.round(a)``.

    ``test_inductor_error`` replaces the Inductor lowerings whose name
    contains "round" with a function that raises, and then compiles this
    function to check the resulting error log.
    """
    output = torch.round(a)
    return output


def inductor_schedule_fn(a):
    """Add a 1000x1000 ones tensor created on the ``cuda`` device to ``a``.

    Used by the CUDA-only schedule, fusion and cudagraphs logging tests.
    """
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


# Positional arguments shared by the tests that compile ``example_fn`` and the
# error-injection functions: a single 1000x1000 ones tensor that requires grad.
ARGS = (torch.ones((1000, 1000), requires_grad=True),)


def multi_record_test(num_records, **kwargs):
    """Build a test method expecting exactly ``num_records`` log records.

    ``kwargs`` are passed straight through to ``make_logging_test``. The
    generated test compiles ``example_fn`` with the "inductor" backend, runs it
    on ``ARGS`` and checks the number of captured records.
    """

    def fn(self, records):
        compiled_fn = torch._dynamo.optimize("inductor")(example_fn)
        compiled_fn(*ARGS)
        self.assertEqual(len(records), num_records)

    return make_logging_test(**kwargs)(fn)


def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    """Build a test method expecting a record count within an inclusive range.

    ``kwargs`` are passed straight through to ``make_logging_test``. The
    generated test compiles ``example_fn`` with the "inductor" backend, runs it
    on ``ARGS`` and checks that the number of captured records lies in
    ``[num_records_lower, num_records_higher]``.
    """

    def fn(self, records):
        compiled_fn = torch._dynamo.optimize("inductor")(example_fn)
        compiled_fn(*ARGS)
        record_count = len(records)
        self.assertGreaterEqual(record_count, num_records_lower)
        self.assertLessEqual(record_count, num_records_higher)

    return make_logging_test(**kwargs)(fn)


def single_record_test(**kwargs):
    """Build a test method expecting exactly one log record.

    Shorthand for ``multi_record_test`` with a record count of 1; ``kwargs``
    are forwarded unchanged.
    """
    expected_records = 1
    return multi_record_test(expected_records, **kwargs)


class LoggingTests(LoggingTestCase):
    """Tests for ``torch._logging`` settings exercised through Dynamo and Inductor.

    Most tests enable one or more logging options via ``make_logging_test`` /
    ``make_settings_test``, run some code, and then assert on the ``records``
    that were captured. Several tests compare log output against the literal
    source text of the functions they trace, so the bodies (and indentation) of
    those nested functions must not be reformatted.

    Additional ``test_<artifact>`` methods are attached to this class at module
    level, after the class definition.
    """

    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(2, output_code=True)
    test_aot_graphs = multi_record_test(3, aot_graphs=True)

    def _find_torch_handler(self, logger):
        # Return the first handler on `logger` recognised by
        # torch._logging._internal._is_torch_handler, or None if there is none.
        # Using next(..., None) makes a following assertIsNotNone meaningful:
        # a for/break scan would leave the last non-matching handler bound.
        return next(
            (
                handler
                for handler in logger.handlers
                if torch._logging._internal._is_torch_handler(handler)
            ),
            None,
        )

    @requires_cuda
    @make_logging_test(schedule=True)
    def test_schedule(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 5)

    @requires_cuda
    @make_logging_test(fusion=True)
    def test_fusion(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @requires_cuda
    @make_logging_test(cudagraphs=True)
    def test_cudagraphs(self, records):
        fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @make_logging_test(recompiles=True)
    def test_recompiles(self, records):
        def fn(x, y):
            return torch.add(x, y)

        fn_opt = torch._dynamo.optimize("inductor")(fn)
        fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000))
        # second call passes an int instead of a tensor for `y`
        fn_opt(torch.ones(1000, 1000), 1)
        self.assertGreater(len(records), 0)

    test_dynamo_debug = within_range_record_test(30, 90, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

    @skipIfTorchDynamo("too slow")
    @make_logging_test(dynamo=logging.DEBUG)
    def test_dynamo_debug_default_off_artifacts(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(torch.ones(1000, 1000))
        self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
        self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)

    @make_logging_test()
    def test_dynamo_error(self, records):
        # Any exception from the call is deliberately swallowed: the test only
        # inspects the "WON'T CONVERT" log record.
        try:
            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        record = self.getRecord(records, "WON'T CONVERT")
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT dynamo_error_fn test_logging.py line N
due to:
Traceback (most recent call last):
torch._dynamo.exc.TorchRuntimeError: Failed running call_method add(*(FakeTensor(..., size=(1000, 1000), grad_fn=<MulBackward0>), FakeTensor(..., size=(10, 10))), **{}):
Attempting to broadcast a dimension of length 10 at -1! Mismatching argument at index 1 had torch.Size([10, 10]); but expected shape should be broadcastable to [1000, 1000]

from user code:
   File "test_logging.py", line N, in dynamo_error_fn
    output = output.add(torch.ones(10, 10))""",  # noqa: B950
        )

    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)

    @make_logging_test()
    def test_inductor_error(self, records):
        import torch._inductor.lowering

        def throw(x):
            raise AssertionError

        # inject an error in the lowerings
        dict_entries = {}
        for x in list(torch._inductor.lowering.lowerings.keys()):
            if "round" in x.__name__:
                dict_entries[x] = throw

        # The ExitStack is used as a context manager so the patched lowerings
        # are restored even when one of the assertions below fails.
        with contextlib.ExitStack() as exitstack:
            exitstack.enter_context(
                unittest.mock.patch.dict(
                    torch._inductor.lowering.lowerings, dict_entries
                )
            )

            # Any exception from the call is deliberately swallowed: the test
            # only inspects the "WON'T CONVERT" log record.
            try:
                fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
                fn_opt(*ARGS)
            except Exception:
                pass
            record = self.getRecord(records, "WON'T CONVERT")
            self.assertExpectedInline(
                munge_exc(record.getMessage()),
                """\
WON'T CONVERT inductor_error_fn test_logging.py line N
due to:
Traceback (most recent call last):
  File "test_logging.py", line N, in throw
    raise AssertionError
torch._inductor.exc.LoweringException: AssertionError:
  target: aten.round.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
  ))

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.round.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
  ))""",
            )

    @requires_distributed()
    @requires_cuda
    @make_logging_test(ddp_graphs=True)
    def test_ddp_graphs(self, records):
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(1024, 1024),
                    torch.nn.Linear(1024, 1024),
                )

            def forward(self, x):
                return self.layers(x)

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        dist.init_process_group("gloo", rank=0, world_size=1)

        # Always tear the process group down, even if compilation or the
        # forward pass raises, so it does not leak into other tests.
        try:
            ddp_model = torch._dynamo.optimize("inductor")(
                DDP(ToyModel().to("cuda:0"), device_ids=[0], bucket_cap_mb=4)
            )

            ddp_model(torch.randn(1024, 1024, device="cuda:0"))
        finally:
            dist.destroy_process_group()

        self.assertEqual(len([r for r in records if "__ddp_graphs" in r.name]), 4)

    # check that logging to a child log of a registered logger
    # does not register it and result in duplicated records
    @make_settings_test("torch._dynamo.output_graph")
    def test_open_registration_with_registered_parent(self, records):
        logger = logging.getLogger("torch._dynamo.output_graph")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_settings_test("torch.utils")
    def test_open_registration(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_logging_test(modules={"torch.utils": logging.INFO})
    def test_open_registration_python_api(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    @make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
    def test_all(self, _):
        registry = torch._logging._internal.log_registry

        dynamo_qnames = registry.log_alias_to_log_qnames["dynamo"]
        for logger_qname in registry.get_log_qnames():
            logger = logging.getLogger(logger_qname)
            actual_level = logger.getEffectiveLevel()

            # if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting
            if any(logger_qname.startswith(d) for d in dynamo_qnames):
                expected_level = logging.INFO
            else:
                expected_level = logging.DEBUG

            self.assertEqual(
                actual_level,
                expected_level,
                msg=(
                    f"expected {logger_qname} is {logging.getLevelName(expected_level)}, "
                    f"got {logging.getLevelName(actual_level)}"
                ),
            )

    @make_logging_test(graph_breaks=True)
    def test_graph_breaks(self, records):
        @torch._dynamo.optimize("inductor")
        def fn(x):
            torch._dynamo.graph_break()
            return x + 1

        fn(torch.ones(1))

        self.assertEqual(len(records), 1)

    @make_settings_test("torch._dynamo.utils")
    def test_dump_compile_times(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(torch.ones(1000, 1000))
        # This function runs during exit via atexit.register.
        # We're not actually going to run atexit._run_exit_funcs() here,
        # because it'll destroy state necessary for other tests.
        torch._dynamo.utils.dump_compile_times()
        self.assertEqual(
            len(
                [r for r in records if "TorchDynamo compilation metrics" in str(r.msg)]
            ),
            1,
        )

    @make_logging_test(dynamo=logging.INFO)
    def test_custom_format_exc(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        try:
            raise RuntimeError("foo")
        except RuntimeError:
            dynamo_log.exception("test dynamo")
            dynamo_log.info("with exc", exc_info=True)
        dynamo_log.info("with stack", stack_info=True)
        self.assertEqual(len(records), 3)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = self._find_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        self.assertIn("Traceback", handler.format(records[0]))
        self.assertIn("Traceback", handler.format(records[1]))
        self.assertIn("Stack", handler.format(records[2]))

    @make_logging_test(dynamo=logging.INFO)
    def test_custom_format(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        test_log = torch._logging.getArtifactLogger(
            torch._dynamo.__name__, "custom_format_test_artifact"
        )
        dynamo_log.info("test dynamo")
        test_log.info("custom format")
        self.assertEqual(len(records), 2)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = self._find_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        self.assertIn("I", handler.format(records[0]))
        self.assertEqual("custom format", handler.format(records[1]))

    @make_logging_test(dynamo=logging.INFO)
    def test_multiline_format(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        dynamo_log.info("test\ndynamo")
        dynamo_log.info("%s", "test\ndynamo")
        dynamo_log.info("test\n%s", "test\ndynamo")
        self.assertEqual(len(records), 3)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = self._find_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        # every line of a multi-line message must carry the level marker
        for record in records:
            r = handler.format(record)
            for l in r.splitlines():
                self.assertIn("I", l)

    test_trace_source_simple = within_range_record_test(1, 100, trace_source=True)

    @make_logging_test(trace_source=True)
    def test_trace_source_if_stmt(self, records):
        def fn(x):
            if x.sum() > 0:
                return x * 2
            return x * 3

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.ones(3, 3))

        found_x2 = False
        found_x3 = False
        for record in records:
            msg = record.getMessage()
            if "return x * 2" in msg:
                found_x2 = True
            if "return x * 3" in msg:
                found_x3 = True

        # only the branch that was actually taken should be reported
        self.assertTrue(found_x2)
        self.assertFalse(found_x3)

    @make_logging_test(trace_source=True)
    def test_trace_source_nested(self, records):
        def fn1(x):
            x = fn2(x)
            return x * 2

        def fn2(x):
            x = fn3(x)
            return x * 3

        def fn3(x):
            return x * 4

        fn_opt = torch._dynamo.optimize("eager")(fn1)
        fn_opt(torch.ones(3, 3))

        found_x2 = False
        found_x3 = False
        found_x4 = False
        for record in records:
            msg = record.getMessage()
            if "return x * 2" in msg:
                found_x2 = True
                self.assertNotIn("inline depth", msg)
            elif "return x * 3" in msg:
                found_x3 = True
                self.assertIn("inline depth: 1", msg)
            elif "return x * 4" in msg:
                found_x4 = True
                self.assertIn("inline depth: 2", msg)
        self.assertTrue(found_x2)
        self.assertTrue(found_x3)
        self.assertTrue(found_x4)

    @make_logging_test(trace_source=True)
    def test_trace_source_cond(self, records):
        from functorch.experimental.control_flow import cond

        def true_fn(x):
            return x * 2

        def false_fn(x):
            return x * 3

        def inner(pred, x):
            return cond(pred, true_fn, false_fn, [x])

        def outer(pred, x):
            return inner(pred, x)

        fn_opt = torch._dynamo.optimize("eager")(outer)
        fn_opt(torch.tensor(True), torch.ones(3, 3))

        found_x2 = False
        found_x3 = False
        for record in records:
            msg = record.getMessage()
            if "return x * 2" in msg:
                found_x2 = True
                self.assertIn("inline depth: 3", msg)
            if "return x * 3" in msg:
                found_x3 = True
                self.assertIn("inline depth: 3", msg)

        self.assertTrue(found_x2)
        self.assertTrue(found_x3)

    @make_logging_test(trace_source=True)
    def test_trace_source_funcname(self, records):
        # NOTE: list comprehensions are inlined in 3.12, so test with tuples
        def fn1():
            def fn2():
                if True:
                    return tuple(torch.ones(3, 3) for _ in range(5))
                return None

            return fn2()

        fn_opt = torch._dynamo.optimize("eager")(fn1)
        fn_opt()

        found_funcname = False
        for record in records:
            msg = record.getMessage()
            if "<genexpr>" in msg and "fn1.fn2" in msg:
                found_funcname = True

        self.assertTrue(found_funcname)

    def test_invalid_artifact_flag(self):
        with self.assertRaises(ValueError):
            torch._logging.set_logs(aot_graphs=5)

    @requires_distributed()
    def test_distributed_rank_logging(self):
        env = dict(os.environ)
        env["TORCH_LOGS"] = "dynamo"
        _, stderr = self.run_process_no_exception(
            """\
import torch.distributed as dist
import logging
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
dist.init_process_group("fake", rank=0, world_size=2, store=store)
dynamo_log = logging.getLogger("torch._dynamo")
dynamo_log.info("woof")
print("arf")
""",
            env=env,
        )
        self.assertIn("[rank0]:", stderr.decode("utf-8"))

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call(self, records):
        def fn(x, y):
            return (x * 2) @ (y * 3)

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.randn(10, 20), torch.randn(20, 30))

        self.assertEqual(len(records), 3)
        # only get last 2 lines
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            return (x * 2) @ (y * 3)
                    ~~^~~""",
        )
        self.assertExpectedInline(
            messages[1],
            """\
            return (x * 2) @ (y * 3)
                              ~~^~~""",
        )
        self.assertExpectedInline(
            messages[2],
            """\
            return (x * 2) @ (y * 3)
                   ~~~~~~~~^~~~~~~~~""",
        )

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call_inline_call(self, records):
        def g(x):
            return x * 2

        def f(x):
            return g(g(x))

        fn_opt = torch._dynamo.optimize("eager")(f)
        fn_opt(torch.randn(3, 3))

        self.assertEqual(len(records), 4)
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            return g(g(x))
                     ~^^^""",
        )
        self.assertExpectedInline(
            messages[1],
            """\
            return x * 2
                   ~~^~~""",
        )
        self.assertExpectedInline(
            messages[2],
            """\
            return g(g(x))
                   ~^^^^^^""",
        )
        self.assertExpectedInline(
            messages[3],
            """\
            return x * 2
                   ~~^~~""",
        )

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call_graph_break(self, records):
        def fn(x):
            x = x * 2
            torch._dynamo.graph_break()
            return x * 3

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.randn(3, 3))

        self.assertEqual(len(records), 3)
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            x = x * 2
                ~~^~~""",
        )
        self.assertExpectedInline(
            messages[-1],
            """\
            return x * 3
                   ~~^~~""",
        )

    @make_logging_test(guards=True, recompiles=True)
    def test_guards_recompiles(self, records):
        def fn(x, ys, zs):
            return inner(x, ys, zs)

        def inner(x, ys, zs):
            for y, z in zip(ys, zs):
                x += y * z
            return x

        ys = [1.0, 2.0]
        zs = [3.0]
        x = torch.tensor([1.0])

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(x, ys, zs)
        # second call passes a shorter `ys` list
        fn_opt(x, ys[:1], zs)

        record_str = "\n".join(r.getMessage() for r in records)

        self.assertIn(
            """L['zs'][0] == 3.0""",
            record_str,
        )
        self.assertIn(
            "len(L['ys']) == 2",
            record_str,
        )

    @make_logging_test(cudagraph_static_inputs=True)
    def test_cudagraph_static_inputs(self, records):
        @torch.compile(mode="reduce-overhead")
        def fn(x):
            return x + 1

        x = torch.ones(2, 2)
        torch._dynamo.mark_static_address(x)
        fn(x)
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 4)

    @skipIfTorchDynamo("too slow")
    @make_logging_test(**torch._logging.DEFAULT_LOGGING)
    def test_default_logging(self, records):
        def fn(a):
            if a.sum() < 0:
                a = torch.sin(a)
            else:
                a = torch.cos(a)
            print("hello")
            return a + 1

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.ones(10, 10))
        fn_opt(-torch.ones(10, 5))

        self.assertGreater(len([r for r in records if ".__graph_breaks" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".__recompiles" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".symbolic_shapes" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".__guards" in r.name]), 0)
        self.assertGreater(
            len([r for r in records if "return a + 1" in r.getMessage()]), 0
        )

    def test_logs_out(self):
        import tempfile

        # NamedTemporaryFile opens the file it creates. On Windows the
        # subprocess started by run_process_no_exception cannot access a file
        # that is still open (PermissionError: [Errno 13] Permission denied),
        # so the handle is closed (by leaving the `with` block) before the
        # subprocess runs. delete=False keeps the file on disk after closing.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            file_path = _as_posix_path(tmp.name)

        # delete=False means the file must be removed manually; do it in a
        # `finally` so it is cleaned up even if the subprocess or read fails.
        try:
            env = dict(os.environ)
            env["TORCH_LOGS"] = "dynamo"
            env["TORCH_LOGS_OUT"] = file_path
            _, stderr = self.run_process_no_exception(
                """\
import torch
@torch.compile(backend="eager")
def fn(a):
    return a.sum()

fn(torch.randn(5))
                """,
                env=env,
            )
            # read the file as UTF-8 explicitly, for Windows
            with open(file_path, encoding="utf-8") as fd:
                lines = fd.read()
        finally:
            os.remove(file_path)

        # line endings differ between platforms: \r\n on Windows, \n on posix
        self.assertEqual(
            empty_line_normalizer(lines),
            empty_line_normalizer(stderr.decode("utf-8")),
        )


# Single-record tests.
# Every registered logging artifact that is not listed in `exclusions` gets an
# auto-generated `test_<artifact>` method on LoggingTests, built with
# single_record_test(<artifact>=True).
exclusions = {
    "aot_graphs",
    "aot_graphs_effects",
    "benchmarking",
    "bytecode",
    "compiled_autograd",
    "compiled_autograd_verbose",
    "cudagraph_static_inputs",
    "cudagraphs",
    "custom_format_test_artifact",
    "ddp_graphs",
    "export",
    "fusion",
    "graph",
    "graph_breaks",
    "graph_code",
    "graph_sizes",
    "guards",
    "loop_ordering",
    "not_implemented",
    "onnx",
    "onnx_diagnostics",
    "output_code",
    "overlap",
    "perf_hints",
    "post_grad_graphs",
    "recompiles",
    "recompiles_verbose",
    "schedule",
    "sym_node",
    "trace_bytecode",
    "trace_call",
    "trace_shape_events",
    "trace_source",
    "verbose_guards",
}
for name in torch._logging._internal.log_registry.artifact_names:
    if name in exclusions:
        continue
    generated_test = single_record_test(**{name: True})
    setattr(LoggingTests, "test_" + name, generated_test)

if __name__ == "__main__":
    # Run this file's tests through torch._dynamo.test_case.run_tests when it
    # is executed directly as a script.
    torch._dynamo.test_case.run_tests()
