# Owner(s): ["module: distributions"]

import io
from numbers import Number

import pytest

import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    constraints,
    Dirichlet,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import (
    _InverseTransform,
    AbsTransform,
    AffineTransform,
    ComposeTransform,
    CorrCholeskyTransform,
    CumulativeDistributionTransform,
    ExpTransform,
    identity_transform,
    IndependentTransform,
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
    PowerTransform,
    ReshapeTransform,
    SigmoidTransform,
    SoftmaxTransform,
    SoftplusTransform,
    StickBreakingTransform,
    TanhTransform,
    Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests


def get_transforms(cache_size):
    """Build the catalogue of transforms exercised by the tests in this file.

    Every transform constructor that is given ``cache_size`` here receives the
    caller's value (``ReshapeTransform``, ``ComposeTransform`` and
    ``CumulativeDistributionTransform`` are built with their defaults).

    Returns:
        A list holding each forward transform followed by the inverse
        (``t.inv``) of every forward transform, in the same order.

    Several entries draw random parameters from the global torch RNG, so the
    construction order below matters for reproducibility.
    """

    def random_affine(*shape):
        # ``loc`` is drawn before ``scale`` (argument evaluation order).
        return AffineTransform(
            torch.randn(*shape), torch.randn(*shape), cache_size=cache_size
        )

    def random_power():
        # The exponent is a scalar tensor overwritten with a standard-normal draw.
        return PowerTransform(
            exponent=torch.tensor(5.0).normal_(), cache_size=cache_size
        )

    forward = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=-2, cache_size=cache_size),
        random_power(),
        random_power(),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        random_affine(5),
        random_affine(4, 5),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        PositiveDefiniteTransform(cache_size=cache_size),
        ComposeTransform([random_affine(4, 5)]),
        ComposeTransform([random_affine(4, 5), ExpTransform(cache_size=cache_size)]),
        ComposeTransform(
            [
                AffineTransform(0, 1, cache_size=cache_size),
                random_affine(4, 5),
                AffineTransform(1, -2, cache_size=cache_size),
                random_affine(4, 5),
            ]
        ),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(random_affine(5), 1),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    return forward + [t.inv for t in forward]


def reshape_transform(transform, shape):
    """Return ``transform`` with its tensor parameters broadcast to ``shape``.

    Needed to squash batch dims when testing the jacobian.

    * ``ComposeTransform``: rebuilt from its recursively reshaped parts,
      keeping the original cache size.
    * ``AffineTransform``: returned unchanged when ``loc`` is a Python number.
      Otherwise ``loc``/``scale`` are expanded to ``shape``, falling back to
      ``reshape`` when ``expand`` raises ``RuntimeError``.
    * Inverse of an affine/compose transform: the forward transform is
      reshaped and then inverted again.
    * Anything else is returned as-is.
    """
    if isinstance(transform, ComposeTransform):
        return ComposeTransform(
            [reshape_transform(part, shape) for part in transform.parts],
            cache_size=transform._cache_size,
        )

    if isinstance(transform, AffineTransform):
        loc, scale = transform.loc, transform.scale
        if isinstance(loc, Number):
            return transform
        try:
            loc, scale = loc.expand(shape), scale.expand(shape)
        except RuntimeError:
            # Shapes that cannot be broadcast (e.g. fewer dims) are reshaped instead.
            loc, scale = loc.reshape(shape), scale.reshape(shape)
        return AffineTransform(loc, scale, cache_size=transform._cache_size)

    inner = transform.inv
    if isinstance(inner, (AffineTransform, ComposeTransform)):
        return reshape_transform(inner, shape).inv
    return transform


# Generate pytest ids
def transform_id(x):
    """Return a readable pytest id, e.g. ``Inv(ExpTransform)(cache_size=0)``."""
    assert isinstance(x, Transform)
    if isinstance(x, _InverseTransform):
        # Name inverses after the transform they invert.
        label = "Inv(" + type(x._inv).__name__ + ")"
    else:
        label = type(x).__name__
    return f"{label}(cache_size={x._cache_size})"


def generate_data(transform):
    """Return a deterministic input tensor lying in ``transform``'s domain.

    The global RNG is reseeded with ``1`` on every call, so repeated calls for
    the same transform return identical data. ``IndependentTransform``
    wrappers are peeled off first. Reshape transforms (and their inverses) get
    standard-normal data of their input shape. For everything else,
    ``independent`` wrappers are stripped from the domain (stopping at
    ``real_vector``) and data is drawn to satisfy the base constraint.

    Raises:
        ValueError: if the (stripped) domain is not supported.
    """
    torch.manual_seed(1)
    base = transform
    while isinstance(base, IndependentTransform):
        base = base.base_transform

    if isinstance(base, ReshapeTransform):
        return torch.randn(base.in_shape)
    if isinstance(base.inv, ReshapeTransform):
        return torch.randn(base.inv.out_shape)

    dom = base.domain
    while isinstance(dom, constraints.independent) and dom is not constraints.real_vector:
        dom = dom.base_constraint
    codom = base.codomain

    pd_like = [constraints.lower_cholesky, constraints.positive_definite]
    if dom in pd_like:
        # Strictly-lower part plus an exponentiated (positive) diagonal.
        mat = torch.randn(6, 6)
        mat = mat.tril(-1) + mat.diag().exp().diag_embed()
        return mat @ mat.T if dom is constraints.positive_definite else mat
    if codom in pd_like:
        return torch.randn(6, 6)

    if dom is constraints.real:
        return torch.empty(4, 5).normal_()
    if dom is constraints.real_vector:
        # CorrCholeskyTransform expects a trailing size of the form
        # D * (D - 1) / 2; 6 corresponds to 4x4 matrices.
        return torch.empty(3, 6).normal_()
    if dom is constraints.positive:
        return torch.empty(4, 5).normal_().exp()
    if dom is constraints.unit_interval:
        return torch.empty(4, 5).uniform_()
    if isinstance(dom, constraints.interval):
        span = dom.upper_bound - dom.lower_bound
        return torch.empty(4, 5).uniform_().mul_(span).add_(dom.lower_bound)
    if dom is constraints.simplex:
        weights = torch.empty(4, 5).normal_().exp()
        weights /= weights.sum(-1, True)
        return weights
    if dom is constraints.corr_cholesky:
        chol = torch.empty(4, 5, 5).normal_().tril()
        # Unit-norm rows, then force the diagonal to be non-negative.
        chol /= chol.norm(dim=-1, keepdim=True)
        chol.diagonal(dim1=-1).copy_(chol.diagonal(dim1=-1).abs())
        return chol
    raise ValueError(f"Unsupported domain: {dom}")


# Two independently constructed catalogues (cached first, then uncached),
# plus the shared identity transform appended once at the end.
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = [
    *TRANSFORMS_CACHE_ACTIVE,
    *TRANSFORMS_CACHE_INACTIVE,
    identity_transform,
]


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    """Inverting a transform twice must give back the very same object."""
    assert transform.inv.inv is transform


@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    """Each catalogue entry equals itself and differs from every other entry."""
    if x is not y:
        assert x != y
    else:
        assert x == y
    # The identity transform is its own inverse.
    assert identity_transform == identity_transform.inv


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    """A size-1 cache must return the identical output object on a repeat call."""
    cached = transform if transform._cache_size != 0 else transform.with_cache(1)
    assert cached._cache_size == 1
    inp = generate_data(cached).requires_grad_()
    try:
        first = cached(inp)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    assert cached(inp) is first


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
    """Round-trip data through ``transform`` and ``transform.inv``.

    Bijective transforms must recover the input. Non-bijective ones only
    need ``t(t.inv(t(x))) == t(x)`` (a pseudo-inverse). With ``test_cached``
    the exact output tensor is inverted, so a cache may answer. Otherwise a
    clone is inverted, which bypasses the cache.
    """
    x = generate_data(transform).requires_grad_()
    # Sanity-check the generated input before testing the transform.
    assert transform.domain.check(x).all()
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    assert y.shape == transform.forward_shape(x.shape)

    if not test_cached:
        try:
            x2 = transform.inv(y.clone())  # a fresh tensor cannot hit the cache
        except NotImplementedError:
            pytest.skip("Not implemented.")
    else:
        x2 = transform.inv(y)  # must work, at worst by caching
    assert x2.shape == transform.inverse_shape(y.shape)
    y2 = transform(x2)

    if not transform.bijective:
        # Only the weaker pseudo-inverse property is guaranteed.
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t(t.inv(t(-))) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
                f"y2 = t(x2) = {y2}",
            ]
        )
        return
    assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
        [
            f"{transform} t.inv(t(-)) error",
            f"x = {x}",
            f"y = t(x) = {y}",
            f"x2 = t.inv(y) = {x2}",
        ]
    )


def test_compose_transform_shapes():
    """Check event_dim of single transforms and of pairwise compositions.

    For these combinations the composed event_dim equals the largest
    event_dim among the parts.
    """
    exp_t = ExpTransform()
    softmax_t = SoftmaxTransform()
    chol_t = LowerCholeskyTransform()

    for part, rank in ((exp_t, 0), (softmax_t, 1), (chol_t, 2)):
        assert part.event_dim == rank

    for pair, rank in (
        ((exp_t, softmax_t), 1),
        ((exp_t, chol_t), 2),
        ((softmax_t, chol_t), 2),
    ):
        assert ComposeTransform(list(pair)).event_dim == rank


# Module-level fixtures for test_transformed_distribution_shapes: one transform
# each with event_dim 0, 1 and 2, and base distributions with different
# batch/event splits.
transform0, transform1, transform2 = (
    ExpTransform(),
    SoftmaxTransform(),
    LowerCholeskyTransform(),
)
base_dist0, base_dist1, base_dist2 = (
    Normal(torch.zeros(4, 4), torch.ones(4, 4)),
    Dirichlet(torch.ones(4, 4)),
    Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4)),
)


@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    """Check the batch/event split, then draw a sample and score it.

    A ``NotImplementedError`` from ``log_prob`` skips the case.
    """
    assert (dist.batch_shape, dist.event_shape) == (batch_shape, event_shape)
    sample = dist.rsample()
    try:
        dist.log_prob(sample)  # must not crash
    except NotImplementedError:
        pytest.skip("Not implemented.")


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    """A traced forward pass must match the eager forward pass."""

    def f(inp):
        return transform(inp)

    example = generate_data(transform).requires_grad_()
    try:
        traced = torch.jit.trace(f, (example,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # NOTE(review): generate_data reseeds with torch.manual_seed(1), so this
    # call reproduces the tracing input rather than producing new data.
    probe = generate_data(transform).requires_grad_()
    assert torch.allclose(f(probe), traced(probe), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    """A traced inverse must match the eager ``transform.inv``."""

    def f(out):
        return transform.inv(out)

    example = generate_data(transform.inv).requires_grad_()
    try:
        traced = torch.jit.trace(f, (example,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # NOTE(review): generate_data reseeds with torch.manual_seed(1), so this
    # call reproduces the tracing input rather than producing new data.
    probe = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(f(probe), traced(probe), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    """A traced ``log_abs_det_jacobian`` must match the eager computation."""

    def f(inp):
        return transform.log_abs_det_jacobian(inp, transform(inp))

    example = generate_data(transform).requires_grad_()
    try:
        traced = torch.jit.trace(f, (example,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # NOTE(review): generate_data reseeds with torch.manual_seed(1), so this
    # call reproduces the tracing input rather than producing new data.
    probe = generate_data(transform).requires_grad_()
    assert torch.allclose(f(probe), traced(probe), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    """Compare ``log_abs_det_jacobian`` against an autograd-computed jacobian.

    The expected log-determinant is derived per batch element:

    1. reshape transforms (and their inverses): zero;
    2. elementwise transforms (``domain.event_dim == 0``): the jacobian must
       be diagonal, and the expected value is the log-abs of that diagonal;
    3. everything else: ``slogdet`` of the per-batch square jacobian. For
       CorrCholesky and StickBreaking, inputs/outputs are first flattened or
       trimmed so the matrix is square.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    # Test shape: one log-det per batch element (domain event dims reduced away).
    target_shape = x.shape[: x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand parameters to the full input shape if required.
    transform = reshape_transform(transform, x.shape)
    # Number of leading batch dims; the remaining dims form the event.
    batch_ndims = x.dim() - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[batch_ndims:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim.
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian: reshaping does not change values, so
    # log|det J| is zero for every batch element.
    if isinstance(transform, ReshapeTransform) or isinstance(
        transform.inv, ReshapeTransform
    ):
        # Slice (not index) the shape so `expected` has exactly the batch
        # shape of `actual` instead of relying on broadcasting.
        expected = x.new_zeros(x.shape[: x.dim() - transform.domain.event_dim])
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            # Keep only the strictly-lower triangle of the output matrix.
            jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(
                lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                tril_matrix_to_vec(x_, diag=-1),
            )
        elif isinstance(transform, StickBreakingTransform):
            # Drop the last output coordinate so the jacobian is square.
            jac = jacobian(lambda x: transform(x)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
        # after reshaping the event dims (see above) to give a batched square matrix whose determinant
        # can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = (
            torch.arange(n)
            .reshape((n,) + (1,) * (len(jac.shape) - 1))
            .expand(gather_idx_shape)
        )
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[
            ..., :out_ndims
        ]  # Remove extra zero-valued dims (for inverse stick-breaking).
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    """Composing affine transforms yields the largest event_dim among them.

    The same must hold for the support of a ``TransformedDistribution`` built
    over a Normal base. Over a Dirichlet base (event_dim 1), the support's
    event_dim is at least 1.
    """
    deepest = max(event_dims)
    parts = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    composed = ComposeTransform(parts)
    assert composed.codomain.event_dim == deepest
    assert composed.domain.event_dim == deepest

    dom_dim = composed.domain.event_dim
    # Give the scalar Normal enough size-1 dims to cover the transform's domain.
    normal = Normal(0, 1)
    if dom_dim:
        normal = normal.expand((1,) * dom_dim)
    assert TransformedDistribution(normal, composed.parts).support.event_dim == deepest

    # The Dirichlet already contributes one event dim.
    dirichlet = Dirichlet(torch.ones(5))
    if dom_dim > 1:
        dirichlet = dirichlet.expand((1,) * (dom_dim - 1))
    supp_dim = TransformedDistribution(dirichlet, parts).support.event_dim
    assert supp_dim == max(1, *event_dims)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    """A chain of reshapes maps (3, 2) events to (2, 3) events over any batch."""
    parts = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    composed = ComposeTransform(parts)
    assert composed.codomain.event_dim == 2
    assert composed.domain.event_dim == 2

    data = torch.randn(batch_shape + (3, 2))
    assert composed(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), parts)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2


@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(
    base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
    """Exercise construction, sampling and ``log_prob`` shapes.

    Cases vary the base distribution's batch/event split and the length of
    the transform chain.
    """
    full_shape = torch.Size([2, 3, 4, 5])
    base_ndims = base_batch_dim + base_event_dim
    # Keep the trailing `base_ndims` dims of full_shape for the base distribution.
    base_dist = Normal(0, 1).expand(full_shape[len(full_shape) - base_ndims :])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)

    candidates = [
        AffineTransform(torch.zeros(full_shape[len(full_shape) - transform_dim :]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = candidates[:num_transforms]
    composed = ComposeTransform(transforms)

    # Too few base dims for the transform's domain: construction must fail.
    if base_ndims < composed.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Sampling must be fully expanded, i.e. (almost) all values distinct.
    samples = d.sample(sample_shape)
    assert samples.shape == sample_shape + d.batch_shape + d.event_shape
    assert len(set(samples.flatten().tolist())) >= 0.9 * samples.numel()

    # log_prob on full samples.
    assert d.log_prob(samples).shape == sample_shape + d.batch_shape

    # log_prob on a single event (all leading sample/batch dims indexed at 0).
    single = samples[(0,) * (samples.dim() - len(d.event_shape))]
    assert d.log_prob(single).shape == d.batch_shape


def test_save_load_transform():
    """A TransformedDistribution survives a save/load round trip.

    Evaluating ``log_prob`` creates a weakref ``_inv``, which cannot be
    pickled. This checks that ``__getstate__`` handles the weakref and that
    the reloaded distribution still evaluates the same density.
    """
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    with io.BytesIO() as stream:
        torch.save(dist, stream)
        stream.seek(0)
        # The payload is an arbitrary pickled Python object (not just tensors),
        # so the restricted `weights_only=True` loader -- the default in
        # recent PyTorch releases -- would reject it.
        other = torch.load(stream, weights_only=False)
    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
    """The gradient of the summed output must agree with ``transform.sign``.

    Transforms whose ``sign`` raises ``NotImplementedError`` are skipped.
    """
    try:
        sign = transform.sign
    except NotImplementedError:
        pytest.skip("Not implemented.")

    x = generate_data(transform).requires_grad_()
    (dydx,) = grad(transform(x).sum(), [x])
    # Every element of derivative * sign must be strictly positive.
    assert (dydx * sign > 0).all()


if __name__ == "__main__":
    # Entry point when this file is executed directly: delegate to
    # run_tests() from torch.testing._internal.common_utils.
    run_tests()
