#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import functools
import io
import os
import tempfile
import unittest
import urllib
import urllib.parse

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


class SimpleModel(torch.nn.Module):
    """Small MLP: Linear(16 -> 64), ReLU, Linear(64 -> 8), ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Tuple targets are bound left to right, so submodules are still
        # registered as layer1, relu1, layer2, relu2 (in that order); the
        # quantization tests fuse them by these exact names.
        self.layer1, self.relu1 = torch.nn.Linear(16, 64), torch.nn.ReLU()
        self.layer2, self.relu2 = torch.nn.Linear(64, 8), torch.nn.ReLU()

    def forward(self, features):
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))


class QuantModel(torch.nn.Module):
    """SimpleModel wrapped between a QuantStub and a DeQuantStub.

    The stubs mark where eager-mode quantization converts tensors into and
    out of the quantized domain.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order (quant, dequant, core) is unchanged.
        self.quant, self.dequant = (
            torch.ao.quantization.QuantStub(),
            torch.ao.quantization.DeQuantStub(),
        )
        self.core = SimpleModel()

    def forward(self, x):
        return self.dequant(self.core(self.quant(x)))


class ModelWithLists(torch.nn.Module):
    """Module whose tensor attributes live in plain Python lists.

    ``rt`` holds a single tensor; ``ot`` holds a tensor followed by ``None``,
    so scripting it yields a list whose elements may be ``None``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        # Bind the optional entry to a local first so TorchScript can refine
        # its type inside the ``is not None`` branch.
        maybe_offset = self.ot[0]
        result = arg + self.rt[0]
        if maybe_offset is not None:
            result = result + maybe_offset
        return result


def webdriver_test(testfunc):
    """Decorate a test method so it runs once per Selenium browser.

    The wrapped method is called as ``testfunc(self, wd, *args, **kwds)``,
    where ``wd`` is a freshly created WebDriver. It runs once for Firefox and
    once for Chrome, each inside its own ``subTest``.

    ``self`` must provide ``needs_resources()`` and the usual ``TestCase``
    ``skipTest``/``subTest`` methods. The test is skipped unless the
    ``RUN_WEBDRIVER`` environment variable is exactly ``"1"``.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        # Imported only after the skip check, so selenium is needed only
        # when webdriver tests are explicitly requested.
        from selenium import webdriver

        for driver in ["Firefox", "Chrome"]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Release the browser even when the test body fails;
                    # otherwise a failing subtest leaks a browser instance.
                    wd.close()

    return wrapper


class TestModelDump(TestCase):
    """Tests for ``torch.utils.model_dump``.

    Most tests are smoke tests. They save a TorchScript model to an
    in-memory buffer and check that model_dump can summarize it. Tests
    decorated with ``webdriver_test`` also render the HTML dump in a real
    browser and inspect the resulting page.
    """

    def needs_resources(self):
        """Hook called before resource-dependent tests; a no-op here."""
        # NOTE(review): presumably an override point for environments that
        # need to gate these tests — confirm.
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # The inlined page must be self-contained: no CDN reference and no
        # ``src=`` attribute that would pull in an external resource.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` with ``torch.jit.save`` and run ``get_model_info`` on it.

        Smoke test only: it checks that info extraction completes and returns
        something other than ``None``.

        Args:
            model: A scripted or traced module accepted by ``torch.jit.save``.
            extra_files: Optional mapping forwarded as ``_extra_files``
                (archive file name -> contents).
        """
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render ``model``'s HTML dump and load it in WebDriver ``wd``.

        The page is delivered as a ``data:`` URL, so no web server is needed.
        """
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the collapsible section titled ``name`` and return its body.

        The section container is found by its ``data-hider-title``
        attribute. Its caret is clicked only when ``data-shown`` is not
        ``"true"``, so an already-open section stays open.

        Returns:
            The first ``div`` element inside the section container.
        """
        # The generic find_element(by, value) API exists in Selenium 3 and 4;
        # the find_element_by_* helpers were removed in Selenium 4.3.
        container = wd.find_element("xpath", f"//div[@data-hider-title='{name}']")
        caret = container.find_element("class name", "caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        return container.find_element("tag name", "div")

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            # (Likely cause: on Windows a NamedTemporaryFile cannot be
            # reopened by name while it is still open.)
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            # The leading None fills the argv[0] slot (assumed to be the
            # program name).
            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Build an eager-mode post-training-quantized ``QuantModel``.

        The Linear+ReLU pairs are fused, the default ``qnnpack`` qconfig is
        applied, the observers are calibrated on one random (2, 16) batch,
        and the model is then converted.
        """
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        # A malformed .json extra file must not make the dump fail.
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # The leading "." keeps each XPath relative to the section body.
            # A bare "//" searches the whole document and could match a table
            # in some other section.
            device = memory_table.find_element(
                "xpath", ".//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(
                "xpath", ".//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # w1 is preserved as an attribute while w2 can be frozen into a
        # constant: 2 + 4 float32 elements in total.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))


if __name__ == "__main__":
    # Running this file directly hands control to PyTorch's common test runner.
    run_tests()
