# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Callable, List, Optional, Tuple

import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from executorch.exir.passes.replace_aten_with_edge_pass import (
    aten_to_edge,
    should_lower_to_edge,
)
from torch import fx
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl, register_fake


# Public API of this module: only the aggregate pattern/replacement getter.
__all__ = [
    "get_quant_patterns_and_replacements",
]

# TODO: extending an existing library that is defined in OSS (torch's
# ``quantized_decomposed`` library) might be a bit confusing; investigate
# whether it is possible to define a separate library instead.

# Schemas for the byte (8-bit) quantized embedding op family, registered on
# the shared ``quantized_decomposed`` library at import time. The Python
# implementations and fake kernels further down attach to these names.
#   embedding_byte           - functional form
#   embedding_byte.dtype     - functional form with an optional output dtype
#   embedding_byte.out       - out-variant; ``out`` is a mutable alias (a!)
#   embedding_byte.dtype_out - out-variant with an optional output dtype
quantized_decomposed_lib.define(
    "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


def embedding_weight_checks(weight, weight_scales, weight_zero_points):
    """Validate a quantized embedding table and its quantization parameters.

    Raises AssertionError when:
      * ``weight`` is not int8/uint8 or is not 2-D;
      * ``weight_scales`` is not float16/float32, is not rank 1 or 2, or has a
        different number of rows than ``weight``;
      * ``weight_zero_points`` (when given) has a different dtype than
        ``weight_scales``, is neither rank 1 nor exactly the shape of
        ``weight_scales``, or has a different number of rows than ``weight``.

    Rank-1 zero points (one per row) are always accepted. Zero points with the
    same shape as ``weight_scales`` are accepted too, so group-wise (rank-2)
    scales can be paired with group-wise zero points.
    """
    assert weight.dtype in [
        torch.int8,
        torch.uint8,
    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
    assert (
        weight.dim() == 2
    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"

    assert weight_scales.dtype in [
        torch.float16,
        torch.float32,
    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
    assert (
        weight_scales.dim() == 1 or weight_scales.dim() == 2
    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
    assert weight_scales.size(0) == weight.size(
        0
    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"

    # Zero points are optional (symmetric quantization); nothing more to check.
    if weight_zero_points is None:
        return

    assert (
        weight_zero_points.dtype == weight_scales.dtype
    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
    # Either one zero point per row, or one per (row, group) mirroring the scales.
    assert (
        weight_zero_points.dim() == 1
        or weight_zero_points.shape == weight_scales.shape
    ), (
        "Expecting weight_zero_points tensor to be None, have dim()==1, or match the "
        f"shape of weight_scales, but found {weight_zero_points.size()} and {weight_scales.size()}"
    )
    assert weight_zero_points.size(0) == weight.size(
        0
    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"


@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
def embedding_byte(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_byte``.

    Dequantizes the byte-quantized ``weight`` table group-wise and then does a
    regular embedding lookup with ``indices``. A rank-1 ``weight_scales``
    means one group spanning the whole row. A rank-2 ``weight_scales`` splits
    each row into ``weight_scales.size(1)`` equal groups. The dequantized
    table, and hence the result, uses ``weight_scales.dtype``.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)

    num_groups = 1
    if weight_scales.dim() == 2:
        num_groups = weight_scales.size(1)
    group_size = weight.size(1) // num_groups

    dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(dequantized, indices)


@register_fake("quantized_decomposed::embedding_byte.out")
def embedding_byte_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_byte.out``.

    The result is computed by the functional ``embedding_byte``. The ``out``
    buffer is neither read nor written here.
    """
    functional_args = (
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )
    return embedding_byte(*functional_args)


@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
def embedding_byte_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_byte.dtype``.

    Dequantizes the byte-quantized ``weight`` table group-wise, producing the
    table in ``dtype``, and then performs an embedding lookup with ``indices``.

    Args:
        weight: 2-D int8/uint8 quantized embedding table.
        weight_scales: per-row (rank 1) or per-group (rank 2) scales.
        weight_zero_points: optional zero points (see embedding_weight_checks).
        weight_quant_min / weight_quant_max: quantization range of ``weight``.
        indices: rows to gather.
        dtype: output dtype. The op schema declares ``ScalarType? dtype=None``.
            When it is ``None``, ``weight_scales.dtype`` is used, matching the
            plain ``embedding_byte`` overload.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = weight.size(1) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Resolve an omitted dtype here instead of forwarding None: the output
    # dtype argument of dequantize_per_channel_group is not optional.
    output_dtype = weight_scales.dtype if dtype is None else dtype
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        output_dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_byte.dtype_out")
def embedding_byte_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_byte.dtype_out``.

    Delegates to the functional ``embedding_byte_dtype`` with the same
    arguments. ``out`` is ignored.
    """
    return embedding_byte_dtype(
        weight=weight,
        weight_scales=weight_scales,
        weight_zero_points=weight_zero_points,
        weight_quant_min=weight_quant_min,
        weight_quant_max=weight_quant_max,
        indices=indices,
        dtype=dtype,
    )


# Schemas for the 2-bit quantized embedding op family. The overloads mirror
# embedding_byte: functional, functional with optional output ``dtype``, and
# the two out-variants whose ``out`` argument is a mutable alias (a!).
quantized_decomposed_lib.define(
    "embedding_2bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
def embedding_2bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_2bit``.

    Each byte of ``weight`` packs four 2-bit codes, least-significant pair
    first, so each unpacked row is 4x as wide as the packed row. The codes are
    shifted down by 2, to [-2, 1] assuming uint8 storage. They are then
    dequantized group-wise in ``weight_scales.dtype`` and used for an
    embedding lookup.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    groups_per_row = weight_scales.size(1) if weight_scales.dim() == 2 else 1
    group_size = (4 * weight.size(1)) // groups_per_row

    # Masks 3, 12, 48, 192 isolate each bit pair; shifting brings it to bits 0-1.
    crumbs = [(weight & (3 << shift)) >> shift for shift in (0, 2, 4, 6)]
    unpacked = torch.stack(crumbs, dim=-1).view(weight.shape[0], -1)
    signed = unpacked.view(torch.int8).add(-2)

    dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        signed,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        signed.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(dequantized, indices)


@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_2bit.out``.

    Shape and dtype come from running the functional ``embedding_2bit``;
    ``out`` is not touched.
    """
    result = embedding_2bit(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )
    return result


@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
def embedding_2bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_2bit.dtype``.

    Unpacks four 2-bit codes per byte of ``weight`` (least-significant pair
    first) and shifts them down by 2. It then dequantizes them group-wise into
    ``dtype`` and performs an embedding lookup with ``indices``.

    Args:
        dtype: output dtype. The op schema declares ``ScalarType? dtype=None``.
            When it is ``None``, ``weight_scales.dtype`` is used, matching the
            plain ``embedding_2bit`` overload.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (4 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Isolate each 2-bit field (bits 0-1, 2-3, 4-5, 6-7) and move it to bits 0-1.
    weight_0 = weight & 3
    weight_1 = (weight & 12) >> 2
    weight_2 = (weight & 48) >> 4
    weight_3 = (weight & 192) >> 6
    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-2)

    # Resolve an omitted dtype here instead of forwarding None: the output
    # dtype argument of dequantize_per_channel_group is not optional.
    output_dtype = weight_scales.dtype if dtype is None else dtype
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        output_dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_2bit.dtype_out``.

    Forwards everything except ``out`` to ``embedding_2bit_dtype``.
    """
    return embedding_2bit_dtype(
        weight=weight,
        weight_scales=weight_scales,
        weight_zero_points=weight_zero_points,
        weight_quant_min=weight_quant_min,
        weight_quant_max=weight_quant_max,
        indices=indices,
        dtype=dtype,
    )


# Schemas for the 4-bit quantized embedding op family. The overloads mirror
# embedding_byte: functional, functional with optional output ``dtype``, and
# the two out-variants whose ``out`` argument is a mutable alias (a!).
quantized_decomposed_lib.define(
    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
def embedding_4bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_4bit``.

    Each byte of ``weight`` packs two 4-bit codes, high nibble first, so each
    unpacked row is 2x as wide as the packed row. The codes are shifted down
    by 8, to [-8, 7] assuming uint8 storage. They are then dequantized
    group-wise in ``weight_scales.dtype`` and used for an embedding lookup.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    groups_per_row = weight_scales.size(1) if weight_scales.dim() == 2 else 1
    group_size = (2 * weight.size(1)) // groups_per_row

    high_nibbles = torch.div(weight, 16, rounding_mode="trunc")
    low_nibbles = torch.remainder(weight, 16)
    nibbles = torch.stack((high_nibbles, low_nibbles), dim=-1).view(
        weight.shape[0], -1
    )
    signed = nibbles.view(torch.int8).add(-8)

    dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        signed,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        signed.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(dequantized, indices)


@register_fake("quantized_decomposed::embedding_4bit.out")
def embedding_4bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_4bit.out``.

    Delegates to the functional ``embedding_4bit`` and returns its result;
    ``out`` is unused.
    """
    functional_args = (
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )
    return embedding_4bit(*functional_args)


@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
def embedding_4bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Reference implementation of ``quantized_decomposed::embedding_4bit.dtype``.

    Unpacks two 4-bit codes per byte of ``weight`` (high nibble first) and
    shifts them down by 8. It then dequantizes them group-wise into ``dtype``
    and performs an embedding lookup with ``indices``.

    Args:
        dtype: output dtype. The op schema declares ``ScalarType? dtype=None``.
            When it is ``None``, ``weight_scales.dtype`` is used, matching the
            plain ``embedding_4bit`` overload.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (2 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # High nibble comes first in the unpacked row, then the low nibble.
    weight_even = weight.div(16, rounding_mode="trunc")
    weight_odd = weight.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-8)

    # Resolve an omitted dtype here instead of forwarding None: the output
    # dtype argument of dequantize_per_channel_group is not optional.
    output_dtype = weight_scales.dtype if dtype is None else dtype
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        output_dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.dtype_out")
def embedding_4bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel for ``embedding_4bit.dtype_out``.

    Computes the result via ``embedding_4bit_dtype``; ``out`` is ignored.
    """
    result = embedding_4bit_dtype(
        weight=weight,
        weight_scales=weight_scales,
        weight_zero_points=weight_zero_points,
        weight_quant_min=weight_quant_min,
        weight_quant_max=weight_quant_max,
        indices=indices,
        dtype=dtype,
    )
    return result


# Schemas with no Python implementation in this file. mixed_mm / mixed_linear
# take a float ``input`` plus a ``weight`` with ``weight_scales`` and optional
# ``weight_zero_points``. NOTE(review): they presumably describe matmul/linear
# over a quantized weight; confirm against the kernels that implement them.
quantized_decomposed_lib.define(
    "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
)

quantized_decomposed_lib.define(
    "mixed_linear(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, ScalarType? dtype=None) -> Tensor",
)

# Quantized add targets of the binary-op rewrites in this module.
# ``add`` / ``add_relu``: two quantized tensors, each with its own
# scale / zero point / quant range, followed by the output qparams.
# ``add.scalar``: note the argument order (qa, a qparams, a_dtype, Scalar b,
# out qparams, out_dtype).
quantized_decomposed_lib.define(
    "add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)

quantized_decomposed_lib.define(
    "add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor"
)

quantized_decomposed_lib.define(
    "add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)


def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
    """Symbolically trace ``f`` and retarget its eligible calls to edge ops.

    Every ``call_function`` node whose target ``should_lower_to_edge`` accepts
    is rewritten in place to ``aten_to_edge(target)``. The module is then
    recompiled so its generated code reflects the new targets.
    """
    traced = fx.symbolic_trace(f)
    call_nodes = (n for n in traced.graph.nodes if n.op == "call_function")
    for call_node in call_nodes:
        if should_lower_to_edge(call_node.target):
            call_node.target = aten_to_edge(call_node.target)
    traced.recompile()
    return traced


def _sixth_input_is_scalar(match, original_graph, pattern_graph):
    """check the node that's matched to the sixth input of the pattern graph

    is a scalar number
    """
    input_idx = 0
    for node in pattern_graph.nodes:
        if node.op == "placeholder":
            if input_idx == 5:
                num_node = node
            input_idx += 1
    if not isinstance(match.nodes_map[num_node], (int, float)):
        return False
    return True


def _get_binary_op_patterns_and_replacements(
    binary_op: Callable,
    qbinary_op: Callable,
    qbinary_scalar_op: Callable,
    qbinary_relu_op: Callable,
) -> List[Tuple[Callable, Callable, List[Callable]]]:
    """Build the quantized rewrite rules for one floating-point binary op.

    Returns four ``(pattern, replacement, match_filters)`` triples, in order:

      1. dq(x), dq(y) -> binary_op -> relu -> q  ==>  qbinary_relu_op
      2. dq(x), dq(y) -> binary_op -> q          ==>  qbinary_op
      3. dq(x) -> binary_op(x, num) -> q         ==>  qbinary_scalar_op
      4. dq(x) -> binary_op(num, x) -> q         ==>  qbinary_scalar_op

    The scalar rules carry ``_sixth_input_is_scalar``, so they only apply
    when ``num`` (the sixth pattern input) matched a Python int/float. All
    dequantize/quantize steps in the patterns use ``torch.uint8``.

    Args:
        binary_op: floating-point op to match (e.g. edge ``aten.add.Tensor``).
        qbinary_op: quantized tensor-tensor replacement op.
        qbinary_scalar_op: quantized tensor-scalar replacement op. It is
            expected to follow the ``add.scalar`` schema in this file, which
            takes ``a_dtype`` before the scalar ``b`` and ``out_dtype`` last.
        qbinary_relu_op: quantized fused binary + relu replacement op.
    """

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_op.name())
    def binary_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_1_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(x, num)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_1_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        # The scalar op takes the input dtype right after the input qparams
        # and the output dtype last (see the ``add.scalar`` schema). Both are
        # uint8, matching the dequantize/quantize ops in the pattern.
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            torch.uint8,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
            torch.uint8,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_2_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(num, x)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_2_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        # Same argument layout as binary_op_scalar_1_replacement: the scalar
        # op always receives the quantized tensor first and the scalar as ``b``.
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            torch.uint8,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
            torch.uint8,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_relu_op.name())
    def binary_relu_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.aten.relu.default(out)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_relu_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_relu_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    # The fused relu rule is listed first, presumably so the fused form is
    # preferred over the plain binary rule when both could match.
    return [
        (
            _trace_and_lower_to_edge_ops(binary_relu_op_pattern),
            _trace_and_lower_to_edge_ops(binary_relu_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_pattern),
            _trace_and_lower_to_edge_ops(binary_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_replacement),
            [_sixth_input_is_scalar],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_replacement),
            [_sixth_input_is_scalar],
        ),
    ]


def _get_binary_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    """Collect the quantized rewrite rules for every supported binary op.

    Each float edge op is mapped to its (tensor, scalar, fused-relu) quantized
    counterparts. The rules for each op are concatenated in mapping order.
    """
    # TODO: replace qbinary op with the ops implemented in lean mode
    binary_op_to_qbinary_ops = {
        exir_ops.edge.aten.add.Tensor: (
            exir_ops.edge.quantized_decomposed.add.default,
            exir_ops.edge.quantized_decomposed.add.scalar,
            exir_ops.edge.quantized_decomposed.add_relu.default,
        ),
    }
    return [
        rule
        for float_op, (q_op, q_scalar_op, q_relu_op) in binary_op_to_qbinary_ops.items()
        for rule in _get_binary_op_patterns_and_replacements(
            float_op, q_op, q_scalar_op, q_relu_op
        )
    ]


def _get_reshape_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    """Rewrite dequantize -> _reshape_alias_copy -> quantize as a plain reshape.

    The replacement applies ``_reshape_alias_copy`` directly to the quantized
    input. Returns a single ``(pattern, replacement, [])`` triple.

    NOTE(review): the replacement drops the requantization, so it only
    preserves values when the output qparams (``out_*``) equal the input
    qparams (``x_*``). Nothing in this rule enforces that; consider adding a
    match filter.
    """

    def pattern(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        reshaped = torch.ops.aten._reshape_alias_copy.default(dequantized, arg0, arg1)
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            reshaped, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

    def replacement(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)

    traced_pair = (
        _trace_and_lower_to_edge_ops(pattern),
        _trace_and_lower_to_edge_ops(replacement),
    )
    return [(*traced_pair, [])]


def _get_slice_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    """Rewrite dequantize -> slice_copy -> quantize as a slice on quantized data.

    Slicing does not change element values, so quantizing back with the same
    qparams used to dequantize reproduces the sliced quantized values. That
    makes a plain ``slice_copy`` on the quantized tensor an equivalent
    replacement. Because the pattern reuses the ``x_*`` qparams for the final
    quantize, it only matches when input and output qparams are the same.

    Returns a single ``(pattern, replacement, [])`` triple.
    """

    def pattern(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        # Closing step of the dq -> op -> q chain. A second dequantize here
        # (applied to a float tensor) could never match a real quantized graph.
        x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        return x

    def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        return x

    return [
        (
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
            [],
        )
    ]


def _get_embedding_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    """Rewrite rules fusing "dequantize weight -> aten.embedding" into
    ``quantized_decomposed.embedding_byte`` (or its ``.dtype`` overload).

    Returns five ``(pattern, replacement, [])`` triples covering per-channel
    and group-wise dequantization, with and without ``padding_idx``, plus a
    group-wise variant with an explicit output dtype. The pattern bodies are
    traced by ``_trace_and_lower_to_edge_ops``, so their op sequence and
    argument order define what gets matched.
    """

    def get_pattern_and_replacement():
        """Define, register, and trace every embedding pattern/replacement pair."""

        # Per-channel dequantize (uint8; the literal 0 is presumably the
        # channel axis) followed by embedding  ==>  embedding_byte.
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indicies,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indicies)
            return out

        def replacement(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indicies,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indicies,
            )
            return out

        # Group-wise dequantize followed by embedding  ==>  embedding_byte.
        # ``group_size`` is not forwarded; the embedding_byte implementation in
        # this file derives it from the shapes of weight and weight_scales.
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        # Same as ``pattern`` but with an explicit padding_idx on the embedding;
        # the replacement drops it.
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indicies,
            padding_idx,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indicies, padding_idx)
            return out

        def replacement_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indicies,
            _,  # padding_idx only matters for training and not when running op for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indicies,
            )
            return out

        # Group-wise variant with padding_idx; both group_size and padding_idx
        # are dropped in the replacement.
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            padding_idx,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
            return out

        def replacement_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            _,  # padding_idx only matters for training and not when running op for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        # Group-wise dequantize into an explicit output ``dtype``  ==>
        # embedding_byte.dtype with the same dtype passed as a keyword.
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
        def pattern_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.dtype(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
                dtype=dtype,
            )
            return out

        # Each pair is traced and lowered to edge ops; no match filters.
        return [
            (
                _trace_and_lower_to_edge_ops(pattern),
                _trace_and_lower_to_edge_ops(replacement),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_groupwise),
                _trace_and_lower_to_edge_ops(replacement_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
                [],
            ),
        ]

    patterns_and_replacements = []
    patterns_and_replacements.extend(
        get_pattern_and_replacement(),
    )
    return patterns_and_replacements


# NOTE(review): the string literal below is a disabled draft of a
# fixed-qparams (e.g. softmax) rewrite. It is never executed. Visible issues
# to fix before enabling it:
#   * ``fixed_qparams_qop`` is used but never defined; the inner function only
#     receives ``fixed_qparams_op``.
#   * The call site passes four arguments to the three-parameter
#     ``get_pattern_and_replacement``.
#   * The pattern's second op is ``dequantize_per_tensor``; a
#     ``quantize_per_tensor`` is presumably intended.
#   * The outer function never returns ``patterns_and_replacements``.
#   * A stray ``n`` precedes the inner ``return``.
"""
def _get_fixed_qparams_ops_patterns_and_replacements() -> List[Tuple[Callable, Callable, List[Callable]]]:
    fixed_qparams_op_to_qop = {
        torch.ops.aten.softmax: (torch.ops.quantized_decomposed.softmax, 1.0 / 256.0, 0)
    }
    def get_pattern_and_replacement(fixed_qparams_op, fixed_scale, fixed_zero_point):
        def pattern(x, x_scale, x_zero_point, x_qmin, x_qmax):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
            x = fixed_qparams_op(x)
            x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, fixed_scale, fixed_zero_point, 0, 255, torch.uint8)
            return x

        def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
            x = fixed_qparams_qop(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
            return x

n        return [(pattern, replacement, [])]

    patterns_and_replacements = []
    for op, (qop, fixed_scale, fixed_zero_point) in fixed_qparams_op_to_qop.items():
        patterns_and_replacements.extend(
            get_pattern_and_replacement(op, qop, fixed_scale, fixed_zero_point)
        )
"""


def get_quant_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    """Return every quantization rewrite rule defined in this module.

    Each entry is a ``(pattern, replacement, match_filters)`` triple. Groups
    appear in this order: binary ops, reshape, slice, embedding. A new list is
    returned on every call.
    """
    collected = []
    collected.extend(_get_binary_ops_patterns_and_replacements())
    collected.extend(_get_reshape_patterns_and_replacements())
    collected.extend(_get_slice_patterns_and_replacements())
    # TODO: enable once the corresponding quantized ops are implemented:
    # collected.extend(_get_fixed_qparams_ops_patterns_and_replacements())
    collected.extend(_get_embedding_ops_patterns_and_replacements())
    return copy.copy(collected)
