import operator_benchmark as op_bench

import torch


# Small sweep tagged "short": a single num_embeddings value (80) combined
# with three embedding_dim values (128, 256, 512).
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,), embedding_dim=(128, 256, 512), tags=("short",)
)

# Larger sweep tagged "long": three num_embeddings values combined with
# seven embedding_dim values ranging from 16 to 2048.
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

# Configs for the three-dimensional benchmarks below: same values as the
# "short" sweep plus a batch_size, which those benchmarks use as the leading
# dimension of the weight tensor.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)

# Float -> packed ("prepack") operators under test. Each entry is an
# (op_name, op_func) pair, as declared by attr_names; the benchmark classes
# below receive the callable through their `op_func` init argument.
conversion_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_prepack", torch.ops.quantized.embedding_bag_byte_prepack),
        ("qembeddingbag_4bit_prepack", torch.ops.quantized.embedding_bag_4bit_prepack),
        ("qembeddingbag_2bit_prepack", torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=("op_name", "op_func"),
)

# Packed -> float ("unpack") operators under test, mirroring conversion_ops:
# one (op_name, op_func) pair per byte / 4-bit / 2-bit variant.
unpack_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_unpack", torch.ops.quantized.embedding_bag_byte_unpack),
        ("qembeddingbag_4bit_unpack", torch.ops.quantized.embedding_bag_4bit_unpack),
        ("qembeddingbag_2bit_unpack", torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=("op_name", "op_func"),
)


class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark a prepack operator on a 2-D float weight table."""

    def init(self, num_embeddings, embedding_dim, op_func):
        """Build the float input and remember the operator to time.

        Args:
            num_embeddings: Number of rows in the generated weight table.
            embedding_dim: Number of columns in the generated weight table.
            op_func: Callable invoked on the weight in ``forward``.
        """
        table_shape = (num_embeddings, embedding_dim)
        # torch.rand samples from [0, 1); the +1 shifts every entry to [1, 2).
        float_table = torch.rand(table_shape, dtype=torch.float) + 1
        # The dict key matches the parameter name of ``forward``.
        self.inputs = {"weight": float_table}
        self.op_func = op_func

    def forward(self, weight):
        """Apply the configured operator to ``weight`` and return its result."""
        return self.op_func(weight)


class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark a prepack operator on a 3-D (batched) float weight tensor."""

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        """Build the batched float input and remember the operator to time.

        Args:
            num_embeddings: Size of the second dimension of the weight.
            embedding_dim: Size of the last dimension of the weight.
            batch_size: Size of the leading dimension of the weight.
            op_func: Callable invoked on the weight in ``forward``.
        """
        tensor_shape = (batch_size, num_embeddings, embedding_dim)
        # torch.rand samples from [0, 1); the +1 shifts every entry to [1, 2).
        float_tensor = torch.rand(tensor_shape, dtype=torch.float) + 1
        # The dict key matches the parameter name of ``forward``.
        self.inputs = {"weight": float_tensor}
        self.op_func = op_func

    def forward(self, weight):
        """Apply the configured operator to ``weight`` and return its result."""
        return self.op_func(weight)


class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark an unpack operator on a synthetic 2-D uint8 tensor."""

    def init(self, num_embeddings, embedding_dim, op_func):
        """Build the uint8 input and remember the operator to time.

        Args:
            num_embeddings: Number of rows in the generated tensor.
            embedding_dim: Base row width; 8 extra columns are appended.
            op_func: Callable invoked on the tensor in ``forward``.
        """
        # Rows are 8 columns wider than embedding_dim -- presumably room for
        # per-row quantization parameters; TODO confirm against the packed
        # layout each unpack op expects.
        row_width = embedding_dim + 8
        float_rows = torch.randn(num_embeddings, row_width, dtype=torch.float)
        # NOTE(review): randn yields negative values, and converting those to
        # uint8 may be platform-dependent -- confirm this synthetic data is
        # what the benchmark intends.
        byte_rows = float_rows.to(torch.uint8)
        # The dict key matches the parameter name of ``forward``.
        self.inputs = {"packed_weight": byte_rows}
        self.op_func = op_func

    def forward(self, packed_weight):
        """Apply the configured operator to ``packed_weight`` and return its result."""
        return self.op_func(packed_weight)


class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark an unpack operator on a synthetic 3-D (batched) uint8 tensor."""

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        """Build the batched uint8 input and remember the operator to time.

        Args:
            num_embeddings: Size of the second dimension of the tensor.
            embedding_dim: Base size of the last dimension; 8 extra columns
                are appended.
            batch_size: Size of the leading dimension of the tensor.
            op_func: Callable invoked on the tensor in ``forward``.
        """
        # Rows are 8 columns wider than embedding_dim -- presumably room for
        # per-row quantization parameters; TODO confirm against the packed
        # layout each unpack op expects.
        row_width = embedding_dim + 8
        float_rows = torch.randn(
            batch_size, num_embeddings, row_width, dtype=torch.float
        )
        # NOTE(review): randn yields negative values, and converting those to
        # uint8 may be platform-dependent -- confirm this synthetic data is
        # what the benchmark intends.
        byte_rows = float_rows.to(torch.uint8)
        # The dict key matches the parameter name of ``forward``.
        self.inputs = {"packed_weight": byte_rows}
        self.op_func = op_func

    def forward(self, packed_weight):
        """Apply the configured operator to ``packed_weight`` and return its result."""
        return self.op_func(packed_weight)


# Register the benchmarks: every operator in an op list is paired with every
# config in the given config list and the given benchmark class.

# 2-D prepack benchmarks over the combined short + long sweeps.
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFloatToFusedBase,
)
# 2-D unpack benchmarks over the combined short + long sweeps.
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFusedToFloatBase,
)
# 3-D (batched) prepack benchmarks.
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFloatToFusedBase,
)
# 3-D (batched) unpack benchmarks.
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFusedToFloatBase,
)

# When executed as a script, hand control to the operator_benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
