import torch


input = []
input.append(torch.tensor([1.0, 2.0, 3.0, 4.0]))
input.append(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
input.append(torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]))
reveal_type(input[0].shape[0])  # E: int
reveal_type(input[1].shape[1])  # E: int
reveal_type(input[2].shape[2])  # E: int
