# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_var_decomposition(op) -> tuple:
    if op == exir_ops.edge.aten.var.correction:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
        )
    if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim):
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.full,
        )
    raise RuntimeError(f"Can't get var decomposition for op {op}")


class DecomposeVarPass(ExportPass):
    """
    This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

    Example:
        y = var_correction(x, dim, keepdim, correction)
    Becomes:
        mean = mean(x, dim)
        diff = sub(x, mean)
        squared_diff = mul(diff, diff)
        sum = sum(squared_diff, dim)
        y = div(sum, max(0, N-correction))
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            exir_ops.edge.aten.var.correction,
            torch.ops.aten.var.correction,
            torch.ops.aten.var.dim,
        ):
            return super().call_operator(op, args, kwargs, meta)
        shape = meta["val"].size()
        dtype = meta["val"].dtype
        dim = args[1] if len(args) > 1 else list(range(len(shape)))
        if op == torch.ops.aten.var.dim:
            correction = args[-2]
            keepdim = args[-1]
        else:
            correction = kwargs["correction"]
            keepdim = kwargs.get("keepdim", False)
        if not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        input_shape = x.data.size()
        N = 1
        for d in dim:
            N *= input_shape[d]

        mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
        diff = super().call_operator(diff_op, (x, mean), {}, meta)
        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
        sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
        full = super().call_operator(
            full_op,
            ([1 for _ in shape], 1 / max(0, N - correction)),
            {"dtype": dtype},
            meta,
        )
        return super().call_operator(mul_op, (sum, full), {}, meta)
