from pathlib import Path

import torch
from torch.fx import symbolic_trace
from torch.package import PackageExporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE


# Destination for the generated ``.pt`` archives: a ``package_bc`` directory
# next to this script, expressed as a "/"-joined string path.
packaging_directory = "/".join((str(Path(__file__).parent), "package_bc"))
# NOTE(review): turning this gate off presumably allows the TorchScript module
# in generate_bc_packages() to be written with save_pickle — confirm against
# torch.package.package_exporter.
torch.package.package_exporter._gate_torchscript_serialization = False


def generate_bc_packages():
    """Create packages for testing backwards compatibility.

    Writes three ``torch.package`` archives under ``packaging_directory``,
    one per variant of ``TestNnModule``:

    * ``test_nn_module.pt``          -- the eager module,
    * ``test_torchscript_module.pt`` -- the ``torch.jit.script`` version,
    * ``test_fx_module.pt``          -- the ``torch.fx`` symbolically traced version.

    Every archive interns all of its dependencies (``intern("**")``).
    Generation is skipped when running in fbcode outside of Sandcastle.
    """
    if IS_FBCODE and not IS_SANDCASTLE:
        return

    # Imported lazily, only once we know packages will be generated.
    # NOTE(review): presumably ``package_a`` resolves only when run from the
    # test directory — confirm.
    from package_a.test_nn_module import TestNnModule

    # Build every variant up front, each from its own fresh instance.
    test_nn_module = TestNnModule()
    test_torchscript_module = torch.jit.script(TestNnModule())
    test_fx_module: torch.fx.GraphModule = symbolic_trace(TestNnModule())

    # Make sure the destination directory exists before any exporter writes
    # into it, so the packages can be regenerated from a fresh tree.
    Path(packaging_directory).mkdir(parents=True, exist_ok=True)

    # (archive file name, package name, resource name, object to pickle)
    exports = (
        ("test_nn_module.pt", "nn_module", "nn_module.pkl", test_nn_module),
        (
            "test_torchscript_module.pt",
            "torchscript_module",
            "torchscript_module.pkl",
            test_torchscript_module,
        ),
        ("test_fx_module.pt", "fx_module", "fx_module.pkl", test_fx_module),
    )
    for file_name, package, resource, obj in exports:
        with PackageExporter(f"{packaging_directory}/{file_name}") as exporter:
            exporter.intern("**")
            exporter.save_pickle(package, resource, obj)


if __name__ == "__main__":
    # Running this file directly (re)writes the backward-compatibility
    # packages into ``packaging_directory``.
    generate_bc_packages()
