# Owner(s): ["oncall: quantization"]

from collections import OrderedDict
import contextlib
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.multiprocessing as mp
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY

# graph mode quantization based on fx
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_fx,
    convert_to_reference_fx,
    _convert_to_reference_decomposed_fx,
    prepare_qat_fx,
    fuse_fx,
)


from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler

from torch.ao.quantization.fx.match_utils import (
    _is_match,
    MatchAllNode,
)

from torch.ao.quantization import (
    QuantType,
)
from torch.ao.quantization.quant_type import _get_quant_type_to_str

from torch.ao.quantization import (
    QuantStub,
    DeQuantStub,
    QuantWrapper,
    default_qconfig,
    default_dynamic_qconfig,
    default_per_channel_qconfig,
    default_qat_qconfig,
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float16_static_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    get_default_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    fuse_modules,
    fuse_modules_qat,
    prepare,
    prepare_qat,
    convert,
    quantize_dynamic,
    default_placeholder_observer,
    default_weight_observer,
    PerChannelMinMaxObserver,
    FixedQParamsFakeQuantize,
    FixedQParamsObserver,
    FusedMovingAvgObsFakeQuantize,
    FakeQuantize,
    MovingAverageMinMaxObserver,
    HistogramObserver,
    ReuseInputObserver,
    QConfig,
    default_embedding_qat_qconfig,
)

from torch.ao.quantization.backend_config import (
    get_fbgemm_backend_config,
    get_qnnpack_backend_config,
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType
)
from torch.ao.quantization.backend_config.native import (
    get_test_only_legacy_native_backend_config,
)

from torch.ao.quantization.qconfig_mapping import (
    _get_symmetric_qnnpack_qconfig_mapping,
    _get_symmetric_qnnpack_qat_qconfig_mapping,
    _GLOBAL_DICT_KEY,
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    _OBJECT_TYPE_DICT_KEY,
    QConfigMapping,
)

from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _get_object_type_qconfig,
    _get_module_name_qconfig,
    _get_module_name_regex_qconfig,
    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.pattern_utils import (
    _DEFAULT_FUSION_PATTERNS,
    _DEFAULT_QUANTIZATION_PATTERNS,
    _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP,
    _DEFAULT_OUTPUT_OBSERVER_MAP,
    _register_fusion_pattern,
    _register_quant_pattern,
    get_default_output_activation_post_process_map
)

from torch.ao.quantization.fx.custom_config import (
    STANDALONE_MODULE_NAME_DICT_KEY,
    STANDALONE_MODULE_CLASS_DICT_KEY,
    FLOAT_TO_OBSERVED_DICT_KEY,
    OBSERVED_TO_QUANTIZED_DICT_KEY,
    NON_TRACEABLE_MODULE_NAME_DICT_KEY,
    NON_TRACEABLE_MODULE_CLASS_DICT_KEY,
    INPUT_QUANTIZED_INDEXES_DICT_KEY,
    OUTPUT_QUANTIZED_INDEXES_DICT_KEY,
    PRESERVED_ATTRIBUTES_DICT_KEY,
    FuseCustomConfig,
    ConvertCustomConfig,
    PrepareCustomConfig,
    StandaloneModuleConfigEntry,
)
import torch.ao.quantization.fx.lstm_utils

from torch.ao.quantization.fx.utils import (
    _reroute_tuple_getitem_pattern,
    NodeInfo,
)

from torch.ao.quantization.fake_quantize import (
    default_fixed_qparams_range_0to1_fake_quant,
    default_fixed_qparams_range_neg1to1_fake_quant,
)

from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    MinMaxObserver,
    _is_activation_post_process,
)

# test utils
from hypothesis import given, settings
from hypothesis import strategies as st
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_quantization import (
    LinearReluLinearModel,
    LinearReluModel,
    LinearBnLeakyReluModel,
    LinearTanhModel,
    ConvBnAddReluModel,
    QuantizationTestCase,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    skip_if_no_torchvision,
    train_one_epoch,
    run_ddp,
    test_only_eval_fn,
    test_only_train_fn,
    ModelForConvTransposeBNFusion,
    get_supported_device_types,
    skipIfNoONEDNN,
)

from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
)

from torch.testing._internal.common_quantized import (
    supported_qengines,
    override_qengines,
    override_quantized_engine,
)

from torch.testing._internal.common_utils import (
    TemporaryFileName,
    IS_ARM64,
    skipIfTorchDynamo,
)

from torch.testing._internal.common_quantization import NodeSpec as ns

from torch.testing import FileCheck

import copy
import itertools
import operator
import unittest
import io
from typing import Callable, Optional, List, Tuple

class BinaryOp(torch.nn.Module):
    """Two 1x1 convolutions whose results are combined twice by a binary op.

    Args:
        binary_op: out-of-place binary callable, e.g. ``operator.add``.
        ibinary_op: in-place counterpart of ``binary_op``; used only when it
            is truthy and ``is_inplace`` is set.
        is_inplace: select ``ibinary_op`` instead of ``binary_op``.
        is_scalar: when set, the second operand is the constant ``3`` and the
            ``y`` input (and ``conv2``) is not used in ``forward``.
    """

    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
        self.is_scalar = is_scalar
        if ibinary_op and is_inplace:
            self.op = ibinary_op
        else:
            self.op = binary_op

    def forward(self, x, y):
        lhs = self.conv1(x)
        rhs = self.conv2(y) if not self.is_scalar else 3
        # first application, e.g. lhs = lhs + rhs
        lhs = self.op(lhs, rhs)
        # second application with the operands swapped, e.g. rhs + lhs
        return self.op(rhs, lhs)

class BinaryOpNonQuantizedInput(torch.nn.Module):
    """Apply a single binary op directly to the module inputs.

    Args:
        binary_op: out-of-place binary callable.
        ibinary_op: in-place counterpart of ``binary_op``; used only when it
            is truthy and ``is_inplace`` is set.
        is_inplace: select ``ibinary_op`` instead of ``binary_op``.
        is_scalar: when set, the second operand is the constant ``3`` and the
            ``y`` input is ignored.
    """

    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        super().__init__()
        self.is_scalar = is_scalar
        if ibinary_op and is_inplace:
            self.op = ibinary_op
        else:
            self.op = binary_op

    def forward(self, x, y):
        rhs = 3 if self.is_scalar else y
        return self.op(x, rhs)

class BinaryOpRelu(torch.nn.Module):
    """Two 1x1 convolutions combined by a binary op, each time followed by relu.

    Args:
        binary_op: out-of-place binary callable.
        ibinary_op: in-place counterpart of ``binary_op``; used only when it
            is truthy and ``is_inplace`` is set.
        is_inplace: select ``ibinary_op`` instead of ``binary_op``.
        relu_callable: either the ``torch.nn.ReLU`` class (an instance is
            created) or a callable that is used as-is.
        is_scalar: when set, the second operand is the constant ``3`` and the
            ``y`` input (and ``conv2``) is not used in ``forward``.
    """

    def __init__(self, binary_op, ibinary_op, is_inplace, relu_callable,
                 is_scalar):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
        if ibinary_op and is_inplace:
            self.op = ibinary_op
        else:
            self.op = binary_op
        self.relu_callable = relu_callable
        self.is_scalar = is_scalar
        # Only the ReLU class itself is instantiated; anything else is called directly.
        self.relu = torch.nn.ReLU() if relu_callable is torch.nn.ReLU else relu_callable

    def forward(self, x, y):
        lhs = self.conv1(x)
        rhs = self.conv2(y) if not self.is_scalar else 3
        lhs = self.relu(self.op(lhs, rhs))
        # second round with the operands swapped
        return self.relu(self.op(rhs, lhs))

@torch.fx.wrap
def _user_func_with_complex_return_type(x):
    """Split ``x`` into size-1 chunks along dim 1 and return them as a list."""
    chunks = torch.split(x, split_size_or_sections=1, dim=1)
    return [*chunks]

class TestFuseFx(QuantizationTestCase):
    def test_fuse_conv_bn_relu(self):
        """Check conv + bn (+ relu) fusion for 1d/2d/3d convolutions.

        Train mode goes through ``prepare_qat_fx`` and eval mode goes through
        ``fuse_fx``; in both cases no standalone ``nn.ReLU`` call may remain.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.conv1d2 = nn.Conv1d(1, 1, 1)
                self.conv2d2 = nn.Conv2d(1, 1, 1)
                self.conv3d2 = nn.Conv3d(1, 1, 1)
                self.bn1d2 = nn.BatchNorm1d(1)
                self.bn2d2 = nn.BatchNorm2d(1)
                self.bn3d2 = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                # conv -> bn chains without relu
                x = self.bn1d(self.conv1d(x))
                x = self.bn2d(self.conv2d(x))
                x = self.bn3d(self.conv3d(x))
                # conv -> bn -> relu chains, all sharing one ReLU module
                x = self.relu(self.bn1d2(self.conv1d2(x)))
                x = self.relu(self.bn2d2(self.conv2d2(x)))
                x = self.relu(self.bn3d2(self.conv3d2(x)))
                return x

        def module_nodes(*module_types):
            return [ns.call_module(module_type) for module_type in module_types]

        # every relu call is expected to end up inside a fused module
        no_plain_relu = {ns.call_module(nn.ReLU): 0}

        # Train mode: fuse_fx is called inside prepare_qat_fx.
        # Currently fusion does not check whether the modules are configured
        # with a qconfig, which is why an empty qconfig dict is enough here.
        # TODO: if we decide to do that in the future, this test needs to
        # be updated
        qat_model = prepare_qat_fx(
            M().train(), {}, example_inputs=(torch.randn(1, 1, 1, 1),))
        self.checkGraphModuleNodes(
            qat_model,
            expected_node_list=module_nodes(
                nni.ConvBn1d,
                nni.ConvBn2d,
                nni.ConvBn3d,
                nni.ConvBnReLU1d,
                nni.ConvBnReLU2d,
                nni.ConvBnReLU3d,
            ),
            expected_node_occurrence=no_plain_relu)

        # Eval mode: fuse_fx is a top level api and only supports eval mode.
        fused_model = fuse_fx(M().eval())
        self.checkGraphModuleNodes(
            fused_model,
            expected_node_list=module_nodes(
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nni.ConvReLU1d,
                nni.ConvReLU2d,
                nni.ConvReLU3d,
            ),
            expected_node_occurrence=no_plain_relu)

    def test_fuse_linear_bn_eval(self):
        """After eval-mode ``fuse_fx`` a Linear call remains and no BatchNorm1d call."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.bn1d = nn.BatchNorm1d(1)

            def forward(self, x):
                return self.bn1d(self.linear(x))

        # fuse_fx is a top level api and only supports eval mode
        fused = fuse_fx(M().eval())
        self.checkGraphModuleNodes(
            fused,
            expected_node_list=[ns.call_module(nn.Linear)],
            expected_node_occurrence={ns.call_module(nn.BatchNorm1d): 0})

    @skipIfNoONEDNN
    def test_fuse_linear_bn_leaky_relu_onednn(self):
        """linear (- bn) - leaky_relu fuses to LinearLeakyReLU with the onednn backend config."""
        # linear - bn - leaky_relu is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config

        fused_only = [ns.call_module(nni.LinearLeakyReLU)]
        nothing_left_unfused = {
            ns.call_module(nn.BatchNorm1d): 0,
            ns.call_module(nn.LeakyReLU): 0,
        }

        for with_bn in (True, False):
            # fuse_fx is a top level api and only supports eval mode
            float_model = LinearBnLeakyReluModel(with_bn).eval()
            fused = fuse_fx(
                float_model, backend_config=get_onednn_backend_config())
            self.checkGraphModuleNodes(
                fused,
                expected_node_list=fused_only,
                expected_node_occurrence=nothing_left_unfused)

    def test_linear_bn_leaky_relu_not_fused_by_default(self):
        """Without a backend config, linear (- bn) - leaky_relu must not become LinearLeakyReLU."""
        # Linear and LeakyReLU stay separate calls; the fused module never appears.
        separate_calls = [
            ns.call_module(nn.Linear),
            ns.call_module(nn.LeakyReLU),
        ]
        no_fused_module = {ns.call_module(nni.LinearLeakyReLU): 0}

        for with_bn in (True, False):
            # fuse_fx is a top level api and only supports eval mode
            fused = fuse_fx(LinearBnLeakyReluModel(with_bn).eval())
            self.checkGraphModuleNodes(
                fused,
                expected_node_list=separate_calls,
                expected_node_occurrence=no_fused_module)

    @skipIfNoONEDNN
    def test_fuse_linear_tanh_for_onednn_backend(self):
        """linear - tanh fuses to LinearTanh with the onednn backend config."""
        # linear - tanh is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config

        # fuse_fx is a top level api and only supports eval mode
        float_model = LinearTanhModel().eval()
        fused = fuse_fx(float_model, backend_config=get_onednn_backend_config())

        self.checkGraphModuleNodes(
            fused,
            expected_node_list=[ns.call_module(nni.LinearTanh)],
            expected_node_occurrence={
                ns.call_module(nn.Linear): 0,
                ns.call_module(nn.Tanh): 0,
            })

    def test_linear_tanh_not_fused_by_default(self):
        """Without a backend config, linear - tanh must not become LinearTanh."""
        # fuse_fx is a top level api and only supports eval mode
        fused = fuse_fx(LinearTanhModel().eval())

        # Linear and Tanh stay separate calls; the fused module never appears.
        self.checkGraphModuleNodes(
            fused,
            expected_node_list=[
                ns.call_module(nn.Linear),
                ns.call_module(nn.Tanh),
            ],
            expected_node_occurrence={ns.call_module(nni.LinearTanh): 0})

    @skipIfNoONEDNN
    def test_fuse_conv_bn_add_relu_onednn(self):
        """conv (- bn) - add (- relu) fuses to ConvAdd2d / ConvAddReLU2d with the onednn backend config.

        Every combination of the ``ConvBnAddReluModel`` flags is exercised.
        The test is guarded by ``skipIfNoONEDNN`` so that it behaves like the
        other onednn-backend tests of this class.
        """
        # conv - bn - add - relu is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        options = itertools.product(
            [True, False],  # with_bn
            [True, False],  # with_relu
            [True, False],  # conv in the left
            [True, False],  # with_two_conv
            [True, False],  # use_torch_add
        )
        for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
            # exactly one fused module of the matching flavour, and no leftover bn
            fused_node = ns.call_module(
                nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d)
            expected_nodes = [fused_node]
            expected_occurrence = {
                fused_node: 1,
                ns.call_module(nn.BatchNorm2d): 0,
            }

            # test eval mode
            m = ConvBnAddReluModel(
                with_bn=with_bn,
                with_relu=with_relu,
                left_conv=left_conv,
                two_conv=two_conv,
                use_torch_add=use_torch_add).eval()

            m = fuse_fx(m,
                        backend_config=get_onednn_backend_config())
            self.checkGraphModuleNodes(
                m,
                expected_node_list=expected_nodes,
                expected_node_occurrence=expected_occurrence)

    def test_fuse_conv_bn_add_relu_by_default(self):
        """Without a backend config, conv (- bn) - add (- relu) must not produce ConvAdd2d."""
        # A plain Conv2d call is expected and the fused ConvAdd2d never appears.
        plain_conv = [ns.call_module(nn.Conv2d)]
        no_conv_add = {ns.call_module(nni.ConvAdd2d): 0}

        flag_combinations = itertools.product(
            [True, False],  # with_bn
            [True, False],  # with_relu
            [True, False],  # conv in the left
            [True, False],  # with_two_conv
            [True, False],  # use_torch_add
        )
        for with_bn, with_relu, left_conv, two_conv, use_torch_add in flag_combinations:
            # test eval mode
            float_model = ConvBnAddReluModel(
                with_bn=with_bn,
                with_relu=with_relu,
                left_conv=left_conv,
                two_conv=two_conv,
                use_torch_add=use_torch_add).eval()
            fused = fuse_fx(float_model)
            self.checkGraphModuleNodes(
                fused,
                expected_node_list=plain_conv,
                expected_node_occurrence=no_conv_add)

    @skipIfNoONEDNN
    def test_fuse_conv_bn_add_relu_lowering(self):
        """ Test fusion and lowering of Conv2d - (bn -) add (- ReLU)
            by FX. For the onednn backend only.
        """
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        qconfig_mapping = get_default_qconfig_mapping('onednn')
        quantize_node = ns.call_function(torch.quantize_per_tensor)
        dequantize_node = ns.call_method("dequantize")
        # expected counts for the reference model are the same for every flag combination
        reference_occurrence = {
            quantize_node: 3,
            dequantize_node: 3,
        }
        with override_quantized_engine('onednn'):
            flag_combinations = itertools.product(
                [True, False],  # with_bn
                [True, False],  # with_relu
                [True, False],  # conv in the left
                [True, False],  # two_conv
                [True, False],  # use_torch_add
            )
            for with_bn, with_relu, left_conv, two_conv, use_torch_add in flag_combinations:
                fused_cls = nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d
                lowered_occurrence = {
                    quantize_node: 1 if two_conv else 2,
                    dequantize_node: 1,
                    ns.call_module(fused_cls): 1,
                    ns.call_module(nn.Conv2d): 0,
                    ns.call_module(nn.ReLU): 0,
                }

                # test eval mode
                float_model = ConvBnAddReluModel(
                    with_bn=with_bn,
                    with_relu=with_relu,
                    left_conv=left_conv,
                    two_conv=two_conv,
                    use_torch_add=use_torch_add).eval()
                example_x = float_model.get_example_inputs()
                prepared = prepare_fx(
                    float_model,
                    qconfig_mapping,
                    example_inputs=example_x,
                    backend_config=get_onednn_backend_config())
                # keep an independent copy so both conversion paths start from the same prepared model
                prepared_copy = copy.deepcopy(prepared)
                lowered = convert_fx(prepared, backend_config=get_onednn_backend_config())
                reference = convert_to_reference_fx(prepared_copy)
                self.checkGraphModuleNodes(lowered, expected_node_occurrence=lowered_occurrence)
                self.checkGraphModuleNodes(reference, expected_node_occurrence=reference_occurrence)
                # the lowered model must also be runnable
                lowered(*example_x)

    def test_fuse_convtranspose_bn_eval(self):
        """After eval-mode ``fuse_fx`` the ConvTranspose calls remain and no BatchNorm call does."""
        fused = fuse_fx(ModelForConvTransposeBNFusion().eval())

        conv_transposes = [
            ns.call_module(conv_type)
            for conv_type in (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        ]
        no_batch_norm = {
            ns.call_module(bn_type): 0
            for bn_type in (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        }
        self.checkGraphModuleNodes(
            fused,
            expected_node_list=conv_transposes,
            expected_node_occurrence=no_batch_norm)


    def test_fuse_module_relu(self):
        """conv - relu and bn - relu pairs are fused by eval-mode ``fuse_fx``."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                # each module is followed by the one shared ReLU module
                x = self.relu(self.conv1d(x))
                x = self.relu(self.conv2d(x))
                x = self.relu(self.conv3d(x))
                x = self.relu(self.bn1d(x))
                x = self.relu(self.bn2d(x))
                x = self.relu(self.bn3d(x))
                return x

        fused = fuse_fx(M().eval())
        fused_types = (
            nni.ConvReLU1d,
            nni.ConvReLU2d,
            nni.ConvReLU3d,
            nni.BNReLU2d,
            nni.BNReLU3d,
        )
        self.checkGraphModuleNodes(
            fused,
            expected_node_list=[ns.call_module(t) for t in fused_types])

    @skipIfNoFBGEMM
    def test_qconfig_fused_module(self):
        """ Quantize linear - relu models using per-object-type qconfigs and
        check the resulting node sequence.

        TODO: add test for all fused modules
        """
        qconfig_dict = {
            "": None,
            "object_type": [(nn.Linear, default_qconfig),
                            (nn.ReLU, default_qconfig),
                            (F.relu, default_qconfig)]
        }

        quantize = ns.call_function(torch.quantize_per_tensor)
        fused_linear_relu = ns.call_module(nniq.LinearReLU)
        dequantize = ns.call_method('dequantize')

        # (model class, node sequence expected after convert_fx)
        cases = (
            (LinearReluModel,
             [quantize, fused_linear_relu, dequantize]),
            (LinearReluLinearModel,
             [quantize, fused_linear_relu, ns.call_module(nnq.Linear), dequantize]),
        )

        for model_cls, expected_nodes in cases:
            float_model = model_cls().eval()
            example_inputs = (torch.rand(5, 5),)
            prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)

            # run the prepared model once on the example inputs before converting
            prepared(*example_inputs)
            quantized = convert_fx(prepared)

            self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_problematic_fuse_example(self):
        """Linear + ReLU inside an ``nn.Sequential`` subclass is fused by ``prepare_fx``
        when both types get separately created fbgemm qconfigs.
        """
        class LinearRelu(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(
                    nn.Linear(5, 5),
                    nn.ReLU(),
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin_relu = LinearRelu()
                self.linear = nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(self.lin_relu(x))

        float_model = M().eval()
        # these qconfigs somehow fail equality where default_qconfig does not;
        # get_default_qconfig is deliberately called once per entry so that
        # each entry gets its own qconfig object
        qconfig_dict = {
            "": None,
            "object_type": [
                (torch.nn.Linear, get_default_qconfig('fbgemm')),
                (torch.nn.ReLU, get_default_qconfig('fbgemm')),
            ],
        }
        prepared = prepare_fx(
            float_model, qconfig_dict, example_inputs=(torch.randn(1, 5),))

        fused_linear_relu = torch.ao.nn.intrinsic.modules.fused.LinearReLU
        self.checkGraphModuleNodes(
            prepared, expected_node=ns.call_module(fused_linear_relu))

    # NOTE: the two string pieces are joined by implicit concatenation, so the
    # first piece has to end with a space to keep the words apart.
    @unittest.skip("Temporarily skipping the test case, will enable after the simple "
                   "pattern format is supported")
    def test_fuse_addtional_fuser_method(self):
        """A user supplied fuser method for (Conv2d, ReLU) should produce the custom fused module.

        Currently skipped, see the skip reason on the decorator.
        """
        class MyConvReLU(torch.nn.Module):
            pass

        def my_conv_relu_fuser(conv, relu):
            return MyConvReLU()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        m = M().eval()
        m = fuse_fx(m, fuse_custom_config={
            "additional_fuser_method_mapping": {
                (torch.nn.Conv2d, torch.nn.ReLU): my_conv_relu_fuser
            }
        })
        self.checkGraphModuleNodes(m, expected_node=ns.call_module(MyConvReLU))

    def test_fuse_custom_pattern(self):
        """Fuse relu(add(maxpool(x), bn(conv(x)))) into the conv through a custom backend pattern.

        Both ``torch.add`` and ``operator.add`` flavours of the model are
        tried against a backend config that registers a pattern for each.
        """
        class M(torch.nn.Module):
            def __init__(self, use_torch_add=True):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3)
                self.add = torch.add if use_torch_add else operator.add

            def forward(self, x):
                pooled = self.maxpool(x)
                out = self.bn(self.conv(x))
                out = self.add(pooled, out)
                return self.relu(out)

        def fuse_conv_bn_relu(is_qat, relu, add_pattern):
            # the third element of add_pattern is the (bn, conv) pair; only conv is kept
            _, _, (_, conv) = add_pattern
            return conv

        def residual_pattern_config(add_fn):
            complex_pattern = (
                nn.ReLU,
                (add_fn, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)),
            )
            return BackendPatternConfig() \
                ._set_pattern_complex_format(complex_pattern) \
                .set_fuser_method(fuse_conv_bn_relu)

        for use_torch_add in (True, False):
            float_model = M(use_torch_add).eval()

            # a fresh backend config with one pattern per add flavour
            backend_config = BackendConfig() \
                .set_backend_pattern_config(residual_pattern_config(torch.add)) \
                .set_backend_pattern_config(residual_pattern_config(operator.add))
            fused = fuse_fx(float_model, backend_config=backend_config)
            self.assertEqual(type(fused.conv), torch.nn.Conv2d)
            # check bn and relu are gone since we replaced the whole pattern to conv
            self.assertFalse(hasattr(fused, "bn"))
            self.assertFalse(hasattr(fused, "relu"))

    def test_fusion_pattern_with_multiple_inputs(self):
        """ This test exercises two keys in backend_config: root_node_getter and
        extra_inputs_getter.
        root_node_getter is used to identify the "root" module in the node pattern,
        i.e. the node that is kept after fusion.
        extra_inputs_getter returns a list of nodes that need to be added to the
        fused node as extra inputs.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3)

            def forward(self, x):
                y = self.maxpool(x)
                x = self.conv(x)
                x = self.bn(x)
                x = torch.add(x, y)
                x = self.relu(x)
                return x

        m = M().eval()

        def fuse_conv_bn_relu(is_qat, relu, add_pattern):
            # add_pattern is unpacked as (add, (bn, conv), extra input); only conv is kept
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv

        def conv_bn_res_relu_root_node_getter(pattern):
            relu, add_pattern = pattern
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv

        def conv_bn_res_relu_extra_inputs_getter(pattern):
            """ get inputs pattern for extra inputs, inputs for root node
            are assumed to be copied over from root node to the fused node
            """
            relu, add_pattern = pattern
            _, bn_pattern, extra_input = add_pattern
            bn, conv = bn_pattern
            return [extra_input]

        conv_bn_res_relu_config = BackendPatternConfig() \
            ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
            .set_fuser_method(fuse_conv_bn_relu) \
            ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
            ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
        backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config)
        m = fuse_fx(m, backend_config=backend_config)
        self.assertEqual(type(m.conv), torch.nn.Conv2d)
        # check bn and relu are gone since we replaced the whole pattern to conv
        self.assertFalse(hasattr(m, "bn"))
        self.assertFalse(hasattr(m, "relu"))

        # check conv module has two inputs
        named_modules = dict(m.named_modules())
        for node in m.graph.nodes:
            if node.op == "call_module" and type(named_modules[node.target]) is torch.nn.Conv2d:
                # the message is passed to the assertion itself so that it
                # shows up in the failure report
                self.assertEqual(
                    len(node.args), 2,
                    "Expecting the fused op to have two arguments")

    def test_fusion_pattern_with_matchallnode(self):
        """This test checks that the node matched by MatchAllNode is regarded as an input
        instead of a module to be fused. For instance, we have two patterns:
            (nn.ReLU, (torch.add, MatchAllNode, nn.Conv2d))
            (nn.ReLU, nn.Conv2d)
        And we want to fuse the following model
            Conv2d -> ReLU +
            Conv2d ------ Add -> ReLU
        ReLU in the first row is matched as MatchAllNode in the residual pattern. But it won't be
        fused as part of that pattern. It needs to be properly fused with the upstream Conv2d.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.relu1 = torch.nn.ReLU()
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu2 = torch.nn.ReLU()

            def forward(self, x):
                # first row: conv -> relu
                y = self.relu1(self.conv1(x))
                # second row: conv -> add(first row) -> relu
                x = torch.add(self.conv2(x), y)
                return self.relu2(x)

        float_model = M().eval()

        def fuse_conv_relu(is_qat, conv, relu):
            return conv

        def fuse_conv_res_relu(is_qat, relu, add_pattern):
            _add, conv, _extra = add_pattern
            return conv

        def conv_res_relu_root_node_getter(pattern):
            _relu, add_pattern = pattern
            _add, conv, _extra = add_pattern
            return conv

        def conv_res_relu_extra_inputs_getter(pattern):
            _relu, add_pattern = pattern
            _add, _conv, extra_input = add_pattern
            return [extra_input]

        conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU))
        conv_relu_config = conv_relu_config.set_fuser_method(fuse_conv_relu)

        residual_pattern = (nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))
        conv_res_relu_config = BackendPatternConfig() \
            ._set_pattern_complex_format(residual_pattern) \
            .set_fuser_method(fuse_conv_res_relu) \
            ._set_root_node_getter(conv_res_relu_root_node_getter) \
            ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)

        backend_config = BackendConfig() \
            .set_backend_pattern_config(conv_relu_config) \
            .set_backend_pattern_config(conv_res_relu_config)
        fused = fuse_fx(float_model, backend_config=backend_config)
        for conv_name in ("conv1", "conv2"):
            self.assertEqual(type(getattr(fused, conv_name)), torch.nn.Conv2d)
        # check relu are gone since we replaced both patterns to conv
        for relu_name in ("relu1", "relu2"):
            self.assertFalse(hasattr(fused, relu_name))


@skipIfNoFBGEMM
class TestQuantizeFx(QuantizationTestCase):
    def test_pattern_match(self):
        """ test MatchAllNode with
            conv - bn - add - relu pattern

        Traces a small conv -> bn -> add -> relu model and checks that
        `_is_match` accepts the relu node against a nested pattern in which
        the second operand of the add is matched by `MatchAllNode`.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x, y):
                x = self.conv(x)
                x = self.bn(x)
                x = x + y
                x = self.relu(x)
                return x

        # The pattern is written output-first: relu <- add <- (bn <- conv, anything)
        pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
        m = torch.fx.symbolic_trace(M())
        modules = dict(m.named_modules())
        relu_nodes = [
            n for n in m.graph.nodes
            if n.op == 'call_module' and isinstance(modules[n.target], nn.ReLU)
        ]
        # Without this check the test would silently pass if tracing ever
        # stopped producing a relu call_module node.
        self.assertEqual(len(relu_nodes), 1)
        for n in relu_nodes:
            self.assertTrue(_is_match(modules, n, pattern))

    def test_pattern_match_constant(self):
        """ test that `_is_match` handles a pattern containing a constant
            (the index 0 passed to getitem)
        """
        class M(torch.nn.Module):
            def forward(self, x):
                x, _ = torch.ops.aten.max_pool2d_with_indices.default(x)
                return x

        pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
        m = torch.fx.symbolic_trace(M())
        # eliminate the code that gets the second output of maxpool, so that the pattern
        # can be matched
        m.graph.eliminate_dead_code()
        modules = dict(m.named_modules())
        getitem_nodes = [
            n for n in m.graph.nodes
            if n.op == "call_function" and n.target == operator.getitem
        ]
        # Guard against a vacuous pass: at least one getitem node must be checked.
        self.assertGreater(len(getitem_nodes), 0)
        for n in getitem_nodes:
            self.assertTrue(_is_match(modules, n, pattern))

    def test_fused_module_qat_swap(self):
        """ Check that a Linear + ReLU pair nested inside a Sequential is
            swapped to the fused QAT LinearReLU module by prepare_qat_fx
        """
        class Tmp(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tmp = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.tmp(x))

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(Tmp(), torch.nn.Linear(5, 5))
                self.mods2 = torch.nn.Linear(5, 5)

            def forward(self, x):
                a = self.mods1(x)
                shifted = torch.add(x, 5)
                projected = self.mods2(shifted)
                return a, torch.add(projected, 5)

        # Only Linear and ReLU get a qat qconfig; everything else is left unquantized.
        object_type_qconfigs = [
            (torch.nn.Linear, default_qat_qconfig),
            (torch.nn.ReLU, default_qat_qconfig),
        ]
        qconfig_dict = {"": None, "object_type": object_type_qconfigs}

        float_model = M().train()
        example_inputs = (torch.randn(1, 5),)
        prepared = prepare_qat_fx(float_model, qconfig_dict, example_inputs=example_inputs)

        nested_block = getattr(prepared.mods1, "0")
        is_fused_qat = isinstance(nested_block.tmp, torch.ao.nn.intrinsic.qat.LinearReLU)
        self.assertTrue(is_fused_qat)

    def _get_conv_linear_test_cases(self, is_reference):
        """ Returns a list of conv and linear test cases, each a tuple with format:
        is_dynamic, ModuleClass, module_constructor_inputs,
        inputs, quantized_node, weight_prepack_node

        Args:
            is_reference: when True the expected quantized_node is the float
                functional op / `nnqr` reference module; otherwise it is the
                `torch.ops.quantized` op / `nnq` (or `nnqd`) module.

        Notes:
            - weight_prepack_node is None for every module-based case.
            - quantized_node is None for the dynamic functional Linear case
              when is_reference is True.
        """
        # Functional 1d conv: weight is held as a Parameter and passed to F.conv1d.
        class FunctionalConv1d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = 1
                self.padding = 0
                self.dilation = 1
                self.groups = 1

            def forward(self, x):
                return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)


        # Module-based 1d conv: thin wrapper exposing the conv as `self.conv`.
        class Conv1d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv1d(*args)

            def forward(self, x):
                return self.conv(x)

        conv1d_input = torch.rand(1, 3, 224)
        conv1d_weight = torch.rand(3, 3, 3)
        # positional args forwarded to torch.nn.Conv1d(*args)
        conv1d_module_args = (3, 3, 3)

        # Functional 2d conv: same layout as the 1d variant with 2-tuples for the hyperparameters.
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

        # Module-based 2d conv wrapper.
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_weight = torch.rand(3, 3, 3, 3)
        # positional args forwarded to torch.nn.Conv2d(*args)
        conv2d_module_args = (3, 3, 3)

        # Functional 3d conv: same layout as above with 3-tuples for the hyperparameters.
        class FunctionalConv3d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1, 1)
                self.padding = (0, 0, 0)
                self.dilation = (1, 1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv3d(
                    x,
                    self.weight,
                    None,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        # Module-based 3d conv wrapper.
        class Conv3d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv3d(*args)

            def forward(self, x):
                return self.conv(x)

        conv3d_input = torch.rand(1, 3, 32, 224, 224)
        conv3d_weight = torch.rand(3, 3, 3, 3, 3)
        # positional args forwarded to torch.nn.Conv3d(*args)
        conv3d_module_args = (3, 3, 3)

        # Functional linear: weight is held as a Parameter and passed to F.linear (no bias).
        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        # Module-based linear: exposes the layer as `self.linear`.
        class LinearModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        # is_dynamic, ModuleClass, module_constructor_inputs,
        # inputs, quantized_node, weight_prepack_node
        tests = [
            # --- static, functional convs ---
            (
                False,
                FunctionalConv1d,
                (conv1d_weight,),
                (conv1d_input,),
                ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d) ,
                ns.call_function(torch.ops.quantized.conv1d_prepack),
            ),
            (
                False,
                FunctionalConv2d,
                (conv2d_weight,),
                (conv2d_input,),
                ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d),
                ns.call_function(torch.ops.quantized.conv2d_prepack),
            ),
            (
                False,
                FunctionalConv3d,
                (conv3d_weight,),
                (conv3d_input,),
                ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d),
                ns.call_function(torch.ops.quantized.conv3d_prepack),
            ),
            # --- static, module convs (no prepack node to check) ---
            (
                False,
                Conv1d,
                conv1d_module_args,
                (conv1d_input,),
                ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
                None
            ),
            (
                False,
                Conv2d,
                conv2d_module_args,
                (conv2d_input,),
                ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
                None
            ),
            (
                False,
                Conv3d,
                conv3d_module_args,
                (conv3d_input,),
                ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
                None
            ),
            # --- functional linear, dynamic then static ---
            (
                True,
                Linear,
                (linear_weight,),
                (linear_input,),
                # no expected node is specified for the reference dynamic functional case
                None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            (
                False,
                Linear,
                (linear_weight,),
                (linear_input,),
                ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            # --- module linear, dynamic then static ---
            (
                True,
                LinearModule,
                (),
                (linear_module_input,),
                ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear),
                None,
            ),
            (
                False,
                LinearModule,
                (),
                (linear_module_input,),
                ns.call_module(nnqr.Linear if is_reference else nnq.Linear),
                None,
            ),
        ]
        return tests

    @skipIfNoFBGEMM
    def test_conv_linear_not_reference(self):
        """ Test quantizing conv and linear

        Every case from `_get_conv_linear_test_cases(is_reference=False)` must
        produce the expected quantized node and contain zero weight prepack
        nodes (when a prepack node is specified for the case).
        """
        for case in self._get_conv_linear_test_cases(is_reference=False):
            (is_dynamic, ModuleClass, ctor_args,
             inputs, quantized_node, prepack_node) = case

            quant_type = QuantType.STATIC
            if is_dynamic:
                quant_type = QuantType.DYNAMIC

            node_occurrence = {prepack_node: 0} if prepack_node else {}

            float_model = ModuleClass(*ctor_args)
            self.checkGraphModeFxOp(
                float_model,
                inputs,
                quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=False,
            )

    @skipIfNoFBGEMM
    def test_conv_linear_reference(self):
        """ Test quantizing functional conv and linear with reference option

        For each case this checks the expected reference node, that the
        reference submodule ("linear" or "conv", when present) carries weight
        quantization parameters both before and after a deepcopy, and that
        those parameters survive a state_dict round trip.
        """
        tests = self._get_conv_linear_test_cases(is_reference=True)
        module_names = ("linear", "conv")

        def _get_keys(prefix, is_dynamic):
            # state_dict keys expected for the reference module named `prefix`
            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
            if not is_dynamic:
                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
            return all_keys

        def checkWeightQParams(model):
            # Inspect `model` itself rather than a variable captured from the
            # enclosing loop, so the check always applies to the argument.
            for module_name in module_names:
                if hasattr(model, module_name):
                    submodule = model.get_submodule(module_name)
                    for attr in ("weight_qscheme", "weight_scale", "weight_zero_point"):
                        self.assertTrue(hasattr(submodule, attr))
                    self.assertIn("Reference", submodule._get_name())

        def checkSerDeser(model, is_dynamic):
            for module_name in module_names:
                if hasattr(model, module_name):
                    # make sure serialization works
                    state_dict = copy.deepcopy(model.state_dict())
                    for key in _get_keys(module_name, is_dynamic):
                        self.assertIn(key, state_dict)
                    # check load_state_dict restores states
                    module = getattr(model, module_name)
                    prev_scale = module.weight_scale
                    module.weight_scale = None
                    model.load_state_dict(state_dict)
                    module = getattr(model, module_name)
                    self.assertTrue(torch.equal(prev_scale, module.weight_scale))

        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = {}
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            result_dict = self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=True)
            qr = result_dict["quantized_reference"]

            checkWeightQParams(qr)
            qr = copy.deepcopy(qr)
            # make sure the qparams are preserved after copy
            checkWeightQParams(qr)

            checkSerDeser(qr, is_dynamic)

    def _get_conv_transpose_test_cases(self, use_relu, is_reference):
        """ Returns a list of transposed conv test cases, each a tuple with format:
        is_dynamic, ModuleClass, module_constructor_inputs,
        inputs, quantized_node, weight_prepack_node

        Args:
            use_relu: when True, every model built from these cases applies a
                relu (F.relu or nn.ReLU) after the transposed conv.
            is_reference: when True the expected quantized_node is the float
                functional op / `nnqr` reference module; otherwise it is the
                `torch.ops.quantized` op / `nnq` module.

        Notes:
            - is_dynamic is False for every case in this list.
            - weight_prepack_node is None for every module-based case.
        """
        # Functional 1d transposed conv; `use_relu` is captured from the enclosing call.
        class FunctionalConvTranspose1d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = 1
                self.padding = 0
                self.output_padding = 0
                self.dilation = 1
                self.groups = 1

            def forward(self, x):
                # note: groups is passed before dilation here
                y = F.conv_transpose1d(
                    x,
                    self.weight,
                    None,
                    self.stride,
                    self.padding,
                    self.output_padding,
                    self.groups,
                    self.dilation
                )
                if use_relu:
                    y = F.relu(y)
                return y

        # Module-based 1d transposed conv, exposed as `self.deconv`.
        class ConvTranspose1d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.deconv = torch.nn.ConvTranspose1d(*args)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.deconv(x)
                if use_relu:
                    y = self.relu(y)
                return y

        conv_transpose1d_input = torch.rand(1, 3, 224)
        conv_transpose1d_weight = torch.rand(3, 3, 3)
        # positional args forwarded to torch.nn.ConvTranspose1d(*args)
        conv_transpose1d_module_args = (3, 3, 3)

        # Functional 2d transposed conv.
        class FunctionalConvTranspose2d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.output_padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                # note: groups is passed before dilation here
                y = F.conv_transpose2d(
                    x,
                    self.weight,
                    None,
                    self.stride,
                    self.padding,
                    self.output_padding,
                    self.groups,
                    self.dilation
                )
                if use_relu:
                    y = F.relu(y)
                return y

        # Module-based 2d transposed conv, exposed as `self.deconv`.
        class ConvTranspose2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.deconv = torch.nn.ConvTranspose2d(*args)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.deconv(x)
                if use_relu:
                    y = self.relu(y)
                return y

        conv_transpose2d_input = torch.rand(1, 3, 224, 224)
        conv_transpose2d_weight = torch.rand(3, 3, 3, 3)
        # positional args forwarded to torch.nn.ConvTranspose2d(*args)
        conv_transpose2d_module_args = (3, 3, 3)

        # Functional 3d transposed conv.
        class FunctionalConvTranspose3d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1, 1)
                self.padding = (0, 0, 0)
                self.output_padding = (0, 0, 0)
                self.dilation = (1, 1, 1)
                self.groups = 1

            def forward(self, x):
                # note: groups is passed before dilation here
                y = F.conv_transpose3d(
                    x,
                    self.weight,
                    None,
                    self.stride,
                    self.padding,
                    self.output_padding,
                    self.groups,
                    self.dilation
                )
                if use_relu:
                    y = F.relu(y)
                return y

        # Module-based 3d transposed conv, exposed as `self.deconv`.
        class ConvTranspose3d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.deconv = torch.nn.ConvTranspose3d(*args)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.deconv(x)
                if use_relu:
                    y = self.relu(y)
                return y

        conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224)
        conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3)
        # positional args forwarded to torch.nn.ConvTranspose3d(*args)
        conv_transpose3d_module_args = (3, 3, 3)

        # is_dynamic, ModuleClass, module_constructor_inputs,
        # inputs, quantized_node, weight_prepack_node
        tests = [
            # --- static, functional transposed convs ---
            (
                False,
                FunctionalConvTranspose1d,
                (conv_transpose1d_weight,),
                (conv_transpose1d_input,),
                ns.call_function(
                    torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d
                ),
                ns.call_function(torch.ops.quantized.conv_transpose1d_prepack),
            ),
            (
                False,
                FunctionalConvTranspose2d,
                (conv_transpose2d_weight,),
                (conv_transpose2d_input,),
                ns.call_function(
                    torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d
                ),
                ns.call_function(torch.ops.quantized.conv_transpose2d_prepack),
            ),
            (
                False,
                FunctionalConvTranspose3d,
                (conv_transpose3d_weight,),
                (conv_transpose3d_input,),
                ns.call_function(
                    torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d),
                ns.call_function(torch.ops.quantized.conv_transpose3d_prepack),
            ),
            # --- static, module transposed convs (no prepack node to check) ---
            (
                False,
                ConvTranspose1d,
                conv_transpose1d_module_args,
                (conv_transpose1d_input,),
                ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d),
                None
            ),
            (
                False,
                ConvTranspose2d,
                conv_transpose2d_module_args,
                (conv_transpose2d_input,),
                ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d),
                None
            ),
            (
                False,
                ConvTranspose3d,
                conv_transpose3d_module_args,
                (conv_transpose3d_input,),
                ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d),
                None
            ),
        ]
        return tests

    @skipIfNoFBGEMM
    def test_conv_transpose_not_reference(self):
        """ Test quantizing transposed conv

        Every case from `_get_conv_transpose_test_cases` (no relu, non
        reference) must produce the expected quantized node and contain zero
        weight prepack nodes (when a prepack node is specified for the case).
        """
        cases = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False)
        for case in cases:
            (is_dynamic, ModuleClass, ctor_args,
             inputs, quantized_node, prepack_node) = case

            quant_type = QuantType.STATIC
            if is_dynamic:
                quant_type = QuantType.DYNAMIC

            node_occurrence = {prepack_node: 0} if prepack_node else {}

            float_model = ModuleClass(*ctor_args)
            self.checkGraphModeFxOp(
                float_model,
                inputs,
                quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=False,
            )

    @skipIfNoFBGEMM
    def test_conv_transpose_reference(self):
        """ Test quantizing transposed conv with reference option

        For each case this checks the expected reference node, that the
        "deconv" reference submodule (when present) carries weight
        quantization parameters both before and after a deepcopy, and that
        those parameters survive a state_dict round trip.
        """
        tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True)
        module_name = "deconv"

        def _get_keys(prefix, is_dynamic):
            # state_dict keys expected for the reference module named `prefix`
            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
            if not is_dynamic:
                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
            return all_keys

        def checkWeightQParams(model):
            # Inspect `model` itself rather than a variable captured from the
            # enclosing loop, so the check always applies to the argument.
            if not hasattr(model, module_name):
                return
            submodule = model.get_submodule(module_name)
            for attr in ("weight_qscheme", "weight_scale", "weight_zero_point"):
                self.assertTrue(hasattr(submodule, attr))
            self.assertIn("Reference", submodule._get_name())

        def checkSerDeser(model, is_dynamic):
            if not hasattr(model, module_name):
                return
            # make sure serialization works
            state_dict = copy.deepcopy(model.state_dict())
            for key in _get_keys(module_name, is_dynamic):
                self.assertIn(key, state_dict)
            # check load_state_dict restores states
            module = getattr(model, module_name)
            prev_scale = module.weight_scale
            module.weight_scale = None
            model.load_state_dict(state_dict)
            module = getattr(model, module_name)
            self.assertTrue(torch.equal(prev_scale, module.weight_scale))

        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = {}
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            result_dict = self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=True)
            qr = result_dict["quantized_reference"]

            checkWeightQParams(qr)
            qr = copy.deepcopy(qr)
            # make sure the qparams are preserved after copy
            checkWeightQParams(qr)

            checkSerDeser(qr, is_dynamic)

    def test_conv_transpose_relu_not_reference(self):
        """ Test quantizing transposed conv + relu
            Fusion with relu is not supported.

        Because there is no fusion, exactly one relu node (module or
        functional, matching the flavour of the case) is expected to remain
        in the quantized graph.
        """
        cases = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False)
        for case in cases:
            (is_dynamic, ModuleClass, ctor_args,
             inputs, quantized_node, prepack_node) = case

            quant_type = QuantType.STATIC
            if is_dynamic:
                quant_type = QuantType.DYNAMIC

            node_occurrence = {}
            if prepack_node:
                node_occurrence[prepack_node] = 0

            # the relu is expected in the same form (module vs. function) as the conv
            relu_is_module = quantized_node.op == 'call_module'
            relu_spec = ns.call_module(nn.ReLU) if relu_is_module else ns.call_function(F.relu)
            node_occurrence[relu_spec] = 1

            float_model = ModuleClass(*ctor_args)
            self.checkGraphModeFxOp(
                float_model,
                inputs,
                quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=False,
            )

    @skipIfNoFBGEMM
    def test_conv_transpose_relu_reference(self):
        """ Test quantizing transposed conv with reference option
            Fusion with relu is not supported.

        Besides the expected reference node, exactly one relu node (module or
        functional, matching the flavour of the case) must remain. The
        "deconv" reference submodule (when present) must carry weight
        quantization parameters before and after a deepcopy, and they must
        survive a state_dict round trip.
        """
        tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True)
        module_name = "deconv"

        def _get_keys(prefix, is_dynamic):
            # state_dict keys expected for the reference module named `prefix`
            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
            if not is_dynamic:
                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
            return all_keys

        def checkWeightQParams(model):
            # Inspect `model` itself rather than a variable captured from the
            # enclosing loop, so the check always applies to the argument.
            if not hasattr(model, module_name):
                return
            submodule = model.get_submodule(module_name)
            for attr in ("weight_qscheme", "weight_scale", "weight_zero_point"):
                self.assertTrue(hasattr(submodule, attr))
            self.assertIn("Reference", submodule._get_name())

        def checkSerDeser(model, is_dynamic):
            if not hasattr(model, module_name):
                return
            # make sure serialization works
            state_dict = copy.deepcopy(model.state_dict())
            for key in _get_keys(module_name, is_dynamic):
                self.assertIn(key, state_dict)
            # check load_state_dict restores states
            module = getattr(model, module_name)
            prev_scale = module.weight_scale
            module.weight_scale = None
            model.load_state_dict(state_dict)
            module = getattr(model, module_name)
            self.assertTrue(torch.equal(prev_scale, module.weight_scale))

        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = {}
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            # relu is not fused, so it must still appear exactly once
            if quantized_node.op == 'call_module':
                node_occurrence[ns.call_module(nn.ReLU)] = 1
            else:
                node_occurrence[ns.call_function(F.relu)] = 1
            result_dict = self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=True)
            qr = result_dict["quantized_reference"]

            checkWeightQParams(qr)
            qr = copy.deepcopy(qr)
            # make sure the qparams are preserved after copy
            checkWeightQParams(qr)

            checkSerDeser(qr, is_dynamic)

    @skipIfNoFBGEMM
    def test_dynamic_quant_weight_observer(self):
        ''' Test that weight observer is run in convert step

        The scale / zero point stored on the converted reference model are
        compared against those computed by a fresh weight observer fed with
        the same weight.
        '''

        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        qconfig = default_dynamic_qconfig
        float_model = M(torch.rand(1, 1)).eval()
        example_inputs = (torch.rand(1, 1),)
        prepared = prepare_fx(float_model, {'': qconfig}, example_inputs=example_inputs)
        quantized = convert_to_reference_fx(prepared)
        actual_qparams = (quantized._scale_0, quantized._zero_point_0)

        # Recompute the expected qparams with an independent observer instance.
        reference_observer = qconfig.weight()
        reference_observer(quantized.weight)
        observed = reference_observer.calculate_qparams()
        # Get the actual value to avoid tensor size mismatch error, torch.Size([]) vs torch.Size([1])
        expected_qparams = (observed[0].item(), observed[1].item())
        self.assertEqual(actual_qparams, expected_qparams)

    def test_conv_bn_relu(self):
        """ Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU"

        For every combination of dimension (1/2/3), with/without relu and
        each entry of `self.static_quant_types`, the FX-quantized model must
        contain the expected quantized conv (or conv-relu) module and its
        output must equal that of the same model quantized through the eager
        mode fuse / prepare / convert flow.
        """
        convs = {
            1: nn.Conv1d,
            2: nn.Conv2d,
            3: nn.Conv3d,
        }
        bns = {
            1: nn.BatchNorm1d,
            2: nn.BatchNorm2d,
            3: nn.BatchNorm3d,
        }
        quantized_convs = {
            1: nnq.Conv1d,
            2: nnq.Conv2d,
            3: nnq.Conv3d,
        }
        quantized_conv_relus = {
            1: nniq.ConvReLU1d,
            2: nniq.ConvReLU2d,
            3: nniq.ConvReLU3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim, has_relu):
                super().__init__()
                self.conv = convs[dim](3, 3, 3)
                self.bn = bns[dim](3)
                self.relu = nn.ReLU() if has_relu else nn.Identity()
                self.has_relu = has_relu
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.bn(x)
                if self.has_relu:
                    x = self.relu(x)
                x = self.dequant(x)
                return x

        options = itertools.product([1, 2, 3], [True, False], self.static_quant_types)
        for dim, has_relu, quant_type in options:
            expected_node = ns.call_module(
                quantized_conv_relus[dim] if has_relu
                else quantized_convs[dim])
            m = M(dim, has_relu)
            # keep an untouched copy so eager mode starts from identical weights
            m_eager = copy.deepcopy(m)
            result_dict = self.checkGraphModeFxOp(
                m,
                self.img_data_dict[dim],
                quant_type,
                expected_node=expected_node,
            )
            result = result_dict["quantized_output"]

            # check numerics
            qengine = torch.backends.quantized.engine
            if quant_type == QuantType.STATIC:
                m_eager.eval()
                qconfig = get_default_qconfig(qengine)
                prepare_fn = prepare
                fuse_fn = fuse_modules
            else:
                m_eager.train()
                qconfig = get_default_qat_qconfig(qengine)
                prepare_fn = prepare_qat
                fuse_fn = fuse_modules_qat

            fuse_list = ["conv", "bn"]
            if has_relu:
                fuse_list.append("relu")
            fuse_fn(m_eager, fuse_list, inplace=True)
            m_eager.qconfig = qconfig
            m_eager = prepare_fn(m_eager)

            eager_inputs = self.img_data_dict[dim][0]
            # run once before convert so the observers / fake quants see data
            m_eager(*eager_inputs)
            m_eager = convert(m_eager)
            result_eager = m_eager(*eager_inputs)
            self.assertEqual(result, result_eager)

    def test_linear_bn(self):
        """ Tests fusion and quantization for "Linear - Bn", comparing the
            FX-quantized output against the eager mode flow
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)
                self.bn = nn.BatchNorm1d(4)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                out = self.quant(x)
                out = self.linear(out)
                out = self.bn(out)
                return self.dequant(out)

        data = (torch.randn(4, 4),)
        for quant_type in self.static_quant_types:
            fx_model = M()
            m_eager = copy.deepcopy(fx_model)
            result_dict = self.checkGraphModeFxOp(
                fx_model,
                data,
                quant_type,
                expected_node=ns.call_module(nnq.Linear),
            )
            result = result_dict["quantized_output"]

            # check numerics vs eager mode
            fuse_list = ["linear", "bn"]
            qengine = torch.backends.quantized.engine
            if quant_type == QuantType.STATIC:
                m_eager.eval()
                qconfig_fn, prepare_fn, fuse_fn = get_default_qconfig, prepare, fuse_modules
            else:
                m_eager.train()
                qconfig_fn, prepare_fn, fuse_fn = get_default_qat_qconfig, prepare_qat, fuse_modules_qat
            qconfig = qconfig_fn(qengine)
            fuse_fn(m_eager, fuse_list, inplace=True)
            m_eager.qconfig = qconfig
            m_eager = prepare_fn(m_eager)
            # run once before convert so the observers / fake quants see data
            m_eager(*data)
            m_eager = convert(m_eager)
            self.assertEqual(result, m_eager(*data))

    @skipIfNoFBGEMM
    def test_dynamic_quant_fp16(self):
        """Run fp16 dynamic quantization on a functional and a module linear,
        through both the reference and the regular convert path, and check
        that no standalone fp16 weight prepack call remains in the graph.
        """
        with override_quantized_engine('fbgemm'):
            class Linear(torch.nn.Module):
                def __init__(self, weight):
                    super().__init__()
                    self.weight = torch.nn.Parameter(weight)

                def forward(self, x):
                    return F.linear(x, self.weight)

            class LinearModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(5, 10)

                def forward(self, x):
                    return self.linear(x)

            functional_input = torch.rand(8, 5)
            functional_weight = torch.rand(10, 5)
            module_input = torch.rand(8, 5)

            # (model class, constructor args, example inputs,
            #  quantized node, weight prepack node)
            cases = [
                (
                    Linear,
                    (functional_weight,),
                    (functional_input,),
                    ns.call_function(torch.ops.quantized.linear_dynamic_fp16),
                    ns.call_function(torch.ops.quantized.linear_prepack_fp16),
                ),
                (
                    LinearModule,
                    (),
                    (module_input,),
                    ns.call_module(nnqd.Linear),
                    None,
                ),
            ]
            # NOTE: the quantized node of each case is listed but not asserted
            # on; only the prepack occurrence count is checked.
            for model_cls, ctor_args, inputs, _quantized_node, prepack_node in cases:
                for is_reference in [True, False]:
                    expected_occurrence = {prepack_node: 0} if prepack_node else {}
                    model = model_cls(*ctor_args).eval()
                    model = prepare_fx(
                        model,
                        {"": float16_dynamic_qconfig},
                        example_inputs=inputs,
                    )
                    if is_reference:
                        model = convert_to_reference_fx(model)
                    else:
                        model = convert_fx(model)
                    self.checkGraphModuleNodes(
                        model, expected_node_occurrence=expected_occurrence)



    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_qat_prepare_device_affinity(self):
        """
        Tests that FX QAT prepare pass respects device affinity
        """
        class Model(nn.Module):

            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                return self.relu(self.bn(self.conv(x)))

        model = Model()
        qat_qconfig = torch.ao.quantization.get_default_qat_qconfig(
            torch.backends.quantized.engine)
        qconfig_dict = {'': qat_qconfig}
        device = torch.device('cuda:0')
        model.to(device)

        example_inputs = (torch.randn(4, 1, 4, 4, device=device),)
        # QAT prepare
        model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)

        # a CUDA input must run through the prepared model as-is
        model(*example_inputs)

        # every parameter and buffer must live on the one expected device
        devices_in_use = {
            tensor.device for tensor in (*model.parameters(), *model.buffers())
        }
        self.assertEqual(len(devices_in_use), 1)
        (only_device,) = devices_in_use
        self.assertEqual(only_device, device)

    @skipIfNoFBGEMM
    def test_dict_output(self):
        """ Smoke test: prepare, calibrate, convert and run a model whose
        input and output are both dictionaries
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                conv_out = self.conv(x["input"])
                return {"output": conv_out}

        example_inputs = ({"input": torch.randn(1, 1, 1, 1)},)
        model = M().eval()
        prepared = prepare_fx(
            model, {"": default_qconfig}, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        # only checks that the converted model executes
        quantized(*example_inputs)

    @override_qengines
    def test_attention(self):
        """ Smoke test for an attention-like pattern: only the conv has a
        qconfig, and its output feeds chunk/view/transpose/mm ops
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            @staticmethod
            def _to_row(t):
                # flatten to a column, then transpose into a single row
                return t.contiguous().view(-1, 1).transpose(0, 1)

            def forward(self, x):
                x = self.conv(x)
                q, k, v = x.chunk(3, dim=0)
                q = self._to_row(q)
                k = self._to_row(k)
                v = self._to_row(v)
                torch._assert(
                    k.size(1) == 1, "key size should be equal to 1"
                )
                r = torch.mm(k, v)
                return q * k + r

        example_inputs = (torch.randn(3, 1, 1, 1),)
        model = M().eval()
        # no global qconfig; quantize nn.Conv2d only
        qconfig_dict = {
            "": None,
            "object_type": [(nn.Conv2d, default_qconfig)],
        }
        # the test passes as long as every step below runs
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)

    def _test_standalone_module(
            self,
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check):
        """ Shared driver for the standalone module tests.

        Quantizes a model containing a standalone submodule, selected once by
        name and once by class, checks node occurrences after prepare and
        after convert for both the parent and the standalone module, and
        compares the output with an equivalent two-conv reference model
        quantized without any standalone module config.
        """
        class StandaloneModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.standalone = StandaloneModule()

            def forward(self, x):
                return self.standalone(self.conv(x))

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv2(self.conv1(x))

        example_inputs = (torch.randn(1, 1, 1, 1),)
        # build both models, then copy M's conv parameters into RefM
        original_m = M().eval()
        original_ref_m = RefM().eval()
        conv_pairs = (
            (original_ref_m.conv1, original_m.conv),
            (original_ref_m.conv2, original_m.standalone.conv),
        )
        for ref_conv, src_conv in conv_pairs:
            ref_conv.weight = torch.nn.Parameter(src_conv.weight.detach())
            ref_conv.bias = torch.nn.Parameter(src_conv.bias.detach())

        for is_name in [True, False]:
            sm_example_inputs = example_inputs
            # the standalone module is identified by attribute name or by class
            if is_name:
                config_key, target = "standalone_module_name", "standalone"
            else:
                config_key, target = "standalone_module_class", StandaloneModule
            prepare_config = {
                config_key: [(target, None, sm_example_inputs, interface_config, None)]
            }

            model_under_test = copy.deepcopy(original_m)
            reference_model = copy.deepcopy(original_ref_m)

            qconfig_dict = {"": default_qconfig}
            # prepare, calibrate, then inspect the observed graphs
            m = prepare_fx(
                model_under_test,
                qconfig_dict,
                example_inputs=example_inputs,
                prepare_custom_config=prepare_config)
            m(*example_inputs)
            self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
            self.checkGraphModuleNodes(
                m.standalone, expected_node_occurrence=standalone_prepare_count_check)

            # convert, then inspect the quantized graphs
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
            self.checkGraphModuleNodes(
                m.standalone, expected_node_occurrence=standalone_convert_count_check)
            res = m(*example_inputs)

            # same flow for the reference model, without a custom config
            ref_m = prepare_fx(
                reference_model,
                qconfig_dict,
                example_inputs=example_inputs,
            )
            ref_m(*example_inputs)
            ref_m = convert_fx(ref_m)
            ref_res = ref_m(*example_inputs)
            self.assertEqual(res, ref_res)

    def test_standalone_module_float_interface(self):
        """Standalone module configured with float input and float output."""
        observer = ns.call_module(torch.ao.quantization.MinMaxObserver)
        quantize = ns.call_function(torch.quantize_per_tensor)
        quantized_conv = ns.call_module(nnq.Conv2d)
        dequantize = ns.call_method("dequantize")

        self._test_standalone_module(
            # empty index lists: no quantized inputs, no quantized outputs
            interface_config={
                "input_quantized_idxs": [],
                "output_quantized_idxs": [],
            },
            # parent: input and output of the first conv; the standalone
            # module's observers are inserted inside the standalone module
            prepare_count_check={observer: 2},
            # standalone: input and output of its conv
            standalone_prepare_count_check={observer: 2},
            convert_count_check={
                quantize: 1,
                quantized_conv: 1,
                dequantize: 1,
            },
            # the standalone module takes and returns float, so it carries
            # its own quantize and dequantize
            standalone_convert_count_check={
                quantize: 1,
                quantized_conv: 1,
                dequantize: 1,
            },
        )

    def test_standalone_module_quantized_interface(self):
        """Standalone module configured with quantized input and output."""
        observer = ns.call_module(torch.ao.quantization.MinMaxObserver)
        quantize = ns.call_function(torch.quantize_per_tensor)
        quantized_conv = ns.call_module(nnq.Conv2d)
        dequantize = ns.call_method("dequantize")

        self._test_standalone_module(
            # index 0 is quantized on both the input and the output side
            interface_config={
                "input_quantized_idxs": [0],
                "output_quantized_idxs": [0],
            },
            # parent: input and output of the first conv
            prepare_count_check={observer: 2},
            # standalone: output of its conv only
            standalone_prepare_count_check={observer: 1},
            convert_count_check={
                # quantize the input of the first conv
                quantize: 1,
                quantized_conv: 1,
                # dequantize the output of the standalone module
                dequantize: 1,
            },
            standalone_convert_count_check={
                # the input is quantized by the parent; the output is
                # quantized by the quantized conv module itself
                quantize: 0,
                quantized_conv: 1,
                # the output is dequantized by the parent
                dequantize: 0,
            },
        )

    @skipIfNoFBGEMM
    def test_qconfig_none(self):
        """A ``None`` qconfig for module name "conv2" overrides the global
        qconfig and leaves that conv in float.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv2(self.conv1(x))

        model = M().eval()
        qconfig_dict = {
            "": default_qconfig,
            "module_name": [("conv2", None)],
        }
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)
        # conv1 is quantized; conv2 stays a float nn.Conv2d after dequantize
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
            ns.call_module(nn.Conv2d),
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_qconfig_module_type(self):
        """An "object_type" qconfig for nn.Conv2d quantizes the conv and
        leaves the linear, which has no qconfig, in float.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.linear = nn.Linear(9, 3)

            def forward(self, x):
                flattened = self.conv(x).reshape((1, -1))
                return self.linear(flattened)

        model = M().eval()
        qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
        example_inputs = (torch.randn(1, 1, 3, 3),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)
        # quantized conv, then dequantize before the float linear
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
            ns.call_module(nn.Linear),
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_qconfig_qat_module_type(self):
        """QAT with "object_type" qconfigs for nn.Linear and nn.ReLU: the
        Linear + ReLU pair becomes a quantized LinearReLU and the trailing
        Linear a quantized Linear.
        """
        class LinearRelu(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(
                    nn.Linear(5, 5),
                    nn.ReLU(),
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin_relu = LinearRelu()
                self.linear = nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(self.lin_relu(x))

        model = M().train()

        # no global qconfig; configure by object type only
        qat_object_types = [
            (torch.nn.Linear, default_qat_qconfig),
            (torch.nn.ReLU, default_qat_qconfig),
        ]
        qconfig_dict = {"": None, "object_type": qat_object_types}
        example_inputs = (torch.rand(5, 5),)
        prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_qconfig_function(self):
        """An "object_type" qconfig keyed on ``operator.add`` quantizes the
        ``x + y`` in the model.
        """
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = M().eval()
        qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
        data = torch.randn(1, 1, 1, 1)
        example_inputs = (data, data)
        prepared = prepare_fx(model, qconfig_dict, example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)
        # the add is expected as quantized.add between quantize and dequantize
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.add),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_qconfig_module_name_regex(self):
        """A "module_name_regex" qconfig whose pattern matches both conv
        submodules results in both being quantized.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv2(self.conv1(x))

        model = M().eval()
        qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)
        # both convs are quantized back to back, with one quantize in front
        # and one dequantize at the end
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

    def test_qconfig_precedence(self):
        """Each submodule must end up with the qconfig from the most specific
        matching entry: module_name over module_name_regex over object_type
        over the global qconfig.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.conv = nn.Conv2d(1, 1, 1)
                self.module_conv1 = nn.Conv2d(1, 1, 1)
                self.module_conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                # global
                x = self.linear(x)
                # global + object_type --> object_type
                x = self.conv(x)
                # global + object_type + module_name_regex --> module_name_regex
                x = self.module_conv1(x)
                # global + object_type + module_name_regex + module_name --> module_name
                x = self.module_conv2(x)
                return x

        global_qconfig = default_qconfig
        object_type_qconfig = default_dynamic_qconfig
        module_name_regex_qconfig = float16_dynamic_qconfig
        module_name_qconfig = default_qat_qconfig

        # submodule attribute name -> qconfig expected to win for it
        expected_by_submodule = (
            ("linear", global_qconfig),
            ("conv", object_type_qconfig),
            ("module_conv1", module_name_regex_qconfig),
            ("module_conv2", module_name_qconfig),
        )

        for device in get_supported_device_types():
            m = M().to(device).eval()
            qconfig_dict = {
                "": global_qconfig,
                "object_type": [(nn.Conv2d, object_type_qconfig)],
                "module_name_regex": [("module_conv*", module_name_regex_qconfig)],
                "module_name": [("module_conv2", module_name_qconfig)],
            }
            m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),))
            for submodule_name, expected in expected_by_submodule:
                actual = getattr(m_prep, submodule_name).qconfig
                # compare the callables wrapped by the observer constructors
                self.assertEqual(actual.activation.p.func, expected.activation.p.func)
                self.assertEqual(actual.weight.p.func, expected.weight.p.func)

    def test_qconfig_module_name_object_type_order(self):
        """Exercise "module_name_object_type_order" qconfig entries, which
        select the N-th occurrence of an object type inside a given module.
        """
        # node specs shared by both expected graphs below
        quantize = ns.call_function(torch.quantize_per_tensor)
        dequantize = ns.call_method("dequantize")
        quantized_linear = ns.call_module(nnq.Linear)
        float_linear = ns.call_module(nn.Linear)
        quantized_add = ns.call_function(torch.ops.quantized.add)
        float_add = ns.call_function(torch.add)

        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(1, 1)
                self.fc2 = nn.Linear(1, 1)

            def forward(self, x):
                x = self.fc2(self.fc1(x))
                x = torch.add(x, x)
                x = torch.add(x, x)
                return x

        class M2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(1, 1)
                self.fc2 = nn.Linear(1, 1)
                self.m1 = M1()

            def forward(self, x):
                x = self.fc2(self.fc1(x))
                x = torch.add(x, x)
                x = torch.add(x, x)
                return self.m1(x)

        class M3(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(1, 1)
                self.fc2 = nn.Linear(1, 1)
                self.m2 = M2()

            def forward(self, x):
                x = self.fc2(self.fc1(x))
                x = torch.add(x, x)
                x = torch.add(x, x)
                return self.m2(x)

        model = M3().eval()
        qconfig = torch.ao.quantization.default_qconfig
        # test various FQNs: global, single child, multiple children.
        # For each (fqn, index), one entry for nn.Linear and one for torch.add.
        qconfig_dict = {
            "module_name_object_type_order": [
                (fqn, object_type, index, qconfig)
                for fqn, index in (("", 0), ("m2", 1), ("m2.m1", 0))
                for object_type in (nn.Linear, torch.add)
            ],
        }
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)

        expected_nodes = [
            # m3
            quantize,
            quantized_linear,
            dequantize,
            float_linear,
            quantize,
            quantized_add,
            dequantize,
            float_add,
            # m2
            float_linear,
            quantize,
            quantized_linear,
            dequantize,
            float_add,
            quantize,
            quantized_add,
            # m1
            quantized_linear,
            dequantize,
            float_linear,
            quantize,
            quantized_add,
            dequantize,
            float_add,
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)

        # test that function order overrides global qconfig
        class M4(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(1, 1)
                self.fc2 = nn.Linear(1, 1)

            def forward(self, x):
                x = self.fc2(self.fc1(x))
                x = torch.add(x, x)
                x = torch.add(x, x)
                return x

        model = M4().eval()
        # a None qconfig for occurrence 1 of each type, on top of a global one
        qconfig_dict = {
            "": torch.ao.quantization.default_qconfig,
            "module_name_object_type_order": [
                ("", nn.Linear, 1, None),
                ("", torch.add, 1, None),
            ],
        }
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)

        expected_nodes = [
            quantize,
            quantized_linear,
            dequantize,
            float_linear,
            quantize,
            quantized_add,
            dequantize,
            float_add,
        ]
        self.checkGraphModuleNodes(quantized, expected_node_list=expected_nodes)


    @override_qengines
    def test_qconfig_dict_with_fused_modules(self):
        """prepare_fx with the default qconfig mapping must not raise for
        linear/conv(+bn) followed by a module, functional or torch relu.
        """
        class LinearReLUModel(torch.nn.Module):
            def __init__(self, relu):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.relu = relu

            def forward(self, x):
                return self.relu(self.linear(x))

        class ConvReLUModel(torch.nn.Module):
            def __init__(self, relu):
                super().__init__()
                self.conv = torch.nn.Conv1d(3, 3, 3)
                self.relu = relu

            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvBnReLUModel(torch.nn.Module):
            def __init__(self, relu):
                super().__init__()
                self.conv = torch.nn.Conv1d(3, 3, 3)
                self.bn = torch.nn.BatchNorm1d(3)
                self.relu = relu

            def forward(self, x):
                return self.relu(self.bn(self.conv(x)))

        model_classes = [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]
        for model_cls in model_classes:
            # a fresh nn.ReLU instance is built for every model class
            relu_variants = [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]
            for relu in relu_variants:
                model = model_cls(relu).eval()
                qengine = torch.backends.quantized.engine
                qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping(qengine)
                # should not crash as in https://github.com/pytorch/pytorch/issues/75825
                prepare_fx(model, qconfig_mapping, example_inputs=(torch.randn(1, 3, 3, 3),))

    # TODO: move QConfigMapping tests to test/quantization/core
    def test_qconfig_mapping_set_global(self):
        qconfig = get_default_qconfig()
        qconfig_mapping = QConfigMapping()
        self.assertEqual(qconfig_mapping.global_qconfig, None)
        qconfig_mapping.set_global(qconfig)
        self.assertEqual(qconfig_mapping.global_qconfig, qconfig)

    def test_qconfig_mapping_set_object_type(self):
        qconfig1 = get_default_qconfig()
        qconfig2 = get_default_qconfig()
        qconfig3 = get_default_qconfig()
        self.assertNotEqual(qconfig1, qconfig2)
        self.assertNotEqual(qconfig1, qconfig3)
        qconfig_mapping = QConfigMapping()
        self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0)
        # Insert some entries
        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1)
        qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2)
        self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2)
        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1)
        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
        # Override existing key
        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3)
        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3)
        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)

    def test_qconfig_mapping_set_module_name_regex(self):
        qconfig1 = get_default_qconfig()
        qconfig2 = get_default_qconfig()
        qconfig3 = get_default_qconfig()
        self.assertNotEqual(qconfig1, qconfig2)
        self.assertNotEqual(qconfig1, qconfig3)
        qconfig_mapping = QConfigMapping()
        self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0)
        # Insert some entries
        qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1)
        qconfig_mapping.set_module_name_regex("foo.*", qconfig2)
        self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2)
        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1)
        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
        # Override existing key
        qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3)
        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3)
        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)

    def test_qconfig_mapping_set_module_name(self):
        qconfig1 = get_default_qconfig()
        qconfig2 = get_default_qconfig()
        qconfig3 = get_default_qconfig()
        self.assertNotEqual(qconfig1, qconfig2)
        self.assertNotEqual(qconfig1, qconfig3)
        qconfig_mapping = QConfigMapping()
        self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0)
        # Insert some entries
        qconfig_mapping.set_module_name("mod1", qconfig1)
        qconfig_mapping.set_module_name("mod2", qconfig2)
        self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1)
        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
        # Override existing key
        qconfig_mapping.set_module_name("mod1", qconfig3)
        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3)
        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)

    def test_qconfig_mapping_set_module_name_object_type_order(self):
        """
        Check that ``QConfigMapping.set_module_name_object_type_order`` keeps
        insertion order, overrides existing keys in place, and that
        ``_maybe_adjust_qconfig_for_module_name_object_type_order`` returns the
        stored qconfig only for an exact (module name, object type, index) match.
        """
        qconfig_a = get_default_qconfig()
        qconfig_b = get_default_qconfig()
        qconfig_c = get_default_qconfig()
        self.assertNotEqual(qconfig_a, qconfig_b)
        self.assertNotEqual(qconfig_a, qconfig_c)
        mapping = QConfigMapping()
        linear_key = ("mod1", torch.nn.Linear, 0)
        relu_key = ("mod2", torch.nn.ReLU, 1)

        def resolve(module_name, object_type, index):
            # The fallback qconfig is None, so a miss shows up as None
            return _maybe_adjust_qconfig_for_module_name_object_type_order(
                mapping, module_name, object_type, index, None)

        def check_entries(expected_linear, expected_relu):
            # Size, key order, stored values, then lookup results
            self.assertEqual(len(mapping.module_name_object_type_order_qconfigs), 2)
            self.assertEqual(next(iter(mapping.module_name_object_type_order_qconfigs)), linear_key)
            self.assertEqual(list(mapping.module_name_object_type_order_qconfigs)[1], relu_key)
            self.assertEqual(mapping.module_name_object_type_order_qconfigs[linear_key], expected_linear)
            self.assertEqual(mapping.module_name_object_type_order_qconfigs[relu_key], expected_relu)
            self.assertEqual(resolve(*linear_key), expected_linear)
            self.assertEqual(resolve(*relu_key), expected_relu)

        self.assertEqual(len(mapping.module_name_object_type_order_qconfigs), 0)
        # Insert some entries
        mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig_a)
        mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig_b)
        check_entries(qconfig_a, qconfig_b)

        # Override existing key
        mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig_c)
        check_entries(qconfig_c, qconfig_b)

        # No match
        misses = (
            ("mod123", torch.nn.Linear, 0),
            ("mod1", torch.nn.Linear, 35),
            ("mod2", torch.nn.Conv2d, 1),
        )
        for miss in misses:
            self.assertIsNone(resolve(*miss))

    def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2):
        """
        Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods.

        Args:
            global_qconfig: value stored under the global key
            qconfig1: value attached to the first entry of every other section
            qconfig2: value attached to the second entry of every other section

        Returns:
            A dict with one entry per section key. Every non-global section is
            a list of tuples whose last element is the qconfig.
        """
        return {
            _GLOBAL_DICT_KEY: global_qconfig,
            _OBJECT_TYPE_DICT_KEY: [
                (torch.nn.Linear, qconfig1),
                (torch.nn.ReLU, qconfig2),
            ],
            _MODULE_NAME_REGEX_DICT_KEY: [
                ("foo.*bar", qconfig1),
                ("foo.*", qconfig2),
            ],
            _MODULE_NAME_DICT_KEY: [
                ("bazbaz", qconfig1),
                ("borbor", qconfig2),
            ],
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
                ("bazbaz", torch.nn.Linear, 0, qconfig1),
                ("foofoo", torch.nn.ReLU, 1, qconfig2),
            ],
        }

    def test_qconfig_mapping_from_dict(self):
        """
        Check that ``QConfigMapping.from_dict`` populates the global qconfig and
        every section from the dummy qconfig_dict, preserving entry order.
        """
        global_qconfig = QConfig(123, "global")
        first = QConfig(1, "one")
        second = QConfig(2, "two")
        source_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, first, second)
        # An unrecognized key is added before conversion; the assertions
        # below only cover the recognized sections
        source_dict["undefined_dict_key"] = [(123, first), (234, second)]
        mapping = QConfigMapping.from_dict(source_dict)

        self.assertEqual(mapping.global_qconfig, global_qconfig)
        expected_object_types = OrderedDict([(torch.nn.Linear, first), (torch.nn.ReLU, second)])
        self.assertEqual(mapping.object_type_qconfigs, expected_object_types)
        expected_regexes = OrderedDict([("foo.*bar", first), ("foo.*", second)])
        self.assertEqual(mapping.module_name_regex_qconfigs, expected_regexes)
        expected_names = OrderedDict([("bazbaz", first), ("borbor", second)])
        self.assertEqual(mapping.module_name_qconfigs, expected_names)
        expected_ordered = OrderedDict([
            (("bazbaz", torch.nn.Linear, 0), first),
            (("foofoo", torch.nn.ReLU, 1), second),
        ])
        self.assertEqual(mapping.module_name_object_type_order_qconfigs, expected_ordered)

    def test_qconfig_mapping_to_dict(self):
        """
        Build a QConfigMapping through its fluent setters and check that
        ``to_dict`` equals the dummy qconfig_dict built from the same qconfigs.
        """
        global_qconfig = QConfig(123, "global")
        first = QConfig(1, "one")
        second = QConfig(2, "two")
        mapping = (
            QConfigMapping()
            .set_global(global_qconfig)
            .set_object_type(torch.nn.Linear, first)
            .set_object_type(torch.nn.ReLU, second)
            .set_module_name_regex("foo.*bar", first)
            .set_module_name_regex("foo.*", second)
            .set_module_name("bazbaz", first)
            .set_module_name("borbor", second)
            .set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, first)
            .set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, second)
        )
        expected = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, first, second)
        self.assertEqual(mapping.to_dict(), expected)

    def test_qconfig_mapping_repr(self):
        """Check that the default QConfigMapping's ``__repr__`` returns a string."""
        rendered = get_default_qconfig_mapping().__repr__()
        self.assertIsInstance(rendered, str)

    def test_default_qconfig_mapping_override_global(self):
        """
        Check that overriding the global qconfig of the default QConfigMapping
        takes effect: the observers inserted by ``prepare_fx`` must come from
        the overriding qconfig rather than from the original global one.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        model = M().eval()
        override_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer)
        mapping = get_default_qconfig_mapping()
        # Override global qconfig
        original_global_qconfig = mapping.global_qconfig
        mapping.set_global(override_qconfig)
        # Verify the correct qconfig was used
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, mapping, example_inputs)
        self.assertIsInstance(original_global_qconfig.activation(), HistogramObserver)
        self.assertIsInstance(override_qconfig.activation(), MinMaxObserver)
        observer_names = ("activation_post_process_0", "activation_post_process_1")
        for observer_name in observer_names:
            self.assertTrue(hasattr(prepared, observer_name))
        for observer_name in observer_names:
            self.assertIsInstance(getattr(prepared, observer_name), MinMaxObserver)

    # Dummy classes for PrepareCustomConfig and ConvertCustomConfig testing

    class _DummyStandaloneModule:
        """Placeholder class used as the standalone module class key in the PrepareCustomConfig tests."""
        pass

    class _DummyFloatModule:
        """Placeholder class used as the float side of the float-to-observed mapping in the PrepareCustomConfig tests."""
        pass

    class _DummyObservedModule:
        """
        Placeholder class used as the observed side of the float-to-observed mapping
        and as the key of the observed-to-quantized mapping in the custom config tests.
        """
        pass

    class _DummyQuantizedModule:
        """Placeholder class used as the quantized side of the observed-to-quantized mapping in the ConvertCustomConfig tests."""
        pass

    class _DummyNonTraceableModule1:
        """First placeholder class passed to ``set_non_traceable_module_classes`` in the PrepareCustomConfig tests."""
        pass

    class _DummyNonTraceableModule2:
        """Second placeholder class passed to ``set_non_traceable_module_classes`` in the PrepareCustomConfig tests."""
        pass

    def test_prepare_custom_config_set_standalone_module_name(self):
        """
        Check that ``set_standalone_module_name`` records exactly one
        ``StandaloneModuleConfigEntry`` under the given module name.
        """
        mapping = QConfigMapping()
        inputs = (torch.randn(3),)
        child_config = PrepareCustomConfig()
        backend = BackendConfig("my_backend")
        expected_entry = StandaloneModuleConfigEntry(mapping, inputs, child_config, backend)

        config = PrepareCustomConfig()
        self.assertEqual(len(config.standalone_module_names), 0)
        config.set_standalone_module_name("module1", mapping, inputs, child_config, backend)
        self.assertEqual(list(config.standalone_module_names), ["module1"])
        self.assertEqual(config.standalone_module_names["module1"], expected_entry)

    def test_prepare_custom_config_set_standalone_module_class(self):
        """
        Check that ``set_standalone_module_class`` records exactly one
        ``StandaloneModuleConfigEntry`` keyed by the given module class.
        """
        mapping = QConfigMapping()
        inputs = (torch.randn(3),)
        child_config = PrepareCustomConfig()
        backend = BackendConfig("my_backend")
        expected_entry = StandaloneModuleConfigEntry(mapping, inputs, child_config, backend)
        module_class = self._DummyStandaloneModule

        config = PrepareCustomConfig()
        self.assertEqual(len(config.standalone_module_classes), 0)
        config.set_standalone_module_class(module_class, mapping, inputs, child_config, backend)
        self.assertEqual(len(config.standalone_module_classes), 1)
        self.assertIn(module_class, config.standalone_module_classes)
        self.assertEqual(config.standalone_module_classes[module_class], expected_entry)

    def test_prepare_custom_config_set_float_to_observed_mapping(self):
        """
        Check that ``set_float_to_observed_mapping`` stores the float -> observed
        class pair under the given quant type.
        """
        float_cls = self._DummyFloatModule
        observed_cls = self._DummyObservedModule
        config = PrepareCustomConfig()
        self.assertEqual(len(config.float_to_observed_mapping), 0)

        config.set_float_to_observed_mapping(float_cls, observed_cls, QuantType.STATIC)

        self.assertEqual(len(config.float_to_observed_mapping), 1)
        self.assertEqual(list(config.float_to_observed_mapping), [QuantType.STATIC])
        self.assertEqual(len(config.float_to_observed_mapping[QuantType.STATIC]), 1)
        self.assertIn(float_cls, config.float_to_observed_mapping[QuantType.STATIC])
        self.assertEqual(config.float_to_observed_mapping[QuantType.STATIC][float_cls], observed_cls)

    def test_prepare_custom_config_set_non_traceable_module_names(self):
        """Check that ``set_non_traceable_module_names`` stores the given list of names."""
        config = PrepareCustomConfig()
        self.assertEqual(len(config.non_traceable_module_names), 0)
        config.set_non_traceable_module_names(["module1", "module2"])
        stored_names = config.non_traceable_module_names
        self.assertEqual(stored_names, ["module1", "module2"])

    def test_prepare_custom_config_set_non_traceable_module_classes(self):
        """Check that ``set_non_traceable_module_classes`` stores the given list of classes."""
        config = PrepareCustomConfig()
        self.assertEqual(len(config.non_traceable_module_classes), 0)
        config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
        expected_classes = [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]
        self.assertEqual(config.non_traceable_module_classes, expected_classes)

    def test_prepare_custom_config_set_input_quantized_indexes(self):
        """Check that ``set_input_quantized_indexes`` stores the given indexes."""
        config = PrepareCustomConfig()
        self.assertEqual(len(config.input_quantized_indexes), 0)
        config.set_input_quantized_indexes([0, 1])
        stored_indexes = config.input_quantized_indexes
        self.assertEqual(stored_indexes, [0, 1])

    def test_prepare_custom_config_set_output_quantized_indexes(self):
        """Check that ``set_output_quantized_indexes`` stores the given indexes."""
        config = PrepareCustomConfig()
        self.assertEqual(len(config.output_quantized_indexes), 0)
        config.set_output_quantized_indexes([0, 1])
        stored_indexes = config.output_quantized_indexes
        self.assertEqual(stored_indexes, [0, 1])

    def test_prepare_custom_config_set_preserved_attributes(self):
        """Check that ``PrepareCustomConfig.set_preserved_attributes`` stores the given attribute names."""
        config = PrepareCustomConfig()
        self.assertEqual(len(config.preserved_attributes), 0)
        config.set_preserved_attributes(["attr1", "attr2"])
        stored_attributes = config.preserved_attributes
        self.assertEqual(stored_attributes, ["attr1", "attr2"])

    def _get_dummy_prepare_custom_config_dict(self):
        """
        Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods.

        The two standalone-module sections each hold one 5-tuple of
        (name or class, QConfigMapping, example inputs, PrepareCustomConfig, BackendConfig).
        """
        # Built in this order so the random example inputs are drawn in the
        # same sequence as the dict sections are listed below
        standalone_by_name = (
            "module1",
            QConfigMapping(),
            (torch.randn(3),),
            PrepareCustomConfig(),
            BackendConfig("my_backend"),
        )
        standalone_by_class = (
            self._DummyStandaloneModule,
            QConfigMapping(),
            (torch.randn(10),),
            PrepareCustomConfig(),
            BackendConfig("my_backend"),
        )
        float_to_observed = {
            "static": {
                self._DummyFloatModule: self._DummyObservedModule
            },
        }
        return {
            STANDALONE_MODULE_NAME_DICT_KEY: [standalone_by_name],
            STANDALONE_MODULE_CLASS_DICT_KEY: [standalone_by_class],
            FLOAT_TO_OBSERVED_DICT_KEY: float_to_observed,
            NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"],
            NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2],
            INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
            OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
            PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]
        }

    def test_prepare_custom_config_from_dict(self):
        """
        Check that ``PrepareCustomConfig.from_dict`` populates every field from
        the dummy prepare_custom_config_dict.
        """
        source_dict = self._get_dummy_prepare_custom_config_dict()
        name_entry = source_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
        class_entry = source_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
        # Each entry is (name or class, *StandaloneModuleConfigEntry fields)
        standalone_name, *name_fields = name_entry
        standalone_class, *class_fields = class_entry
        expected_name_entry = StandaloneModuleConfigEntry(*name_fields)
        expected_class_entry = StandaloneModuleConfigEntry(*class_fields)
        config = PrepareCustomConfig.from_dict(source_dict)

        # Standalone modules
        self.assertEqual(len(config.standalone_module_names), 1)
        self.assertIn(standalone_name, config.standalone_module_names)
        self.assertEqual(config.standalone_module_names[standalone_name], expected_name_entry)
        self.assertEqual(len(config.standalone_module_classes), 1)
        self.assertIn(standalone_class, config.standalone_module_classes)
        self.assertEqual(config.standalone_module_classes[standalone_class], expected_class_entry)

        # Float to observed mapping
        float_cls = self._DummyFloatModule
        self.assertEqual(len(config.float_to_observed_mapping), 1)
        self.assertEqual(list(config.float_to_observed_mapping), [QuantType.STATIC])
        self.assertEqual(len(config.float_to_observed_mapping[QuantType.STATIC]), 1)
        self.assertIn(float_cls, config.float_to_observed_mapping[QuantType.STATIC])
        self.assertEqual(config.float_to_observed_mapping[QuantType.STATIC][float_cls],
                         self._DummyObservedModule)

        # Other
        self.assertEqual(config.non_traceable_module_names, ["module2", "module3"])
        self.assertEqual(config.non_traceable_module_classes,
                         [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
        self.assertEqual(config.input_quantized_indexes, [0, 1])
        self.assertEqual(config.output_quantized_indexes, [0, 1])
        self.assertEqual(config.preserved_attributes, ["attr1", "attr2"])

    def test_prepare_custom_config_to_dict(self):
        """
        Build a PrepareCustomConfig through its fluent setters and check that
        ``to_dict`` equals the dummy prepare_custom_config_dict.
        """
        expected_dict = self._get_dummy_prepare_custom_config_dict()
        (sm_name, name_mapping, name_inputs, name_child, name_backend) = \
            expected_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
        (sm_class, class_mapping, class_inputs, class_child, class_backend) = \
            expected_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
        config = (
            PrepareCustomConfig()
            .set_standalone_module_name(sm_name, name_mapping, name_inputs, name_child, name_backend)
            .set_standalone_module_class(sm_class, class_mapping, class_inputs, class_child, class_backend)
            .set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule)
            .set_non_traceable_module_names(["module2", "module3"])
            .set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
            .set_input_quantized_indexes([0, 1])
            .set_output_quantized_indexes([0, 1])
            .set_preserved_attributes(["attr1", "attr2"])
        )
        # PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts
        expected_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (
            sm_name, name_mapping.to_dict(), name_inputs, name_child.to_dict(), name_backend)
        expected_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (
            sm_class, class_mapping.to_dict(), class_inputs, class_child.to_dict(), class_backend)
        self.assertEqual(config.to_dict(), expected_dict)

    def test_convert_custom_config_set_observed_to_quantized_mapping(self):
        """
        Check that ``set_observed_to_quantized_mapping`` stores the observed ->
        quantized class pair under the given quant type.
        """
        observed_cls = self._DummyObservedModule
        quantized_cls = self._DummyQuantizedModule
        config = ConvertCustomConfig()
        self.assertEqual(len(config.observed_to_quantized_mapping), 0)

        config.set_observed_to_quantized_mapping(observed_cls, quantized_cls, QuantType.STATIC)

        self.assertEqual(len(config.observed_to_quantized_mapping), 1)
        self.assertEqual(list(config.observed_to_quantized_mapping), [QuantType.STATIC])
        self.assertIn(observed_cls, config.observed_to_quantized_mapping[QuantType.STATIC])
        self.assertEqual(config.observed_to_quantized_mapping[QuantType.STATIC][observed_cls], quantized_cls)

    def test_convert_custom_config_set_preserved_attributes(self):
        """Check that ``ConvertCustomConfig.set_preserved_attributes`` stores the given attribute names."""
        config = ConvertCustomConfig()
        self.assertEqual(len(config.preserved_attributes), 0)
        config.set_preserved_attributes(["attr1", "attr2"])
        stored_attributes = config.preserved_attributes
        self.assertEqual(stored_attributes, ["attr1", "attr2"])

    def _get_dummy_convert_custom_config_dict(self):
        """
        Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods.
        """
        # A single observed -> quantized pair, filed under the "static" key
        static_mapping = {self._DummyObservedModule: self._DummyQuantizedModule}
        dummy_dict = {}
        dummy_dict[OBSERVED_TO_QUANTIZED_DICT_KEY] = {"static": static_mapping}
        dummy_dict[PRESERVED_ATTRIBUTES_DICT_KEY] = ["attr1", "attr2"]
        return dummy_dict

    def test_convert_custom_config_from_dict(self):
        """
        Check that ``ConvertCustomConfig.from_dict`` populates the observed ->
        quantized mapping and the preserved attributes from the dummy dict.
        """
        observed_cls = self._DummyObservedModule
        config = ConvertCustomConfig.from_dict(self._get_dummy_convert_custom_config_dict())

        self.assertEqual(len(config.observed_to_quantized_mapping), 1)
        self.assertEqual(list(config.observed_to_quantized_mapping), [QuantType.STATIC])
        self.assertEqual(len(config.observed_to_quantized_mapping[QuantType.STATIC]), 1)
        self.assertIn(observed_cls, config.observed_to_quantized_mapping[QuantType.STATIC])
        self.assertEqual(config.observed_to_quantized_mapping[QuantType.STATIC][observed_cls],
                         self._DummyQuantizedModule)
        self.assertEqual(config.preserved_attributes, ["attr1", "attr2"])

    def test_convert_custom_config_to_dict(self):
        """
        Build a ConvertCustomConfig through its fluent setters and check that
        ``to_dict`` equals the dummy convert_custom_config_dict.
        """
        config = (
            ConvertCustomConfig()
            .set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule)
            .set_preserved_attributes(["attr1", "attr2"])
        )
        self.assertEqual(config.to_dict(), self._get_dummy_convert_custom_config_dict())

    def test_fuse_custom_config_set_preserved_attributes(self):
        """Check that ``FuseCustomConfig.set_preserved_attributes`` stores the given attribute names."""
        config = FuseCustomConfig()
        self.assertEqual(len(config.preserved_attributes), 0)
        config.set_preserved_attributes(["attr1", "attr2"])
        stored_attributes = config.preserved_attributes
        self.assertEqual(stored_attributes, ["attr1", "attr2"])

    def test_fuse_custom_config_from_dict(self):
        """Check that ``FuseCustomConfig.from_dict`` reads the preserved attributes from the dict."""
        source_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
        config = FuseCustomConfig.from_dict(source_dict)
        stored_attributes = config.preserved_attributes
        self.assertEqual(stored_attributes, ["attr1", "attr2"])

    def test_fuse_custom_config_to_dict(self):
        """Check that ``FuseCustomConfig.to_dict`` emits the preserved attributes under their dict key."""
        expected_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
        config = FuseCustomConfig()
        config = config.set_preserved_attributes(["attr1", "attr2"])
        self.assertEqual(config.to_dict(), expected_dict)

    def test_remove_qconfig(self):
        """
        Check that no module of the converted model still carries a
        ``qconfig`` attribute after ``prepare_fx`` followed by ``convert_fx``.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.avg_pool = torch.nn.AvgPool2d(1)

            def forward(self, x):
                return self.avg_pool(x)

        model = M().eval()
        qconfig_dict = {'': default_qconfig}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        converted = convert_fx(prepared)
        converted(*example_inputs)
        for module_name, submodule in converted.named_modules():
            still_has_qconfig = hasattr(submodule, 'qconfig')
            self.assertFalse(still_has_qconfig, 'qconfig is not removed for ' + module_name)

    def test_return_none(self):
        """
        Check that ``prepare_fx`` and ``convert_fx`` complete without raising
        on a module whose ``forward`` has an empty body.
        """
        class M(torch.nn.Module):
            def forward(self, x):
                pass

        model = M().eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        convert_fx(prepared)

    def test_default_quant_after_none_qconfig(self):
        """ Make sure default quant is inserted properly"""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                # NOTE(review): the conv2 result is never returned, so forward
                # implicitly returns None -- confirm this is intentional
                x = self.conv1(x)
                x = x.transpose(1, 2)
                x = self.conv2(x)

        model = M().eval()
        # conv1 gets a None qconfig; the "" entry supplies default_qconfig
        conv1_overrides = [("conv1", None)]
        qconfig_dict = {"": default_qconfig, "module_name": conv1_overrides}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        convert_fx(prepared)

    def test_qconfig_for_call_method(self):
        """
        Check how ``transpose`` call_method nodes are handled when a submodule's
        qconfig differs from the global one, by comparing the converted graph
        against an expected node sequence for two opposite configurations.
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = x.transpose(2, 3)
                x = self.conv(x)
                return x.transpose(2, 3)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.sub(x)
                x = self.conv2(x)
                return x.transpose(2, 3)

        # Node specs shared by both expected sequences
        quantize = ns.call_function(torch.quantize_per_tensor)
        dequantize = ns.call_method("dequantize")
        transpose = ns.call_method("transpose")
        quantized_conv = ns.call_module(nnq.Conv2d)
        float_conv = ns.call_module(nn.Conv2d)

        # since sub is configured to have qconfig None, we should dequantize the output
        # of self.conv1 and quantize the input of self.conv2
        # dequantize after conv2 should happen after transpose since
        # it is configured with default_qconfig
        # nodes in Sub module instance is not quantized
        sub_disabled_config = {"": default_qconfig, "module_name": [("sub", None)]}
        sub_disabled_nodes = [
            quantize,
            quantized_conv,
            dequantize,
            transpose,
            float_conv,
            transpose,
            quantize,
            quantized_conv,
            transpose,
            dequantize,
        ]

        # Only nodes in Sub module instance are quantized
        # the first transpose is not quantized because the input is not quantized
        sub_only_config = {"": None, "module_name": [("sub", default_qconfig)]}
        sub_only_nodes = [
            float_conv,
            quantize,
            transpose,
            quantized_conv,
            transpose,
            dequantize,
            float_conv,
            transpose,
        ]

        cases = (
            (sub_disabled_config, sub_disabled_nodes),
            (sub_only_config, sub_only_nodes),
        )
        for qconfig_dict, expected_nodes in cases:
            example_inputs = (torch.randn(2, 1, 3, 3),)
            model = M().eval()
            prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
            prepared(torch.randn(2, 1, 3, 3))
            converted = convert_fx(prepared)
            self.checkGraphModuleNodes(converted, expected_node_list=expected_nodes)
            # make sure it runs
            converted(*example_inputs)

    def test_qconfig_for_call_func(self):
        """
        Check that functional linear calls follow the qconfig of the module that
        owns them: the two calls inside ``mods1`` are expected as quantized
        linear ops, while the call inside ``mods2`` (qconfig None) is expected
        to remain ``torch.nn.functional.linear``.
        """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        float_model = M().eval()
        example_inputs = (torch.rand(5, 5),)
        qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]}
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        converted = convert_fx(prepared)

        quantized_linear = ns.call_function(torch.ops.quantized.linear)
        expected_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            quantized_linear,
            quantized_linear,
            ns.call_method('dequantize'),
            ns.call_function(torch.nn.functional.linear),
        ]
        self.checkGraphModuleNodes(converted, expected_node_list=expected_nodes)
        converted(torch.rand(5, 5))

    def test_preserve_attributes(self):
        """
        Check that an attribute listed under "preserved_attributes" is still
        present, with its value, after ``prepare_fx`` and after ``convert_fx``.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        def check_preserved(module):
            self.assertTrue(hasattr(module, "preserved_attr"))
            self.assertEqual(module.preserved_attr, 3)

        model = M()
        model.eval()
        model.preserved_attr = 3
        example_inputs = (torch.randn(1, 1, 1, 1),)

        prepared = prepare_fx(
            model,
            {"": default_qconfig},
            example_inputs=example_inputs,
            prepare_custom_config={"preserved_attributes": ["preserved_attr"]})
        check_preserved(prepared)

        converted = convert_fx(
            prepared,
            convert_custom_config={"preserved_attributes": ["preserved_attr"]})
        check_preserved(converted)

    @skipIfNoFBGEMM
    def test_qat_and_script(self):
        """
        Check that a model prepared with ``prepare_qat_fx`` can be scripted and
        run, and that the observer / fake-quant toggles applied to the scripted
        module show up in its state dict.
        """
        model = LinearModelWithSubmodule().train()
        qengine = torch.backends.quantized.engine
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
        x = torch.randn(5, 5)
        model = prepare_qat_fx(model, qconfig_dict, example_inputs=(x,))

        # ensure scripting works
        scripted = torch.jit.script(model)
        # run one round to make sure model runs
        scripted(x)
        checker = FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True)
        checker.run(scripted.graph)

        # disable fake_quant and observer
        toggles_by_epoch = {
            1: torch.ao.quantization.disable_observer,
            2: torch.ao.quantization.disable_fake_quant,
        }
        for epoch in range(3):
            toggle = toggles_by_epoch.get(epoch)
            if toggle is not None:
                scripted.apply(toggle)

        flag_suffixes = ['.fake_quant_enabled', '.observer_enabled']

        def check_flags(expected_value):
            # Every enable flag in the state dict must hold expected_value
            expected = torch.tensor([expected_value], dtype=torch.int64)
            for key, value in scripted.state_dict().items():
                if any(suffix in key for suffix in flag_suffixes):
                    self.assertEqual(value, expected)

        # ensure the fake_quant and observer have been disabled.
        check_flags(0)

        # enable them back
        scripted.apply(torch.ao.quantization.enable_fake_quant)
        scripted.apply(torch.ao.quantization.enable_observer)
        check_flags(1)

    @skipIfNoFBGEMM
    def test_save_observer_state_dict(self):
        """
        Check that observer statistics saved with ``get_observer_state_dict``
        and restored with ``load_observer_state_dict`` into a freshly prepared
        model yield a converted model with the same output, for both values of
        ``weights_only`` in ``torch.load``.
        """
        float_model = LinearModelWithSubmodule().eval()
        qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
        x = torch.randn(5, 5)
        example_inputs = (x,)
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)

        # run it through input
        prepared(x)
        # save state_dict of model
        observer_state = torch.ao.quantization.get_observer_state_dict(prepared)

        reference_quantized = convert_fx(prepared)

        buffer = io.BytesIO()
        torch.save(observer_state, buffer)

        # Load the stats into new model
        for weights_only in (True, False):
            buffer.seek(0)
            reprepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)

            restored_state = torch.load(buffer, weights_only=weights_only)
            torch.ao.quantization.load_observer_state_dict(reprepared, restored_state)

            restored_quantized = convert_fx(reprepared)

            # Verify that loaded state dict produces same results.
            self.assertEqual(reference_quantized(x), restored_quantized(x))

    @skipIfNoFBGEMM
    def test_custom_module_class(self):
        """Check custom module swapping for static and dynamic quantization.

        A model holding a user-defined ``CustomModule`` is prepared and
        converted with custom-module configs. Its output is compared with a
        reference model made of two plain linears carrying the same
        parameters, quantized through the default flow.
        """
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        class ObservedCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                wrapped = cls(float_module.linear)
                wrapped.qconfig = float_module.qconfig
                return wrapped

        class StaticQuantCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                # hand the wrapper's observer to the inner linear before it
                # is turned into a quantized linear
                inner = observed_module.linear
                inner.activation_post_process = observed_module.activation_post_process
                return cls(nnq.Linear.from_float(inner))

        class DynamicQuantCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                inner = observed_module.linear
                inner.qconfig = observed_module.qconfig
                return cls(nnqd.Linear.from_float(inner))

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(3, 3)
                self.custom = CustomModule()

            def forward(self, x):
                return self.custom(self.linear(x))

        class RefM(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 3)
                self.linear2 = nn.Linear(3, 3)

            def forward(self, x):
                return self.linear2(self.linear1(x))

        # instantiate M and RefM and align the parameters
        float_m = M().eval()
        float_ref_m = RefM().eval()
        for ref_linear, src_linear in (
            (float_ref_m.linear1, float_m.linear),
            (float_ref_m.linear2, float_m.custom.linear),
        ):
            ref_linear.weight = nn.Parameter(src_linear.weight.detach())
            ref_linear.bias = nn.Parameter(src_linear.bias.detach())

        a16_qconfig = QConfig(
            activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536),
            weight=default_weight_observer,
        )
        # each entry: (qconfig, quantized custom module class, expected observer count)
        # NOTE(review): "static_a16" does not appear to be selected by the loop
        # below, which only looks up the keys derived from QuantType.STATIC and
        # QuantType.DYNAMIC - confirm whether it is meant to be exercised.
        test_configs = {
            "static": (default_qconfig, StaticQuantCustomModule, 3),
            "static_a16": (a16_qconfig, StaticQuantCustomModule, 3),
            "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
        }

        for quant_type in (QuantType.STATIC, QuantType.DYNAMIC):
            key = _get_quant_type_to_str(quant_type)
            qconfig, quantized_module_class, num_observers = test_configs[key]
            qconfig_dict = {"": qconfig}
            if key == "static":
                prepare_custom_config_dict = {
                    "float_to_observed_custom_module_class": {
                        "static": {CustomModule: ObservedCustomModule},
                    },
                }
                convert_custom_config_dict = {
                    "observed_to_quantized_custom_module_class": {
                        "static": {ObservedCustomModule: quantized_module_class},
                    },
                }
            else:
                prepare_custom_config_dict = {
                    "non_traceable_module_class": [CustomModule],
                }
                convert_custom_config_dict = {
                    "observed_to_quantized_custom_module_class": {
                        "dynamic": {CustomModule: quantized_module_class},
                    },
                }

            example_inputs = (torch.randn(3, 3),)
            # check prepared model
            prepared = prepare_fx(
                copy.deepcopy(float_m),
                qconfig_dict,
                example_inputs=example_inputs,
                prepare_custom_config=prepare_custom_config_dict)
            # calibration
            prepared(*example_inputs)
            # all activation observers are inserted in the top level module
            self.checkGraphModuleNodes(
                prepared,
                expected_node_occurrence={
                    ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers,
                })

            # check converted/quantized model
            quantized = convert_fx(
                prepared,
                convert_custom_config=convert_custom_config_dict)
            if quant_type == QuantType.STATIC:
                self.checkGraphModuleNodes(
                    quantized,
                    expected_node_occurrence={
                        ns.call_function(torch.quantize_per_tensor): 1,
                        ns.call_module(nnq.Linear): 1,
                        ns.call_method('dequantize'): 1,
                    })
            self.assertEqual(type(quantized.custom), quantized_module_class)
            actual = quantized(*example_inputs)

            # quantize the reference model
            ref_prepared = prepare_fx(
                copy.deepcopy(float_ref_m), qconfig_dict, example_inputs=example_inputs)
            ref_prepared(*example_inputs)
            ref_quantized = convert_fx(ref_prepared)
            expected = ref_quantized(*example_inputs)
            self.assertEqual(actual, expected)

    @skipIfNoFBGEMM
    def test_custom_module_class_input_has_multiple_users(self):
        """ Tests that the flow still works when the input of custom module
        has multiple users
        """
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        class ObservedCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                wrapped = cls(float_module.linear)
                wrapped.qconfig = float_module.qconfig
                return wrapped

        class StaticQuantCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                inner = observed_module.linear
                inner.activation_post_process = observed_module.activation_post_process
                return cls(nnq.Linear.from_float(inner))

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(3, 3)
                self.custom = CustomModule()

            def forward(self, x0):
                # x0 feeds both the custom module and the plain linear
                return self.custom(x0) + self.linear(x0)

        float_model = M().eval()
        example_inputs = (torch.randn(3, 3),)
        prepared = prepare_fx(
            float_model,
            {"": default_qconfig},
            example_inputs=example_inputs,
            prepare_custom_config={
                "float_to_observed_custom_module_class": {
                    "static": {CustomModule: ObservedCustomModule},
                },
            })
        # make sure it works
        quantized = convert_fx(
            prepared,
            convert_custom_config={
                "observed_to_quantized_custom_module_class": {
                    "static": {ObservedCustomModule: StaticQuantCustomModule},
                },
            })
        # make sure it runs
        quantized(*example_inputs)

    @skipIfNoFBGEMM
    def test_custom_module_class_input_has_duplicate_nodes(self):
        """ Tests that the flow still works when the graph has
        multiple nodes with the same custom module target.
        """
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        class ObservedCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                wrapped = cls(float_module.linear)
                wrapped.qconfig = float_module.qconfig
                return wrapped

        class StaticQuantCustomModule(nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                inner = observed_module.linear
                inner.activation_post_process = observed_module.activation_post_process
                return cls(nnq.Linear.from_float(inner))

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom = CustomModule()

            def forward(self, x0):
                # the same custom submodule is called twice on the same input
                return self.custom(x0) + self.custom(x0)

        float_model = M().eval()
        example_inputs = (torch.randn(3, 3),)
        prepared = prepare_fx(
            float_model,
            {"": default_qconfig},
            example_inputs=example_inputs,
            prepare_custom_config={
                "float_to_observed_custom_module_class": {
                    "static": {CustomModule: ObservedCustomModule},
                },
            })
        # make sure it works
        quantized = convert_fx(
            prepared,
            convert_custom_config={
                "observed_to_quantized_custom_module_class": {
                    "static": {ObservedCustomModule: StaticQuantCustomModule},
                },
            })
        # make sure it runs
        quantized(*example_inputs)

    @skipIfNoFBGEMM
    def test_non_traceable_module(self):
        """Modules marked non-traceable, by name or by class, stay opaque.

        After ``prepare_fx`` each marked module is expected to occur exactly
        once as a ``call_module`` node.
        """
        class NonTraceable(nn.Module):
            def forward(self, x):
                for key in x.keys():
                    print(x[key])
                return x

        class NonTraceable2(nn.Module):
            def forward(self, x):
                # data dependent control flow is not traceable
                for item in x:
                    print(item)
                return x

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = NonTraceable()
                self.m2 = NonTraceable2()

            def forward(self, x):
                return self.m2(self.m1(x))

        float_model = M().eval()
        # m1 is excluded by attribute name, NonTraceable2 by class
        prepare_custom_config_dict = {
            "non_traceable_module_name": ["m1"],
            "non_traceable_module_class": [NonTraceable2],
        }
        prepared = prepare_fx(
            float_model,
            {"": default_qconfig},
            example_inputs=({"key": torch.randn(1)},),
            prepare_custom_config=prepare_custom_config_dict)

        # make sure these modules are not traced
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence={
                ns.call_module(NonTraceable): 1,
                ns.call_module(NonTraceable2): 1,
            })

    def test_prepared_model_deepcopy(self):
        """Ensures that copy.deepcopy works correctly on a prepared model.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self._foobar = 'foobar'
                self.foobar2 = 'foobar2'

            def forward(self, x):
                x = self.conv(x)
                return x

        m = M()
        m.eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(4, 1, 4, 4),)
        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        # calibrate
        prepared(*example_inputs)
        # copy
        prepared_copy = copy.deepcopy(prepared)
        # quantize, should run with no errors
        quantized = convert_fx(prepared_copy)

    def test_quantized_model_type(self):
        """ Test state_dict and deepcopy works properly in the quantized model
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        example_inputs = (torch.rand(8, 5),)

        def build_quantized():
            # prepare + convert a new M instance (no calibration run)
            prepared = prepare_fx(
                M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
            return convert_fx(prepared)

        quantized = build_quantized()

        # test deepcopy
        quantized_copy = copy.deepcopy(quantized)
        self.assertEqual(quantized_copy(*example_inputs), quantized(*example_inputs))

        # test state_dict: load it into a second, independently built model
        state_dict = quantized.state_dict()
        reloaded = build_quantized()
        reloaded.load_state_dict(state_dict)
        self.assertEqual(reloaded(*example_inputs), quantized(*example_inputs))

    def test_dequantize(self):
        r""" Test to make sure dequantize node are placed before
        non-quantizable node
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.act = nn.GELU()

            def forward(self, x):
                return self.act(self.conv(x))

        sample = torch.rand(5, 1, 3, 3, dtype=torch.float)
        for quant_type in self.static_quant_types:
            # expected order: quantized conv, then dequantize, then float GELU
            expected_nodes = [
                ns.call_module(nnq.Conv2d),
                ns.call_method("dequantize"),
                ns.call_module(nn.GELU),
            ]
            self.checkGraphModeFxOp(
                M().eval(), (sample,), quant_type, expected_node_list=expected_nodes)

    def test_sequential(self):
        """Both convs nested in an ``nn.Sequential`` become quantized convs."""
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.convs = nn.Sequential(
                    nn.Conv2d(1, 1, 1),
                    nn.Conv2d(1, 1, 1),
                )

            def forward(self, x):
                return self.convs(x)

        sample = torch.rand(5, 1, 3, 3, dtype=torch.float)
        for quant_type in self.static_quant_types:
            expected_nodes = [
                ns.call_module(nnq.Conv2d),
                ns.call_module(nnq.Conv2d),
            ]
            self.checkGraphModeFxOp(
                M().eval(), (sample,), quant_type, expected_node_list=expected_nodes)

    def _test_quantized_inputs_outputs(
            self, prepare_custom_config_dict, prepare_count_check,
            convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized

        A two-conv model is prepared with ``prepare_custom_config_dict``; the
        node occurrences in the prepared graph are checked against
        ``prepare_count_check`` and, after one calibration run and conversion,
        those of the converted graph against ``convert_count_check``.
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv2(self.conv1(x))

        float_model = M()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(1, 1, 4, 4),)
        float_model.eval()

        prepared = prepare_fx(
            float_model, qconfig_dict,
            example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=prepare_count_check)

        # calibrate before converting
        prepared(*example_inputs)
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(converted, expected_node_occurrence=convert_count_check)

    def test_quantized_input_quantized_output(self):
        """Quantized input and output: expect no quantize or dequantize nodes."""
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict={
                'input_quantized_idxs': [0], 'output_quantized_idxs': [0]},
            prepare_count_check={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
            },
            convert_count_check={
                ns.call_function(torch.quantize_per_tensor): 0,
                ns.call_method('dequantize'): 0,
            },
        )

    def test_fp32_input_quantized_output(self):
        """fp32 input, quantized output: expect one quantize, no dequantize."""
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict={'output_quantized_idxs': [0]},
            prepare_count_check={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
            },
            convert_count_check={
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method('dequantize'): 0,
            },
        )

    def test_quantized_input_fp32_output(self):
        """Quantized input, fp32 output: expect no quantize, one dequantize."""
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict={'input_quantized_idxs': [0]},
            prepare_count_check={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
            },
            convert_count_check={
                ns.call_function(torch.quantize_per_tensor): 0,
                ns.call_method('dequantize'): 1,
            },
        )

    def test_fp32_input_fp32_output(self):
        """Default fp32 input and output: expect one quantize, one dequantize."""
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict={},
            prepare_count_check={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
            },
            convert_count_check={
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method('dequantize'): 1,
            },
        )

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-Channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.eval()
        qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
        with self.assertRaises(AssertionError) as context:
            # The result is deliberately not bound: prepare_fx is expected to
            # raise, and a local named `mp` would shadow the module-level
            # `torch.multiprocessing` alias.
            prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
        # assertEqual reports both strings on mismatch, unlike assertTrue(a == b)
        self.assertEqual(
            str(context.exception),
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

    @skipIfNoFBGEMM
    def test_qparams_buffers(self):
        """Check the qparam entries a converted functional-linear model exposes.

        The converted model's ``state_dict`` is expected to hold one scale and
        one zero_point per quantized linear op, the model must be scriptable
        with the non-packed-weight keys preserved, and a fixed set of
        generated attribute names must exist on the converted module.
        """
        class Linear(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = nn.Sequential(
                    Linear(),
                    Linear(),
                )
                self.mods2 = Linear()

            def forward(self, x):
                return self.mods2(self.mods1(x))

        model = M().eval()
        example_inputs = (torch.rand(5, 5),)
        m = prepare_fx(model, {"": default_qconfig}, example_inputs=example_inputs)
        m(*example_inputs)
        m = convert_fx(m)

        keys = m.state_dict().keys()
        quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0
        for name in keys:
            # the input_* branches come first so that input qparams are not
            # counted as per-op output qparams by the broader substring checks
            if 'input_scale' in name:
                quant_scale_count += 1
            elif 'input_zero_point' in name:
                quant_zero_point += 1
            elif 'scale' in name:
                scale_count += 1
            elif 'zero_point' in name:
                zero_point_count += 1

        # Expect each quantized linear op to have a scale and zero point
        self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
        self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
        m(*example_inputs)

        # ensure it is scriptable
        scripted = torch.jit.script(m)
        scripted_keys = scripted.state_dict().keys()
        scripted.mods1_0_packed_weight_0 = m.state_dict()["mods1_0_packed_weight_0"]
        non_packed_weight_keys = {name for name in keys if "_packed_weight" not in name}
        self.assertTrue(
            set(scripted_keys) == non_packed_weight_keys,
            "Expected the scripted model to preserve the state_dict for non-packed weight attributes")

        # TODO: probably don't want to hardcode the attribute names, since they are generated
        expected_attrs = (
            "mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
            "mods1_0_scale_1", "mods1_0_zero_point_1",
            "mods1_1_scale_1", "mods1_1_zero_point_1",
            "mods2_scale_1", "mods2_zero_point_1",
        )
        for attr_name in expected_attrs:
            self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")

    @skipIfNoFBGEMM
    def test_packed_weight_fused_op(self):
        """Check that packed-weight attributes exist when a relu follows the linears.

        The model chains three functional linears followed by ``F.relu``.
        After prepare, one calibration run and convert, the converted module
        must expose one ``*_packed_weight_0`` attribute per linear.
        """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = Linear()
                self.relu = F.relu

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                x = self.relu(x)
                return x

        model = M().eval()
        example_inputs = (torch.rand(5, 5),)
        qconfig_dict = {"": default_qconfig}
        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        m(*example_inputs)
        m = convert_fx(m)
        # Use unittest assertions rather than bare `assert`: the latter is
        # stripped under `python -O`, which would let this test pass vacuously.
        for attr_name in (
                "mods1_0_packed_weight_0",
                "mods1_1_packed_weight_0",
                "mods2_packed_weight_0"):
            self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")

    @skipIfNoFBGEMM
    def test_mul_add_fp16_config(self):
        """Smoke test: scalar mul/add followed by functional linears under
        ``float16_dynamic_qconfig`` prepares, converts and runs on fbgemm.
        """
        class Linear(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = nn.Sequential(
                    Linear(),
                    Linear(),
                )
                self.mods2 = Linear()

            def forward(self, x):
                return self.mods2(self.mods1(x * 5 + 5))

        with override_quantized_engine('fbgemm'):
            float_model = M().eval()
            example_inputs = (torch.rand(5, 5),)
            prepared = prepare_fx(
                float_model,
                {"": float16_dynamic_qconfig},
                example_inputs=example_inputs)
            quantized = convert_fx(prepared)
            # make sure it runs
            quantized(*example_inputs)

    def test_getattr_with_nontensor_result(self):
        """
        Verifies that binary ops get quantized correctly if some
        of the args are nodes but not Tensors, such as an `x.ndim`
        pattern.
        """
        class M1(torch.nn.Module):
            def forward(self, x):
                dims = x.ndim
                dims_sub = dims - 1
                dims_sub2 = dims_sub - 1
                x = torch.add(x, dims_sub2)
                return x

        class M2(torch.nn.Module):
            def forward(self, x):
                dims = x.ndim
                dims_sub = dims - 2
                mul = [1] * dims_sub
                dims_list = [-1, x.size(1)] + mul
                x = x.view(dims_list)
                return x

        class M3(torch.nn.Module):
            def forward(self, x):
                shape = x.shape
                x = x.view(shape)
                return x

        for cls in (M1, M2, M3):
            m = cls().eval()
            example_inputs = (torch.rand(4, 4, 4, 4),)
            m(*example_inputs)
            qconfig_dict = {'': torch.ao.quantization.default_qconfig}
            mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
            mp(torch.rand(4, 4, 4, 4))
            mc = convert_fx(mp)

    class _NonReferenceTestModel(nn.Module):
        """Conv -> ReLU -> max-pool -> flatten -> ``func`` -> linear model.

        ``func`` is a caller-supplied callable invoked as ``func(x, y, z)`` on
        the flattened activations; ``lin_in``/``lin_out`` size the final
        linear layer.
        """

        def __init__(self, func, lin_in, lin_out):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.lin = nn.Linear(lin_in, lin_out)
            self.func = func

        def forward(self, x, y, z):
            activated = F.relu(self.conv1(x))
            flat = torch.flatten(self.pool(activated), 1)
            return self.lin(self.func(flat, y, z))

    # This function looks at the node specified by the NodeInfo in the key of
    # node_info_to_non_tensor_args and checks that the args at specified indices
    # are not observed (since they are non tensors). If the args at those indices
    # are a tuple/list (which do not show up as nodes) the function checks the
    # individual elements of the tuple/list recursively.
    def _check_not_observed(self, model, node_info_to_non_tensor_args):
        """Assert that selected args of selected nodes are not observed.

        Args:
            model: module exposing ``graph.nodes`` whose nodes are scanned;
                ``call_module`` arg targets are resolved with
                ``getattr(model, target)``.
            node_info_to_non_tensor_args: maps ``NodeInfo(op, target)`` to the
                positional-arg indices that must not be produced by an
                activation post process on every matching node.

        Indices beyond the end of a node's ``args`` are ignored. Tuple/list
        args are checked element by element. Args that are not graph nodes
        (plain Python values such as ints, floats or ``None``) are skipped.
        """

        # this is a helper function (for easier recursion) that checks whether
        # arg_node is observed
        def _check_node_not_observed(model, arg_node, node):
            if isinstance(arg_node, (tuple, list)):
                for new_node in arg_node:
                    _check_node_not_observed(model, new_node, node)
            # Plain Python values have no `op` attribute; reading it through
            # getattr keeps literals (possibly nested in a tuple/list) from
            # raising AttributeError. They are not nodes, so there is nothing
            # to check for them.
            elif getattr(arg_node, "op", None) == "call_module":
                self.assertFalse(
                    _is_activation_post_process(getattr(model, arg_node.target)),
                    f"Arg: {arg_node} of node: {node} is observed but is not a float tensor",
                )

        for node in model.graph.nodes:
            indices = node_info_to_non_tensor_args.get(
                NodeInfo(node.op, node.target), []
            )
            for index in indices:
                if index < len(node.args):
                    arg_node = node.args[index]
                    _check_node_not_observed(model, arg_node, node)

    # This test checks that the model gets prepared correct, doesn't have observers
    # on specific ops (see _check_not_observed) and that the prepared model runs
    def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args):
        """Prepare ``model`` with the fbgemm default qconfig, verify via
        ``_check_not_observed`` that the listed args carry no observers, then
        run the prepared model once on ``args``.
        """
        model.eval()
        fbgemm_qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        prepared_model = prepare_fx(
            model, {"": fbgemm_qconfig}, example_inputs=tuple(args))
        self._check_not_observed(prepared_model, node_info_to_non_tensor_args)
        # the prepared model must also execute on the example inputs
        prepared_model(*args)

    def test_masked_fill_nontensor_args_not_observed(self):
        """Args 1 and 2 of the ``masked_fill`` method call must not be observed."""
        def masked_fill_op(x, y, z):
            return x.masked_fill(y, z)

        net = self._NonReferenceTestModel(masked_fill_op, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1)
        unobserved_arg_indices = {NodeInfo("call_method", "masked_fill"): [1, 2]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_permute_nontensor_args_not_observed(self):
        """Args 1 and 2 of the ``permute`` method call must not be observed."""
        def permute_op(x, y, z):
            return x.permute(y, z)

        net = self._NonReferenceTestModel(permute_op, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), 0, 1)
        unobserved_arg_indices = {NodeInfo("call_method", "permute"): [1, 2]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_repeat_nontensor_args_not_observed(self):
        """Args 1 and 2 of the ``repeat`` method call must not be observed."""
        def repeat_op(x, y, z):
            return x.repeat(y, z)

        net = self._NonReferenceTestModel(repeat_op, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), 2, 1)
        unobserved_arg_indices = {NodeInfo("call_method", "repeat"): [1, 2]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_reshape_nontensor_args_not_observed(self):
        """Arg 2 of the ``reshape`` method call (``y``) must not be observed."""
        def reshape_op(x, y, z):
            # z is accepted but unused
            return x.reshape(-1, y)

        net = self._NonReferenceTestModel(reshape_op, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), 5, None)
        unobserved_arg_indices = {NodeInfo("call_method", "reshape"): [2]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_size_nontensor_args_not_observed(self):
        """Arg 1 of the ``size`` method call (``y``) must not be observed."""
        def size_op(x, y, z):
            # z is accepted but unused
            return x.reshape((-1, x.size(y)))

        net = self._NonReferenceTestModel(size_op, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), 0, None)
        unobserved_arg_indices = {NodeInfo("call_method", "size"): [1]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_transpose_nontensor_args_not_observed(self):
        """Args 1 and 2 of the ``transpose`` method call must not be observed."""
        def transpose_op(x, y, z):
            return x.transpose(y, z)

        net = self._NonReferenceTestModel(transpose_op, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), 0, 1)
        unobserved_arg_indices = {NodeInfo("call_method", "transpose"): [1, 2]}
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_torch_transpose_nontensor_args_not_observed(self):
        """Run the dtype-propagation check on a model calling ``torch.transpose``."""
        # TODO: make torch.transpose traceable by fx when using
        # variable nontensor arguments; passing y and z through to
        # torch.transpose currently errors, hence the hard-coded dims.
        def transpose_fn(x, y, z):
            return torch.transpose(x, 0, 1)

        net = self._NonReferenceTestModel(transpose_fn, 5, 1)
        # NOTE(review): the key pairs op "call_method" with the function object
        # torch.transpose; a direct torch.transpose call presumably traces as
        # "call_function", in which case this entry never matches - confirm.
        unobserved_arg_indices = {
            NodeInfo("call_method", torch.transpose): [1, 2]
        }
        inputs = (torch.randn(5, 3, 32, 32), 0, 1)
        self._test_dtype_propagation(net, unobserved_arg_indices, *inputs)

    def test_unsqueeze_nontensor_args_not_observed(self):
        """Run the dtype-propagation check on ``Tensor.unsqueeze``.

        Positional arg 1 of the ``unsqueeze`` call is registered as a
        non-tensor argument for the check.
        """
        def func(x, y, z):
            return x.unsqueeze(y)

        model = self._NonReferenceTestModel(func, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), 1, None)
        unsqueeze_node = NodeInfo("call_method", "unsqueeze")
        self._test_dtype_propagation(model, {unsqueeze_node: [1]}, *inputs)

    def test_unsqueeze__nontensor_args_not_observed(self):
        """Run the dtype-propagation check on the in-place ``Tensor.unsqueeze_``.

        Positional arg 1 of the ``unsqueeze_`` call is registered as a
        non-tensor argument for the check.
        """
        def func(x, y, z):
            return x.unsqueeze_(y)

        model = self._NonReferenceTestModel(func, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), 1, None)
        unsqueeze_node = NodeInfo("call_method", "unsqueeze_")
        self._test_dtype_propagation(model, {unsqueeze_node: [1]}, *inputs)

    def test_torch_unsqueeze_nontensor_args_not_observed(self):
        """Run the dtype-propagation check on the functional ``torch.unsqueeze``.

        The dim is hard-coded inside ``func``; ``y`` and ``z`` are accepted
        but not forwarded.
        """
        # TODO: make torch.unsqueeze traceable by fx when using variable
        # nontensor arguments; ``torch.unsqueeze(x, y)`` currently errors.
        def func(x, y, z):
            return torch.unsqueeze(x, 1)

        model = self._NonReferenceTestModel(func, 1176, 1)
        inputs = (torch.randn(5, 3, 32, 32), 1, None)
        # NOTE(review): ``torch.unsqueeze(...)`` is a function call, which FX
        # normally records as a ``call_function`` node, so this ``call_method``
        # key may never match a node -- confirm against the helper.
        unsqueeze_node = NodeInfo("call_method", torch.unsqueeze)
        self._test_dtype_propagation(model, {unsqueeze_node: [1]}, *inputs)

    def test_view_nontensor_args_not_observed(self):
        """Run the dtype-propagation check on ``Tensor.view``.

        ``func`` calls ``x.view(-1, y)``; positional arg 2 of that call (the
        ``y`` slot) is registered as a non-tensor argument.
        """
        def func(x, y, z):
            return x.view(-1, y)

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), 5, None)
        view_node = NodeInfo("call_method", "view")
        self._test_dtype_propagation(model, {view_node: [2]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_list_args(self):
        """Run the dtype-propagation check on ``reshape`` with a list shape.

        The whole shape is passed in as one list input (``y``), which is
        positional arg 1 of the ``reshape`` call.
        """
        def func(x, y, z):
            return x.reshape(y)

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), [-1, 5], None)
        reshape_node = NodeInfo("call_method", "reshape")
        self._test_dtype_propagation(model, {reshape_node: [1]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_split_list_args(self):
        """Run the dtype-propagation check on ``reshape`` with a list built in ``func``.

        The shape list is assembled from the two scalar inputs ``y`` and ``z``
        and is positional arg 1 of the ``reshape`` call.
        """
        def func(x, y, z):
            return x.reshape([y, z])

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), -1, 5)
        reshape_node = NodeInfo("call_method", "reshape")
        self._test_dtype_propagation(model, {reshape_node: [1]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_tuple_args(self):
        """Run the dtype-propagation check on ``reshape`` with a tuple shape.

        The whole shape is passed in as one tuple input (``y``), which is
        positional arg 1 of the ``reshape`` call.
        """
        def func(x, y, z):
            return x.reshape(y)

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), (-1, 5), None)
        reshape_node = NodeInfo("call_method", "reshape")
        self._test_dtype_propagation(model, {reshape_node: [1]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_split_tuple_args(self):
        """Run the dtype-propagation check on ``reshape`` with a tuple built in ``func``.

        The shape tuple is assembled from the two scalar inputs ``y`` and
        ``z`` and is positional arg 1 of the ``reshape`` call.
        """
        def func(x, y, z):
            return x.reshape((y, z))

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), -1, 5)
        reshape_node = NodeInfo("call_method", "reshape")
        self._test_dtype_propagation(model, {reshape_node: [1]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_dict_args(self):
        """Run the dtype-propagation check on ``transpose`` with dict-sourced dims.

        Both dims are looked up in the dict input ``y`` and land in
        positional args 1 and 2 of the ``transpose`` call.
        """
        def func(x, y, z):
            return x.transpose(y["first"], y["second"])

        model = self._NonReferenceTestModel(func, 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), {"first": 0, "second": 1}, None)
        transpose_node = NodeInfo("call_method", "transpose")
        self._test_dtype_propagation(model, {transpose_node: [1, 2]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_dict_tuple_args(self):
        """Run the dtype-propagation check on ``reshape`` with a dict-held tuple.

        Unlike the sibling tests, the wrapped callable is an ``nn.Module``;
        the shape tuple is read from ``y["shape"]`` and is positional arg 1
        of the ``reshape`` call.
        """
        class reshape_module(nn.Module):
            def forward(self, x, y, z):
                return x.reshape(y["shape"])

        model = self._NonReferenceTestModel(reshape_module(), 5, 1)
        inputs = (torch.randn(5, 3, 32, 32), {"shape": (-1, 5)}, None)
        reshape_node = NodeInfo("call_method", "reshape")
        self._test_dtype_propagation(model, {reshape_node: [1]}, *inputs)

    def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self):
        """Run the dtype-propagation check on ``reshape`` with a tuple of dict items.

        The shape tuple is assembled inside ``func`` from two entries of the
        dict input ``y`` and is positional arg 1 of the ``reshape`` call.
        """
        def func(x, y, z):
            return x.reshape((y["first"], y["second"]))

        model = self._NonReferenceTestModel(func, 5, 1)
        args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None]
        # ``func`` calls ``reshape``, so the mapping is keyed on the
        # ``reshape`` node; a ``transpose`` key would match nothing in the
        # traced graph and leave arg 1 unchecked.
        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)

    def test_assert_on_size_after_quant_layer(self):
        """
        Checks that a model which asserts on the size of a conv output can
        be prepared, calibrated, converted and executed without error.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                torch._assert(x.size(1) == 1, 'foobar')
                return x

        float_model = M().eval()
        example_inputs = (torch.rand(4, 1, 4, 4),)
        # run the float model once before quantizing
        float_model(*example_inputs)

        prepared = prepare_fx(
            float_model,
            {'': torch.ao.quantization.default_qconfig},
            example_inputs=example_inputs,
        )
        prepared(*example_inputs)

        quantized = convert_fx(prepared)
        quantized(*example_inputs)

    def test_fp32_sum(self):
        """
        Checks that models using ``torch.sum`` after a conv (M1) or between
        two convs (M2) go through prepare, calibration, convert and a final
        run without error.
        """
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = torch.stack([x])
                x = torch.sum(x)
                return x

        class M2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x1 = torch.stack([x])
                x1 = torch.sum(x1, dim=0)
                x2 = self.conv2(x1)
                return x2

        for model_cls in (M1, M2):
            float_model = model_cls().eval()
            example_inputs = (torch.rand(4, 1, 4, 4),)
            # run the float model once before quantizing
            float_model(*example_inputs)

            prepared = prepare_fx(
                float_model,
                {'': torch.ao.quantization.default_qconfig},
                example_inputs=example_inputs,
            )
            prepared(*example_inputs)

            quantized = convert_fx(prepared)
            quantized(*example_inputs)

    def test_fusion_pattern_unquantized(self):
        """
        Checks that prepare/convert run without errors when a submodule
        containing an add + relu sequence is configured with a ``None``
        qconfig while the rest of the model uses the default qconfig.
        """
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = torch.add(x, 1.0)
                x = torch.nn.functional.relu(x)
                return x

        class Parent(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child = Child()
                self.conv = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.child(x)
                x = self.conv(x)
                return x

        float_model = Parent().eval()
        # default qconfig globally, with 'child' explicitly set to None
        qconfig_dict = {
            '': torch.ao.quantization.default_qconfig,
            'module_name': [('child', None)],
        }
        example_inputs = (torch.rand(1, 1, 1, 1),)
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        # only conversion is exercised; the converted model is not run
        convert_fx(prepared)

    def test_state_dict(self):
        """ Packed params must appear in state_dict, be restorable through
        load_state_dict, and survive a save/load round trip through a file
        """

        # functional linear: packed weight shows up in the state_dict
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.rand(4, 30)
                self.b = torch.rand(4)

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        linear_model = M1().eval()
        linear_qconfig_dict = {"": default_qconfig}
        linear_model = prepare_fx(
            linear_model, linear_qconfig_dict, example_inputs=(torch.randn(1, 30),))
        linear_model = convert_fx(linear_model)
        self.assertTrue("_packed_weight_0" in linear_model.state_dict())

        # functional conv2d: packed weight shows up in the state_dict
        class M2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.rand(3, 3, 3, 3)
                self.b = torch.rand(3)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)

        ref_model = M2().eval()
        qconfig_dict = {"": default_qconfig}
        ref_model = prepare_fx(
            ref_model, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
        ref_model = convert_fx(ref_model)
        state_dict = ref_model.state_dict()
        self.assertTrue("_packed_weight_0" in state_dict)

        # reference weight/bias/output taken from the first conv model
        ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"])
        data = torch.rand(1, 3, 5, 5)
        ref_res = ref_model(data)

        def checkModel(m, data, ref_weight, ref_bias, ref_res):
            # weight, bias and output must all agree with the reference
            out = m(data)
            w, b = m._packed_weight_0.unpack()
            self.assertEqual(w, ref_weight)
            self.assertEqual(b, ref_bias)
            self.assertEqual(out, ref_res)

        # a second, independently initialized model must differ from the
        # reference until the reference state_dict is loaded into it
        loaded = M2().eval()
        loaded = prepare_fx(loaded, qconfig_dict, (data,))
        loaded = convert_fx(loaded)
        res = loaded(data)
        weight, bias = loaded._packed_weight_0.unpack()
        self.assertNotEqual(weight, ref_weight)
        self.assertNotEqual(bias, ref_bias)
        self.assertNotEqual(res, ref_res)
        loaded.load_state_dict(state_dict)
        checkModel(loaded, data, ref_weight, ref_bias, ref_res)

        # same check after writing the state_dict to disk and reading it back
        reloaded = M2().eval()
        reloaded = prepare_fx(reloaded, qconfig_dict, example_inputs=(data,))
        reloaded = convert_fx(reloaded)
        reloaded.load_state_dict(state_dict)
        with TemporaryFileName() as fname:
            torch.save(reloaded.state_dict(), fname)
            # weights_only=False as this is loading a ScriptModule
            restored_state = torch.load(fname, weights_only=False)
            reloaded.load_state_dict(restored_state)

        checkModel(reloaded, data, ref_weight, ref_bias, ref_res)

    @skipIfNoFBGEMM
    def test_preserve_qconfig(self):
        """
        Checks that converting with ``_remove_qconfig=False`` leaves a
        ``qconfig`` attribute on a submodule of the converted model.
        """
        with override_quantized_engine('fbgemm'):
            class Linear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.w = torch.ones(5, 5)
                    self.b = torch.zeros(5)

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.w, self.b)

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.mods1 = torch.nn.Sequential(
                        Linear(),
                        Linear()
                    )
                    self.mods2 = torch.nn.Sigmoid()

                def forward(self, x):
                    x = self.mods1(x)
                    x = self.mods2(x)
                    return x

            float_model = M().eval()
            # only functional linear gets an (fp16 dynamic) qconfig
            qconfig_dict = {
                "object_type": [
                    (torch.nn.functional.linear, float16_dynamic_qconfig),
                ],
            }
            example_inputs = (torch.rand(5, 5),)
            prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
            prepared(*example_inputs)
            converted = convert_fx(prepared, _remove_qconfig=False)

            self.assertTrue(hasattr(converted.mods2, 'qconfig'))

    def test_not_used(self):
        """ Quantize a model where the result of an op (an in-place
        sigmoid) is never consumed"""

        class M(torch.nn.Module):
            def forward(self, x):
                x = x + x
                x.sigmoid_()
                return x

        float_model = M().eval()
        fp16_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig)
        # the test only requires that prepare and convert complete
        prepared = prepare_fx(float_model, fp16_mapping, example_inputs=(torch.randn(1),))
        convert_fx(prepared)

    def test_qparams_fqn(self):
        """ Check that the generated scale/zero_point attributes on the
        converted model carry the name of the linear that first uses them. """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )

            def forward(self, x):
                x = torch.cat((x,), 1)
                tmp = x.size()
                x = self.mods1(x)
                y = x * tmp[0]
                return y

        float_model = M().eval()
        # nothing is quantized by default; only functional linear/relu are
        quantized_ops = [
            (torch.nn.functional.linear, default_qconfig),
            (torch.nn.functional.relu, default_qconfig),
        ]
        qconfig_dict = {"": None, "object_type": quantized_ops}
        example_inputs = (torch.rand(5, 5),)
        m = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        m(*example_inputs)
        m = convert_fx(m)
        keys = m.state_dict().keys()
        m(torch.randn(5, 5))
        # TODO: probably don't want to hardcode the attribute names, since they are generated
        expected_attrs = (
            "mods1_0_input_scale_0",
            "mods1_0_input_zero_point_0",
            "mods1_0_scale_0",
            "mods1_0_zero_point_0",
            "mods1_1_scale_0",
            "mods1_1_zero_point_0",
        )
        for attr_name in expected_attrs:
            self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")

    def test_no_obs_between_unmatched_node_and_copy_node(self):
        """
        Checks that no observer is placed between an unmatched node and a
        node matched to CopyNodeQuantizeHandler.  Observers need Tensor
        activations, and the output of an unmatched node is not guaranteed
        to be a Tensor.
        """

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = _user_func_with_complex_return_type(x)
                x1 = x[0] + 1
                return x1, x[1]

        float_model = M().eval()

        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(4, 4, 4, 4),)
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        # this call fails if an observer was inserted right after
        # _user_func_with_complex_return_type
        prepared(*example_inputs)
        converted = convert_fx(prepared)
        converted(*example_inputs)

    def test_fold_quant_dequant(self):
        """ Check that the converted graph is left with exactly one
            quantize_per_tensor call and one dequantize call.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                x = torch.cat((x,), 1)
                tmp = x.size()
                x = torch.nn.functional.linear(x, self.w, self.b)
                y = x * tmp[0]
                return y

        float_model = M().eval()
        # nothing is quantized by default; only functional linear is
        qconfig_dict = {
            "": None,
            "object_type": [(torch.nn.functional.linear, default_qconfig)],
        }
        example_inputs = (torch.rand(5, 5),)
        m = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        m(*example_inputs)
        m = convert_fx(m)
        keys = m.state_dict().keys()
        m(*example_inputs)

        graph_nodes = list(m.graph.nodes)
        num_dequant = sum(
            1 for n in graph_nodes
            if n.op == "call_method" and n.target == "dequantize"
        )
        num_quant = sum(
            1 for n in graph_nodes
            if n.op == "call_function" and n.target == torch.quantize_per_tensor
        )
        self.assertEqual(num_dequant, 1)
        self.assertEqual(num_quant, 1)

    def test_quant_output_always_observed(self):
        """
        With the output hardcoded as quantized (``output_quantized_idxs``),
        checks the expected fake-quantize and quantize_per_tensor node
        counts for three QAT models, including one whose last non-output
        node is not quantizeable.
        """
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        prepare_custom_config_dict = {'output_quantized_idxs': [0]}
        example_inputs = (torch.randn(4, 1, 4, 4),)

        # non-quantizeable node, quantized output
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.identity = torch.nn.Identity()

            def forward(self, x):
                x = self.identity(x)
                return x

        # quantizeable node, quantized output
        class M2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return x

        # quantizeable node, quantized dictionary output
        class M3(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return {"output": x}

        # All three models expect the same counts.  For the conv models the
        # two fake-quantize modules are one for weights, one for activations.
        for model_cls in (M1, M2, M3):
            self.checkGraphModeFxOp(
                model_cls(), example_inputs, QuantType.QAT,
                prepare_expected_node_occurrence={
                    ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
                },
                expected_node_occurrence={
                    ns.call_function(torch.quantize_per_tensor): 1,
                },
                prepare_custom_config=prepare_custom_config_dict)

    def test_deepcopy_preserve_attributes(self):
        """Check that a preserved attribute survives ``copy.deepcopy``, both
        after prepare and after convert."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = 3

            def forward(self, x):
                return x

        def assert_attr_kept(module):
            # the attribute must exist on the module and be recorded in its meta
            self.assertTrue(hasattr(module, "attr"))
            self.assertTrue("attr" in module.meta[_USER_PRESERVED_ATTRIBUTES_KEY])

        m = M().eval()
        m = prepare_fx(
            m,
            {"": default_qconfig},
            example_inputs=(torch.randn(1),),
            prepare_custom_config={"preserved_attributes": ["attr"]})
        # preserved attributes are also stored in meta so that it doesn't get lost
        # during deepcopy
        assert_attr_kept(m)
        assert_attr_kept(copy.deepcopy(m))

        m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]})
        assert_attr_kept(m)
        assert_attr_kept(copy.deepcopy(m))

    def test_output_lists_and_dicts(self):
        """Check that prepare/convert do not crash on a model whose output
        nests lists and dicts.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return {'foo': [x]}, [{'foo': [[x]]}]

        float_model = M().eval()
        prepared = prepare_fx(
            float_model,
            {'': torch.ao.quantization.default_qconfig},
            example_inputs=(torch.randn(1, 1, 1, 1),),
        )
        # only conversion is exercised; the converted model is not run
        convert_fx(prepared)

    def test_shape_followed_by_quantized_op(self):
        """ Reading ``.shape`` between two convs must not add an extra
        quantize/dequantize pair to the converted graph
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2)
                self.conv2 = torch.nn.Conv2d(2, 2, 2)

            def forward(self, x):
                x = self.conv1(x)
                s = x.shape
                torch._assert(s == x.shape, "")
                x = self.conv2(x)
                return x

        # quantize and run the converted model once
        float_model = M().eval()
        example_inputs = (torch.randn(2, 2, 4, 4),)
        prepared = prepare_fx(float_model, {"": default_qconfig}, example_inputs=example_inputs)
        quantized = convert_fx(prepared)
        quantized(*example_inputs)

        # exactly one quantize and one dequantize in the converted graph
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=expected_occurrence)

    def test_trace_quantize_per_tensor(self):
        """Check that a converted conv model can be run through
        ``torch.fx.Transformer`` without error."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return x

        float_model = M().eval()
        prepared = prepare_fx(
            float_model, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),))
        quantized = convert_fx(prepared)
        # Make sure this runs without error
        torch.fx.Transformer(quantized).transform()

    def test_copy_node_has_shared_actpp_instance(self):
        """ Check that the output of a CopyNode uses the same
        observer/fake_quant instance as its input
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3)

            def forward(self, x):
                x = self.avgpool2d(x)
                return x

        for quant_type in self.static_quant_types:
            m = M()
            # Expected observer count for both input and output, per quant
            # type.  Nothing below reads this mapping.
            occurrence_map = {
                QuantType.STATIC: {
                    ns.call_module(torch.ao.quantization.MinMaxObserver): 2
                },
                QuantType.QAT: {
                    ns.call_module(torch.ao.quantization.FakeQuantize): 2
                }
            }
            is_qat = quant_type == QuantType.QAT
            if is_qat:
                m.train()
            else:
                m.eval()
            prepare_fn = prepare_qat_fx if is_qat else prepare_fx
            qconfig = default_qat_qconfig if is_qat else default_qconfig
            if is_qat:
                actpp_module_class = torch.ao.quantization.FakeQuantize
            else:
                actpp_module_class = torch.ao.quantization.MinMaxObserver

            example_inputs = (torch.randn(1, 3, 3, 3),)
            m = prepare_fn(m, {"": qconfig}, example_inputs=example_inputs)

            # counting duplicates, the instance is reachable under two names
            count_with_duplicates = sum(
                isinstance(mod, actpp_module_class)
                for _, mod in m.named_modules(remove_duplicate=False)
            )
            self.assertEqual(count_with_duplicates, 2)

            # without duplicates there is a single instance
            count_unique = sum(
                isinstance(mod, actpp_module_class)
                for _, mod in m.named_modules()
            )
            self.assertEqual(count_unique, 1)

            m_copy = copy.deepcopy(m)
            m = convert_fx(m)
            m_reference = convert_to_reference_fx(m_copy)

            # checks for non-reference quantized model
            quantized_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method("dequantize"): 1
            }
            quantized_order = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_module(torch.nn.AvgPool2d),
                ns.call_method("dequantize"),
            ]
            self.checkGraphModuleNodes(
                m,
                expected_node_occurrence=quantized_occurrence,
                expected_node_list=quantized_order)

            # checks for reference quantized model, for copy nodes we'll have
            # dequant - copy_node - quant patterns which will be fused later
            # in the backend lowering step
            reference_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 2
            }
            reference_order = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method("dequantize"),
                ns.call_module(torch.nn.AvgPool2d),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method("dequantize"),
            ]
            self.checkGraphModuleNodes(
                m_reference,
                expected_node_occurrence=reference_occurrence,
                expected_node_list=reference_order)

    def test_linear_qint8_activation(self):
        """Check that a conv + linear model with qint8 activations can be
        converted to the reference pattern and executed
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 2, 2, 2)
                self.linear = torch.nn.Linear(8, 5)

            def forward(self, x):
                x = self.conv(x)
                x = torch.flatten(x, 1)
                x = self.linear(x)
                return x

        float_model = M().eval()
        example_inputs = (torch.rand(2, 1, 5, 5),)
        # symmetric per-tensor qint8 activations, per-channel weights
        qint8_activation = torch.ao.quantization.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        )
        qconfig = torch.ao.quantization.QConfig(
            activation=qint8_activation,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
        prepared = prepare_fx(float_model, {"": qconfig}, example_inputs=example_inputs)
        reference = convert_to_reference_fx(prepared)
        reference(*example_inputs)

    def test_preserve_tuple(self):
        """ Check that the tuple passed to the lstm submodule is still a
        tuple in the prepared graph
        """

        class LSTM(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lstm = nn.LSTM(50, 50, 1)

            def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]):
                h = state[0]
                c = state[1]
                return self.lstm(inputs, (h, c))

        float_model = LSTM().eval()
        example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50))
        # the prepared model is only inspected here, never executed
        prepared = prepare_fx(float_model, {"": default_qconfig}, example_inputs=example_inputs)
        # make sure the arg[1] of lstm module is a tuple
        for node in prepared.graph.nodes:
            if node.target != "lstm":
                continue
            self.assertEqual(type(node.args[1]), tuple)

    def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence):
        """
        Prepare and convert ``model`` with the LSTM custom-module mappings,
        checking node occurrences and running the model after each step.

        ``prepare_node_occurrence`` is checked against the prepared graph and
        ``convert_node_occurrence`` against the converted graph.
        """
        # float nn.LSTM -> quantizable LSTM at prepare time
        prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
            torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
        # quantizable LSTM -> quantized LSTM at convert time
        convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
            torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)
        example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))

        prepared = prepare_fx(
            model,
            get_default_qconfig_mapping(),
            example_inputs,
            prepare_custom_config=prepare_custom_config,
        )
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=prepare_node_occurrence)
        prepared(*example_inputs)

        converted = convert_fx(prepared, convert_custom_config=convert_custom_config)
        self.checkGraphModuleNodes(converted, expected_node_occurrence=convert_node_occurrence)
        converted(*example_inputs)

    def test_static_lstm(self):
        """
        Statically quantized custom module LSTM whose output tuple is
        unpacked, with each tensor consumed by its own linear layer.
        """
        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lstm = nn.LSTM(50, 50, 1)
                self.linear1 = nn.Linear(50, 10)
                self.linear2 = nn.Linear(50, 10)
                self.linear3 = nn.Linear(50, 10)

            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
                (out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0))
                out = self.linear1(out)
                h0_out = self.linear2(h0_out)
                c0_out = self.linear3(c0_out)
                return (out, (h0_out, c0_out))

        expected_after_prepare = {
            ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
        }
        expected_after_convert = {
            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
            ns.call_function(torch.quantize_per_tensor): 3,
            # one dequantize each for lstm[0], lstm[1][0] and lstm[1][1]
            ns.call_method("dequantize"): 3,
            # getitems: lstm[0], lstm[1], lstm[1][0], lstm[1][1]
            ns.call_function(operator.getitem): 4,
            # no tuple is consumed as a whole, so none is rebuilt
            ns.call_function(tuple): 0,
        }
        self._test_static_lstm_helper(MyModel(), expected_after_prepare, expected_after_convert)

    def test_static_lstm_consume_tuple(self):
        """
        Statically quantize the custom module LSTM when a following module consumes
        the output tuple itself: first the whole tuple, then only the hidden tuple.
        """
        class ModuleAfterLSTM(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.identity = torch.nn.Identity()

            def forward(self, x):
                return self.identity(x)

        class ConsumeWholeTuple(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lstm = nn.LSTM(50, 50, 1)
                self.module_after_lstm = ModuleAfterLSTM()

            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
                lstm_result = self.lstm(inputs, (h0, c0))
                # consume tuple (output, (hidden0, hidden1))
                return self.module_after_lstm(lstm_result)

        class ConsumeHiddenTuple(ConsumeWholeTuple):
            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
                lstm_result = self.lstm(inputs, (h0, c0))
                # consume tuple (hidden0, hidden1)
                return self.module_after_lstm(lstm_result[1])

        # The prepared graph looks the same in both scenarios
        expected_after_prepare = {
            ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
        }

        # Scenario 1: the whole tuple (output, (hidden0, hidden1)) is consumed
        whole_tuple_expected = {
            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
            ns.call_function(torch.quantize_per_tensor): 3,
            # lstm[0].dequantize()
            # lstm[1][0].dequantize()
            # lstm[1][1].dequantize()
            ns.call_method("dequantize"): 3,
            # lstm[0], lstm[1], lstm[1][0], lstm[1][1]
            ns.call_function(operator.getitem): 4,
            # tuple(output_dq, tuple(hidden0_dq, hidden1_dq))
            ns.call_function(tuple): 2,
        }
        whole_tuple_model = ConsumeWholeTuple()
        self._test_static_lstm_helper(whole_tuple_model, expected_after_prepare, whole_tuple_expected)

        # Scenario 2: only the hidden tuple (hidden0, hidden1) is consumed
        hidden_tuple_model = ConsumeHiddenTuple()
        hidden_tuple_expected = {
            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
            ns.call_function(torch.quantize_per_tensor): 3,
            # lstm[1][0].dequantize()
            # lstm[1][1].dequantize()
            ns.call_method("dequantize"): 2,
            # lstm[1], lstm[1][0], lstm[1][1]
            ns.call_function(operator.getitem): 3,
            # tuple(hidden0_dq, hidden1_dq)
            ns.call_function(tuple): 1,
        }
        self._test_static_lstm_helper(hidden_tuple_model, expected_after_prepare, hidden_tuple_expected)

    def test_static_lstm_with_custom_fixed_qparams(self):
        """
        Statically quantize an LSTM whose inner submodules each get custom fixed
        qparams. To do this, users subclass `torch.ao.nn.quantizable.LSTM` and
        register the subclass in the custom module mapping.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lstm = torch.nn.LSTM(50, 50, 1)

            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
                return self.my_lstm(inputs, (h0, c0))

        # Construct a BackendConfig that supports qint32 for certain ops
        # TODO: build a BackendConfig from scratch instead of modifying an existing one
        qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
        qint32_patterns = (torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul)
        my_backend_config = get_qnnpack_backend_config()
        for pattern_config in my_backend_config.configs:
            if pattern_config.pattern in qint32_patterns:
                pattern_config.add_dtype_config(qint32_dtype_config)

        class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
            """
            Example user-provided observed LSTM that pins fixed qparams on the
            inner ops.
            """
            @classmethod
            def from_float(cls, float_lstm):
                assert isinstance(float_lstm, cls._FLOAT_MODULE)
                fixed_obs = FixedQParamsObserver.with_args
                # uint16, [-16, 16)
                linear_output_obs_ctr = fixed_obs(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
                # uint16, [0, 1)
                sigmoid_obs_ctr = fixed_obs(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
                # uint16, [-1, 1)
                tanh_obs_ctr = fixed_obs(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
                # int16, [-16, 16)
                cell_state_obs_ctr = fixed_obs(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
                # uint8, [-1, 1)
                hidden_state_obs_ctr = fixed_obs(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
                example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
                return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
                    float_lstm=float_lstm,
                    example_inputs=example_inputs,
                    backend_config=my_backend_config,
                    linear_output_obs_ctr=linear_output_obs_ctr,
                    sigmoid_obs_ctr=sigmoid_obs_ctr,
                    tanh_obs_ctr=tanh_obs_ctr,
                    cell_state_obs_ctr=cell_state_obs_ctr,
                    hidden_state_obs_ctr=hidden_state_obs_ctr,
                )

        class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
            """
            Example user-provided quantized LSTM that builds a reference quantized
            module out of a `UserObservedLSTM`.
            """
            @classmethod
            def from_observed(cls, observed_lstm):
                assert isinstance(observed_lstm, cls._FLOAT_MODULE)
                return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
                    observed_lstm=observed_lstm,
                    backend_config=my_backend_config,
                )

        # FX graph mode quantization
        m = MyModel()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
        prepare_custom_config = PrepareCustomConfig()
        prepare_custom_config = prepare_custom_config.set_float_to_observed_mapping(
            torch.nn.LSTM, UserObservedLSTM)
        convert_custom_config = ConvertCustomConfig()
        convert_custom_config = convert_custom_config.set_observed_to_quantized_mapping(
            torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
        prepared = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            prepare_custom_config,
            backend_config=my_backend_config,
        )
        prepared(*example_inputs)
        converted = convert_fx(
            prepared,
            convert_custom_config,
            backend_config=my_backend_config,
        )
        converted(*example_inputs)

        # Find the patterns [dq - op - q_to_specific_dtype] in the graph and
        # verify that qparams and dtypes are set correctly in the quantize ops.
        # Each value is (scale, zero_point, dtype); a None scale/zero_point is not checked.
        node_name_to_expected_quantize_args = {
            "igates": (None, None, torch.quint8),
            "hgates": (None, None, torch.quint8),
            "add": (2 ** -11, 2 ** 15, torch.qint32),  # gates.add
            "input_gate": (2 ** -16, 0, torch.qint32),
            "forget_gate": (2 ** -16, 0, torch.qint32),
            "cell_gate": (2 ** -15, 2 ** 15, torch.qint32),
            "output_gate": (2 ** -16, 0, torch.qint32),
            "mul": (2 ** -11, 0, torch.qint32),  # fgate_cx.mul
            "mul_1": (2 ** -11, 0, torch.qint32),  # igate_cgate.mul
            "add_1": (2 ** -11, 0, torch.qint32),  # fgate_cx_igate_cgate.add
            "mul_2": (2 ** -7, 2 ** 7, torch.quint8),  # ogate_cy.mul
        }
        cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell
        matched_names = set()
        for node in cell.graph.nodes:
            expected_args = node_name_to_expected_quantize_args.get(node.name)
            if expected_args is None:
                continue
            matched_names.add(node.name)
            # Match preceding dequantize
            self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
            # Match following quantize with the specific qparams and dtypes
            expected_scale, expected_zp, expected_dtype = expected_args
            for user in node.users:
                self.assertEqual(user.target, torch.quantize_per_tensor)
                if expected_scale is not None:
                    # scale is stored as an attribute on the cell, referenced by name
                    self.assertEqual(getattr(cell, user.args[1].target), expected_scale)
                if expected_zp is not None:
                    # zero point is stored the same way
                    self.assertEqual(getattr(cell, user.args[2].target), expected_zp)
                self.assertEqual(user.args[-1], expected_dtype)
        # Ensure all patterns were matched
        self.assertEqual(matched_names, set(node_name_to_expected_quantize_args))

    def test_reroute_tuple_getitem_patterns(self):
        """
        Rerouting the graph below should make the output read `b` directly. Once
        that is done, no other node (the inputs `a` and `c` included) is needed.

             a   b     c
             |   \\   /
             \\   tuple
              \\   /
               tuple
               /  \\
              /    \\
             |      \\
             |       \\
             |        \\
        getitem0    getitem1
             |      /     \\
             | getitem0  getitem1
             |     \\     /
             \\      tuple
              \\      /
               \\    /
                tuple
                  |
               getitem1
                  |
               getitem0
                  |
                output
        """
        # Construct graph manually because symbolic_trace does not insert tuple and getitem nodes
        graph = torch.fx.Graph()
        node_a = graph.create_node("placeholder", "a")
        node_b = graph.create_node("placeholder", "b")
        node_c = graph.create_node("placeholder", "c")
        inner_tuple = graph.call_function(tuple, args=([node_b, node_c],))
        outer_tuple = graph.call_function(tuple, args=([node_a, inner_tuple],))

        # Break down tuple and reconstruct it again
        a_item = graph.call_function(operator.getitem, args=(outer_tuple, 0))
        inner_item = graph.call_function(operator.getitem, args=(outer_tuple, 1))
        b_item = graph.call_function(operator.getitem, args=(inner_item, 0))
        c_item = graph.call_function(operator.getitem, args=(inner_item, 1))
        rebuilt_inner = graph.call_function(tuple, args=([b_item, c_item],))
        rebuilt_outer = graph.call_function(tuple, args=([a_item, rebuilt_inner],))

        # Output tuple[1][0]
        picked_inner = graph.call_function(operator.getitem, args=(rebuilt_outer, 1))
        picked_b = graph.call_function(operator.getitem, args=(picked_inner, 0))
        output = graph.output(picked_b)

        # Do reroute
        _reroute_tuple_getitem_pattern(graph)

        # Assert that output reroutes to `b` directly, and all other nodes can be removed
        output_ancestors = []

        def collect_ancestors(node):
            # Depth-first, pre-order walk over everything reachable through `args`
            for parent in node.args:
                output_ancestors.append(parent)
                collect_ancestors(parent)

        collect_ancestors(output)
        self.assertEqual(output_ancestors, [node_b])
        self.assertEqual(output.args[0], node_b)

    def test_relu_lowering(self):
        """
        Compare quantize/dequantize node counts for a lone functional relu: the
        lowered model is expected to have one of each, the reference model two.
        """
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.relu(x)

        prepared = prepare_fx(M().eval(), {"": default_qconfig}, example_inputs=(torch.randn(1),))
        # Copy before converting so both conversions start from the same prepared model
        prepared_copy = copy.deepcopy(prepared)
        lowered = convert_fx(prepared)
        reference = convert_to_reference_fx(prepared_copy)

        lowered_counts = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1,
        }
        reference_counts = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(lowered, expected_node_occurrence=lowered_counts)
        self.checkGraphModuleNodes(reference, expected_node_occurrence=reference_counts)

    @skipIfNoFBGEMM
    def test_dynamic_with_fusion(self):
        """
        Dynamic quantization with Linear + ReLU fusion: two module-level
        Linear/ReLU pairs followed by a functional linear + relu pair.
        """
        with override_quantized_engine('fbgemm'):
            class LinearRelu(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(5, 5)
                    self.relu = torch.nn.ReLU()

                def forward(self, x):
                    return self.relu(self.linear(x))

            class Linear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.w = torch.ones(5, 5)
                    self.b = torch.zeros(5)

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.w, self.b)

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
                    self.mods2 = Linear()
                    self.relu = F.relu

                def forward(self, x):
                    x = self.mods1(x)
                    x = self.mods2(x)
                    return self.relu(x)

            # qconfig -> op expected for the functional linear + relu pair
            dynamic_quantized_ops = {
                float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16,
                default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic
            }
            for qconfig, fused_functional_op in dynamic_quantized_ops.items():
                model = M().eval()
                example_inputs = (torch.rand(5, 5),)
                quantized = prepare_fx(model, {"": qconfig}, example_inputs=example_inputs)
                quantized = convert_fx(quantized)
                quantized(*example_inputs)
                expected_order = [
                    ns.call_module(nniqd.LinearReLU),
                    ns.call_module(nniqd.LinearReLU),
                    ns.call_function(fused_functional_op),
                ]
                self.checkGraphModuleNodes(quantized, expected_node_list=expected_order)

    @skipIfNoFBGEMM
    def test_dynamic_with_fusion_multiple_uses(self):
        """
        Dynamic quantization with Linear + ReLU fusion when the same Linear/ReLU
        submodule is called twice in the forward pass.
        """
        with override_quantized_engine('fbgemm'):
            class LinearRelu(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(5, 5)
                    self.relu = torch.nn.ReLU()

                def forward(self, x):
                    return self.relu(self.linear(x))

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear_relu = LinearRelu()

                def forward(self, x):
                    # the same submodule is applied twice in a row
                    once = self.linear_relu(x)
                    return self.linear_relu(once)

            for qconfig in (float16_dynamic_qconfig, default_dynamic_qconfig):
                model = M().eval()
                example_inputs = (torch.randn(5, 5),)
                quantized = prepare_fx(model, {"": qconfig}, example_inputs=example_inputs)
                quantized = convert_fx(quantized)
                quantized(*example_inputs)
                # one fused dynamic module call per use
                expected_order = [ns.call_module(nniqd.LinearReLU)] * 2
                self.checkGraphModuleNodes(quantized, expected_node_list=expected_order)

    @skipIfNoFBGEMM
    def test_dynamic_linear_input_multiple_use(self):
        """
        Dynamic quantization when one input tensor feeds two separate
        Linear/ReLU submodules.
        """
        with override_quantized_engine('fbgemm'):
            class LinearRelu(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(5, 5)
                    self.relu = torch.nn.ReLU()

                def forward(self, x):
                    return self.relu(self.linear(x))

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.mod1 = LinearRelu()
                    self.mod2 = LinearRelu()

                def forward(self, x):
                    # both branches read the same input
                    first_branch = self.mod1(x)
                    second_branch = self.mod2(x)
                    return first_branch + second_branch

            for qconfig in (float16_dynamic_qconfig, default_dynamic_qconfig):
                model = M().eval()
                example_inputs = (torch.rand(5, 5, 5),)
                quantized = prepare_fx(model, {"": qconfig}, example_inputs=example_inputs)
                quantized = convert_fx(quantized)
                quantized(*example_inputs)
                expected_order = [ns.call_module(nniqd.LinearReLU)] * 2
                self.checkGraphModuleNodes(quantized, expected_node_list=expected_order)

    def test_ref_linear_module(self):
        """
        Check that a model converted to the reference linear module produces
        exactly the same output as the fbgemm/qnnpack-lowered model.
        """
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.linear(x)

        class M2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.linear(x))

        for model_cls in (M1, M2):
            float_model = model_cls().eval()
            example_inputs = (torch.randn(5, 10),)
            prepared = prepare_fx(float_model, {"": default_qconfig}, example_inputs=example_inputs)
            # Copy before converting so both conversions start from the same prepared model
            prepared_copy = copy.deepcopy(prepared)
            lowered = convert_fx(prepared)
            reference = convert_to_reference_fx(prepared_copy)
            lowered_out = lowered(*example_inputs)
            reference_out = reference(*example_inputs)
            self.assertTrue(torch.equal(lowered_out, reference_out))

    def test_ref_conv_module(self):
        """
        Check that a model converted to the reference conv module produces
        exactly the same output as the fbgemm/qnnpack-lowered model, for
        1d/2d/3d convolutions with and without a trailing ReLU.
        """
        convs = {
            1: nn.Conv1d,
            2: nn.Conv2d,
            3: nn.Conv3d,
        }

        class M1(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = convs[dim](3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        class M2(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = convs[dim](3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        for dim, model_cls in itertools.product([1, 2, 3], [M1, M2]):
            float_model = model_cls(dim).eval()
            sample = self.img_data_dict[dim][0][0]
            prepared = prepare_fx(float_model, {"": default_qconfig}, example_inputs=(sample,))
            # Copy before converting so both conversions start from the same prepared model
            prepared_copy = copy.deepcopy(prepared)
            lowered = convert_fx(prepared)
            reference = convert_to_reference_fx(prepared_copy)
            lowered_out = lowered(sample)
            reference_out = reference(sample)
            self.assertTrue(torch.equal(lowered_out, reference_out))

    def test_sub_scalar(self):
        """
        Quantize a chain of scalar additions and subtractions and check the
        number of quantize/dequantize nodes in the converted graph.
        """
        class M(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                x = x - 1
                x = x + 3
                x = x - 4
                return x

        prepared = prepare_fx(M().eval(), {"": default_qconfig}, example_inputs=(torch.rand(3),))
        converted = convert_fx(prepared)
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(converted, expected_node_occurrence=expected_counts)

    def test_observer_fqn(self):
        """
        Test to make sure the observer FQN is based on the quantizable op/module that it is observing
        and uses the modules FQN to determine the observer name.
        """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = Linear()
                self.mods3 = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.mods1(x)
                x = torch.add(x, 4)
                x = self.mods2(x)
                y = torch.add(x, 2)
                z = torch.mul(x, 5)
                a = self.mods3(y)
                return a, z

        model = M().eval()

        # `example_inputs` must be a tuple of positional inputs; the trailing comma
        # matters, otherwise a bare tensor is passed instead of a 1-tuple.
        prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5),))

        # Collect, in registration order, the names of all MinMaxObserver submodules
        name_list = [
            name
            for name, mod in prepared.named_modules()
            if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver)
        ]
        # Indices absent from this list are not expected to name MinMaxObserver
        # instances (presumably other observer types — not verified here).
        expected_name_list = ['activation_post_process_0',
                              'activation_post_process_1',
                              'activation_post_process_2',
                              'activation_post_process_3',
                              'activation_post_process_4',
                              'activation_post_process_6',
                              'activation_post_process_7',
                              'activation_post_process_10']
        # assertEqual (not a bare assert) so the check survives `python -O` and
        # reports a readable diff on mismatch
        self.assertEqual(name_list, expected_name_list)

    def test_conv_lowering(self):
        """
        Check that the reference pattern for a standalone conv module is lowered to
        the quantized conv module (1d/2d/3d), and that the lowered model produces
        exactly the same output as the reference model.
        """
        convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        # Use the `torch.ao.nn.quantized` alias (`nnq`) like the rest of this file,
        # rather than the deprecated `torch.nn.quantized` namespace.
        qconvs = {1: nnq.Conv1d, 2: nnq.Conv2d, 3: nnq.Conv3d}

        class M(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = convs[dim](3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        for dim in range(1, len(convs) + 1):
            m = M(dim).eval()
            data = self.img_data_dict[dim][0][0]
            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
            # Copy before converting so both conversions start from the same prepared model
            m_ref = copy.deepcopy(m)
            m_ref = convert_to_reference_fx(m_ref)
            m = convert_fx(m)
            out_ref = m_ref(data)
            out = m(data)
            # check that reference pattern for quantized conv module is fused
            expected_node_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_module(qconvs[dim]): 1,
                ns.call_method("dequantize"): 1
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence)
            # checking result match
            self.assertTrue(torch.equal(out_ref, out))

    def test_convert_qconfig_mapping(self):
        """
        Prepare a model for QAT, then pass `convert_fx` a qconfig mapping that
        sets some entries to None, once by module name and once by object type,
        and check the node counts and ordering of the converted graph.
        """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods3 = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.mods1(x)
                x = torch.add(x, 4)
                z = torch.mul(x, 5)
                x = self.mods3(z)
                return x

        def object_type_entries(linear_module_qconfig):
            # "object_type" section shared by the prepare and convert mappings;
            # only the qconfig attached to nn.Linear varies between them.
            return [
                (nn.functional.linear, get_default_qat_qconfig("fbgemm")),
                (torch.add, get_default_qat_qconfig("fbgemm")),
                (nn.Linear, linear_module_qconfig),
            ]

        model = M().train()

        for check in ["module_name", "object_type"]:
            prepare_qconfig_dict = {
                "": None,
                "object_type": object_type_entries(get_default_qat_qconfig("fbgemm")),
            }
            example_inputs = (torch.rand(5, 5),)
            prepared = prepare_qat_fx(model, prepare_qconfig_dict, example_inputs=example_inputs)
            prepared(*example_inputs)

            if check == "module_name":
                # Same mapping as prepare, plus a None qconfig for the module "mods1.0"
                convert_qconfig_dict = {
                    "": None,
                    "object_type": object_type_entries(get_default_qat_qconfig("fbgemm")),
                    "module_name": [("mods1.0", None)],
                }
                expected_counts = {
                    ns.call_function(torch.quantize_per_tensor): 2,
                    ns.call_function(torch.nn.functional.linear): 1,
                    ns.call_function(torch.ops.quantized.linear): 1,
                    ns.call_function(torch.ops.quantized.add): 1,
                    ns.call_method("dequantize"): 2
                }
                expected_order = [
                    ns.call_function(torch.nn.functional.linear),
                    ns.call_function(torch.quantize_per_tensor),
                    ns.call_function(torch.ops.quantized.linear),
                    ns.call_function(torch.ops.quantized.add),
                    ns.call_method("dequantize"),
                    ns.call_function(torch.quantize_per_tensor),
                    ns.call_module(nnq.Linear),
                    ns.call_method("dequantize"),
                ]
            else:
                # check == "object_type": nn.Linear gets a None qconfig at convert time
                convert_qconfig_dict = {
                    "": None,
                    "object_type": object_type_entries(None),
                }
                expected_counts = {
                    ns.call_function(torch.quantize_per_tensor): 1,
                    ns.call_function(torch.ops.quantized.linear): 2,
                    ns.call_function(torch.ops.quantized.add): 1,
                    ns.call_function(torch.mul): 1,
                    ns.call_method("dequantize"): 1
                }
                expected_order = [
                    ns.call_function(torch.quantize_per_tensor),
                    ns.call_function(torch.ops.quantized.linear),
                    ns.call_function(torch.ops.quantized.linear),
                    ns.call_function(torch.ops.quantized.add),
                    ns.call_method("dequantize"),
                    ns.call_function(torch.mul),
                    ns.call_module(nn.Linear),
                ]

            converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict)
            converted(torch.rand(5, 5))
            self.checkGraphModuleNodes(
                converted,
                expected_node_occurrence=expected_counts,
                expected_node_list=expected_order)

    def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2):
        """
        Call both fake-quantize constructors and assert that the resulting
        instances have equal `_observer_ctr` attributes.
        """
        first_observer_ctr = fq1()._observer_ctr
        second_observer_ctr = fq2()._observer_ctr
        self.assertEqual(first_observer_ctr, second_observer_ctr)

    def test_register_patterns(self):
        """
        Register dummy fusion and quantization patterns through the registration
        decorators and check that they land in the default pattern maps and the
        default output observer / fake-quantize maps.
        """
        def cleanUp():
            # pop(..., None) instead of `del`: if a registration below fails part-way,
            # some keys never get added, and a KeyError raised from cleanup would be
            # reported on top of the real failure.
            _DEFAULT_FUSION_PATTERNS.pop("dummy_fusion", None)
            for key in ("dummy_quant", "dummy_quant2", "dummy_quant3"):
                _DEFAULT_QUANTIZATION_PATTERNS.pop(key, None)
            for key in ("dummy_quant2", "dummy_quant3"):
                _DEFAULT_OUTPUT_OBSERVER_MAP.pop(key, None)
                _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP.pop(key, None)
        # Register the cleanup first so the global maps are restored even on failure
        self.addCleanup(cleanUp)

        @_register_fusion_pattern("dummy_fusion")
        class DummyFusion:
            pass

        @_register_quant_pattern("dummy_quant")
        class DummyQuant:
            pass

        @_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer)
        class DummyQuant2:
            pass

        @_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer)
        class DummyQuant3:
            pass

        # The decorated classes are stored under their pattern keys
        self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion)
        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant)
        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2)
        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3)
        # Patterns registered with a fixed-qparams observer also appear in the output maps
        self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer)
        self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer)
        self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"],
                                                  default_fixed_qparams_range_0to1_fake_quant)
        self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
                                                  default_fixed_qparams_range_neg1to1_fake_quant)
        # The public getter returns the fake-quantize map for training and the
        # observer map otherwise
        output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
        output_observer_map = get_default_output_activation_post_process_map(is_training=False)
        self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer)
        self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
                                                  default_fixed_qparams_range_neg1to1_fake_quant)



    def test_reuse_input_qconfig(self):
        """
        Check how `reshape` is handled under the default qconfig mapping: after a
        quantized conv it sits between the conv and the dequantize, and on its own
        it introduces no quantize/dequantize nodes at all.

        The models are only prepared and converted, never called.
        """
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                x = self.conv(x)
                x = x.reshape()
                return x

        class M2(torch.nn.Module):
            def forward(self, x):
                x = x.reshape()
                return x

        example_inputs = (torch.randn(1, 3, 3, 3),)

        # reshape following a quantized conv
        m = M1().eval()
        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
        m = convert_fx(m)
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("reshape"),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list)

        # reshape alone: nothing should be quantized
        m = M2().eval()
        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
        m = convert_fx(m)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            # spelled correctly so the zero-occurrence check targets the real method name
            ns.call_method("dequantize"): 0,
        }
        node_list = [
            ns.call_method("reshape"),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=node_occurrence,
            expected_node_list=node_list)

    def test_stack_trace_preserved_linear(self):
        """The 'linear' call_module node must carry a stack trace after
        prepare_fx, after convert_to_reference_fx and after convert_fx.
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                x = self.linear(x)
                return x

        def _linear_stack_trace_status(graph_module):
            # Returns (has_stack_trace, last_node_visited); the scan stops at
            # the first call_module node whose target is 'linear'.
            has_stack_trace = False
            visited = None
            for visited in graph_module.graph.nodes:
                if visited.op == 'call_module' and visited.target == 'linear':
                    has_stack_trace = visited.stack_trace is not None
                    break
            return has_stack_trace, visited

        float_model = M().eval()
        prepared = prepare_fx(
            float_model, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))

        found_stack_trace, _ = _linear_stack_trace_status(prepared)
        self.assertTrue(found_stack_trace)

        # test reference model
        reference_model = convert_to_reference_fx(copy.deepcopy(prepared))
        found_stack_trace, n = _linear_stack_trace_status(reference_model)
        self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True")

        # test quantized model
        quantized_model = convert_fx(prepared)
        found_stack_trace, n = _linear_stack_trace_status(quantized_model)
        self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False")

    def test_qat_skip_untraced(self):
        """prepare_qat_fx must leave the contents of non-traceable modules alone.

        One submodule is excluded from tracing by its class
        ("non_traceable_module_class") and a second one by its attribute name
        ("non_traceable_module_name"). After prepare_qat_fx the ``linear``
        inside each of them must still be a plain ``torch.nn.Linear`` and not
        the QAT linear module.
        """
        class UnTraceableModuleClass(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        class UnTraceableModuleName(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.untraceable_module_class = UnTraceableModuleClass()
                # Must be a different class than the one listed under
                # "non_traceable_module_class"; otherwise this submodule is
                # skipped by class and the name-based path is never exercised.
                self.untraceable_module_name = UnTraceableModuleName()

            def forward(self, x):
                x = self.untraceable_module_class(x)
                x = self.untraceable_module_name(x)
                return x

        mod = M()

        qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()}
        prepare_custom_config_dict = {
            "non_traceable_module_class": [UnTraceableModuleClass],
            "non_traceable_module_name": ["untraceable_module_name"],
        }
        example_inputs = (torch.randn(2, 2),)
        # A single prepare call is enough; the earlier duplicate call with
        # identical arguments added nothing to the checks below.
        mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
            mod.train(), qconfig_dict, example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config_dict
        )
        self.assertIsInstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear)
        self.assertIsInstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear)
        # The exact-type checks are the meaningful ones here, since an
        # isinstance check alone would also accept a subclass of nn.Linear.
        self.assertIsNot(
            type(mod_prep.untraceable_module_class.linear),
            torch.ao.nn.qat.modules.linear.Linear,
            "prepare_qat_fx should not convert anything inside untraced module classes",
        )
        self.assertIsNot(
            type(mod_prep.untraceable_module_name.linear),
            torch.ao.nn.qat.modules.linear.Linear,
            "prepare_qat_fx should not convert anything inside modules named in untraced_module_names",
        )

    def test_qconfig_dict_setup(self):
        """Check the quint8 quantization range of the activation post-process
        modules inserted by prepare_fx for the default (QAT) qconfig mappings.

        For "fbgemm" the expected range is [0, 127]; for "qnnpack" it is
        [0, 255].
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.Conv1d = torch.nn.Conv1d(1, 1, 1)
                self.Conv2d = torch.nn.Conv2d(1, 1, 1)
                self.Conv3d = torch.nn.Conv3d(1, 1, 1)
                self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1)
                self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1)
                self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1)
                self.Linear = torch.nn.Linear(1, 1, 1)

            def forward(self, x):
                x = self.Conv1d(x)
                x = self.Conv2d(x)
                x = self.Conv3d(x)
                x = self.ConvTranspose1d(x)
                x = self.ConvTranspose2d(x)
                x = self.ConvTranspose3d(x)
                x = self.Linear(x)
                x = torch.nn.functional.conv1d(x, torch.rand(2, 2))
                x = torch.nn.functional.conv2d(x, torch.rand(2, 2))
                x = torch.nn.functional.conv3d(x, torch.rand(2, 2))
                x = torch.nn.functional.linear(x, torch.rand(2, 2))
                return x

        backends = ["qnnpack", "fbgemm"]
        for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
            for backend in backends:
                m = M().eval()
                qconfig_dict = func(backend)
                # example_inputs must be a tuple; the trailing comma was
                # missing before, so a bare tensor was passed instead.
                m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
                lower_bnd = 0
                upper_bnd = 127 if backend == "fbgemm" else 255
                for mod in m.modules():
                    if not (_is_activation_post_process(mod) and mod.dtype == torch.quint8):
                        continue
                    # For FakeQuantize modules the range is read from the
                    # wrapped observer rather than from the module itself.
                    if isinstance(mod, FakeQuantize):
                        self.assertEqual(mod.activation_post_process.quant_min, lower_bnd)
                        self.assertEqual(mod.activation_post_process.quant_max, upper_bnd)
                    else:
                        self.assertEqual(mod.quant_min, lower_bnd)
                        self.assertEqual(mod.quant_max, upper_bnd)

    def test_prepare_mode(self):
        """Smoke test: prepare_fx and prepare_qat_fx accept a model in either
        training or eval mode.
        """
        class LinearModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        def _test(prepare_fn, qconfig_dict):
            base_model = LinearModel()
            # nn.Module.train()/eval() return the module itself
            training_copy = copy.deepcopy(base_model).train()
            example_inputs = (torch.randn(1, 5),)
            prepare_fn(training_copy, qconfig_dict, example_inputs=example_inputs)
            eval_copy = copy.deepcopy(base_model).eval()
            prepare_fn(eval_copy, qconfig_dict, example_inputs=example_inputs)

        # Each prepare function is paired with its matching default mapping.
        _test(prepare_fx, get_default_qconfig_mapping())
        _test(prepare_qat_fx, get_default_qat_qconfig_mapping())

    def _validate_qconfig_against_backend_config_constraints(
            self,
            model: torch.nn.Module,
            qconfig: QConfig,
            backend_config: BackendConfig,
            satisfies_constraints: bool,
            qconfig_name: Optional[str] = None):
        """
        Prepare, calibrate and convert `model` with `qconfig` applied to `torch.nn.Linear`,
        then check whether the linear was lowered under `backend_config`.

        When `satisfies_constraints` is True, exactly one quantized Linear and no float
        Linear is expected in the converted graph; otherwise the opposite. If the check
        fails and `qconfig_name` is given, an error line naming the QConfig is printed
        before the AssertionError is re-raised.
        """
        qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        example_inputs = (torch.rand((1, 30), dtype=torch.float),)
        prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
        prepared(*example_inputs)
        converted = convert_fx(prepared, backend_config=backend_config)

        num_quantized_linear, num_float_linear = (1, 0) if satisfies_constraints else (0, 1)
        expected_node_occurrence = {
            ns.call_module(torch.ao.nn.quantized.Linear) : num_quantized_linear,
            ns.call_module(torch.nn.Linear) : num_float_linear,
        }
        try:
            self.checkGraphModuleNodes(converted, expected_node_occurrence=expected_node_occurrence)
        except AssertionError as e:
            if qconfig_name is not None:
                print(f"ERROR: Validation for QConfig '{qconfig_name}' failed")
            raise e

    def test_backend_config_quantization_range(self):
        """
        Validate QConfig quantization ranges against the quant_min/quant_max bounds
        declared in a BackendConfig for torch.nn.Linear.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(30, 4).float()

            def forward(self, x):
                return self.linear(x)

        def _activation_dtype():
            # activations: quint8 restricted to [0, 31]
            return DTypeWithConstraints(
                dtype=torch.quint8,
                quant_min_lower_bound=0,
                quant_max_upper_bound=31,
            )

        # weights: qint8 restricted to [-64, 63]
        weight_dtype = DTypeWithConstraints(
            dtype=torch.qint8,
            quant_min_lower_bound=-64,
            quant_max_upper_bound=63,
        )
        dtype_config = DTypeConfig(
            input_dtype=_activation_dtype(),
            output_dtype=_activation_dtype(),
            weight_dtype=weight_dtype,
            bias_dtype=torch.float,
        )
        linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_reference_quantized_module(nnqr.Linear)
        backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)

        def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
            self._validate_qconfig_against_backend_config_constraints(
                MyModel(), qconfig, backend_config, satisfies_constraints)

        symmetric = torch.per_tensor_symmetric
        cases = [
            # Case 1: QConfig ranges fit within backend ranges, OK
            (
                QConfig(
                    activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8),
                    weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=symmetric)),
                True,
            ),
            # Case 2: QConfig activation range falls outside backend range, should fail
            (
                QConfig(
                    activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8),
                    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=symmetric)),
                False,
            ),
            # Case 3: QConfig weight range falls outside backend range, should fail
            (
                QConfig(
                    activation=MinMaxObserver.with_args(dtype=torch.quint8),
                    weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=symmetric)),
                False,
            ),
            # Case 4: QConfig doesn't specify range, should fail
            (
                QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver),
                False,
            ),
        ]
        for qconfig, should_satisfy in cases:
            validate_qconfig(qconfig, satisfies_constraints=should_satisfy)

    def test_backend_config_scale_min(self):
        """
        Validate the `eps` of QConfig observers against the minimum scale value
        (`scale_min_lower_bound`) declared in a BackendConfig for torch.nn.Linear.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(30, 4).float()

            def forward(self, x):
                return self.linear(x)

        backend_scale_min = 2 ** -12
        dtype_config = DTypeConfig(
            input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=backend_scale_min),
            output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=backend_scale_min),
            weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=backend_scale_min),
            bias_dtype=torch.float,
        )

        linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_reference_quantized_module(nnqr.Linear)
        backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)

        def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
            self._validate_qconfig_against_backend_config_constraints(
                MyModel(), qconfig, backend_config, satisfies_constraints)

        symmetric = torch.per_tensor_symmetric
        cases = [
            # Case 1: QConfig min scale value == backend min scale value, OK
            (
                QConfig(
                    activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12),
                    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=symmetric, eps=2 ** -12)),
                True,
            ),
            # Case 2: QConfig min scale value > backend min scale value, OK
            (
                QConfig(
                    activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10),
                    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=symmetric, eps=2 ** -10)),
                True,
            ),
            # Case 3: QConfig activation min scale value < backend min scale value, should fail
            (
                QConfig(
                    activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14),
                    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=symmetric)),
                False,
            ),
            # Case 4: QConfig weight min scale value < backend min scale value, should fail
            (
                QConfig(
                    activation=MinMaxObserver.with_args(dtype=torch.quint8),
                    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=symmetric, eps=2 ** -14)),
                False,
            ),
            # Case 5: QConfig doesn't specify eps, should fail
            (
                QConfig(
                    activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0),
                    weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0)),
                False,
            ),
        ]
        for qconfig, should_satisfy in cases:
            validate_qconfig(qconfig, satisfies_constraints=should_satisfy)

    def test_qnnpack_backend_config(self):
        """
        Each default QNNPACK QConfig listed below must satisfy the constraints of the
        QNNPACK BackendConfig, i.e. a Linear quantized with it gets lowered.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(30, 4).float()

            def forward(self, x):
                return self.linear(x)

        # (qconfig, human readable name used in the failure message)
        named_qconfigs: List[Tuple[QConfig, str]] = [
            (get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"),
            (get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"),
            (get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"),
            (default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"),
            (default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"),
            # TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862
            # (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"),
            # (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"),
        ]
        qnnpack_backend_config = get_qnnpack_backend_config()
        for candidate, candidate_name in named_qconfigs:
            self._validate_qconfig_against_backend_config_constraints(
                MyModel(),
                candidate,
                qnnpack_backend_config,
                satisfies_constraints=True,
                qconfig_name=candidate_name,
            )

    def test_symmetric_qnnpack_qconfig_mapping(self):
        """
        Quantize a single Linear with the mapping returned by
        `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping`
        under the QNNPACK BackendConfig and check that it is lowered to a quantized Linear.
        """
        # Nothing to test when the qnnpack engine is unavailable.
        if "qnnpack" not in supported_qengines:
            return

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(30, 4).float()

            def forward(self, x):
                return self.linear(x)

        with override_quantized_engine("qnnpack"):
            qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
            example_inputs = (torch.rand((1, 30), dtype=torch.float),)
            backend_config = get_qnnpack_backend_config()
            prepared = prepare_fx(MyModel(), qconfig_mapping, example_inputs, backend_config=backend_config)
            # calibration pass
            prepared(*example_inputs)
            quantized = convert_fx(prepared, backend_config=backend_config)
            self.checkGraphModuleNodes(
                quantized,
                expected_node_occurrence={
                    ns.call_module(torch.ao.nn.quantized.Linear) : 1,
                    ns.call_module(torch.nn.Linear) : 0,
                },
            )
            # the lowered model must also be runnable
            quantized(*example_inputs)

    def test_symmetric_qnnpack_qat_qconfig_mapping(self):
        """
        Quantize a single Linear with the mapping returned by
        `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping`
        under the QNNPACK BackendConfig and check that it is lowered to a quantized Linear.
        """
        # Nothing to test when the qnnpack engine is unavailable.
        if "qnnpack" not in supported_qengines:
            return

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(30, 4).float()

            def forward(self, x):
                return self.linear(x)

        with override_quantized_engine("qnnpack"):
            qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping()
            example_inputs = (torch.rand((1, 30), dtype=torch.float),)
            backend_config = get_qnnpack_backend_config()
            # NOTE(review): the QAT mapping is fed to prepare_fx rather than
            # prepare_qat_fx here -- confirm that this is intended.
            prepared = prepare_fx(MyModel(), qconfig_mapping, example_inputs, backend_config=backend_config)
            prepared(*example_inputs)
            quantized = convert_fx(prepared, backend_config=backend_config)
            self.checkGraphModuleNodes(
                quantized,
                expected_node_occurrence={
                    ns.call_module(torch.ao.nn.quantized.Linear) : 1,
                    ns.call_module(torch.nn.Linear) : 0,
                },
            )
            # the lowered model must also be runnable
            quantized(*example_inputs)


    def test_get_executorch_backend_config(self):
        """Smoke test: building the executorch BackendConfig must not raise."""
        from torch.ao.quantization.backend_config import get_executorch_backend_config
        # the returned config is not inspected, only the call itself is exercised
        get_executorch_backend_config()

    def test_backend_config_check_for_weight_and_bias(self):
        """ Test to make sure the backend_config check for weight and bias
        runs when the qconfig is None for the ops with weight and bias.
        Previously the error was not hit because the input is checked first and
        the checks for weight and bias were skipped.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # NOTE(review): torch.tensor((5, 5)) is a 1-D tensor holding two
                # values, not a 5x5 matrix; the model is only passed to
                # prepare_fx in this test -- confirm this is intended.
                self.weight = torch.tensor((5, 5))
                self.bias = torch.tensor((5,))

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        float_model = M().eval()
        # empty mapping: no qconfig is assigned to any op
        empty_qconfig_mapping = QConfigMapping()
        quint8_weighted_op_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float,
        )
        # addmm(bias, x, weight): bias is argument 0 and weight is argument 2
        addmm_pattern_config = BackendPatternConfig(torch.addmm) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .set_dtype_configs([quint8_weighted_op_dtype_config]) \
            ._set_input_type_to_index({"weight": 2, "bias": 0})
        backend_config = BackendConfig().set_backend_pattern_config(addmm_pattern_config)
        example_inputs = (torch.rand(1, 5),)
        # make sure this runs
        prepare_fx(float_model, empty_qconfig_mapping, example_inputs, backend_config=backend_config)

    def test_get_default_qconfig_valid_backend(self):
        """ Checks that an AssertionError matching "not supported" is raised by
        each default qconfig / qconfig-mapping getter for an unexpected backend
        """
        qconfig_getters = (
            get_default_qconfig,
            get_default_qat_qconfig,
            get_default_qconfig_mapping,
            get_default_qat_qconfig_mapping,
        )
        for invalid_backend in ("imaginary_backend", 3):
            for getter in qconfig_getters:
                with self.assertRaisesRegex(AssertionError, "not supported"):
                    getter(invalid_backend)

    def test__convert_to_reference_decomposed_fx(self):
        """Convert a prepared Linear model with _convert_to_reference_decomposed_fx,
        check the decomposed quantize/dequantize node counts, and compare its
        output with the convert_to_reference_fx model.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        float_model = M().eval()
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
        example_inputs = (torch.randn(1, 5),)
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # convert an independent copy so both flavours start from the same prepared state
        reference = convert_to_reference_fx(copy.deepcopy(prepared))
        decomposed = _convert_to_reference_decomposed_fx(prepared)

        decomposed_ops = torch.ops.quantized_decomposed
        self.checkGraphModuleNodes(
            decomposed,
            expected_node_occurrence={
                ns.call_function(decomposed_ops.quantize_per_tensor.default): 2,
                ns.call_function(decomposed_ops.dequantize_per_tensor.default): 2,
            })
        # make sure it runs
        expected_output = reference(*example_inputs)
        actual_output = decomposed(*example_inputs)
        self.assertEqual(actual_output, expected_output)

    @skipIfNoQNNPACK
    def test__convert_to_reference_decomposed_fx_dynamic_quant(self):
        """Dynamic-quantization variant: a Linear using default_dynamic_qconfig
        is converted with _convert_to_reference_decomposed_fx, its decomposed
        node counts are checked, and its output is compared with the
        convert_to_reference_fx model.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # to avoid reduce_range
        with override_quantized_engine("qnnpack"):
            float_model = M().eval()
            qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
                .set_object_type(torch.nn.Linear, default_dynamic_qconfig)
            example_inputs = (torch.randn(1, 5),)
            prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
            prepared(*example_inputs)
            # convert an independent copy so both flavours start from the same prepared state
            reference = convert_to_reference_fx(copy.deepcopy(prepared))
            decomposed = _convert_to_reference_decomposed_fx(prepared)

            decomposed_ops = torch.ops.quantized_decomposed
            self.checkGraphModuleNodes(
                decomposed,
                expected_node_occurrence={
                    ns.call_function(decomposed_ops.choose_qparams.tensor): 1,
                    ns.call_function(decomposed_ops.quantize_per_tensor.tensor): 1,
                    ns.call_function(decomposed_ops.dequantize_per_tensor.tensor): 1,
                })
            # make sure it runs
            expected_output = reference(*example_inputs)
            actual_output = decomposed(*example_inputs)
            self.assertEqual(actual_output, expected_output)

    def test__convert_to_reference_decomposed_fx_per_channel_quant(self):
        """Per-channel variant: F.linear using default_per_channel_qconfig is
        converted with _convert_to_reference_decomposed_fx, its decomposed node
        counts are checked, and its output is compared with the
        convert_to_reference_fx model.
        """
        class M(torch.nn.Module):
            def forward(self, x, weight, bias):
                return F.linear(x, weight, bias)

        float_model = M().eval()
        qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
            .set_object_type(F.linear, default_per_channel_qconfig)
        # (input, weight, bias)
        example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,))
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
        prepared(*example_inputs)
        # convert an independent copy so both flavours start from the same prepared state
        reference = convert_to_reference_fx(copy.deepcopy(prepared))
        decomposed = _convert_to_reference_decomposed_fx(prepared)

        decomposed_ops = torch.ops.quantized_decomposed
        self.checkGraphModuleNodes(
            decomposed,
            expected_node_occurrence={
                # for input and output activations
                ns.call_function(decomposed_ops.quantize_per_tensor.default): 2,
                ns.call_function(decomposed_ops.dequantize_per_tensor.default): 2,
                # for weight
                ns.call_function(decomposed_ops.quantize_per_channel.default): 1,
                ns.call_function(decomposed_ops.dequantize_per_channel.default): 1,
            })
        # make sure it runs
        expected_output = reference(*example_inputs)
        actual_output = decomposed(*example_inputs)
        self.assertEqual(actual_output, expected_output)

    def test_change_backend_config_for_fixed_qparam_ops(self):
        """ Making sure we can skip validation of qconfigs for fixedqparam ops based
        on BackendConfig
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tanh = torch.nn.Tanh()

            def forward(self, x: torch.Tensor):
                x = self.tanh(x)
                return x

        tanh_model = M().eval()
        # A global default_qconfig is set, but it will be ignored because the
        # empty BackendConfig below does not support anything. This makes sure
        # the qconfig is not validated when the BackendConfig has no fixed
        # qparam op related configuration.
        global_mapping = QConfigMapping().set_global(default_qconfig)
        empty_backend_config = BackendConfig()
        # make sure this runs
        prepare_fx(
            tanh_model,
            qconfig_mapping=global_mapping,
            example_inputs=(torch.randn(1, 2, 3, 4),),
            backend_config=empty_backend_config
        )

    def test_channel_shuffle_lowering(self):
        """Check node counts after convert_fx and convert_to_reference_fx for
        the module, torch.* and torch.nn.functional.* forms of channel shuffle.
        """
        # Three versions of channel shuffle
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.op = torch.nn.ChannelShuffle(2)

            def forward(self, x):
                return self.op(x + x) + x

        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.channel_shuffle(x + x, 2) + x

        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.channel_shuffle(x + x, 2) + x

        x = torch.randn(4, 4, 4, 4)
        # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle
        model_node_pairs = [
            (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)),
            (M2().eval(), ns.call_function(torch.channel_shuffle)),
            (M3().eval(), ns.call_function(torch.channel_shuffle))
        ]
        for float_model, shuffle_node in model_node_pairs:
            prepared = prepare_fx(float_model, {"": default_qconfig}, example_inputs=(x,))
            prepared_copy = copy.deepcopy(prepared)
            quantized = convert_fx(prepared)
            reference = convert_to_reference_fx(prepared_copy)
            # lowered model: a single quantize at the input, a single dequantize at the output
            self.checkGraphModuleNodes(
                quantized,
                expected_node_occurrence={
                    shuffle_node: 1,
                    ns.call_function(torch.quantize_per_tensor): 1,
                    ns.call_method("dequantize"): 1
                })
            # reference model: four quantize/dequantize pairs are expected
            self.checkGraphModuleNodes(
                reference,
                expected_node_occurrence={
                    shuffle_node: 1,
                    ns.call_function(torch.quantize_per_tensor): 4,
                    ns.call_method("dequantize"): 4
                })

    def test_match_pattern_with_multiple_args(self):
        r""" Test that we can match a pattern that has multiple arguments
        Pattern:
                           shape \
        transpose (observed) -> reshape -> output (observed) ->

        where `reshape` has two arguments
        """

        def _get_pattern_configs():
            quint8_dtype_config = DTypeConfig(
                input_dtype=torch.quint8,
                output_dtype=torch.quint8,
                weight_dtype=torch.qint8,
                bias_dtype=torch.float,
            )

            def root_node_getter(node_pattern):
                # the pattern is laid out as (reshape, transpose, shape); transpose is the root
                _reshape, transpose_node, _shape = node_pattern
                return transpose_node

            reshape_transpose_config = (
                BackendPatternConfig()
                ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode))
                .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
                .set_dtype_configs([quint8_dtype_config])
                ._set_root_node_getter(root_node_getter)
            )
            return [reshape_transpose_config]

        backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs())

        class M(torch.nn.Module):
            def forward(self, x):
                x = torch.transpose(x, 0, 1)
                x = torch.reshape(x, (-1,))
                return x

        float_model = M().eval()
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        example_inputs = (torch.randn(1, 3, 3, 3),)
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs, backend_config=backend_config)
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence={
                # one for input of the pattern and one for output of the pattern
                ns.call_module(MinMaxObserver): 2
            })

    def _test_linear_activation_fusion_lowering_helper(
            self, module, example_inputs, qconfig_mapping,
            backend_config, fused_module, root_module, activation_module):
        """Prepare and convert `module`, then check the resulting node counts.

        The lowered model must contain exactly one `fused_module`, one
        quantize_per_tensor and one dequantize, and no `root_module` or
        `activation_module`. The reference model must contain two
        quantize_per_tensor and two dequantize nodes. The lowered model is
        finally run on `example_inputs`.
        """
        prepared = prepare_fx(
            module.eval(),
            qconfig_mapping,
            example_inputs=example_inputs,
            backend_config=backend_config)
        prepared_copy = copy.deepcopy(prepared)
        lowered = convert_fx(prepared, backend_config=backend_config)
        reference = convert_to_reference_fx(prepared_copy)

        expected_lowered = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1,
            ns.call_module(fused_module): 1,
            ns.call_module(root_module): 0,
            ns.call_module(activation_module): 0,
        }
        expected_reference = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(lowered, expected_node_occurrence=expected_lowered)
        self.checkGraphModuleNodes(reference, expected_node_occurrence=expected_reference)
        lowered(*example_inputs)

    @skipIfNoONEDNN
    def test_linear_leaky_relu_lowering(self):
        """ Test fusion and lowering of Linear - (bn -) LeakyReLU
            by FX. For onednn backend only.
        """
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        onednn_qconfig_mapping = get_default_qconfig_mapping('onednn')
        with override_quantized_engine('onednn'):
            # Cover the pattern both with and without the BatchNorm variant.
            for with_bn in (True, False):
                model = LinearBnLeakyReluModel(with_bn)
                self._test_linear_activation_fusion_lowering_helper(
                    module=model,
                    example_inputs=model.get_example_inputs(),
                    qconfig_mapping=onednn_qconfig_mapping,
                    backend_config=get_onednn_backend_config(),
                    fused_module=nniq.LinearLeakyReLU,
                    root_module=nn.Linear,
                    activation_module=nn.LeakyReLU)

    @skipIfNoONEDNN
    def test_linear_tanh_lowering(self):
        """ Test fusion and lowering of Linear - Tanh
            by FX. For onednn backend only.
        """
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        qconfig_mapping = get_default_qconfig_mapping('onednn')
        # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
        #      Need to be able to support fusion of ops with different qconfigs
        # Since tanh must have 'fixed_qparams_qconfig' while linear should use
        # the global qconfig, we need to set qconfigs for them manually here for
        # fusion and cannot put such configs in onednn's default qconfig_mapping.
        # Known issue:
        # Cannot fuse linear - tanh and quantize standalone tanh at the same time.
        shared_qconfig = get_default_qconfig('onednn')
        for module_type in (torch.nn.Linear, torch.nn.Tanh):
            qconfig_mapping.set_object_type(module_type, shared_qconfig)
        with override_quantized_engine('onednn'):
            model = LinearTanhModel()
            self._test_linear_activation_fusion_lowering_helper(
                module=model,
                example_inputs=model.get_example_inputs(),
                qconfig_mapping=qconfig_mapping,
                backend_config=get_onednn_backend_config(),
                fused_module=nniq.LinearTanh,
                root_module=nn.Linear,
                activation_module=nn.Tanh)

    @override_qengines
    def test_linear_size_view(self):
        """Linear (optionally + ReLU) followed by ``view(x.size(0), ...)`` is lowered.

        The converted graph is expected to hold exactly one quantized linear
        module (``nnq.Linear`` or ``nniq.LinearReLU``) surrounded by a single
        quantize_per_tensor / dequantize pair.
        """
        class M(torch.nn.Module):
            def __init__(self, use_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(16, 32)
                self.relu = torch.nn.ReLU()
                self.use_relu = use_relu

            def forward(self, x):
                x = self.linear(x)
                if self.use_relu:
                    x = self.relu(x)
                return x.view(x.size(0), 1, 4, 8)

        for use_relu in [False, True]:
            model_fp32 = M(use_relu).eval()
            qengine = torch.backends.quantized.engine
            qconfig_mapping = get_default_qconfig_mapping(qengine)
            # prepare_fx documents example_inputs as a tuple of positional args,
            # so wrap the single input tensor instead of passing it bare.
            example_inputs = (torch.randn((5, 16)),)
            model_fp32(*example_inputs)
            prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
            prepared_model(*example_inputs)  # calibrate
            quantized_model = convert_fx(prepared_model)
            node_occurrence = {
                ns.call_module(nnq.Linear): 0 if use_relu else 1,
                ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method("dequantize"): 1
            }
            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    @override_qengines
    def test_linear_shape_view(self):
        """Linear (optionally + ReLU) followed by ``view(x.shape[0], ...)`` is lowered.

        The converted graph is expected to hold exactly one quantized linear
        module (``nnq.Linear`` or ``nniq.LinearReLU``) surrounded by a single
        quantize_per_tensor / dequantize pair.
        """
        class M(torch.nn.Module):
            def __init__(self, use_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(16, 32)
                self.relu = torch.nn.ReLU()
                self.use_relu = use_relu

            def forward(self, x):
                x = self.linear(x)
                if self.use_relu:
                    x = self.relu(x)
                return x.view(x.shape[0], 1, 4, 8)

        for use_relu in [False, True]:
            model_fp32 = M(use_relu).eval()
            qengine = torch.backends.quantized.engine
            qconfig_mapping = get_default_qconfig_mapping(qengine)
            # prepare_fx documents example_inputs as a tuple of positional args,
            # so wrap the single input tensor instead of passing it bare.
            example_inputs = (torch.randn((5, 16)),)
            model_fp32(*example_inputs)
            prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
            prepared_model(*example_inputs)  # calibrate
            quantized_model = convert_fx(prepared_model)
            node_occurrence = {
                ns.call_module(nnq.Linear): 0 if use_relu else 1,
                ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method("dequantize"): 1
            }
            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    def test_mixed_dtypes(self):
        """
        Test that multiple dtypes can be used in the same model for different layers,
        and the dtypes will be converted correctly between the layers.
        """
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()
                self.float_functional = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x: torch.Tensor):
                x = self.linear1(x)  # qint32
                x = self.linear2(x)  # quint8
                linear2 = x
                x = self.sigmoid(x)  # back to qint32
                x = self.tanh(x)  # back to quint8
                x = self.float_functional.add(linear2, x)  # adding two quint8's together
                return x

        def make_qconfig(scale, zp, dtype):
            # Fixed qparams for the activation, default observer for the weight.
            return QConfig(
                activation=FixedQParamsObserver.with_args(scale=scale, zero_point=zp, dtype=dtype),
                weight=torch.ao.quantization.default_weight_observer)

        # Set up a QConfigMapping that specifies different qparams and dtypes for different layers
        qconfig_mapping = QConfigMapping() \
            .set_global(get_default_qconfig("qnnpack")) \
            .set_module_name("linear1", make_qconfig(1234, 11, torch.qint32)) \
            .set_module_name("linear2", make_qconfig(2345, 22, torch.quint8)) \
            .set_object_type(torch.nn.Sigmoid, make_qconfig(3456, 33, torch.qint32)) \
            .set_object_type(torch.nn.Tanh, make_qconfig(4567, 44, torch.quint8))

        # Set up BackendConfig that supports the dtypes configured in the above QConfigMapping
        weighted_op_qint32_dtype_config = DTypeConfig(
            input_dtype=torch.qint32,
            output_dtype=torch.qint32,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float,
        )
        fixed_qparams_op_quint8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
        )
        fixed_qparams_op_qint32_dtype_config = DTypeConfig(
            input_dtype=torch.qint32,
            output_dtype=torch.qint32,
        )
        backend_config = get_qnnpack_backend_config()
        for config in backend_config.configs:
            if config.pattern == torch.nn.Linear:
                config.add_dtype_config(weighted_op_qint32_dtype_config)
            elif config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh]:
                config.add_dtype_config(fixed_qparams_op_quint8_dtype_config)
                config.add_dtype_config(fixed_qparams_op_qint32_dtype_config)

        # Produce the reference quantized model
        m = MyModule()
        example_inputs = (torch.rand(5, 5),)
        prepared = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
        prepared(*example_inputs)  # calibrate
        converted = convert_to_reference_fx(prepared, backend_config=backend_config)
        converted(*example_inputs)

        # Verify that the reference model is correct
        #
        # Reference model until add should be:
        # fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq ->
        # q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) ->
        # q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq ->
        # q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq)
        #
        # Complete reference model with add should be:
        # [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output

        target_to_expected_dtypes = {
            "linear1": torch.qint32,
            "linear2": torch.quint8,
            "sigmoid": torch.qint32,
            "tanh": torch.quint8,
            torch.add: torch.quint8,
        }
        # Find the patterns [dq - op_fp32 - q_to_specific_dtype] in the graph
        linear2_node = tanh_node = None
        matched_targets = set()
        for node in converted.graph.nodes:
            if node.target not in target_to_expected_dtypes:
                continue
            matched_targets.add(node.target)

            # Match preceding dequantize
            self.assertIn(len(node.args), (1, 2))
            self.assertTrue(all(arg.target == "dequantize" for arg in node.args))

            # Match following quantize with the specific dtypes
            self.assertEqual(len(node.users), 1)
            user = next(iter(node.users.keys()))
            self.assertEqual(user.target, torch.quantize_per_tensor)
            self.assertEqual(user.args[-1], target_to_expected_dtypes[node.target])

            # Match [dq - torch.add(linear2_dq, tanh_dq) - q]
            if node.target == "linear2":
                linear2_node = node
            elif node.target == "tanh":
                tanh_node = node
            elif node.target == torch.add:
                linear2_dq, tanh_dq = node.args
                # dq.args[0] is the quantize node, whose first arg is the fp32 op
                self.assertEqual(tanh_dq.args[0].args[0], tanh_node)
                self.assertEqual(linear2_dq.args[0].args[0], linear2_node)

        # Without this check the loop above would pass vacuously if none of
        # the expected ops were present in the converted graph.
        for target in target_to_expected_dtypes:
            self.assertIn(target, matched_targets)

    def test_lowering_functional_conv_with_kwargs(self):
        """Functional conv1d/2d/3d called with keyword arguments is lowered.

        For each dimension the converted model must run and contain exactly
        one call to the matching ``torch.ops.quantized.conv{N}d`` op.
        """
        dim_to_op = {
            1: F.conv1d,
            2: F.conv2d,
            3: F.conv3d,
        }
        dim_to_qop = {
            1: torch.ops.quantized.conv1d,
            2: torch.ops.quantized.conv2d,
            3: torch.ops.quantized.conv3d,
        }

        class Mod(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, dimension):
                super().__init__()
                self.dim = dimension
                self.op = dim_to_op[dimension]
                kernel_sizes = [kernel_size] * self.dim
                self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_sizes))

            def forward(self, input):
                return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
                               padding=[0] * self.dim, dilation=[1] * self.dim, groups=1)

        for dimension in [1, 2, 3]:
            model = Mod(3, 16, 3, dimension)
            model.eval()
            qconfig_mapping = get_default_qconfig_mapping()
            input_shape = (1, 3, *([8] * dimension))
            # prepare_fx documents example_inputs as a tuple of positional args
            example_inputs = (torch.randn(input_shape),)
            prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
            prepared_model(*example_inputs)  # calibrate
            quantized_model = convert_fx(prepared_model)
            # This should pass
            quantized_model(*example_inputs)
            # Ensure the quantized model has the expected op
            node_occurrence = {
                ns.call_function(dim_to_qop[dimension]): 1,
            }
            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    def test_lowering_functional_conv_transpose_with_kwargs(self):
        """Functional conv_transpose1d/2d/3d called with keyword arguments is lowered.

        For each dimension the converted model must run and contain exactly
        one call to the matching ``torch.ops.quantized.conv_transpose{N}d`` op.
        """
        dim_to_op = {
            1: F.conv_transpose1d,
            2: F.conv_transpose2d,
            3: F.conv_transpose3d,
        }
        dim_to_qop = {
            1: torch.ops.quantized.conv_transpose1d,
            2: torch.ops.quantized.conv_transpose2d,
            3: torch.ops.quantized.conv_transpose3d,
        }

        class Mod(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, dimension):
                super().__init__()
                self.dim = dimension
                self.op = dim_to_op[dimension]
                kernel_sizes = [kernel_size] * self.dim
                # Transposed conv weights are laid out (in_channels, out_channels, ...)
                self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes))

            def forward(self, input):
                return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
                               padding=[0] * self.dim, output_padding=[0] * self.dim,
                               dilation=[1] * self.dim, groups=1)

        for dimension in [1, 2, 3]:
            model = Mod(3, 16, 3, dimension)
            model.eval()
            qconfig_mapping = get_default_qconfig_mapping()
            input_shape = (1, 3, *([8] * dimension))
            # prepare_fx documents example_inputs as a tuple of positional args
            example_inputs = (torch.randn(input_shape),)
            prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
            prepared_model(*example_inputs)  # calibrate
            quantized_model = convert_fx(prepared_model)
            # This should pass
            quantized_model(*example_inputs)
            # Ensure the quantized model has the expected op
            node_occurrence = {
                ns.call_function(dim_to_qop[dimension]): 1,
            }
            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    def test_lowering_functional_linear_with_kwargs(self):
        """Functional linear called with ``bias`` as a keyword argument is lowered.

        The converted model must run and contain exactly one call to
        ``torch.ops.quantized.linear``.
        """
        class Mod(nn.Module):
            def __init__(self, in_channels, out_channels):
                super().__init__()
                self.weight = nn.Parameter(torch.randn(out_channels, in_channels))

            def forward(self, input):
                return F.linear(input, self.weight, bias=None)

        model = Mod(8, 4)
        model.eval()
        qconfig_mapping = get_default_qconfig_mapping()
        # prepare_fx documents example_inputs as a tuple of positional args
        example_inputs = (torch.randn(1, 8),)
        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
        prepared_model(*example_inputs)  # calibrate
        quantized_model = convert_fx(prepared_model)
        # This should pass
        quantized_model(*example_inputs)
        # Ensure the quantized model has the expected op
        node_occurrence = {
            ns.call_function(torch.ops.quantized.linear): 1,
        }
        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    def setUp(self):
        """Create the custom qconfig and the table of commonly quantized patterns."""
        super().setUp()
        # Symmetric per-tensor qint8 activations, per-channel weights.
        activation_observer = torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        )
        self.custom_qconfig = torch.ao.quantization.QConfig(
            activation=activation_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer
        )
        # Every pattern below maps to the same default quantize handler.
        default_handled_patterns = [
            torch.nn.ConvTranspose1d,
            torch.nn.ConvTranspose2d,
            torch.nn.ELU,
            torch.nn.LeakyReLU,
            torch.nn.Hardswish,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.SiLU,
            torch.nn.Mish,
            torch.nn.GELU,
            torch.nn.Softmax,
            torch.nn.functional.elu,
            torch.nn.functional.hardswish,
            torch.nn.functional.instance_norm,
            torch.nn.functional.layer_norm,
            torch.nn.functional.leaky_relu,
            torch.nn.functional.silu,
            torch.nn.functional.mish,
            torch.nn.functional.gelu,
            torch.nn.functional.softmax,
            torch.sum,
        ]
        self.common_quant_patterns = dict.fromkeys(
            default_handled_patterns, DefaultNodeQuantizeHandler)

    """Unit tests for individual ops
    """
    @skipIfNoFBGEMM
    def test_linear_module(self):
        """Check nn.Linear, Linear + ReLU and Linear + BatchNorm1d under fbgemm.

        For the static flows the lowered output is also compared against the
        output of the reference quantized model.
        """
        with override_quantized_engine('fbgemm'):
            class LinearModel(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(30, 4).float()

                def forward(self, x):
                    return self.linear(x)

            class LinearReLUModel(torch.nn.Module):
                def __init__(self, f_relu=False):
                    super().__init__()
                    self.linear = torch.nn.Linear(30, 4).float()
                    # Either the functional or the module flavour of ReLU
                    self.relu = F.relu if f_relu else torch.nn.ReLU()

                def forward(self, x):
                    x = self.linear(x)
                    x = self.relu(x)
                    return x

            class LinearBnModel(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(4, 4).float()
                    self.bn = torch.nn.BatchNorm1d(4)

                def forward(self, x):
                    x = self.linear(x)
                    x = self.bn(x)
                    return x

            def assert_matches_reference(results):
                # Lowered and reference models must produce the same output
                self.assertEqual(results["quantized_output"], results["quantized_reference_output"])

            # Test linear
            linear_inputs = (torch.rand((1, 30), dtype=torch.float),)
            for quant_type in self.all_quant_types:
                model = LinearModel()
                if quant_type == QuantType.DYNAMIC:
                    expected_module = nnqd.Linear
                else:
                    expected_module = nnq.Linear
                results = self.checkGraphModeFxOp(
                    model, linear_inputs, quant_type, ns.call_module(expected_module))
                if quant_type in self.static_quant_types:
                    assert_matches_reference(results)

            # TODO: enable test for dynamic quant
            # Test linear-relu
            for f_relu in [True, False]:
                for quant_type in [QuantType.STATIC, QuantType.QAT]:
                    model = LinearReLUModel(f_relu)
                    results = self.checkGraphModeFxOp(
                        model, linear_inputs, quant_type, ns.call_module(nniq.LinearReLU))
                    assert_matches_reference(results)

            # Test linear-bn
            bn_inputs = (torch.rand((4, 4), dtype=torch.float),)
            for quant_type in self.static_quant_types:
                model = LinearBnModel()
                results = self.checkGraphModeFxOp(
                    model, bn_inputs, quant_type, ns.call_module(nnq.Linear))
                assert_matches_reference(results)

    @skipIfNoFBGEMM
    def test_functional_linear(self):
        """Check F.linear (optionally + relu) for dynamic, static and QAT flows.

        Verifies the observers inserted by prepare, the nodes produced by
        convert, and, for the non-dynamic flows, that the lowered output
        equals the reference output and that the packed weight is folded
        into the state_dict.
        """
        with override_quantized_engine('fbgemm'):
            class FuncLinear(torch.nn.Module):
                def __init__(self, use_bias, has_relu, f_relu):
                    super().__init__()
                    self.w = torch.randn(4, 30)
                    self.b = torch.randn(4)
                    self.use_bias = use_bias
                    if has_relu:
                        if f_relu:
                            self.relu_or_id = F.relu
                        else:
                            self.relu_or_id = torch.nn.ReLU()
                    else:
                        self.relu_or_id = torch.nn.Identity()

                def forward(self, x):
                    if self.use_bias:
                        x = F.linear(x, self.w, self.b)
                    else:
                        x = F.linear(x, self.w)
                    x = self.relu_or_id(x)
                    return x

            data = (torch.rand((1, 30), dtype=torch.float),)
            quant_type_to_qlinear_fun = {
                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
                QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
            }
            quant_type_to_qlinear_relu_fun = {
                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
                QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
            }

            options = itertools.product(
                self.all_quant_types,
                (True, False),  # use_bias
                (True, False),  # has_relu
                (True, False),  # functional relu
            )
            for quant_type, use_bias, has_relu, f_relu in options:
                # when has_relu is False, we are using an nn.Identity and
                # we will insert observer/fake_quant for the output of nn.Identity since
                # it is a copy node, that's why we have extra observer/fake_quant
                # when has_relu is False
                quant_type_to_prepare_expected_node_occurrence = {
                    QuantType.DYNAMIC: {
                        ns.call_module(torch.ao.quantization.PlaceholderObserver): 1,
                        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
                    },
                    # There should be 3 observers: after input, weight and activation.
                    # one more observer for torch.nn.Identity when there is no relu
                    QuantType.STATIC: {
                        ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
                        ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
                    },
                    # There should be 3 fake quants: after input, weight and activation.
                    # one more fake quant for torch.nn.Identity when there is no relu
                    QuantType.QAT: {
                        ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
                    },
                }
                model = FuncLinear(use_bias, has_relu, f_relu)
                if has_relu:
                    qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type]
                else:
                    qlinear_fun = quant_type_to_qlinear_fun[quant_type]

                # Static and QAT flows expect one quantize_per_tensor /
                # dequantize pair; the dynamic flow expects neither.
                num_quant_dequant = 0 if quant_type == QuantType.DYNAMIC else 1

                convert_node_occurrence = {
                    ns.call_function(torch.quantize_per_tensor): num_quant_dequant,
                    qlinear_fun: 1,
                    ns.call_method("dequantize"): num_quant_dequant,
                }
                prepare_expected_node_occurrence = \
                    quant_type_to_prepare_expected_node_occurrence[quant_type]
                result_dict = self.checkGraphModeFxOp(
                    model, data, quant_type, qlinear_fun,
                    prepare_expected_node_occurrence=prepare_expected_node_occurrence,
                    expected_node_occurrence=convert_node_occurrence)
                if quant_type != QuantType.DYNAMIC:
                    self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
                    # Ensure packed weights in lowered models are folded
                    self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())

    @skipIfNoFBGEMM
    def test_linear_dynamic_fp16(self):
        """Check F.linear (optionally + relu) with the fp16 dynamic qconfig.

        Covers both the lowered flow (quantized fp16 dynamic ops) and the
        reference flow (plain F.linear plus a ``to`` call for the weight).
        """
        with override_quantized_engine('fbgemm'):
            class FuncLinear(torch.nn.Module):
                def __init__(self, use_bias, has_relu, f_relu):
                    super().__init__()
                    self.w = torch.randn(4, 30)
                    self.b = torch.randn(4)
                    self.use_bias = use_bias
                    if not has_relu:
                        self.relu = torch.nn.Identity()
                    elif f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()

                def forward(self, x):
                    if self.use_bias:
                        x = F.linear(x, self.w, self.b)
                    else:
                        x = F.linear(x, self.w)
                    x = self.relu(x)
                    return x

            data = (torch.rand((1, 30), dtype=torch.float),)
            flags = (True, False)
            # (use_bias, has_relu, functional relu, is_reference)
            options = itertools.product(flags, flags, flags, flags)
            for use_bias, has_relu, f_relu, is_reference in options:
                model = FuncLinear(use_bias, has_relu, f_relu)
                if is_reference:
                    expected_target = torch.nn.functional.linear
                elif has_relu:
                    expected_target = torch.ops.quantized.linear_relu_dynamic_fp16
                else:
                    expected_target = torch.ops.quantized.linear_dynamic_fp16
                qlinear_fun = ns.call_function(expected_target)
                # activation and weight
                prepare_node_occurrence = {
                    ns.call_module(torch.ao.quantization.PlaceholderObserver): 2
                }
                # one "to" call for the weight in the reference flow only
                num_to_calls = 1 if is_reference else 0
                convert_node_occurrence = {
                    qlinear_fun: 1,
                    ns.call_method("to"): num_to_calls,
                }
                self.checkGraphModeFxOp(
                    model, data, QuantType.DYNAMIC, qlinear_fun,
                    is_reference=is_reference,
                    custom_qconfig_dict={"": float16_dynamic_qconfig},
                    prepare_expected_node_occurrence=prepare_node_occurrence,
                    expected_node_occurrence=convert_node_occurrence)

    def test_linear_static_fp16(self):
        """Check F.linear (optionally + relu) with the fp16 static qconfig.

        Static fp16 ops are not supported, so the linear stays an unfused
        F.linear call surrounded by ``to`` / ``dequantize`` calls.
        """
        class FuncLinear(torch.nn.Module):
            def __init__(self, use_bias, has_relu, f_relu):
                super().__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                self.use_bias = use_bias
                if not has_relu:
                    self.relu = torch.nn.Identity()
                elif f_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU()

            def forward(self, x):
                if self.use_bias:
                    x = F.linear(x, self.w, self.b)
                else:
                    x = F.linear(x, self.w)
                x = self.relu(x)
                return x

        data = (torch.rand((1, 30), dtype=torch.float),)
        flags = (True, False)
        # (use_bias, has_relu, functional relu, is_reference)
        options = itertools.product(flags, flags, flags, flags)
        backend_config = get_test_only_legacy_native_backend_config()
        for use_bias, has_relu, f_relu, is_reference in options:
            model = FuncLinear(use_bias, has_relu, f_relu)
            linear_fun = ns.call_function(torch.nn.functional.linear)
            # when has_relu is False, we are using an nn.Identity and
            # we will insert observer/fake_quant for the output of nn.Identity since
            # it is a copy node, that's why we have extra observer/fake_quant
            # when has_relu is False
            # activation, weight, bias and output
            num_observers = 3 + int(use_bias) + int(not has_relu)
            prepare_node_occurrence = {
                ns.call_module(torch.ao.quantization.PlaceholderObserver): num_observers,
            }
            # We have extra to and dequantize when is_reference is True
            # and has_relu is False since when has_relu is False, we
            # have an nn.Identity in the model, which is a CopyNode
            # and we would add extra quant - dequant for CopyNode in
            # reference patterns
            # activation, weight, bias and output
            num_casts = 3 + int(use_bias) + int(not has_relu and is_reference)
            convert_node_occurrence = {
                # we don't support static fp16 ops, so the linear function
                # is unfused
                linear_fun: 1,
                ns.call_method("to"): num_casts,
                ns.call_method("dequantize"): num_casts,
            }
            self.checkGraphModeFxOp(
                model, data, QuantType.DYNAMIC, linear_fun,
                is_reference=is_reference,
                custom_qconfig_dict={"": float16_static_qconfig},
                prepare_expected_node_occurrence=prepare_node_occurrence,
                expected_node_occurrence=convert_node_occurrence,
                backend_config=backend_config)

    @skipIfNoFBGEMM
    def test_conv_module(self):
        """Check that nn.Conv1d/2d/3d become nnq.Conv1d/2d/3d in the static flows."""
        float_convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        quantized_convs = {1: nnq.Conv1d, 2: nnq.Conv2d, 3: nnq.Conv3d}

        class ConvWrapper(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = float_convs[dim](3, 3, 3).float()

            def forward(self, x):
                return self.conv(x)

        # Expected quantized node, keyed by dim
        quantized_nodes = {
            dim: ns.call_module(qconv) for dim, qconv in quantized_convs.items()
        }
        for dim, quant_type in itertools.product([1, 2, 3], self.static_quant_types):
            self.checkGraphModeFxOp(
                ConvWrapper(dim), self.img_data_dict[dim], quant_type,
                quantized_nodes[dim])

    @skipIfNoFBGEMM
    def test_functional_conv(self):
        """Test for functional conv and functional conv + relu.

        Exercises ``F.conv1d``/``F.conv2d``/``F.conv3d`` with and without a
        bias, optionally followed by a relu (module or functional), for every
        static quant type while the quantized engine is overridden to
        'fbgemm'. Checks the observers inserted by prepare, the ops present
        after convert, and that the packed weight is folded into the
        converted model's state dict.
        """
        with override_quantized_engine('fbgemm'):
            convs = {
                1: torch.nn.functional.conv1d,
                2: torch.nn.functional.conv2d,
                3: torch.nn.functional.conv3d,
            }

            class FuncConv(torch.nn.Module):
                def __init__(self, dim, use_bias, has_relu, f_relu):
                    super().__init__()
                    self.dim = dim
                    # weight is (out_channels=3, in_channels=3, 3, ..., 3)
                    self.w = torch.randn((3,) * (dim + 2))
                    self.b = torch.randn(3) if use_bias else None
                    self.stride = (1,) * dim
                    self.padding = (0,) * dim
                    self.dilation = (1,) * dim
                    self.groups = 1
                    self.use_bias = use_bias
                    if has_relu:
                        if f_relu:
                            self.relu = F.relu
                        else:
                            self.relu = torch.nn.ReLU()
                    else:
                        self.relu = torch.nn.Identity()

                def forward(self, x):
                    x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
                    x = self.relu(x)
                    return x

            quant_type_to_qconv_fun = {
                QuantType.STATIC: {
                    1: ns.call_function(torch.ops.quantized.conv1d),
                    2: ns.call_function(torch.ops.quantized.conv2d),
                    3: ns.call_function(torch.ops.quantized.conv3d)
                },
                QuantType.QAT: {
                    1: ns.call_function(torch.ops.quantized.conv1d),
                    2: ns.call_function(torch.ops.quantized.conv2d),
                    3: ns.call_function(torch.ops.quantized.conv3d)
                },
            }
            quant_type_to_qconv_relu_fun = {
                QuantType.STATIC: {
                    1: ns.call_function(torch.ops.quantized.conv1d_relu),
                    2: ns.call_function(torch.ops.quantized.conv2d_relu),
                    3: ns.call_function(torch.ops.quantized.conv3d_relu)
                },
                QuantType.QAT: {
                    1: ns.call_function(torch.ops.quantized.conv1d_relu),
                    2: ns.call_function(torch.ops.quantized.conv2d_relu),
                    3: ns.call_function(torch.ops.quantized.conv3d_relu)
                },
            }

            options = itertools.product(
                [1, 2, 3],  # dims
                self.static_quant_types,
                (True, False),  # use_bias
                (True, False),  # has_relu
                (True, False),  # functional relu
            )
            for dim, quant_type, use_bias, has_relu, f_relu in options:
                # when has_relu is False, we are using an nn.Identity and
                # we will insert observer/fake_quant for the output of nn.Identity since
                # it is a copy node, that's why we have extra observer/fake_quant
                # when has_relu is False
                quant_type_to_prepare_expected_node_occurrence = {
                    QuantType.DYNAMIC: {},
                    # With relu there are 3 observers in total (after input,
                    # weight and activation); without relu there are 4, the
                    # extra one observing the nn.Identity output.
                    QuantType.STATIC: {
                        ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
                        ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
                    },
                    # Same counting as STATIC, but every observed value uses
                    # the same fake-quantize module type: 3 with relu, 4 without.
                    QuantType.QAT: {
                        ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
                    },
                }
                # input is (batch=2, channels=3, 4, ..., 4)
                data_dims = [2, 3] + [4] * dim
                data = (torch.randn(tuple(data_dims), dtype=torch.float),)
                model = FuncConv(dim, use_bias, has_relu, f_relu)
                if has_relu:
                    qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim]
                else:
                    qconv_fun = quant_type_to_qconv_fun[quant_type][dim]

                convert_node_occurrence = {
                    ns.call_function(torch.quantize_per_tensor): 1,
                    qconv_fun: 1,
                    ns.call_method("dequantize"): 1
                }
                prepare_expected_node_occurrence = \
                    quant_type_to_prepare_expected_node_occurrence[quant_type]
                result_dict = self.checkGraphModeFxOp(
                    model, data, quant_type, qconv_fun,
                    prepare_expected_node_occurrence=prepare_expected_node_occurrence,
                    expected_node_occurrence=convert_node_occurrence)
                if quant_type != QuantType.DYNAMIC:
                    self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
                    # Ensure packed weights in lowered models are folded
                    self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())

    @skipIfNoFBGEMM
    def test_quantized_conv_relu(self):
        """Check conv + relu fusion into ``nniq.ConvReLU{1,2,3}d``.

        The relu is applied as a module (inplace and not) and as a functional
        call (inplace and not); every variant must produce the fused module.
        """
        float_conv_by_dim = {
            1: torch.nn.Conv1d,
            2: torch.nn.Conv2d,
            3: torch.nn.Conv3d,
        }

        class ConvNdRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super().__init__()
                self.conv = float_conv_by_dim[dim](3, 3, 3).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                conv_out = self.conv(x)
                return self.relu(conv_out)

        class ConvNdFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = float_conv_by_dim[dim](3, 3, 3).float()

            def forward(self, x):
                conv_out = self.conv(x)
                return F.relu(conv_out)

        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = float_conv_by_dim[dim](3, 3, 3).float()

            def forward(self, x):
                conv_out = self.conv(x)
                return F.relu(conv_out, True)

        # dim -> fused quantized module expected after convert
        fused_node_by_dim = {
            1: ns.call_module(nniq.ConvReLU1d),
            2: ns.call_module(nniq.ConvReLU2d),
            3: ns.call_module(nniq.ConvReLU3d),
        }
        for dim, quant_type in itertools.product([1, 2, 3], self.static_quant_types):
            # Build all four variants up front, then check them one by one.
            variants = [
                ConvNdRelu(dim, True),
                ConvNdRelu(dim, False),
                ConvNdFunctionalRelu(dim),
                ConvNdInplaceFunctionalRelu(dim),
            ]
            for model in variants:
                self.checkGraphModeFxOp(
                    model,
                    self.img_data_dict[dim],
                    quant_type,
                    fused_node_by_dim[dim],
                )


    def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op):
        """Shared int8 static-quant check for a binary op.

        Args:
            binary_op: out-of-place float op (e.g. ``operator.add``).
            ibinary_op: in-place counterpart of ``binary_op``.
            quantized_op: quantized op expected in the non-reference graph.
        """
        quant_type = QuantType.STATIC
        data = tuple(
            torch.randn(1, 1, 1, 1, dtype=torch.float) for _ in range(2)
        )
        # testing for default int8 static quant
        for is_inplace, is_scalar, is_reference in itertools.product(
                (True, False), repeat=3):
            if not is_reference:
                expected_nodes = None
                quantized_node = ns.call_function(quantized_op)
            else:
                # reference graph: dequantize -> float op -> quantize
                expected_nodes = [
                    ns.call_method("dequantize"),
                    ns.call_function(binary_op),
                    ns.call_function(torch.quantize_per_tensor),
                ]
                quantized_node = None

            # BinaryOpNonQuantizedInput tests the binary op should be quantized
            # even when it is not fed with a quantized input
            for model_cls in (BinaryOp, BinaryOpNonQuantizedInput):
                self.checkGraphModeFxOp(
                    model_cls(binary_op, ibinary_op, is_inplace, is_scalar),
                    data,
                    quant_type,
                    quantized_node,
                    expected_node_list=expected_nodes,
                    is_reference=is_reference)


    def _test_binary_op_float16_impl(self, binary_op, ibinary_op):
        """Shared fp16 static-quant check for a binary op.

        Counts the ``to`` calls left in the converted graph for both the
        ``BinaryOp`` and ``BinaryOpNonQuantizedInput`` models.

        Args:
            binary_op: out-of-place float op.
            ibinary_op: in-place counterpart of ``binary_op``.
        """
        data = tuple(
            torch.randn(1, 1, 1, 1, dtype=torch.float) for _ in range(2)
        )
        quant_type = QuantType.STATIC
        # testing for fp16 static quant
        # we are producing fp16 patterns
        custom_qconfig_dict = {
            "object_type": [(binary_op, float16_static_qconfig)]
        }
        backend_config = get_test_only_legacy_native_backend_config()
        # (model class, expected "to" count for scalar, for non-scalar)
        cases = (
            # output_conv1, output_add1, output_add2 for scalar
            # output_conv1, output_conv2, output_add1, output_add2 for non-scalar
            (BinaryOp, 3, 4),
            # input_add, output_add for scalar
            # input_add1, input_add2, output_add for non-scalar
            (BinaryOpNonQuantizedInput, 2, 3),
        )
        for is_inplace, is_scalar in itertools.product((True, False), repeat=2):
            for model_cls, scalar_count, non_scalar_count in cases:
                to_count = scalar_count if is_scalar else non_scalar_count
                node_occurrence = {ns.call_method("to"): to_count}
                self.checkGraphModeFxOp(
                    model_cls(binary_op, ibinary_op, is_inplace, is_scalar),
                    data,
                    quant_type,
                    expected_node_occurrence=node_occurrence,
                    custom_qconfig_dict=custom_qconfig_dict,
                    backend_config=backend_config)

    def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op):
        """Shared int8 static-quant check for a binary op followed by relu.

        Args:
            binary_op: out-of-place float op.
            ibinary_op: in-place counterpart of ``binary_op``.
            quantized_op: fused quantized op expected in the converted graph.
        """
        data = tuple(
            torch.rand((1, 1, 1, 1), dtype=torch.float) for _ in range(2)
        )
        quant_type = QuantType.STATIC
        quantized_node = ns.call_function(quantized_op)
        relu_variants = [nn.ReLU, F.relu, torch.relu]
        for is_inplace_op, relu_callable, is_scalar in itertools.product(
                [True, False], relu_variants, [True, False]):
            self.checkGraphModeFxOp(
                BinaryOpRelu(
                    binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar),
                data,
                quant_type,
                quantized_node)

    def _test_binary_op_relu_float16_impl(self, binary_op, ibinary_op):
        """Shared fp16 static-quant check for a binary op followed by relu.

        fp16 is the global qconfig and ``torch.nn.Conv2d`` is explicitly left
        unquantized; the converted graph must contain the expected number of
        ``to`` calls.

        Args:
            binary_op: out-of-place float op.
            ibinary_op: in-place counterpart of ``binary_op``.
        """
        data = tuple(
            torch.rand((1, 1, 1, 1), dtype=torch.float) for _ in range(2)
        )
        quant_type = QuantType.STATIC
        custom_qconfig_dict = {
            "": float16_static_qconfig,
            "object_type": [(torch.nn.Conv2d, None)]
        }
        backend_config = get_test_only_legacy_native_backend_config()
        relu_variants = [nn.ReLU, F.relu, torch.relu]
        for is_inplace_op, relu_callable, is_scalar in itertools.product(
                [True, False], relu_variants, [True, False]):
            to_count = 3 if is_scalar else 4
            node_occurrence = {ns.call_method("to"): to_count}
            model = BinaryOpRelu(
                binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar)
            self.checkGraphModeFxOp(
                model,
                data,
                quant_type,
                custom_qconfig_dict=custom_qconfig_dict,
                expected_node_occurrence=node_occurrence,
                backend_config=backend_config)


    @skipIfNoFBGEMM
    def test_add(self):
        """Run the int8 and fp16 binary-op checks for addition."""
        float_op, inplace_op = operator.add, operator.iadd
        self._test_binary_op_int8_impl(
            float_op, inplace_op, torch.ops.quantized.add)
        self._test_binary_op_float16_impl(float_op, inplace_op)

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_sub(self):
        """Run the fp16 binary-op check for subtraction (operator and torch forms)."""
        op_pairs = (
            (operator.sub, operator.isub),
            (torch.sub, None),
        )
        for float_op, inplace_op in op_pairs:
            self._test_binary_op_float16_impl(float_op, inplace_op)

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_div(self):
        """Run the fp16 binary-op check for division (operator and torch forms)."""
        op_pairs = (
            (operator.truediv, operator.itruediv),
            (torch.div, None),
        )
        for float_op, inplace_op in op_pairs:
            self._test_binary_op_float16_impl(float_op, inplace_op)

    @skipIfNoFBGEMM
    def test_mul(self):
        """Run the int8 and fp16 binary-op checks for multiplication."""
        float_op, inplace_op = operator.mul, operator.imul
        self._test_binary_op_int8_impl(
            float_op, inplace_op, torch.ops.quantized.mul)
        self._test_binary_op_float16_impl(float_op, inplace_op)

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_sum(self):
        """fp16 static quantization of two chained ``torch.sum`` calls."""
        class Sum(torch.nn.Module):
            def forward(self, x):
                kept = torch.sum(x, [1], keepdim=True)
                return torch.sum(kept, [1])

        data = torch.randn(1, 2, 3, 4, dtype=torch.float)
        # testing for fp16 static quant
        # we are producing fp16 patterns
        fp16_sum_qconfig = {
            "object_type": [(torch.sum, float16_static_qconfig)]
        }
        # input_sum1, output_sum1, output_sum2
        expected_to_calls = {ns.call_method("to"): 3}
        self.checkGraphModeFxOp(
            Sum(),
            data,
            QuantType.STATIC,
            expected_node_occurrence=expected_to_calls,
            custom_qconfig_dict=fp16_sum_qconfig)

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_bmm(self):
        """fp16 static quantization of ``torch.bmm`` fed with float inputs."""
        # Only referenced by the commented-out method-call variant below.
        class BMMMethod(torch.nn.Module):
            def forward(self, x, y):
                return x.bmm(y)

        data = tuple(
            torch.randn(1, 1, 1, dtype=torch.float) for _ in range(2)
        )
        quant_type = QuantType.STATIC
        # testing for fp16 static quant
        # we are producing fp16 patterns
        custom_qconfig_dict = {
            "object_type": [
                (torch.bmm, float16_static_qconfig),
                ("bmm", float16_static_qconfig),
            ]
        }
        # input_bmm1, input_bmm2, output_bmm
        node_occurrence = {ns.call_method("to"): 3}
        model = BinaryOpNonQuantizedInput(torch.bmm, None, False, False)
        self.checkGraphModeFxOp(
            model,
            data,
            quant_type,
            expected_node_occurrence=node_occurrence,
            custom_qconfig_dict=custom_qconfig_dict)

        # TODO: support call_method("bmm")
        # we can transform call_method("bmm") to call_function(torch.bmm)
        # self.checkGraphModeFxOp(
        #     BMMMethod(), data, quant_type,
        #     expected_node_occurrence=node_occurrence,
        #     custom_qconfig_dict=custom_qconfig_dict,
        #     print_debug_info=True)

    @skipIfNoFBGEMM
    def test_add_relu(self):
        """Run the int8 and fp16 binary-op + relu checks for addition."""
        float_op, inplace_op = operator.add, operator.iadd
        self._test_binary_op_relu_int8_impl(
            float_op, inplace_op, torch.ops.quantized.add_relu)
        self._test_binary_op_relu_float16_impl(float_op, inplace_op)

    @skipIfNoFBGEMM
    def test_add_relu_multiple_uses_of_relu(self):
        """A single ReLU submodule used after two different adds must yield
        two ``quantized.add_relu`` calls, and the result must be scriptable.
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU(inplace=True)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()

            def forward(self, x, y):
                # the same relu module is applied twice
                first = self.sub.relu(x + y)
                return self.sub.relu(first + y)

        example_inputs = (torch.randn(3), torch.randn(3))
        prepared = prepare_fx(
            M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        quantized = convert_fx(prepared)
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_function(torch.ops.quantized.add_relu): 2,
                ns.call_method("dequantize"): 1,
            })
        # check the model is scriptable
        scripted = torch.jit.script(quantized)
        # check the model is runnable
        scripted(*example_inputs)

    @skipIfNoFBGEMM
    def test_mul_relu(self):
        """Run the int8 and fp16 binary-op + relu checks for multiplication."""
        float_op, inplace_op = operator.mul, operator.imul
        self._test_binary_op_relu_int8_impl(
            float_op, inplace_op, torch.ops.quantized.mul_relu)
        self._test_binary_op_relu_float16_impl(float_op, inplace_op)

    # TODO(future PR): make more generic
    def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence):
        """Prepare ``model`` for QAT with the default 'fbgemm' QAT qconfig and
        check the node counts of the prepared graph.

        Args:
            model: float module to prepare.
            example_inputs: example inputs forwarded to ``prepare_qat_fx``.
            expected_node_occurrence: node -> expected count in the prepared graph.
        """
        qat_qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        # named `prepared` so the local does not hide the module-level `mp` import
        prepared = prepare_qat_fx(
            model, {'': qat_qconfig}, example_inputs=example_inputs)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=expected_node_occurrence)

    @skipIfNoFBGEMM
    def test_quantized_add_qat(self):
        """QAT prepare of add-scalar / conv / add-scalar / relu / conv must
        insert five fake-quantize modules.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(torch.add(x, 1.0))
                x = torch.relu(torch.add(x, 1.0))
                return self.conv2(x)

        model = M()
        inputs = (torch.randn(1, 1, 1, 1),)
        fake_quant_node = ns.call_module(
            torch.ao.quantization.FusedMovingAvgObsFakeQuantize)
        self._test_quantized_add_mul_qat(model, inputs, {fake_quant_node: 5})

    @skipIfNoFBGEMM
    def test_quantized_mul_qat(self):
        """QAT prepare of mul-scalar / conv / mul-scalar / relu / conv must
        insert five fake-quantize modules.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(torch.mul(x, 1.0))
                x = torch.relu(torch.mul(x, 1.0))
                return self.conv2(x)

        model = M()
        inputs = (torch.randn(1, 1, 1, 1),)
        fake_quant_node = ns.call_module(
            torch.ao.quantization.FusedMovingAvgObsFakeQuantize)
        self._test_quantized_add_mul_qat(model, inputs, {fake_quant_node: 5})

    def test_int8_input_no_unnecessary_fq(self):
        """
        When the graph input is already quantized and the single node does
        not need an activation observer, no activation observer may be
        inserted: exactly one fake-quantize module is expected.
        """
        class M(nn.Module):
            def __init__(self, scalar):
                super().__init__()
                self.scalar = scalar
                self.add_func = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                return self.add_func.add_scalar(x, self.scalar)

        model = M(0.5)
        qat_qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        example_inputs = (torch.randn(1),)
        prepared = torch.ao.quantization.quantize_fx.prepare_qat_fx(
            model,
            {'': qat_qconfig},
            example_inputs=example_inputs,
            # mark graph input 0 as already quantized
            prepare_custom_config={"input_quantized_idxs": [0]})
        fake_quant_node = ns.call_module(
            torch.ao.quantization.FusedMovingAvgObsFakeQuantize)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence={fake_quant_node: 1})

    @skipIfNoFBGEMM
    def test_cat(self):
        """Quantization of the output of cat depends on the inputs of cat:
        the output is only quantized when the inputs are quantized.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                return torch.cat([self.conv1(x), self.conv2(y)], 1)

        def expected_after_convert(is_reference):
            """Return (expected node list, expected node occurrence)."""
            if not is_reference:
                return None, {
                    # output of cat
                    ns.call_method("dequantize"): 1,
                    ns.call_function(torch.cat): 1,
                    # for two inputs
                    ns.call_function(torch.quantize_per_tensor): 2,
                }
            node_list = [
                ns.call_method("dequantize"),
                ns.call_function(torch.cat),
                ns.call_function(torch.quantize_per_tensor),
            ]
            occurrence = {
                # inputs and outputs of the two conv, and output of cat
                ns.call_method("dequantize"): 5,
                ns.call_function(torch.cat): 1,
                # inputs and outputs of the two conv, and output of cat
                ns.call_function(torch.quantize_per_tensor): 5,
            }
            return node_list, occurrence

        example_inputs = tuple(
            torch.randn(1, 2, 5, 5, dtype=torch.float) for _ in range(2)
        )
        quantized_node = ns.call_function(torch.cat)
        for quant_type, is_reference in itertools.product(
                self.static_quant_types, [True, False]):
            node_list, node_occurrence = expected_after_convert(is_reference)
            self.checkGraphModeFxOp(
                M(),
                example_inputs,
                quant_type,
                quantized_node,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence,
                is_reference=is_reference)

        # check cat is using the same observer for input and output
        prepared = prepare_fx(
            M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        # two inputs and one output of torch.cat are using same observer, so we have
        # 2 observers that's replicated
        modules_with_duplicates = dict(prepared.named_modules(remove_duplicate=False))
        modules_deduplicated = dict(prepared.named_modules())
        self.assertEqual(
            len(modules_with_duplicates), len(modules_deduplicated) + 2)
        # make sure the converted model runs
        converted = convert_fx(prepared)
        converted(*example_inputs)

    @skipIfNoFBGEMM
    def test_qbatch_norm(self):
        """Standalone BatchNorm2d/3d: quantized module expected in the lowered
        graph, float module expected in the reference graph.
        """
        float_bn_by_dim = {
            # TODO: quantized batchnorm 1d module is missing
            # 1 : torch.nn.BatchNorm1d,
            2: torch.nn.BatchNorm2d,
            3: torch.nn.BatchNorm3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = float_bn_by_dim[dim](3).to(torch.float)

            def forward(self, x):
                return self.bn(x)

        # (is_reference, dim) -> node expected after convert
        expected_node = {
            # (False, 1): ns.call_module(nnq.BatchNorm1d),
            (False, 2): ns.call_module(nnq.BatchNorm2d),
            (False, 3): ns.call_module(nnq.BatchNorm3d),
            # (True, 1): ns.call_module(nn.BatchNorm1d),
            (True, 2): ns.call_module(nn.BatchNorm2d),
            (True, 3): ns.call_module(nn.BatchNorm3d),
        }
        for quant_type, dim, is_reference in itertools.product(
                self.static_quant_types, [2, 3], [True, False]):
            self.checkGraphModeFxOp(
                M(dim),
                self.img_data_dict[dim],
                quant_type,
                expected_node[(is_reference, dim)],
                is_reference=is_reference)

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu(self):
        """BatchNorm2d/3d followed by relu (module or functional, inplace or
        not) must become ``nniq.BNReLU*d`` when lowered and ``nni.BNReLU*d``
        in reference mode.
        """
        float_bn_by_dim = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super().__init__()
                self.bn = float_bn_by_dim[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                bn_out = self.bn(x)
                return self.relu(bn_out)

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = float_bn_by_dim[dim](3).to(torch.float)

            def forward(self, x):
                bn_out = self.bn(x)
                return F.relu(bn_out, False)

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = float_bn_by_dim[dim](3).to(torch.float)

            def forward(self, x):
                bn_out = self.bn(x)
                return F.relu(bn_out, True)

        # (is_reference, dim) -> node expected after convert
        expected_node = {
            (True, 2): ns.call_module(nni.BNReLU2d),
            (True, 3): ns.call_module(nni.BNReLU3d),
            (False, 2): ns.call_module(nniq.BNReLU2d),
            (False, 3): ns.call_module(nniq.BNReLU3d),
        }
        for quant_type, dim, is_reference in itertools.product(
                self.static_quant_types, [2, 3], [True, False]):
            # Build all four variants up front, then check them one by one.
            variants = [
                BNRelu(dim, True),
                BNRelu(dim, False),
                BNFuncRelu(dim),
                BNFuncInplaceRelu(dim),
            ]
            for model in variants:
                self.checkGraphModeFxOp(
                    model,
                    self.img_data_dict[dim],
                    quant_type,
                    expected_node[(is_reference, dim)],
                    is_reference=is_reference)

    def _test_activation_impl(
            self, float_module, float_op, quantized_module, quantized_op):
        """Shared check for an activation op that takes an inplace option.

        Args:
            float_module: float ``nn.Module`` class, constructed with the
                inplace flag.
            float_op: torch or functional op, called as ``op(input, inplace)``.
            quantized_module: quantized module class expected when the
                module form is lowered.
            quantized_op: quantized op expected when the functional form
                is lowered.
        """
        class M(torch.nn.Module):
            def __init__(self, is_module, inplace):
                super().__init__()
                self.is_module = is_module
                self.inplace = inplace
                self.op = float_module(self.inplace) if self.is_module else float_op

            def forward(self, input):
                if not self.is_module:
                    return self.op(input, self.inplace)
                return self.op(input)

        # (is_module, is_reference) -> node expected after convert
        expected_node = {
            (True, True): ns.call_module(float_module),
            (True, False): ns.call_module(quantized_module),
            (False, True): ns.call_function(float_op),
            (False, False): ns.call_function(quantized_op),
        }

        for is_module, is_inplace, quant_type, is_reference in itertools.product(
                [True, False], [True, False], self.static_quant_types, [True, False]):
            self.checkGraphModeFxOp(
                M(is_module, is_inplace),
                self.img_data_2d,
                quant_type,
                expected_node[(is_module, is_reference)],
                is_reference=is_reference)

    def test_hardswish(self):
        """Run the shared activation check for hardswish."""
        self._test_activation_impl(
            float_module=nn.Hardswish,
            float_op=F.hardswish,
            quantized_module=nnq.Hardswish,
            quantized_op=torch.ops.quantized.hardswish,
        )

    def test_elu(self):
        """Run the shared activation check for ELU."""
        self._test_activation_impl(
            float_module=nn.ELU,
            float_op=F.elu,
            quantized_module=nnq.ELU,
            quantized_op=torch.ops.quantized.elu,
        )

    def test_leaky_relu(self):
        """Run the shared activation check for leaky relu."""
        self._test_activation_impl(
            float_module=nn.LeakyReLU,
            float_op=F.leaky_relu,
            quantized_module=nnq.LeakyReLU,
            quantized_op=torch.ops.quantized.leaky_relu,
        )

    def test_prelu(self):
        """PReLU with 1 and 4 parameters: quantized module expected when
        lowered, float module expected in reference mode.
        """
        class M(torch.nn.Module):
            def __init__(self, num_param: int):
                super().__init__()
                self.op = torch.nn.PReLU(num_parameters=num_param)

            def forward(self, input):
                return self.op(input)

        X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]]
        # is_reference -> node expected after convert
        expected_node = {
            True: ns.call_module(torch.nn.PReLU),
            False: ns.call_module(torch.ao.nn.quantized.PReLU),
        }

        for num_parameter, quant_type, is_reference in itertools.product(
                [1, 4], self.static_quant_types, [True, False]):
            self.checkGraphModeFxOp(
                M(num_parameter),
                X,
                quant_type,
                expected_node[is_reference],
                is_reference=is_reference)

    def _test_norm_impl(
            self, float_module, float_op, op_args, data, quantized_module, quantized_op,
            skip_op_arg_for_functional=False):
        """Shared check for a normalization op in module and functional form.

        Args:
            float_module: float ``nn.Module`` class, built as ``float_module(*op_args)``.
            float_op: torch or functional op.
            op_args: list of positional arguments for the module/op.
            data: inputs forwarded to ``checkGraphModeFxOp``.
            quantized_module: quantized module class expected for the module form.
            quantized_op: quantized op expected for the functional form.
            skip_op_arg_for_functional: when True, ``op_args`` is not passed
                to the functional op.
        """
        class M(torch.nn.Module):
            def __init__(self, is_module):
                super().__init__()
                self.is_module = is_module
                self.op = float_module(*op_args) if self.is_module else float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                call_args = [input]
                if not skip_op_arg_for_functional:
                    call_args.extend(op_args)
                return self.op(*call_args)

        # is_module -> node expected after convert
        expected_node = {
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, quant_type in itertools.product(
                [True, False], self.static_quant_types):
            self.checkGraphModeFxOp(
                M(is_module), data, quant_type, expected_node[is_module])

    def _test_norm_float16_impl(
            self, float_module, float_op, op_args, data,
            skip_op_arg_for_functional=False):
        """Shared fp16 static-quant check for a normalization op.

        Both the module and the functional form get ``float16_static_qconfig``
        and the converted graph must contain exactly two ``to`` calls.

        Args:
            float_module: float ``nn.Module`` class, built as ``float_module(*op_args)``.
            float_op: torch or functional op.
            op_args: list of positional arguments for the module/op.
            data: inputs forwarded to ``checkGraphModeFxOp``.
            skip_op_arg_for_functional: when True, ``op_args`` is not passed
                to the functional op.
        """
        class M(torch.nn.Module):
            def __init__(self, is_module):
                super().__init__()
                self.is_module = is_module
                self.op = float_module(*op_args) if self.is_module else float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                call_args = [input]
                if not skip_op_arg_for_functional:
                    call_args.extend(op_args)
                return self.op(*call_args)

        qconfig_dict = {
            "object_type": [
                (target, float16_static_qconfig)
                for target in (float_module, float_op)
            ]
        }
        node_occurrence = {ns.call_method("to"): 2}
        for is_module, quant_type in itertools.product(
                [True, False], self.static_quant_types):
            self.checkGraphModeFxOp(
                M(is_module),
                data,
                quant_type,
                custom_qconfig_dict=qconfig_dict,
                expected_node_occurrence=node_occurrence)

    def test_layer_norm(self):
        """Run the shared normalization check for layer norm."""
        inputs = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        normalized_shape = [2, 5, 5]
        self._test_norm_impl(
            nn.LayerNorm,
            F.layer_norm,
            [normalized_shape],
            inputs,
            nnq.LayerNorm,
            torch.ops.quantized.layer_norm)

    def test_instance_norm(self):
        """Run the shared normalization check for InstanceNorm1d/2d/3d."""
        # dim -> (input shape, float module class, quantized module class)
        per_dim = {
            1: ((1, 4, 5), nn.InstanceNorm1d, nnq.InstanceNorm1d),
            2: ((1, 4, 5, 1), nn.InstanceNorm2d, nnq.InstanceNorm2d),
            3: ((1, 4, 5, 1, 1), nn.InstanceNorm3d, nnq.InstanceNorm3d),
        }
        # create all inputs up front, in dim order
        data_by_dim = {
            dim: (torch.rand(shape, dtype=torch.float),)
            for dim, (shape, _, _) in per_dim.items()
        }
        for dim, (_, float_cls, quantized_cls) in per_dim.items():
            self._test_norm_impl(
                float_cls,
                F.instance_norm,
                [4],
                data_by_dim[dim],
                quantized_cls,
                torch.ops.quantized.instance_norm,
                # the module's constructor argument is not passed to F.instance_norm
                skip_op_arg_for_functional=True)

    def test_norm_weight_bias(self):
        """Functional linear followed by functional layer_norm with explicit
        weight and bias tensors: both ops must be quantized, with a single
        quantize at the start and a single dequantize at the end.
        """
        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = Linear()
                self.scale = torch.randn(5, 5)
                self.bias = torch.randn(5, 5)

            def forward(self, x):
                linear_out = self.mods1(x)
                return F.layer_norm(
                    linear_out, [5, 5], weight=self.scale, bias=self.bias)

        model = M()
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_function(torch.ops.quantized.linear): 1,
            ns.call_function(torch.ops.quantized.layer_norm): 1,
            ns.call_method("dequantize"): 1,
        }
        inputs = (torch.rand(5, 5),)
        self.checkGraphModeFxOp(
            model,
            inputs,
            QuantType.STATIC,
            expected_node_occurrence=expected_counts,
            custom_qconfig_dict=get_default_qconfig_mapping().to_dict()
        )

    def _test_default_node_quant_handler_ops(
            self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None
    ):
        """Quantize ``module`` followed by ``functional`` and check the resulting node order.

        Args:
            module: module class; instantiated with no arguments.
            functional: callable applied to the module's output.
            qconfig: qconfig installed as the global ("") qconfig.
            is_reference: when True convert with ``convert_to_reference_fx``,
                otherwise with ``convert_fx``.
            node_list: expected ordered node list passed to
                ``checkGraphModuleNodes`` (defaults to an empty list).
            additional_quant_pattern_dict: mapping forwarded under the
                "additional_quant_pattern" key of the prepare custom config
                (defaults to an empty dict).
        """
        class M(torch.nn.Module):
            def __init__(self, mod, func):
                super().__init__()
                self.module = mod()
                self.functional = func

            def forward(self, x):
                x = self.module(x)
                x = self.functional(x)
                return x

        if node_list is None:
            node_list = []
        if additional_quant_pattern_dict is None:
            additional_quant_pattern_dict = {}

        example_inputs = (torch.randn((2, 2, 2, 2)),)
        # NOTE(review): "additional_quant_pattern" is a legacy key; confirm the
        # current prepare_fx custom config still honors it before re-enabling
        # the skipped tests that rely on this helper.
        prepare_custom_config_dict = {"additional_quant_pattern": additional_quant_pattern_dict}
        qconfig_dict = {"": qconfig}

        m = M(module, functional).eval()
        # Pass example_inputs and the custom config by keyword, matching the
        # other prepare_fx calls in this file; the third positional slot of
        # prepare_fx is example_inputs, not the custom config.
        m_prep = prepare_fx(
            m, qconfig_dict, example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config_dict)
        m_prep(*example_inputs)
        # The choice between reference and regular conversion is made by
        # picking the convert function; no is_reference keyword is passed.
        convert_fn = convert_to_reference_fx if is_reference else convert_fx
        m_quant = convert_fn(m_prep)
        m_quant(*example_inputs)

        self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)

    @unittest.skip("TODO: reenable with backend_config api")
    def test_gelu_normal(self):
        """Non-reference conversion of GELU: module followed by functional form."""
        gelu_module = torch.nn.GELU
        gelu_fn = torch.nn.functional.gelu
        expected_nodes = [ns.call_module(gelu_module), ns.call_function(gelu_fn)]
        self._test_default_node_quant_handler_ops(
            gelu_module,
            gelu_fn,
            torch.ao.quantization.get_default_qconfig("fbgemm"),
            False,
            expected_nodes,
        )

    @unittest.skip("TODO: reenable with backend_config api")
    def test_softmax_normal(self):
        """Non-reference conversion of Softmax: quantized module followed by functional form."""
        softmax_module = torch.nn.Softmax
        softmax_fn = torch.nn.functional.softmax
        expected_nodes = [
            ns.call_module(torch.ao.nn.quantized.Softmax),
            ns.call_function(softmax_fn),
        ]
        self._test_default_node_quant_handler_ops(
            softmax_module,
            softmax_fn,
            torch.ao.quantization.get_default_qconfig("fbgemm"),
            False,
            expected_nodes,
        )

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_gelu_reference(self):
        """Reference conversion of GELU: expect quantize/dequantize pairs around each op."""
        gelu_module = torch.nn.GELU
        gelu_fn = torch.nn.functional.gelu

        quantize = ns.call_function(torch.quantize_per_tensor)
        dequantize = ns.call_method("dequantize")
        # q -> dq -> module -> q -> dq -> functional -> q -> dq
        expected_nodes = [
            quantize, dequantize,
            ns.call_module(gelu_module),
            quantize, dequantize,
            ns.call_function(gelu_fn),
            quantize, dequantize,
        ]

        # TODO: change these to use backend_config
        handler_patterns = {
            torch.nn.GELU: DefaultNodeQuantizeHandler,
            torch.nn.functional.gelu: DefaultNodeQuantizeHandler,
        }
        # first with the default fbgemm qconfig and explicit handler patterns
        self._test_default_node_quant_handler_ops(
            gelu_module,
            gelu_fn,
            torch.ao.quantization.get_default_qconfig("fbgemm"),
            True,
            expected_nodes,
            handler_patterns,
        )
        # then with the qconfig / patterns stored on the test instance
        self._test_default_node_quant_handler_ops(
            gelu_module,
            gelu_fn,
            self.custom_qconfig,
            True,
            expected_nodes,
            additional_quant_pattern_dict=self.common_quant_patterns,
        )

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_softmax_reference(self):
        """Reference conversion of Softmax: expect quantize/dequantize pairs around each op."""
        softmax_module = torch.nn.Softmax
        softmax_fn = torch.nn.functional.softmax

        quantize = ns.call_function(torch.quantize_per_tensor)
        dequantize = ns.call_method("dequantize")
        # q -> dq -> module -> q -> dq -> functional -> q -> dq
        expected_nodes = [
            quantize, dequantize,
            ns.call_module(softmax_module),
            quantize, dequantize,
            ns.call_function(softmax_fn),
            quantize, dequantize,
        ]

        handler_patterns = {
            torch.nn.Softmax: DefaultNodeQuantizeHandler,
            torch.nn.functional.softmax: DefaultNodeQuantizeHandler,
        }
        # first with the default fbgemm qconfig and explicit handler patterns
        self._test_default_node_quant_handler_ops(
            softmax_module,
            softmax_fn,
            torch.ao.quantization.get_default_qconfig("fbgemm"),
            True,
            expected_nodes,
            handler_patterns,
        )
        # then with the qconfig / patterns stored on the test instance
        self._test_default_node_quant_handler_ops(
            softmax_module,
            softmax_fn,
            self.custom_qconfig,
            True,
            expected_nodes,
            additional_quant_pattern_dict=self.common_quant_patterns,
        )

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_silu_reference(self):
        """Reference conversion of SiLU with an fp16 static qconfig and with the custom qconfig."""
        silu_module = torch.nn.SiLU
        silu_fn = torch.nn.functional.silu

        def expected_nodes(convert_node):
            # convert -> dq -> module -> convert -> dq -> functional -> convert -> dq
            dequantize = ns.call_method("dequantize")
            return [
                convert_node, dequantize,
                ns.call_module(silu_module),
                convert_node, dequantize,
                ns.call_function(silu_fn),
                convert_node, dequantize,
            ]

        # fp16 static qconfig: the conversion node is a `to` method call
        self._test_default_node_quant_handler_ops(
            silu_module,
            silu_fn,
            float16_static_qconfig,
            True,
            expected_nodes(ns.call_method("to")),
        )

        # custom qconfig: the conversion node is quantize_per_tensor
        self._test_default_node_quant_handler_ops(
            silu_module,
            silu_fn,
            self.custom_qconfig,
            True,
            expected_nodes(ns.call_function(torch.quantize_per_tensor)),
            additional_quant_pattern_dict=self.common_quant_patterns,
        )

    @unittest.skip("This is no longer needed right now, can enable later with new api")
    def test_mish_reference(self):
        """Reference conversion of Mish with an fp16 static qconfig and with the custom qconfig."""
        mish_module = torch.nn.Mish
        mish_fn = torch.nn.functional.mish

        def expected_nodes(convert_node):
            # convert -> dq -> module -> convert -> dq -> functional -> convert -> dq
            dequantize = ns.call_method("dequantize")
            return [
                convert_node, dequantize,
                ns.call_module(mish_module),
                convert_node, dequantize,
                ns.call_function(mish_fn),
                convert_node, dequantize,
            ]

        # fp16 static qconfig: the conversion node is a `to` method call
        self._test_default_node_quant_handler_ops(
            mish_module,
            mish_fn,
            float16_static_qconfig,
            True,
            expected_nodes(ns.call_method("to")),
        )

        # custom qconfig: the conversion node is quantize_per_tensor
        self._test_default_node_quant_handler_ops(
            mish_module,
            mish_fn,
            self.custom_qconfig,
            True,
            expected_nodes(ns.call_function(torch.quantize_per_tensor)),
            additional_quant_pattern_dict=self.common_quant_patterns,
        )

    def test_bmm_int_reference(self):
        """ int8 is not supported for bmm so we won't produce reference
            pattern for it
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bmm = torch.bmm

            def forward(self, x, y):
                return self.bmm(x, y)

        example_inputs = (torch.randn((2, 2, 2,)), torch.randn((2, 2, 2,)))
        global_qconfig = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}

        prepared = prepare_fx(M().eval(), global_qconfig, example_inputs=example_inputs)
        prepared(*example_inputs)
        # this test always takes the reference conversion path
        reference = convert_to_reference_fx(prepared)
        reference(*example_inputs)

        # bmm is expected to remain a plain torch.bmm call
        self.checkGraphModuleNodes(
            reference,
            expected_node_list=[ns.call_function(torch.bmm)],
        )

    @skipIfNoFBGEMM
    def test_clamp(self):
        """Check the node order after quantizing a conv followed by a chain of clamp-like ops.

        The chain mixes ReLU6 / Hardtanh (module, in-place module and
        functional forms) with ``torch.clamp`` and ``Tensor.clamp``; the
        expected ordered nodes are quantize, quantized conv, dequantize.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(inplace=True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                out = self.relu6(self.conv(x))
                # in-place module: return value is discarded
                self.relu6_(out)
                out = F.relu6(out)
                out = torch.clamp(out, -3, 3).clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                out = self.hardtanh(out)
                # in-place module: return value is discarded
                self.hardtanh_(out)
                return F.hardtanh(out)

        example = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        # nodes that must appear, in this order
        ordered_nodes = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(
                M(), example, quant_type, expected_node_list=ordered_nodes)

    def test_fixed_qparams_ops_fp16(self):
        """Sigmoid/tanh chain under a global fp16 static qconfig: expect 7 `to` calls."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                # module -> function -> method form, for sigmoid then tanh
                sig = torch.sigmoid(self.sigmoid(x)).sigmoid()
                return torch.tanh(self.tanh(sig)).tanh()

        example = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
        # TODO: use get_default_qconfig_mapping once it handles fp16
        fp16_mapping = QConfigMapping().set_global(float16_static_qconfig)
        self.checkGraphModeFxOp(
            M(),
            example,
            QuantType.STATIC,
            custom_qconfig_dict=fp16_mapping,
            expected_node_occurrence={ns.call_method("to"): 7},
            backend_config=get_test_only_legacy_native_backend_config(),
        )

    def test_fixed_qparams_ops_qint8(self):
        """Sigmoid/tanh chain in reference mode: expect 7 quantize and 7 dequantize nodes."""
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                # module -> function -> method form, for sigmoid then tanh
                sig = torch.sigmoid(self.sigmoid(x)).sigmoid()
                return torch.tanh(self.tanh(sig)).tanh()

        example = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
        # global qconfig: symmetric per-tensor quint8 histogram observer for activations
        activation_observer = HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8)
        global_qconfig = torch.ao.quantization.QConfig(
            activation=activation_observer,
            weight=default_weight_observer)
        mapping = get_default_qconfig_mapping().set_global(global_qconfig)
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 7,
            ns.call_method("dequantize"): 7,
        }
        self.checkGraphModeFxOp(
            M(),
            example,
            QuantType.STATIC,
            custom_qconfig_dict=mapping,
            expected_node_occurrence=expected_counts,
            is_reference=True,
        )

    def test_fixed_qparams_ops_wrong_qconfig(self):
        """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                # module -> function -> method form, for sigmoid then tanh
                sig = torch.sigmoid(self.sigmoid(x)).sigmoid()
                return torch.tanh(self.tanh(sig)).tanh()

        example = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
        mapping = QConfigMapping().set_global(default_qconfig)
        float_model = M().eval()
        # no quantize / dequantize nodes may appear in the converted graph
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
        }
        self.checkGraphModeFxOp(
            float_model,
            example,
            QuantType.STATIC,
            custom_qconfig_dict=mapping,
            expected_node_occurrence=expected_counts,
            is_reference=True,
        )
        # the submodules on the model passed in are still the float classes
        self.assertIsInstance(float_model.sigmoid, torch.nn.Sigmoid)
        self.assertIsInstance(float_model.tanh, torch.nn.Tanh)

    @skipIfNoFBGEMM
    def test_general_shape_ops(self):
        """Check quantize/dequantize placement around a long chain of shape-style ops.

        The model places many ops (scalar add/mul, max pooling, flatten,
        reshape, view, transpose, relu variants, squeeze/unsqueeze, ...)
        between two convolutions. It is only prepared and converted, never
        executed: the test inspects the node counts and node order of the
        converted graph.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = x.reshape([-1])
                # NOTE(review): `x` itself is passed as a size argument here;
                # the model is only traced, never run -- confirm this is intended
                x = x.resize_(1, 1, x)
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                # chunk is not supported since observer only supports
                # observing single Tensor currently
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        example_inputs = (torch.rand(1, 3, 10, 10),)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M().eval()
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            # input of conv and two outputs of getitem
            ns.call_function(torch.quantize_per_tensor) : 2,
            # output of the model and two outputs of getitem
            ns.call_method('dequantize') : 2
        }
        # relative order only: quantize, the two quantized convs, dequantize
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)


        # Checking the is_reference output
        # (only verifies that reference conversion completes; the result is
        # not inspected)
        m = M().eval()
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        # not runnable
        quantized = convert_to_reference_fx(prepared)


    @skipIfNoFBGEMM
    def test_ave_pool_with_custom_cfg(self):
        """Check the pattern produced for AvgPool2d when the model input is marked as quantized.

        The prepare custom config lists input index 0 under
        ``input_quantized_idxs``; the converted graph is expected to contain
        the AvgPool2d module followed by exactly one dequantize.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.avg_pool2d = torch.nn.AvgPool2d(3)

            def forward(self, x):
                return self.avg_pool2d(x)

        float_model = M().eval()
        sample = (torch.randn(1, 3, 3, 3),)
        # the model is only prepared and converted here, never called
        prepared = prepare_fx(
            float_model,
            {'': default_qconfig},
            example_inputs=sample,
            prepare_custom_config={"input_quantized_idxs": [0]},
        )
        converted = convert_fx(prepared)

        # exactly one dequantize, placed after the pooling module
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={ns.call_method('dequantize'): 1},
            expected_node_list=[
                ns.call_module(nn.AvgPool2d),
                ns.call_method('dequantize'),
            ],
        )

    @skipIfNoFBGEMM
    def test_general_value_ops(self):
        """Check quantize/dequantize placement around a chain of value-style ops.

        The model places average pooling (module and functional forms),
        adaptive average pooling, mean and interpolate calls between two
        applications of the same convolution. It is only prepared and
        converted, never executed: the test inspects the node counts and node
        order of the converted graph.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1)
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))

            def forward(self, x):
                x = self.conv(x)
                # module forms
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                # functional forms
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                # note: `(1)` is the plain int 1, not a tuple
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                # mean as function and as method, with and without dims
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                # the same conv module is applied a second time
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M().eval()
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        example_inputs = (torch.randn(1, 3, 3, 3),)
        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_method('dequantize') : 1
        }
        # relative order only: quantize, the two quantized conv calls, dequantize
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)

    def test_copy_node_fp32_input(self):
        """ CopyNode works for both fp32 and int8 inputs, this is a test to make
        sure that a CopyNode can be successfully quantized in both cases
        """
        class M(torch.nn.Module):
            def forward(self, x):
                return x.relu()

        prepared = prepare_fx(
            M().eval(),
            {"": default_reuse_input_qconfig},
            example_inputs=(torch.randn(1),),
        )
        converted = convert_fx(prepared)
        # smoke check: the converted model accepts a float input
        converted(torch.rand(1))

    def test_getitem(self):
        """ Make sure we only insert observer for getitem if the following node is matched
        or needs to be quantized
        """
        def check_sigmoid_case(model, inputs):
            # Shared verification for the models that apply sigmoid: two
            # FixedQParamsObserver modules after prepare, a quantize followed
            # by a dequantize after convert, and the converted model must run.
            prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs=inputs)
            self.checkGraphModuleNodes(prepared, expected_node_occurrence={
                ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2
            })
            converted = convert_fx(prepared)
            self.checkGraphModuleNodes(converted, expected_node_list=[
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method("dequantize")
            ])
            converted(*inputs)

        # getitem alone: no MinMaxObserver may be inserted
        class M(torch.nn.Module):
            def forward(self, xs):
                return xs[0]

        plain_model = M().eval()
        plain_inputs = (torch.rand(1, 2),)
        plain_prepared = prepare_fx(
            plain_model, get_default_qconfig_mapping(), example_inputs=plain_inputs)
        self.checkGraphModuleNodes(plain_prepared, expected_node_occurrence={
            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
        })
        plain_converted = convert_fx(plain_prepared)
        plain_converted(*plain_inputs)

        # getitem followed by sigmoid
        class M2(torch.nn.Module):
            def forward(self, xs):
                first = xs[0]
                return torch.sigmoid(first)

        check_sigmoid_case(M2().eval(), ([torch.rand(1, 2)],))

        # testing prepare recognizes non-Tensor input for getitem
        class M3(torch.nn.Module):
            def forward(self, x):
                s = x.shape
                n, c = s[:2]
                x = torch.sigmoid(x)
                return x

        check_sigmoid_case(M3().eval(), (torch.rand(1, 2, 3, 4),))


    @skipIfNoFBGEMM
    def test_fixed_qparams_ops(self):
        """Check observer insertion and conversion for fixed-qparams ops.

        A conv is followed by sigmoid, hardsigmoid, tanh (module, function
        and method forms) and a Softmax module. For both eval (``prepare_fx``)
        and train (``prepare_qat_fx``) modes the test checks:

        * the number of fixed-qparams observers / fake-quantizes inserted,
        * node counts and order after ``convert_fx``,
        * node counts and order after ``convert_to_reference_fx``,
        * the scale and zero_point of the converted softmax module.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.softmax = torch.nn.Softmax(dim=0)

            def forward(self, x):
                x = self.conv(x)
                # F.sigmoid is deprecated
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                x = x.sigmoid()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                # TODO(future PR): handle F.softmax
                x = self.softmax(x)
                return x

        # the same number of activation_post_process modules is expected in
        # eval and in train mode
        fq_count = 10
        for eval_mode in [True, False]:
            m = M()
            if eval_mode:
                m.eval()
                qconfig_mapping = get_default_qconfig_mapping()
                prepare_fn = prepare_fx
            else:
                m.train()
                qconfig_mapping = get_default_qat_qconfig_mapping()
                prepare_fn = prepare_qat_fx
            # nothing to fuse so skipping the fuse step
            m_copy = copy.deepcopy(m)
            example_inputs = (torch.rand(3, 3, 3, 3),)
            prepared = prepare_fn(m, qconfig_mapping, example_inputs=example_inputs)
            prepared_copy = copy.deepcopy(prepared)
            # check that prepare does not change model result
            if eval_mode:
                self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs))
            # check the correct number of activation_post_process is inserted
            expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize
            count_check = {
                ns.call_module(expected_activation_post_process) : fq_count,
            }
            self.checkGraphModuleNodes(
                prepared,
                expected_node_occurrence=count_check)
            quantized = convert_fx(prepared)
            quantized_reference = convert_to_reference_fx(prepared_copy)

            # This checks that the dequantize from the output of first conv
            # is being propagated to the end, so that we don't insert extra
            # observers
            # check exact counts of quantize and dequantize
            count_check = {
                ns.call_function(torch.quantize_per_tensor) : 1,
                ns.call_method('dequantize') : 1
            }
            order_check = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_module(nnq.Conv2d),
                ns.call_module(nn.Sigmoid),
                ns.call_module(nnq.Softmax),
                ns.call_method('dequantize'),
            ]
            self.checkGraphModuleNodes(
                quantized,
                expected_node_occurrence=count_check,
                expected_node_list=order_check)

            reference_count_check = {
                ns.call_function(torch.quantize_per_tensor) : 12,
                ns.call_method('dequantize') : 12
            }
            reference_order_check = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nnqr.Conv2d),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nn.Sigmoid),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nn.Softmax),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
            ]
            self.checkGraphModuleNodes(
                quantized_reference,
                expected_node_occurrence=reference_count_check,
                expected_node_list=reference_order_check)

            # Verify that softmax scale and zero_point are correct.
            # The absolute difference is compared so that a scale that is too
            # small fails as well (a signed difference would let it through).
            self.assertLessEqual(abs(quantized.softmax.scale - (1.0 / 256)), 1e-8)
            self.assertEqual(quantized.softmax.zero_point, 0)

    def test_float_functional(self):
        """Check how FX graph mode quantization handles FloatFunctional ops.

        A model chaining ``add``, ``add_scalar``, ``mul``, ``mul_scalar``,
        ``add_relu`` and ``cat`` through ``FloatFunctional`` is prepared and
        converted for every static quant type. The eager mode flow is run on
        an equivalent wrapped model as well, but outputs are not compared.
        """
        class TorchAdd(nn.Module):
            """Wrapper around torch.add so that all ops can be found at build"""
            def __init__(self) -> None:
                super().__init__()
                self.add_func = nnq.FloatFunctional()

            def forward(self, x, y):
                return self.add_func.add(x, y)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ff1 = TorchAdd()
                # ff2 .. ff6 are plain FloatFunctional instances
                for idx in range(2, 7):
                    setattr(self, f"ff{idx}", nnq.FloatFunctional())

            def forward(self, x):
                out = self.ff1(x, x)
                out = self.ff2.add_scalar(out, 3)
                out = self.ff3.mul(out, out)
                out = self.ff4.mul_scalar(out, 3)
                out = self.ff5.add_relu(out, out)
                return self.ff6.cat([out])

        example_inputs = (torch.rand(3, 3),)
        # Note: QAT test succeeded by chance, to make it actually work
        # we need to fix eager mode FloatFunctional by removing
        # activation_post_process in add_scalar and mul_scalar
        for quant_type in self.static_quant_types:
            is_qat = quant_type == QuantType.QAT
            fx_model = M()
            eager_model = torch.ao.quantization.QuantWrapper(M())
            if is_qat:
                fx_model.train()
                eager_model.train()
                qconfig = default_qat_qconfig
                observer_cls = torch.ao.quantization.FakeQuantize
                fx_prepare, eager_prepare = prepare_qat_fx, prepare_qat
            else:
                fx_model.eval()
                eager_model.eval()
                qconfig = default_qconfig
                observer_cls = torch.ao.quantization.MinMaxObserver
                fx_prepare, eager_prepare = prepare_fx, prepare

            # after prepare: 7 activation post processes, no FloatFunctional call_module left
            fx_model = fx_prepare(fx_model, {"": qconfig}, example_inputs=example_inputs)
            self.checkGraphModuleNodes(
                fx_model,
                expected_node_occurrence={
                    ns.call_module(observer_cls): 7,
                    ns.call_module(torch.ao.nn.quantized.FloatFunctional): 0,
                },
            )
            fx_model(*example_inputs)

            # after convert: the ops appear as quantized functions in this order
            fx_model = convert_fx(fx_model)
            self.checkGraphModuleNodes(
                fx_model,
                expected_node_list=[
                    ns.call_function(torch.quantize_per_tensor),
                    ns.call_function(torch.ops.quantized.add),
                    ns.call_function(torch.ops.quantized.add),
                    ns.call_function(torch.ops.quantized.mul),
                    ns.call_function(torch.ops.quantized.mul),
                    ns.call_function(torch.ops.quantized.add_relu),
                    ns.call_function(torch.cat),
                    ns.call_method('dequantize'),
                ],
            )

            # run the eager mode flow on the reference model
            eager_model.qconfig = qconfig
            eager_model = eager_prepare(eager_model)
            eager_model(*example_inputs)
            eager_model = convert(eager_model)
            # FX Graph Mode and Eager Mode now diverge in numerics of
            # add_scalar and mul_scalar, so the outputs are not compared here.

    def test_embedding(self):
        """Check quantization of ``nn.Embedding`` in FX graph mode.

        With a float-qparams weight-only qconfig the module is expected to
        become ``nnq.Embedding``; with ``None`` or ``default_qconfig`` it is
        expected to stay ``nn.Embedding`` and no ``MinMaxObserver`` is inserted.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        index_values = [9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
        weight_only_qconfigs = (
            float_qparams_weight_only_qconfig,
            float_qparams_weight_only_qconfig_4bit,
        )
        for qconfig_type in weight_only_qconfigs:
            dynamic_model = M().eval()
            example_inputs = (torch.tensor(index_values),)

            # check dynamic quant
            self.checkGraphModeFxOp(
                dynamic_model,
                example_inputs,
                QuantType.DYNAMIC,
                ns.call_module(nnq.Embedding),
                custom_qconfig_dict={"": qconfig_type},
            )

            # check static quantization: (global qconfig, node expected after convert)
            static_model = M().eval()
            static_cases = [
                (qconfig_type, ns.call_module(nnq.Embedding)),
                (None, ns.call_module(nn.Embedding)),
                (default_qconfig, ns.call_module(nn.Embedding)),
            ]
            for qconfig, expected_node in static_cases:
                prepared = prepare_fx(
                    static_model, {"": qconfig}, example_inputs=example_inputs)
                no_observers = {ns.call_module(torch.ao.quantization.MinMaxObserver): 0}
                self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_observers)
                converted = convert_fx(prepared)
                self.checkGraphModuleNodes(converted, expected_node=expected_node)
                # make sure it runs
                converted(*example_inputs)

    def test_embedding_bag(self):
        """Check quantization of ``nn.EmbeddingBag`` in FX graph mode.

        With a per-channel float-qparams weight observer (``quint8`` and
        ``quint4x2``) the module is expected to become ``nnq.EmbeddingBag``.
        With ``None`` or ``default_qconfig`` as the global qconfig it is
        expected to stay ``nn.EmbeddingBag`` with no ``MinMaxObserver`` inserted.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True)

            def forward(self, indices, offsets):
                return self.emb(indices, offsets)

        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        quantized_node = ns.call_module(nnq.EmbeddingBag)
        example_inputs = (indices, offsets)

        for dtype in [torch.quint8, torch.quint4x2]:
            model = M().eval()
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_qparams_qconfig = QConfig(activation=default_placeholder_observer,
                                            weight=float_qparams_observer)
            self.checkGraphModeFxOp(
                model,
                example_inputs,
                QuantType.DYNAMIC,
                quantized_node,
                custom_qconfig_dict={"": float_qparams_qconfig}
            )

        # check it works in None and static qconfig
        for qconfig in [None, default_qconfig]:
            # Use the qconfig under test (previously default_qconfig was always
            # used, so the None case was never exercised) ...
            qconfig_dict = {"": qconfig}
            # ... and prepare the fresh float model created for this iteration
            # instead of the leftover `model` from the loop above.
            m = M().eval()
            m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
            self.checkGraphModuleNodes(m, expected_node_occurrence={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 0
            })
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
            # make sure it runs
            m(*example_inputs)

    def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
        """Compare eager and FX dynamic quantization of recurrent models.

        For every (qconfig, module type string) combination, ``M(module_type_str)``
        is quantized once with eager ``quantize_dynamic`` and once with FX
        prepare/convert. Both outputs on ``sample_input`` must be equal, and the
        FX result is passed to ``checkScriptable``.
        """
        for qconfig, module_type_str in itertools.product(qconfigs, module_type_strs):
            eager_model = M(module_type_str).eval()
            fx_model = copy.deepcopy(eager_model)

            # fp16 dynamic quant is not supported for qnnpack
            is_fp16 = qconfig is float16_dynamic_qconfig
            if is_fp16 and torch.backends.quantized.engine == 'qnnpack':
                continue

            # eager mode: one entry per module type, all sharing the same qconfig
            eager_model = quantize_dynamic(
                eager_model,
                qconfig_spec={module_type: qconfig for module_type in module_types},
            )

            # FX graph mode: the same assignment expressed through "object_type"
            object_type_qconfigs = [(module_type, qconfig) for module_type in module_types]
            fx_model = prepare_fx(
                fx_model,
                {"object_type": object_type_qconfigs},
                example_inputs=(sample_input,),
            )
            fx_model = convert_fx(fx_model)

            self.assertEqual(eager_model(sample_input), fx_model(sample_input))
            self.checkScriptable(fx_model, [[sample_input]], True)

    @override_qengines
    def test_rnn_cell(self):
        """Run the eager-vs-FX dynamic quantization check on RNN cell modules."""
        engine = torch.backends.quantized.engine
        if engine != 'fbgemm' and engine != 'qnnpack':
            # the test only runs for these two engines
            return
        sample_input = torch.tensor(
            [[100, -155], [-155, 100], [100, -155]],
            dtype=torch.float,
        )
        self._test_rnn_impl(
            qconfigs=[per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig],
            M=RNNCellDynamicModel,
            module_type_strs=['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'],
            module_types=[torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell],
            sample_input=sample_input,
        )

    @override_qengines
    def test_rnn(self):
        """Run the eager-vs-FX dynamic quantization check on LSTM and GRU modules."""
        engine = torch.backends.quantized.engine
        if engine != 'fbgemm' and engine != 'qnnpack':
            # the test only runs for these two engines
            return
        niter = 10
        # one (3, 2) step, repeated niter times along a new leading dimension
        single_step = torch.tensor(
            [[100, -155], [-155, 100], [100, -155]],
            dtype=torch.float,
        )
        sample_input = single_step.unsqueeze(0).repeat(niter, 1, 1)
        self._test_rnn_impl(
            qconfigs=[per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig],
            M=RNNDynamicModel,
            module_type_strs=['LSTM', 'GRU'],
            module_types=[torch.nn.LSTM, torch.nn.GRU],
            sample_input=sample_input,
        )

    def _test_conv_transpose_impl(
            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
        """Compare FX and eager static quantization of a transposed convolution.

        Args:
            float_cls: float module class, instantiated as ``float_cls(1, 1, 1)``.
            q_cls: module class expected exactly once in the FX-converted graph.
            data: input used for calibration and for the final comparison.
        """
        with override_quantized_engine('qnnpack'):
            # Two fp32 models sharing the same weights: one for FX, one for eager
            fx_float = torch.nn.Sequential(float_cls(1, 1, 1))
            eager_float = torch.nn.Sequential(float_cls(1, 1, 1))
            eager_float.load_state_dict(fx_float.state_dict())
            eager_float = torch.ao.quantization.QuantWrapper(eager_float)

            # FX graph mode
            fx_results = self.checkGraphModeFxOp(
                fx_float, (data,), QuantType.STATIC,
                expected_node_occurrence={ns.call_module(q_cls): 1},
            )
            fx_output = fx_results["quantized_output"]

            # Eager mode
            eager_float.qconfig = get_default_qconfig(torch.backends.quantized.engine)
            eager_float.eval()
            eager_prepared = torch.ao.quantization.prepare(eager_float)
            eager_prepared(data)
            eager_quantized = torch.ao.quantization.convert(eager_prepared)
            eager_output = eager_quantized(data)

            # both flows must produce the same result
            self.assertEqual(fx_output, eager_output)

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_1d(self):
        """Run the conv transpose FX-vs-eager comparison for the 1d module."""
        data = torch.randn(4, 1, 4)
        self._test_conv_transpose_impl(
            float_cls=torch.nn.ConvTranspose1d,
            q_cls=nnq.ConvTranspose1d,
            data=data,
        )

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_2d(self):
        """Run the conv transpose FX-vs-eager comparison for the 2d module."""
        data = torch.randn(4, 1, 4, 4)
        self._test_conv_transpose_impl(
            float_cls=torch.nn.ConvTranspose2d,
            q_cls=nnq.ConvTranspose2d,
            data=data,
        )

    def test_reshape_fp16(self):
        """Check observers and converted ops for linear -> reshape -> linear.

        The global qconfig is ``float16_static_qconfig`` while
        ``torch.nn.functional.linear`` is overridden with ``default_qconfig``.
        """
        class M(torch.nn.Module):
            def __init__(self, w, b):
                super().__init__()
                self.w = w
                self.b = b

            def forward(self, x):
                hidden = torch.nn.functional.linear(x, self.w)
                hidden = hidden.reshape(-1, 4)
                return torch.nn.functional.linear(hidden, self.w)

        weight = torch.randn(4, 4)
        bias = torch.randn(4)
        float_model = M(weight, bias).eval()
        qconfig_dict = {
            # reshape will be quantized to fp16 as requested by this qconfig
            "": float16_static_qconfig,
            "object_type": [
                (torch.nn.functional.linear, default_qconfig)
            ]
        }
        backend_config = get_test_only_legacy_native_backend_config()
        example_inputs = (torch.randn(1, 4),)

        prepared = prepare_fx(
            float_model, qconfig_dict, example_inputs=example_inputs,
            backend_config=backend_config)
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence={
                # input and weight of first and second linear, output of first and second linear
                ns.call_module(torch.ao.quantization.MinMaxObserver): 6,
                # we insert placeholder observer for both input and output of reshape
                ns.call_module(torch.ao.quantization.PlaceholderObserver): 2,
            },
        )

        converted = convert_fx(prepared, backend_config=backend_config)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 2,
                # dequantize after first linear, before reshape and before output
                ns.call_method("dequantize"): 3,
                # before reshape, to(fp16)
                ns.call_method("to"): 1,
                ns.call_function(torch.ops.quantized.linear): 2,
            },
        )
        # make sure it runs
        converted(torch.randn(2, 4))

    def test_multiple_qconfigs_for_single_value(self):
        """ Test multiple qconfigs for a single value"""
        class M(torch.nn.Module):
            def __init__(self, w, b):
                super().__init__()
                self.w = w
                self.b = b

            def forward(self, x):
                linear_out = torch.nn.functional.linear(x, self.w)
                return torch.sigmoid(linear_out)

        weight = torch.randn(4, 4)
        bias = torch.randn(4)
        float_model = M(weight, bias).eval()

        # TODO: use get_default_qconfig_mapping once it handles fp16
        qconfig_mapping = QConfigMapping()
        qconfig_mapping = qconfig_mapping.set_global(float16_static_qconfig)
        qconfig_mapping = qconfig_mapping.set_object_type(
            torch.nn.functional.linear, default_qconfig)

        example_inputs = (torch.randn(1, 4),)
        backend_config = get_test_only_legacy_native_backend_config()
        prepared = prepare_fx(
            float_model, qconfig_mapping, example_inputs=example_inputs,
            backend_config=backend_config)
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence={
                # input and weight of linear, output of linear
                ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
                # input and output of sigmoid
                ns.call_module(torch.ao.quantization.PlaceholderObserver): 2,
            },
        )

        # make sure convert runs, then count the conversion nodes
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_method("dequantize"): 3,
                ns.call_method("to"): 2,
            },
        )

    def test_boolean_tensor(self):
        """ Make sure we don't insert observer for boolean Tensors """
        class M(torch.nn.Module):
            def forward(self, x, mask):
                expanded_mask = mask.unsqueeze(0).unsqueeze(1)
                return x.masked_fill(expanded_mask, 1)

        example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool())
        prepared = prepare_fx(
            M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        # no MinMaxObserver is expected anywhere in the prepared graph
        self.checkGraphModuleNodes(
            prepared,
            expected_node_occurrence={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 0,
            },
        )
        converted = convert_fx(prepared)
        converted(*example_inputs)

    def test_chunk(self):
        """Make sure a model using ``torch.chunk`` can be prepared, converted and run."""
        class M(torch.nn.Module):
            def forward(self, x):
                first, second = torch.chunk(x, 2)
                return first + second

        example_inputs = (torch.rand(2, 2, 2, 2),)
        prepared = prepare_fx(
            M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        prepared(*example_inputs)
        converted = convert_fx(prepared)
        # no assertions: the test passes if everything above and this call run
        converted(*example_inputs)

    def test_ref_pattern_multi_use(self):
        """Check node counts when a quantized linear output feeds two float ops.

        Only ``nn.Linear`` and ``nn.ReLU`` get a qconfig; the global qconfig is
        ``None``, so ``torch.mul`` and ``torch.add`` are expected to stay float.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear1 = torch.nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear(x)
                z = self.linear1(x)
                # z is consumed by two different ops
                a = torch.mul(z, 5)
                b = torch.add(z, 5)
                return (y, a, b)

        float_model = M().eval()
        quantized_types = (torch.nn.Linear, torch.nn.ReLU)
        qconfig_dict = {
            "": None,
            "object_type": [
                (module_type, get_default_qconfig("fbgemm"))
                for module_type in quantized_types
            ],
        }
        example_inputs = (torch.randn(1, 5),)
        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 1,
                ns.call_module(nnq.Linear): 2,
                ns.call_method("dequantize"): 2,
                ns.call_function(torch.add): 1,
                ns.call_function(torch.mul): 1,
            },
        )

    def test_qmatmul(self):
        """Check that ``torch.matmul`` is replaced by ``torch.ops.quantized.matmul``."""
        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.matmul(x, y)

        float_model = M().eval()
        example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=example_inputs)
        prepared(*example_inputs)
        quantized = convert_fx(prepared)
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence={
                ns.call_function(torch.matmul): 0,
                ns.call_function(torch.ops.quantized.matmul): 1,
            },
        )
        # verify no crash
        quantized(*example_inputs)

    def test_pixel_shuffle(self):
        """Check quantize/dequantize counts for conv -> functional pixel_shuffle -> view -> add.

        Uses the qnnpack backend config and default qnnpack qconfig mapping.
        """
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                out = self.conv(x)
                out = nn.functional.pixel_shuffle(out, 2)
                out = out.view(-1, 8, 2, 2)
                # the bias parameter lives on a nested submodule
                return out + self.bias.bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        float_model = MyModel()
        example_inputs = (torch.randn(1, 8, 3, 3),)
        prepared = prepare_fx(
            float_model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=example_inputs,
            backend_config=backend_config,
        )
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
            },
        )

    def test_pixel_shuffle_module(self) -> None:
        """Check node counts for conv -> ``nn.PixelShuffle`` -> view -> add.

        Uses the qnnpack backend config and default qnnpack qconfig mapping;
        the ``nn.PixelShuffle`` module is expected to remain in the graph once.
        """
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.ps = nn.PixelShuffle(upscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                out = self.conv(x)
                out = self.ps(out)
                out = out.view(-1, 8, 2, 2)
                # the bias parameter lives on a nested submodule
                return out + self.bias.bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        float_model = MyModel()
        example_inputs = (torch.randn(1, 8, 3, 3),)
        prepared = prepare_fx(
            float_model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=example_inputs,
            backend_config=backend_config,
        )
        converted = convert_fx(prepared)
        self.checkGraphModuleNodes(
            converted,
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
                ns.call_module(nn.PixelShuffle): 1,
            },
        )

    def test_pixel_unshuffle(self):
        """Check quantize/dequantize counts for conv -> functional pixel_unshuffle -> add.

        The same expectation is checked for the fbgemm and qnnpack backends.
        """
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                out = self.conv(x)
                out = nn.functional.pixel_unshuffle(out, 2)
                # the bias parameter lives on a nested submodule
                return out + self.bias.bias

        # backend name -> factory of the matching backend config
        backend_config_factories = (
            ("fbgemm", get_fbgemm_backend_config),
            ("qnnpack", get_qnnpack_backend_config),
        )
        for backend, make_backend_config in backend_config_factories:
            backend_config = make_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            float_model = MyModel()
            example_inputs = (torch.randn(1, 8, 6, 6),)
            prepared = prepare_fx(
                float_model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=example_inputs,
                backend_config=backend_config,
            )
            converted = convert_fx(prepared)
            self.checkGraphModuleNodes(
                converted,
                expected_node_occurrence={
                    ns.call_function(torch.quantize_per_tensor): 2,
                    ns.call_method("dequantize"): 1,
                },
            )

    def test_pixel_unshuffle_module(self) -> None:
        """Check node counts for conv -> ``nn.PixelUnshuffle`` -> add.

        The same expectation is checked for the fbgemm and qnnpack backends;
        the ``nn.PixelUnshuffle`` module is expected to remain in the graph once.
        """
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.unshuffle = nn.PixelUnshuffle(downscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                out = self.conv(x)
                out = self.unshuffle(out)
                # the bias parameter lives on a nested submodule
                return out + self.bias.bias

        # backend name -> factory of the matching backend config
        backend_config_factories = (
            ("fbgemm", get_fbgemm_backend_config),
            ("qnnpack", get_qnnpack_backend_config),
        )
        for backend, make_backend_config in backend_config_factories:
            backend_config = make_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            float_model = MyModel()
            example_inputs = (torch.randn(1, 8, 6, 6),)
            prepared = prepare_fx(
                float_model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=example_inputs,
                backend_config=backend_config,
            )
            converted = convert_fx(prepared)
            self.checkGraphModuleNodes(
                converted,
                expected_node_occurrence={
                    ns.call_function(torch.quantize_per_tensor): 2,
                    ns.call_method("dequantize"): 1,
                    ns.call_module(nn.PixelUnshuffle): 1,
                },
            )



    def test_narrow(self):
        """Check quantize/dequantize counts for conv -> ``torch.narrow`` -> add.

        The same expectation is checked for the fbgemm and qnnpack backends.
        """
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(4))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                out = self.conv(x)
                out = torch.narrow(out, 1, 0, 4)
                # the bias parameter lives on a nested submodule
                return out + self.bias.bias

        # backend name -> factory of the matching backend config
        backend_config_factories = (
            ("fbgemm", get_fbgemm_backend_config),
            ("qnnpack", get_qnnpack_backend_config),
        )
        for backend, make_backend_config in backend_config_factories:
            backend_config = make_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            float_model = MyModel()
            example_inputs = (torch.randn(1, 8, 3, 3),)
            prepared = prepare_fx(
                float_model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=example_inputs,
                backend_config=backend_config,
            )
            converted = convert_fx(prepared)
            self.checkGraphModuleNodes(
                converted,
                expected_node_occurrence={
                    ns.call_function(torch.quantize_per_tensor): 2,
                    ns.call_method("dequantize"): 1,
                },
            )

class TestQuantizeFxModels(QuantizationTestCase):
    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_static_gpu_convert_basic(self):
        """Prepare, calibrate and convert to reference on CUDA; output must stay on CUDA."""

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                features = self.relu1(self.conv1(x))
                return self.linear1(features.view(-1))

        cuda_input = torch.randn((5, 1, 6, 6)).to('cuda')
        example_inputs = (cuda_input,)
        float_model = Net().to('cuda').eval()
        qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}

        prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
        prepared(*example_inputs)
        reference_quantized = convert_to_reference_fx(prepared)

        result = reference_quantized(*example_inputs)
        self.assertEqual(result.device.type, 'cuda')

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_switch_device_prepare_convert(self):
        """Prepare on one device, move to the other, then convert to reference.

        The output of the converted model must be on the device the prepared
        model was moved to.
        """

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                features = self.relu1(self.conv1(x))
                return self.linear1(features.view(-1))

        # (device used for prepare/calibration, device used for convert/inference)
        device_pairs = (('cuda', 'cpu'), ('cpu', 'cuda'))
        for device, device_after in device_pairs:
            data = torch.randn((5, 1, 6, 6)).to(device)
            float_model = Net().to(device).eval()
            qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}

            prepared = prepare_fx(float_model, qconfig_dict, example_inputs=(data,))
            prepared(data)
            prepared.to(device_after)

            reference_quantized = convert_to_reference_fx(prepared)
            result = reference_quantized(data.to(device_after))
            self.assertEqual(result.device.type, device_after)

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_prepare_serialize_switch_device_convert(self):
        """Transfer observer state through a state_dict, switch device, then convert.

        For every (device, device_after) combination a calibrated prepared
        model's state_dict is loaded into a second prepared model, which is
        moved to ``device_after`` and converted to a reference model; its
        output must be on ``device_after``.
        """
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                features = self.conv1(x)
                return self.linear1(features.view(-1))

        devices = ('cuda', 'cpu')
        for device in devices:
            for device_after in devices:
                data = torch.randn((5, 1, 6, 6)).to(device)
                float_model = Net().to(device).eval()
                qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}

                # two prepared copies of the same float model
                calibrated = prepare_fx(float_model, qconfig_dict, example_inputs=(data,))
                restored = prepare_fx(float_model, qconfig_dict, example_inputs=(data,))

                # calibrate the first copy and keep only its state_dict
                calibrated(data)
                saved_state = calibrated.state_dict()
                del calibrated

                restored.load_state_dict(saved_state)
                restored.to(device_after)
                reference_quantized = convert_to_reference_fx(restored)
                result = reference_quantized(data.to(device_after))
                self.assertEqual(result.device.type, device_after)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    def test_model_dropout(self):
        """Run QAT prepare and convert on torchvision's mobilenet_v3_small and call the result."""
        from torchvision import models
        float_model = models.mobilenet_v3_small()
        qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm')
        example_inputs = (torch.randn(1, 3, 224, 224),)
        prepared = prepare_qat_fx(float_model, qconfig_mapping, example_inputs=example_inputs)
        prepared(*example_inputs)

        # on ARM64 the conversion runs with the qnnpack engine selected
        if IS_ARM64:
            engine_context = override_quantized_engine("qnnpack")
        else:
            engine_context = contextlib.nullcontext()
        with engine_context:
            quantized = convert_fx(prepared)
        quantized(*example_inputs)

    def _test_model_impl(
            self, mode, name, model, eager_quantizable_model,
            check_with_eager=True,
            diff_of_quant=None,
            diff_from_eager=None):
        """Quantize ``model`` with FX graph mode and optionally compare with eager mode.

        Args:
            mode: 'static' uses ``default_qconfig`` with ``prepare_fx``; any other
                value uses ``default_qat_qconfig`` with ``prepare_qat_fx``. 'qat'
                trains for one epoch and 'ddp' spawns ``run_ddp`` workers.
                NOTE(review): ``run_ddp`` and ``world_size`` are not defined in the
                visible part of this file (hence the noqa markers) - confirm
                before relying on the 'ddp' mode.
            name: model name; 'inception_v3' selects a 299x299 input, everything
                else a 224x224 input. Also used as key in the diff dictionaries.
            model: float model to quantize with FX graph mode.
            eager_quantizable_model: eager mode counterpart exposing
                ``fuse_model()``, or ``None`` to skip the eager comparison.
            check_with_eager: when True, assert that FX and eager outputs match
                exactly (only for non-tuple outputs).
            diff_of_quant: optional ``{mode: {name: diff}}`` dict receiving the max
                abs difference between float and FX quantized outputs.
            diff_from_eager: optional ``{mode: {name: diff}}`` dict receiving the
                max abs difference between eager and FX quantized outputs.
                If either dict is ``None``, both are replaced by fresh dicts.
        """
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1,))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        script = torch.jit.script(model)

        # make sure graph module and script module are both runnable
        original_out = model(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)

        # set to train just before quantization
        prepare_fx_fn = prepare_fx
        if mode != 'static':
            model.train()
            prepare_fx_fn = prepare_qat_fx

        # example_inputs is a required argument of prepare_fx / prepare_qat_fx
        prepared = prepare_fx_fn(model, qconfig_dict, example_inputs=(input_value,))

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),  # noqa: F821
                     nprocs=world_size,  # noqa: F821
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
        else:
            # static mode: calibrate by running the input a few times
            for _ in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = convert_fx(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script_out), 'graph, scripted graph'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),  # noqa: F821
                         nprocs=world_size,  # noqa: F821
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                # `criterion` was created in the FX 'qat' branch above
                train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
            else:
                for _ in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(diff_from_eager[mode][name], 0,
                                     'Result of graph mode quantization and ' +
                                     'eager mode quantization on model: ' + name +
                                     ' should match. Mode: ' + mode +
                                     ' diff:' + str(diff_from_eager[mode][name]))

    def _test_building_block(self, quant_type, BB):
        """Run eager mode and FX graph mode quantization end to end on ``BB``.

        Two copies of ``BB()`` are created. One is fused (when it defines
        ``fuse_model``), wrapped in ``QuantWrapper`` and prepared with the
        eager mode API; the other is prepared with the FX graph mode API.
        Both are calibrated/trained on the same data, converted, and run once.
        The outputs are intentionally not compared (see the comment in the
        body), so this only checks that both flows run without raising.

        Args:
            quant_type: ``QuantType.STATIC`` or ``QuantType.QAT``.
            BB: zero-argument callable returning the module to quantize.
        """
        eager = BB().float()
        graph = copy.deepcopy(eager)

        if quant_type == QuantType.STATIC:
            qconfig = default_qconfig
            eager_prepare = prepare
            graph_prepare = prepare_fx
            eager.eval()
            graph.eval()
            calibrate_or_train = test_only_eval_fn
            data = self.img_data_2d
        else:
            assert quant_type == QuantType.QAT
            qconfig = default_qat_qconfig
            eager_prepare = prepare_qat
            graph_prepare = prepare_qat_fx
            eager.train()
            graph.train()
            calibrate_or_train = test_only_train_fn
            data = self.img_data_2d_train

        # First input of the first sample; used as example input and for the
        # forward passes below.
        sample_input = data[0][0]

        if hasattr(eager, "fuse_model"):
            eager.fuse_model()
        eager = QuantWrapper(eager)
        eager.qconfig = qconfig
        eager = eager_prepare(eager)

        # A QConfigMapping with only a global qconfig is the non-deprecated
        # spelling of the legacy ``{"": qconfig}`` dict.
        qconfig_mapping = QConfigMapping().set_global(qconfig)
        graph = graph_prepare(graph, qconfig_mapping, example_inputs=(sample_input,))

        # Forward passes on the prepared models; results are not checked.
        eager(sample_input)
        graph(sample_input)
        # Eager Mode and FX Graph Mode QAT now differ in numerics both
        # in Post Training and QAT because FX Graph Mode uses same fake_quant instances
        # for input and output of CopyNode
        # self.assertEqual(eager_out, graph_out)

        calibrate_or_train(eager, data)
        calibrate_or_train(graph, data)

        eager = convert(eager)
        graph = convert_fx(graph)

        # Forward passes on the converted models; results are not checked.
        eager(sample_input)
        graph(sample_input)

    @override_qengines
    def test_resnet_base(self):
        """Exercise ``_test_building_block`` with ``ResNetBase`` for every entry of ``self.static_quant_types``."""
        for quant_type in self.static_quant_types:
            for block_cls in (ResNetBase,):
                self._test_building_block(quant_type, block_cls)

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("skip for now since tbb failed")
    def test_torchvision(self):
        """Compare FX graph mode and eager mode quantization on torchvision models.

        For every model constructor found in ``torchvision.models.quantization``
        and for both the 'static' and 'qat' modes, builds the float model and
        its eager-quantizable counterpart and delegates the comparison to
        ``self._test_model_impl``. (model, mode) pairs listed in
        ``fx_eager_not_matching`` are run without the eager-equality check.
        """
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        from torchvision.models.quantization.utils import _replace_relu

        def get_available_classification_models(models):
            # Names of callables in the module namespace whose first character
            # is not upper-case and is not an underscore.
            return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

        quantized_model_list = set(get_available_classification_models(quantized_models))
        # test eager and graph consistency: only models that have an eager
        # quantizable counterpart are visited. Sorted so that the iteration
        # order is reproducible from run to run.
        model_list = sorted(quantized_model_list)
        # mobilenet/inception_v3/googlenet qat is not working due to AdaptiveAveragePool qat
        # we might observe the output of AdaptiveAveragePool in the future
        # and re-enable the test
        fx_eager_not_matching = [
            ("mobilenet_v2", "qat"),
            ("inception_v3", "qat"),
            ("googlenet", "qat")
        ]  # because relu6 is replaced as relu in mobilenetv2

        diff_of_quant = {}
        diff_from_eager = {}
        modes = ['static', 'qat']
        options = itertools.product(modes, model_list)
        for mode, name in options:
            kwargs = {}
            # turn off transform input for inception_v3 since
            # it's not quantized in eager mode and in fx graph
            # mode we can't skip quantizing a method right now
            # (might be supported in the future)
            if name in ["inception_v3", "googlenet"]:
                kwargs["transform_input"] = False
            eager_quantizable_model = None
            if name in quantized_model_list:
                eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float()
            # compare with eager mode quantized model when it is available:
            # load the pretrained float model only in that case
            pretrained = eager_quantizable_model is not None
            model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float()
            if name == "mobilenet_v2":
                _replace_relu(model)
            # disable aux logits
            if hasattr(model, "aux_logits"):
                model.aux_logits = False
                model.AuxLogits = None
                # Explicit None check: module truthiness can depend on __len__.
                if eager_quantizable_model is not None:
                    eager_quantizable_model.aux_logits = False
                    eager_quantizable_model.AuxLogits = None

            check_with_eager = (name, mode) not in fx_eager_not_matching
            self._test_model_impl(
                mode, name, model, eager_quantizable_model,
                check_with_eager,
                diff_of_quant, diff_from_eager)

        def print_diffs(diffs):
            # Debugging helper: dump the per-mode, per-model differences.
            for mode, diffs_for_mode in diffs.items():
                print('mode:', mode)
                for name, diff in diffs_for_mode.items():
                    print(name, ':', diff)

        # Uncomment to inspect the collected differences while debugging:
        # print('differences between float and quantized')
        # print_diffs(diff_of_quant)
        # print('----------------------')
        # print('differences between graph mode and eager mode')
        # print_diffs(diff_from_eager)
        # print('----------------------')

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979")
    def test_resnet18_ddp(self):
        """Run ``_test_model_impl`` in 'ddp' mode on torchvision's resnet18.

        Builds the float resnet18 and its eager-quantizable counterpart (both
        without pretrained weights) and hands them to ``_test_model_impl``.
        The test is currently disabled by the ``unittest.skip`` decorator.
        """
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        # Previously ``name`` was never bound here, so the lookups below would
        # have raised NameError had the test not been skipped.
        name = 'resnet18'
        eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float()
        model = models.__dict__[name](pretrained=False).eval().float()
        self._test_model_impl(
            'ddp', name, model, eager_quantizable_model)

    @override_qengines
    def test_qat_embeddingbag_linear(self):
        """QAT an EmbeddingBag + Linear model through FX graph mode and check the result.

        The model is prepared with ``prepare_qat_fx`` using the default QAT
        qconfig for the current engine globally and
        ``default_embedding_qat_qconfig`` for ``EmbeddingBag``, trained with
        ``test_only_train_fn``, converted with ``convert_fx``, and then checked
        for quantized submodule types, scriptability and absence of qconfigs.
        """
        # NOTE(review): ``device`` is not used in the loop body, so each
        # iteration repeats the same check -- confirm whether the model and
        # data were meant to be moved to ``device``.
        for device in get_supported_device_types():
            class EmbeddingBagLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                            per_sample_weights: Optional[torch.Tensor] = None):
                    x = self.emb(input, offsets, per_sample_weights)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            qconfig_mapping = QConfigMapping() \
                .set_global(get_default_qat_qconfig(qengine)) \
                .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig)

            # Two [indices, float tensor] pairs handed to test_only_train_fn.
            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            # Inputs (not outputs) for the evaluation / scripting checks.
            eval_inputs = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingBagLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_mapping, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_mapping)

            def checkQuantized(model):
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                # (assertTrue(x, cls) would treat ``cls`` as the failure
                # message and always pass, so assert the type explicitly.)
                self.assertIsInstance(model.emb, nnq.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertIsInstance(model.linear, nnq.Linear)

                test_only_eval_fn(model, eval_inputs)
                self.checkScriptable(model, eval_inputs)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)


    @override_qengines
    def test_qat_embedding_linear(self):
        """QAT an Embedding + Linear model through FX graph mode and check the result.

        The model is prepared with ``prepare_qat_fx`` using the default QAT
        qconfig for the current engine globally and
        ``default_embedding_qat_qconfig`` for ``Embedding``, trained with
        ``test_only_train_fn``, converted with ``convert_fx``, and then checked
        for quantized submodule types, scriptability and absence of qconfigs.
        """
        # NOTE(review): ``device`` is not used in the loop body, so each
        # iteration repeats the same check -- confirm whether the model and
        # data were meant to be moved to ``device``.
        for device in get_supported_device_types():
            class EmbeddingLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor):
                    x = torch.sum(self.emb(input), dim=1)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            # Same configuration as the legacy
            # {"": ..., "object_type": [...]} dict, expressed with the
            # QConfigMapping API used by the EmbeddingBag variant of this test.
            qconfig_mapping = QConfigMapping() \
                .set_global(get_default_qat_qconfig(qengine)) \
                .set_object_type(torch.nn.Embedding, default_embedding_qat_qconfig)

            # Two [indices, float tensor] pairs handed to test_only_train_fn.
            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            # Inputs (not outputs) for the evaluation / scripting checks.
            eval_inputs = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_mapping, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_mapping)

            def checkQuantized(model):
                # Make sure Embedding is now a quantized Embedding.
                # (assertTrue(x, cls) would treat ``cls`` as the failure
                # message and always pass, so assert the type explicitly.)
                self.assertIsInstance(model.emb, nnq.Embedding)
                # Also test that Linear has been quantized.
                self.assertIsInstance(model.linear, nnq.Linear)

                test_only_eval_fn(model, eval_inputs)
                self.checkScriptable(model, eval_inputs)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)

    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    @override_qengines
    def test_qat_functional_linear(self, device):
        """Check ``FusedMovingAvgObsFakeQuantize`` against ``FakeQuantize`` for functional linear QAT.

        The same model of stacked ``F.linear`` calls is prepared twice with
        ``prepare_qat_fx``: once with a ``FakeQuantize`` based qconfig (the
        reference) and once with an otherwise identical
        ``FusedMovingAvgObsFakeQuantize`` based qconfig. Over ten iterations,
        with observers enabled from iteration 2 and fake quant from iteration
        4, the forward outputs and the input gradients of the two prepared
        models must match. When fbgemm is available, the converted models
        are compared on CPU as well.

        Args:
            device: device string supplied by hypothesis ("cpu" or "cuda").
        """
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            # NOTE(review): kept as a plain return rather than a skip;
            # presumably ``override_qengines`` runs this body once per engine
            # -- confirm before changing.
            return

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(Linear(), Linear())
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        def make_qat_qconfig(fake_quant_cls):
            # Both qconfigs under test differ only in the fake-quantize class:
            # quint8 over [0, 255] for activations, qint8 over [-128, 127]
            # for weights, MovingAverageMinMaxObserver, no reduced range.
            return QConfig(
                activation=fake_quant_cls.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    dtype=torch.quint8,
                    reduce_range=False,
                ),
                weight=fake_quant_cls.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=-128,
                    quant_max=127,
                    dtype=torch.qint8,
                    reduce_range=False,
                ),
            )

        model = M().train()
        example_inputs = (torch.randn(1, 5),)

        ref_qconfig_mapping = QConfigMapping().set_global(make_qat_qconfig(FakeQuantize))
        prepared_ref = prepare_qat_fx(model, ref_qconfig_mapping, example_inputs=example_inputs)

        custom_qconfig_mapping = QConfigMapping().set_global(
            make_qat_qconfig(FusedMovingAvgObsFakeQuantize)
        )
        prepared = prepare_qat_fx(model, custom_qconfig_mapping, example_inputs=example_inputs)

        prepared.to(device)
        prepared_ref.to(device)

        both = (prepared, prepared_ref)

        # Start with fake quant and observers switched off on both models.
        for module in both:
            module.apply(torch.ao.quantization.disable_fake_quant)
            module.apply(torch.ao.quantization.disable_observer)

        for i in range(10):
            if i == 2:
                for module in both:
                    module.apply(torch.ao.quantization.enable_observer)
            if i == 4:
                for module in both:
                    module.apply(torch.ao.quantization.enable_fake_quant)

            inp = torch.randn(5, 5, device=device, requires_grad=True)
            out_ref = prepared_ref(inp)
            out = prepared(inp)
            torch.testing.assert_close(out, out_ref)

            # try backward pass
            labels = torch.randn(5, 5, device=device)
            loss = (out - labels).sum()
            grad = torch.autograd.grad(loss, [inp])
            loss_ref = (out_ref - labels).sum()
            grad_ref = torch.autograd.grad(loss_ref, [inp])
            torch.testing.assert_close(grad[0], grad_ref[0])

        if 'fbgemm' in torch.backends.quantized.supported_engines:
            # During the lowering step in convert, fold_weight calls quantized::linear_prepack
            # which doesn't support QuantizedCuda backend
            prepared.cpu()
            prepared_ref.cpu()
            converted = convert_fx(prepared)
            converted_ref = convert_fx(prepared_ref)
            inp = torch.rand(5, 5)
            out = converted(inp)
            out_ref = converted_ref(inp)

            torch.testing.assert_close(out, out_ref)
if __name__ == "__main__":
    # Direct execution is not supported: fail loudly and point at the
    # aggregate test runner instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )
