# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-unsafe


# This file contains passes that simplify the arguments of ops in the graph.

import sys
from typing import Optional

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class SimplifySliceOpPass(ExportPass):
    """
    Simplify the start and end indices of slice and slice_scatter ops.

    For ``slice_copy.Tensor`` and ``slice_scatter.default`` calls:

    - ``dim`` is made non-negative;
    - ``start``/``end`` are made non-negative and clamped so that
      ``0 <= start <= end <= size(dim)``.

    A slice that selects no elements is replaced outright:

    - ``slice_copy`` becomes a ``full`` op producing an empty tensor with the
      input's shape, except for a size of 0 along ``dim``;
    - ``slice_scatter`` becomes its input tensor, since nothing is scattered.

    All other ops are forwarded unchanged.
    """

    def adjust_slice_range(
        self,
        length: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> tuple[int, int]:
        """
        Normalize slice bounds against a dimension of size ``length``.

        Args:
            length: Size of the sliced dimension.
            start: Start index. ``None`` means 0; negative values count from
                the end of the dimension.
            end: Exclusive end index. ``None`` means "to the end"; negative
                values count from the end of the dimension.
            step: Slice step. The clamped bounds do not depend on it, so it
                is not used here; it is kept for interface compatibility.

        Returns:
            ``(start, end)`` satisfying ``0 <= start <= end <= length``.
        """
        start_val = start if start is not None else 0
        end_val = end if end is not None else sys.maxsize  # 2^63 - 1

        # Negative indices count from the end of the dimension.
        if start_val < 0:
            start_val += length
        if end_val < 0:
            end_val += length

        # Clamp start into [0, length].
        if start_val < 0:
            start_val = 0
        elif start_val >= length:
            start_val = length

        # Clamp end into [start, length]; an inverted range collapses to an
        # empty slice at ``start``.
        if end_val < start_val:
            end_val = start_val
        elif end_val >= length:
            end_val = length

        return (start_val, end_val)

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite slice_copy / slice_scatter calls with normalized arguments.

        Any other op is forwarded unchanged to ``ExportPass.call_operator``.
        """
        # We are only interested in slice_copy or slice_scatter ops
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.slice_scatter.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # slice_scatter has an extra ``src`` argument at index 1, which shifts
        # the positions of dim/start/end/step by one.
        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
        offset = 1 if slice_scatter else 0

        def _arg(index, default):
            # Positional argument at ``index`` (slice_copy numbering), or
            # ``default`` when the caller omitted it.
            pos = index + offset
            return args[pos] if len(args) > pos else default

        # Extract the tensor to be sliced, and the slicing dimension
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        dim = _arg(1, 0)
        # Make dim non-negative
        if dim < 0:
            dim += in_tensor.dim()
        length = in_tensor.size(dim)

        # Get the adjusted start and end indices
        step = _arg(4, 1)
        (start_val, end_val) = self.adjust_slice_range(
            length, _arg(2, None), _arg(3, None), step
        )

        if start_val >= end_val:
            # Empty slice: scattering nothing leaves the input unchanged.
            if slice_scatter:
                return args[0]
            # Produce an empty tensor with the input's shape and size 0 along
            # ``dim``. Every other dimension is kept as-is (including ones
            # that are already 0) so the rank, and hence the meaning of
            # ``dim``, is preserved.
            empty_shape = list(in_tensor.shape)
            empty_shape[dim] = 0
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": in_tensor.dtype},
                meta,
            )

        # Create new args
        new_args = (
            (args[0],)
            + ((args[1],) if slice_scatter else ())
            + (dim, start_val, end_val, step)
        )
        return super().call_operator(op, new_args, kwargs, meta)


# This class encapsulates all the functions that simplify the op's args
class CadenceSimplifyOpsInGraph:
    passes = [
        SimplifySliceOpPass,
    ]
