import torch


# Labels passed to ``cref.record_start`` to tag whether a bucket's RPC
# payload carried a sparse or a dense gradient tensor.
RPC_SPARSE, RPC_DENSE = "rpc_sparse", "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function that converts a sparse tensor into a list holding the
    indices, values, and size of its coalesced form, for sending over RPC.
    Args:
        sparse_tensor (torch.Tensor): sparse_coo_tensor to convert; it is
            coalesced first, so entries with duplicate indices are summed
    Returns:
        list: ``[indices, values, size]`` of the coalesced tensor
    """
    coalesced = sparse_tensor.coalesce()
    return [coalesced.indices(), coalesced.values(), coalesced.size()]


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function that rebuilds a coalesced sparse_coo_tensor from its
    list representation.
    Args:
        sparse_rpc_format (list): ``[indices, values, size]``, the layout
            produced by ``sparse_tensor_to_rpc_format``
    Returns:
        torch.Tensor: the coalesced sparse_coo_tensor
    """
    indices = sparse_rpc_format[0]
    values = sparse_rpc_format[1]
    size = sparse_rpc_format[2]
    rebuilt = torch.sparse_coo_tensor(indices, values, size)
    return rebuilt.coalesce()


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook
    during .backward(). The method supports processing sparse and dense
    tensors. It records RPC future completion time metric for the trainer.

    The bucket's buffer is moved to CPU unless ``state.cref.use_cuda_rpc``
    is set. A sparse buffer is converted with ``sparse_tensor_to_rpc_format``.
    The buffer is then sent to ``state.cref.server_rref`` through an async
    ``average_gradient`` RPC, along with the batch number and bucket index.
    Args:
        state (object): maintains state during the training process; must
            expose ``cref``, ``batch_number`` and ``get_key(bucket_index)``
        bucket (GradBucket): gradient bucket
    Returns:
        torch.futures.Future: completes with a single-element list holding
        the server's result as a tensor on CUDA device ``cref.rank``. A list
        reply from the server is first rebuilt with
        ``sparse_rpc_format_to_tensor``.
    """
    cref = state.cref
    tensor = bucket.buffer()
    if not cref.use_cuda_rpc:
        # Move to CPU unless the trainer is configured for CUDA RPC.
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        # Sparse tensors are sent as [indices, values, size].
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [cref.server_rref, state.batch_number, b_index, tensor]
    key = state.get_key(b_index)
    cref.record_start("hook_future_metric", key, RPC_SPARSE if sparse else RPC_DENSE)
    rpc_fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(completed_fut):
        # ``then`` invokes this only after the RPC future has completed, so
        # the metric end is taken here and ``wait()`` returns immediately.
        cref.record_end("hook_future_metric", key)
        result = completed_fut.wait()
        if isinstance(result, list):
            # The server replied in the sparse list format; rebuild the tensor.
            result = sparse_rpc_format_to_tensor(result)
        return [result.cuda(cref.rank)]

    return rpc_fut.then(callback)
