# Owner(s): ["module: unknown"]

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    skipIfRocm,
)

# torchvision is an optional dependency: tests that need it are skipped
# (via ``skipIfNoTorchVision``) when it cannot be imported.
try:
    from torchvision import models as torchvision_models
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Used to skip the tests that require a CUDA device.
HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    """Return the total FLOPs recorded under ``mode``'s ``"Global"`` scope.

    The result is a string so it can be passed directly to
    ``assertExpectedInline``.
    """
    return str(sum(mode.flop_counts["Global"].values()))


def T(*shape, requires_grad=False):
    """Shorthand for ``torch.randn(*shape, requires_grad=requires_grad)``."""
    sample = torch.randn(*shape, requires_grad=requires_grad)
    return sample


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    """Tests for ``torch.utils.flop_counter.FlopCounterMode``.

    Most tests run a few ops under the mode and compare the recorded FLOP
    totals with hand-computed or recorded (``assertExpectedInline``) values.
    Re-entering an existing ``mode`` starts a fresh count, so the expected
    totals in ``test_op`` refer only to the ops inside each ``with`` block.
    """

    def test_flop_counter_variety(self):
        """matmul-style ops and nn.Linear are all counted."""
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))  # 4 * 6 * 2 * 5 = 240
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)  # 240
            torch.matmul(T(5, 6), T(6, 7))  # 5 * 7 * 2 * 6 = 420
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))  # 6 * 8 * 2 * 7 = 672
            mod(T(8, 9))  # 8 * 10 * 2 * 9 = 1440

        # 240 + 240 + 420 + 672 + 1440 = 3012
        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        """Per-op FLOP formulas for matmul and convolution variants."""
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))

        # Three addmms; the bias shape does not change the count:
        # 3 * (4 * 6 * 2 * 5) = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))

        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)

        # Padding enters through the output size: 6 + 2 * 1 - 4 + 1 = 5.
        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in = 28800
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)

        # Padding enters through the output size: 6 + 2 * 1 - 4 + 1 = 5.
        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in = 1440
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        """Backward-pass matmuls are counted along with the forward ones."""
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

        # forward: mm 4 * 6 * 2 * 5 = 240, bmm 7 * 4 * 7 * 2 * 6 = 2352.
        # Only the first operands require grad, so backward adds the same
        # again: 2 * (240 + 2352) = 5184
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_backward_reset(self):
        """Counts accumulate across repeated backward calls within one mode."""
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a.mm(a.t()).sum().backward()
            a.mm(a.t()).sum().backward()

        # Per iteration: forward 4 * 4 * 2 * 5 = 160, backward 2 * 160 = 320
        # (both operands are `a`); two iterations -> 2 * 480 = 960
        self.assertExpectedInline(get_total_flops(mode), """960""")

    def test_torchscript(self):
        """A scripted function is counted the same as its eager version."""

        def foo(x):
            return torch.mm(x, x)

        with FlopCounterMode() as mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        """Ops in a custom autograd.Function's forward and backward are counted."""

        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(
                    grad_output, grad_output
                )

        a = T(5, 5, requires_grad=True)
        with FlopCounterMode() as mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

        # forward: one 5x5 mm = 250; custom backward: two 5x5 mms = 500
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        """Conv gradients can be reproduced using only conv/conv_transpose forwards."""
        # [conv backwards decomposition as conv forwards]

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(
                        grad_out.transpose(1, 0), inp.transpose(1, 0)
                    )
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        # Only the transposed path is exercised: with these shapes `weight`
        # is laid out as (in, out, k) for conv_transpose1d.
        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        """Conv backward counts exactly twice the forward, plain and transposed."""

        def assert_equivalence(f, expected_forward=None):
            with FlopCounterMode() as mode:
                f()
            conv_forward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution
            ]
            conv_backward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution_backward
            ]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        # NOTE(review): `groups` is unpacked but never passed to the convs, so
        # grouped convolutions are not actually exercised here. Enabling it
        # would also require grouped weight shapes (in_channels // groups);
        # confirm the backward/forward ratio for grouped convs first.
        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        """Per-module attribution on torchvision's resnet18."""
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        """Counting stays correct over repeated ConvTranspose2d forward/backward."""
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for _ in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        """custom_mapping overrides formulas; with _get_raw the formula gets out_val."""
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out_val):
            return out_val.numel()

        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        """Ops the counter does not count (here randn and cos) add zero FLOPs."""
        with FlopCounterMode() as mode:
            T(4, 5).cos()

        self.assertEqual(mode.get_total_flops(), 0)

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        """SDPA FLOP counts are consistent across math/flash/mem-efficient backends."""
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e.
        # seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        """Nested-tensor SDPA counts equal the sum over per-sequence dense calls."""

        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()

            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return q, k, v

        def get_dense_flops(q, k, v, backend, with_backward=False):
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)

            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        """Jagged flash-attention FLOPs under FakeTensorMode match a dense call."""
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)

            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        dense_x = torch.randn(
            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
        ).transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )

    def test_addmm_out(self):
        """mm with an out= tensor is counted: 10 * 10 * 2 * 10 = 2000."""

        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))

        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        """The mode installs one global forward pre-hook/hook and removes them on exit."""
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        with FlopCounterMode():
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        """Module attribution works when module inputs/outputs are dicts or tuples."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
        self.assertExpectedInline(
            str(mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(
            str(mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        """Passing a module to FlopCounterMode emits a 'not needed' warning."""
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        """register_flop_formula works for custom ops and rejects OpOverloads."""
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        called = 0

        with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        x = torch.randn(3)
        with FlopCounterMode(display=False) as mode:
            foo(x)

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()
