# Owner(s): ["oncall: package/deploy"]

import torch


try:
    from torchvision.models import resnet18

    class TorchVisionTest(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.tvmod = resnet18()

        def forward(self, x):
            x = a_non_torch_leaf(x, x)
            return torch.relu(x + 3.0)

except ImportError:
    pass


def a_non_torch_leaf(a, b):
    return a + b
