# Source code for brainpy._src.math.sparse.csr_mv

# -*- coding: utf-8 -*-


from typing import Union, Tuple

import jax
from jax import numpy as jnp
from jax.experimental.sparse import csr
from jax.interpreters import ad

import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (register_general_batching, XLACustomOp)
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy.errors import PackageMissingError

ti = import_taichi(error_if_not_found=False)

__all__ = [
  'csrmv',
]


def csrmv(
    data: Union[float, jnp.ndarray, Array],
    indices: Union[jnp.ndarray, Array],
    indptr: Union[jnp.ndarray, Array],
    vector: Union[jnp.ndarray, Array],
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
):
  """Product of a CSR sparse matrix and a dense vector.

  The inputs are validated here and the computation itself is delegated to
  :py:func:`raw_csrmv_taichi`.

  This function supports JAX transformations, including `jit()`, `grad()`,
  `vmap()` and `pmap()`.

  Parameters
  ----------
  data: ndarray, float
    The non-zero values of the matrix: an array of shape ``(nse,)``, or a
    single value (a float or an array of shape ``(1,)``) shared by all
    stored entries.
  indices: ndarray
    An integer array of shape ``(nse,)``.
  indptr: ndarray
    An integer array of shape ``(shape[0] + 1,)``.
  vector: ndarray
    An array of shape ``(shape[0] if transpose else shape[1],)`` and
    dtype ``data.dtype``. A boolean vector is cast to ``data.dtype``.
  shape: tuple of int
    A length-2 tuple representing the matrix shape.
  transpose: bool
    A boolean specifying whether to transpose the sparse matrix
    before computing.

  Returns
  -------
  y : ndarray
    The array of shape ``(shape[1] if transpose else shape[0],)`` representing
    the matrix vector product.

  Raises
  ------
  TypeError
    If ``data`` is not float16/float32/float64, or if ``data`` and ``vector``
    have different dtypes.
  ValueError
    If any of ``data``, ``indices``, ``indptr`` or ``vector`` is not
    one-dimensional, or if ``indices`` / ``indptr`` are not of integer type.
  """
  data = jnp.atleast_1d(as_jax(data))
  indices = as_jax(indices)
  indptr = as_jax(indptr)
  vector = as_jax(vector)

  # Boolean vectors (e.g. spike events) are promoted to the weight dtype.
  if vector.dtype == jnp.bool_:
    vector = as_jax(vector, dtype=data.dtype)

  if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]:
    raise TypeError('Only support float16, float32 or float64 type. '
                    f'But we got {data.dtype}.')
  if data.dtype != vector.dtype:
    raise TypeError('The types of data and vector should be the same. '
                    f'But we got {data.dtype} != {vector.dtype}.')

  # Explicit raise instead of ``assert`` so the check survives ``python -O``.
  if not (data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1):
    raise ValueError('data, indices, indptr and vector should all be 1D arrays. '
                     f'But we got data.ndim={data.ndim}, indices.ndim={indices.ndim}, '
                     f'indptr.ndim={indptr.ndim}, vector.ndim={vector.ndim}.')
  if not jnp.issubdtype(indices.dtype, jnp.integer):
    raise ValueError('indices should be a 1D vector with integer type.')
  if not jnp.issubdtype(indptr.dtype, jnp.integer):
    raise ValueError('indptr should be a 1D vector with integer type.')

  # if the shape of indices is (0,), then we return a zero vector
  if indices.shape[0] == 0:
    return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

  return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]
def raw_csrmv_taichi(
    data: Union[float, jnp.ndarray, Array],
    indices: Union[jnp.ndarray, Array],
    indptr: Union[jnp.ndarray, Array],
    vector: Union[jnp.ndarray, Array],
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
):
  """Dispatch a CSR matrix-vector product to the matching primitive.

  No validation is performed here; callers are expected to pass 1D arrays.

  Parameters
  ----------
  data: ndarray
    The stored values. ``data.shape[0] == 1`` selects the "homo" kernels
    (one value shared by every stored entry); any other length selects the
    "heter" path.
  indices: ndarray
    Column index of every stored entry.
  indptr: ndarray
    Row pointer array.
  vector: ndarray
    The dense vector.
  shape: tuple of int
    The matrix shape.
  transpose: bool
    Whether to multiply with the transposed matrix.

  Returns
  -------
  outs : list
    A one-element list holding the product, of length
    ``shape[1] if transpose else shape[0]``.

  Raises
  ------
  PackageMissingError
    If Taichi is not installed.
  """
  if ti is None:
    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

  out_shape = shape[1] if transpose else shape[0]

  if data.shape[0] != 1:
    # Heterogeneous weights on GPU go through JAX's own CSR matvec primitive.
    if bm.get_platform() == 'gpu':
      return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)]
    prim = _csr_matvec_transpose_heter_p if transpose else _csr_matvec_heter_p
  else:
    prim = _csr_matvec_transpose_homo_p if transpose else _csr_matvec_homo_p

  return prim(data,
              indices,
              indptr,
              vector,
              outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
              transpose=transpose,
              shape=shape)


if ti is not None:

  # -------------
  # CPU operators
  # -------------
  #
  # NOTE(review): every kernel that uses ``+=`` on ``out`` relies on ``out``
  # being zero-initialised before the kernel runs -- confirm against XLACustomOp.

  # Transposed product, shared weight: row ``row_i`` scatters
  # ``value * vector[row_i]`` into the columns it stores. The row loop is
  # serialized, presumably because different rows can hit the same ``out`` slot.
  @ti.kernel
  def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
                                            col_indices: ti.types.ndarray(ndim=1),
                                            row_ptr: ti.types.ndarray(ndim=1),
                                            vector: ti.types.ndarray(ndim=1),
                                            out: ti.types.ndarray(ndim=1)):
    value = values[0]
    ti.loop_config(serialize=True)
    for row_i in range(row_ptr.shape[0] - 1):
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        out[col_indices[j]] += value * vector[row_i]

  # Transposed product, per-entry weights; same scatter pattern as above.
  @ti.kernel
  def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
                                             col_indices: ti.types.ndarray(ndim=1),
                                             row_ptr: ti.types.ndarray(ndim=1),
                                             vector: ti.types.ndarray(ndim=1),
                                             out: ti.types.ndarray(ndim=1)):
    ti.loop_config(serialize=True)
    for row_i in range(row_ptr.shape[0] - 1):
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        out[col_indices[j]] += vector[row_i] * values[j]

  # Plain product, shared weight: each row gathers from ``vector`` and writes
  # only its own ``out[row_i]``, so the row loop is not serialized.
  @ti.kernel
  def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
                                  col_indices: ti.types.ndarray(ndim=1),
                                  row_ptr: ti.types.ndarray(ndim=1),
                                  vector: ti.types.ndarray(ndim=1),
                                  out: ti.types.ndarray(ndim=1)):
    value = values[0]
    # ti.loop_config(serialize=True)
    for row_i in range(row_ptr.shape[0] - 1):
      r = 0.
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        r += vector[col_indices[j]]
      out[row_i] = r * value

  # Plain product, per-entry weights; one private accumulator per row.
  @ti.kernel
  def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
                                   col_indices: ti.types.ndarray(ndim=1),
                                   row_ptr: ti.types.ndarray(ndim=1),
                                   vector: ti.types.ndarray(ndim=1),
                                   out: ti.types.ndarray(ndim=1)):
    # ti.loop_config(serialize=True)
    for row_i in range(row_ptr.shape[0] - 1):
      r = 0.
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        r += values[j] * vector[col_indices[j]]
      out[row_i] = r

  # -------------
  # GPU operators
  # -------------
  #
  # Each row is handled by 32 consecutive loop indices: ``i >> 5`` is the row
  # and ``i & 31`` is the starting offset inside the row. Every index then
  # walks the row's stored entries with a stride of 32 and accumulates its
  # partial result into ``out`` with ``+=``.

  @ti.kernel
  def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
                                            col_indices: ti.types.ndarray(ndim=1),
                                            row_ptr: ti.types.ndarray(ndim=1),
                                            vector: ti.types.ndarray(ndim=1),
                                            out: ti.types.ndarray(ndim=1)):
    value = values[0]
    for i in range((row_ptr.shape[0] - 1) * 32):
      row_i = i >> 5
      index = i & 31
      j = row_ptr[row_i] + index
      end_index = row_ptr[row_i + 1]
      while j < end_index:
        out[col_indices[j]] += value * vector[row_i]
        j += 32

  @ti.kernel
  def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
                                  col_indices: ti.types.ndarray(ndim=1),
                                  row_ptr: ti.types.ndarray(ndim=1),
                                  vector: ti.types.ndarray(ndim=1),
                                  out: ti.types.ndarray(ndim=1)):
    value = values[0]
    for i in range((row_ptr.shape[0] - 1) * 32):
      row_i = i >> 5
      index = i & 31
      r = 0.
      j = row_ptr[row_i] + index
      end_index = row_ptr[row_i + 1]
      while j < end_index:
        r += vector[col_indices[j]]
        j += 32
      out[row_i] += value * r

  @ti.kernel
  def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
                                             col_indices: ti.types.ndarray(ndim=1),
                                             row_ptr: ti.types.ndarray(ndim=1),
                                             vector: ti.types.ndarray(ndim=1),
                                             out: ti.types.ndarray(ndim=1)):
    for i in range((row_ptr.shape[0] - 1) * 32):
      row_i = i >> 5
      index = i & 31
      j = row_ptr[row_i] + index
      end_index = row_ptr[row_i + 1]
      while j < end_index:
        out[col_indices[j]] += values[j] * vector[row_i]
        j += 32

  @ti.kernel
  def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
                                   col_indices: ti.types.ndarray(ndim=1),
                                   row_ptr: ti.types.ndarray(ndim=1),
                                   vector: ti.types.ndarray(ndim=1),
                                   out: ti.types.ndarray(ndim=1)):
    for i in range((row_ptr.shape[0] - 1) * 32):
      row_i = i >> 5
      index = i & 31
      r = 0.
      j = row_ptr[row_i] + index
      end_index = row_ptr[row_i + 1]
      while j < end_index:
        r += values[j] * vector[col_indices[j]]
        j += 32
      out[row_i] += r  # TODO: warp-level primitive

  def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
    """JVP w.r.t. the stored values: the same product with ``val_dot`` as the values."""
    return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)

  def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
    """JVP w.r.t. the dense vector: the same product with ``vec_dot`` as the vector."""
    return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)

  def _sparse_csr_matvec_transpose(
      ct, data, indices, indptr, vector, *, outs, transpose, shape,
  ):
    """Transposition rule shared by all Taichi CSR matvec primitives.

    ``ct`` is the list of output cotangents; only ``ct[0]`` is used. The
    returned tuple has one entry per operand, in operand order.

    Raises
    ------
    ValueError
      If the transposition is requested with respect to ``indices`` or ``indptr``.
    """
    if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
      raise ValueError("Cannot transpose with respect to sparse indices.")

    if ad.is_undefined_primal(vector):
      # Check for a symbolic zero *before* touching ``ct[0]``: an ``ad.Zero``
      # is not an array and must not be handed to the kernel.
      if type(ct[0]) is ad.Zero:
        ct_vector = ad.Zero(vector)
      else:
        ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
      return data, indices, indptr, ct_vector

    if type(ct[0]) is ad.Zero:
      ct_data = ad.Zero(data)
    elif data.aval.shape[0] == 1:  # scalar
      # d(out)/d(value) is the product computed with a unit weight. The unit
      # weight takes the vector's dtype so the cotangent is not silently
      # computed in the default float type.
      ct_data = raw_csrmv_taichi(jnp.ones(1, dtype=vector.dtype), indices, indptr, vector,
                                 shape=shape, transpose=transpose)[0]
      ct_data = jnp.inner(ct[0], ct_data)
    else:
      row, col = csr_to_coo(indices, indptr)
      ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row]
    return ct_data, indices, indptr, vector

  def _define_op(cpu_kernel, gpu_kernel):
    """Wrap a CPU/GPU kernel pair in an ``XLACustomOp`` with the shared JVP and transpose rules."""
    prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
    # No JVP for the two index operands (``indices`` and ``indptr``).
    prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector)
    prim.def_transpose_rule(_sparse_csr_matvec_transpose)
    return prim

  # transpose homo
  _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu,
                                            gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu)

  # no transpose homo
  _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu,
                                  gpu_kernel=_sparse_csr_matvec_homo_gpu)

  # transpose heter
  _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu,
                                             gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu)

  # no transpose heter
  _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu,
                                   gpu_kernel=_sparse_csr_matvec_heter_gpu)

  # heter cusparse
  _csr_matvec_cusparse_p = csr.csr_matvec_p
  register_general_batching(_csr_matvec_cusparse_p)