#pragma once

#include <ATen/Tensor.h>
#include <ATen/core/Scalar.h>

namespace at::native::sparse::impl {

TORCH_API Tensor& _compressed_row_strided_mm_out(
    const Tensor& compressed_row_sparse,
    const Tensor& strided,
    Tensor& result);

TORCH_API Tensor& _compressed_row_strided_addmm_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result);

namespace cpu {

void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result);

void add_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha,
    const Tensor& result);

void triangular_solve_out_sparse_csr(
    const Tensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular);

} // namespace cpu
} // namespace at::native::sparse::impl
