# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CastInt64ToInt32Pass(ExportPass):
    """Narrow the int64 buffers of an exported program to int32.

    For every graph node whose ``meta["val"]`` is an int64 ``FakeTensor`` and
    whose name maps to a buffer in the program's graph signature, the tensor
    stored in ``exported_program.state_dict`` is replaced by an int32 copy and
    the node's fake tensor is converted to int32 to match.

    Nodes that are not buffers (or whose buffer is not held in the state
    dict) are left untouched.
    """

    def __init__(self, exported_program: torch.export.ExportedProgram):
        """Store the exported program whose buffers will be cast.

        Args:
            exported_program: Program providing the graph signature and the
                state dict that this pass mutates in place.
        """
        super().__init__()
        self.exported_program = exported_program

    def _to_int32(self, graph_module: torch.fx.GraphModule) -> None:
        """Cast int64 buffer nodes of ``graph_module`` to int32 in place.

        Args:
            graph_module: Graph whose nodes are inspected; only the nodes'
                ``meta["val"]`` entries and the program's state dict are
                modified.
        """
        # Both are property lookups on the exported program; fetch them once
        # instead of on every loop iteration.
        inputs_to_buffers = self.exported_program.graph_signature.inputs_to_buffers
        state_dict = self.exported_program.state_dict

        for node in graph_module.graph.nodes:
            # Not every node is guaranteed to carry a "val" entry.
            fake_tensor = node.meta.get("val")
            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                continue
            if fake_tensor.dtype != torch.int64:
                continue

            # Only buffer placeholders have backing data to cast. Other int64
            # nodes (user inputs, op results, ...) have no entry here and
            # previously triggered a KeyError.
            buffer_name = inputs_to_buffers.get(node.name)
            if buffer_name is None or buffer_name not in state_dict:
                # NOTE(review): buffers missing from the state dict are
                # presumably non-persistent ones stored elsewhere; they are
                # skipped rather than cast -- confirm this is acceptable.
                continue

            # NOTE: narrowing does not range-check; values outside the int32
            # range are not preserved by the cast.
            state_dict[buffer_name] = state_dict[buffer_name].to(torch.int32)
            # Update the fake tensor only after the real tensor was cast so
            # metadata and stored data never disagree.
            node.meta["val"] = fake_tensor.to(torch.int32)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass on ``graph_module``.

        Casts the int64 buffers, recompiles the module, then delegates to
        ``ExportPass.call`` and wraps the graph module it produces.

        Args:
            graph_module: The graph module to transform.

        Returns:
            A ``PassResult`` holding the graph module returned by the base
            class, always flagged as modified.
        """
        self._to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
