# -*- coding: utf-8 -*-
"""
Key points for the operator customization:
1. `index` has two kinds of types: int32, int64
2. `data` has two kinds of types: float32, float64
3. `events` has three kinds of types: bool (True or False), float32, float64
"""
from typing import Union, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import ad
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import XLACustomOp
from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy.errors import PackageMissingError
# Public API of this module.
__all__ = [
    'csrmv'
]

# Taichi is an optional dependency: ``ti`` is None when it is not installed.
# In that case the kernel/primitive definitions below are skipped and
# ``raw_csrmv_taichi`` raises a PackageMissingError when called.
ti = import_taichi(error_if_not_found=False)
def csrmv(
    data: Union[float, jax.Array],
    indices: jax.Array,
    indptr: jax.Array,
    events: jax.Array,
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
) -> jax.Array:
    """Event-driven product of a sparse CSR matrix and a dense event vector.

    This function supports JAX transformations, including ``jit()``,
    ``grad()``, ``vmap()`` and ``pmap()``.

    Parameters
    ----------
    data: ndarray, float
        The nonzero values, an array of shape ``(nse,)`` (a scalar is
        broadcast as a single homogeneous weight).
    indices: ndarray
        Column indices, an array of shape ``(nse,)`` with an integer dtype.
    indptr: ndarray
        Row pointers, an array of shape ``(shape[0] + 1,)`` and dtype
        ``indices.dtype``.
    events: ndarray
        The event vector, of shape ``(shape[0] if transpose else shape[1],)``.
        It may be boolean (spike / no spike) or floating point (nonzero
        entries count as events).
    shape: tuple
        A length-2 tuple representing the dense matrix shape.
    transpose: bool
        Whether to transpose the sparse matrix before computing.
        If ``transpose=True``, the operator computes based on the
        event-driven property of the ``events`` vector.

    Returns
    -------
    y : Array
        The matrix-vector product, of shape
        ``(shape[1] if transpose else shape[0],)``.
    """
    data = as_jax(data)
    indices = as_jax(indices)
    indptr = as_jax(indptr)
    events = as_jax(events)

    # --- input validation (checks kept in the original order so the same
    # --- error is raised first for any given invalid input) ---
    data = jnp.atleast_1d(data)
    if np.ndim(data) != 1:
        raise ValueError('data should be a scalar or 1D vector. '
                         f'But we got {np.ndim(data)}-D array.')
    if data.shape[0] not in (1, indices.shape[0]):
        raise ValueError('The size of data should be 1 or be consistent with indices.'
                         f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')
    if np.ndim(indices) != 1:
        raise ValueError('indices should be a 1D vector with integer type.')
    if np.ndim(indptr) != 1:
        raise ValueError('indptr should be a 1D vector with integer type.')
    integer_dtypes = (jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                      jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64)
    if indices.dtype not in integer_dtypes:
        raise ValueError(
            'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
    if indptr.dtype not in integer_dtypes:
        raise ValueError(
            'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.')
    if np.ndim(events) != 1:
        raise ValueError('events should be a 1D vector.')
    if len(shape) != 2:
        raise ValueError('shape should be a length-2 tuple.')
    expected = shape[0] if transpose else shape[1]
    if events.shape[0] != expected:
        if transpose:
            raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.')
        raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).')

    # An empty sparse matrix contributes nothing: short-circuit with zeros.
    if indices.shape[0] == 0:
        return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

    return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
def raw_csrmv_taichi(
    data: Union[float, jax.Array],
    indices: jax.Array,
    indptr: jax.Array,
    events: jax.Array,
    *,
    shape: Tuple[int, int],
    transpose: bool = False
):
    """Dispatch to the concrete event CSR matvec primitive.

    Selects one of the eight registered primitives according to three
    non-traceable properties: whether the matrix is transposed, whether the
    events are boolean, and whether the weight is homogeneous (a single
    scalar value). Returns a one-element list holding the output array of
    shape ``(shape[1] if transpose else shape[0],)``.

    Raises
    ------
    PackageMissingError
        If taichi is not installed.
    """
    if ti is None:
        raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators')

    is_bool = events.dtype == jnp.bool_
    is_homo = data.shape[0] == 1
    # One primitive per (transpose, bool-events, homo-weight) combination,
    # because these flags cannot be passed as traced kernel arguments.
    dispatch = {
        (True, True, True): _event_csrmv_transpose_bool_homo_p,
        (True, True, False): _event_csrmv_transpose_bool_heter_p,
        (True, False, True): _event_csrmv_transpose_homo_p,
        (True, False, False): _event_csrmv_transpose_heter_p,
        (False, True, True): _event_csrmv_bool_homo_p,
        (False, True, False): _event_csrmv_bool_heter_p,
        (False, False, True): _event_csrmv_homo_p,
        (False, False, False): _event_csrmv_heter_p,
    }
    prim = dispatch[(bool(transpose), bool(is_bool), bool(is_homo))]

    out_length = shape[1] if transpose else shape[0]
    return prim(data,
                indices,
                indptr,
                events,
                outs=[jax.ShapeDtypeStruct(shape=(out_length,), dtype=data.dtype)],
                transpose=transpose,
                shape=shape)
if ti is not None:
# -------------
# CPU operators
# -------------
# 1. The benchmarking shows that the performance of the following transpose
# kernels is maximized when using serialized mode
# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable
# arguments, we have to define each kernel separately when the
# non-differentiable/non-jittable arguments are different.
@ti.kernel
def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
                                              indices: ti.types.ndarray(ndim=1),
                                              indptr: ti.types.ndarray(ndim=1),
                                              events: ti.types.ndarray(ndim=1),
                                              out: ti.types.ndarray(ndim=1)):
    # CPU, transposed, boolean events, homogeneous weight:
    # each active (True) row scatter-adds the single shared weight into the
    # output entries given by its column indices.
    value = values[0]  # the one shared weight
    # Serial loop: benchmarking showed serialized mode is fastest for these
    # scatter-style transpose kernels (see the note above the kernels).
    ti.loop_config(serialize=True)
    for row_i in range(indptr.shape[0] - 1):
        if events[row_i]:  # event-driven: silent rows are skipped entirely
            for j in range(indptr[row_i], indptr[row_i + 1]):
                out[indices[j]] += value
@ti.kernel
def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
                                               indices: ti.types.ndarray(ndim=1),
                                               indptr: ti.types.ndarray(ndim=1),
                                               events: ti.types.ndarray(ndim=1),
                                               out: ti.types.ndarray(ndim=1)):
    # CPU, transposed, boolean events, heterogeneous weights:
    # each active row scatter-adds its per-nonzero weights ``values[j]``.
    ti.loop_config(serialize=True)  # serialized mode benchmarks fastest here
    for row_i in range(indptr.shape[0] - 1):
        if events[row_i]:  # skip silent rows
            for j in range(indptr[row_i], indptr[row_i + 1]):
                out[indices[j]] += values[j]
@ti.kernel
def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
                                         indices: ti.types.ndarray(ndim=1),
                                         indptr: ti.types.ndarray(ndim=1),
                                         events: ti.types.ndarray(ndim=1),
                                         out: ti.types.ndarray(ndim=1)):
    # CPU, transposed, float events, homogeneous weight:
    # identical to the bool variant except the event test is ``!= 0.``.
    value = values[0]  # the one shared weight
    ti.loop_config(serialize=True)  # serialized mode benchmarks fastest here
    for row_i in range(indptr.shape[0] - 1):
        if events[row_i] != 0.:  # any nonzero value counts as an event
            for j in range(indptr[row_i], indptr[row_i + 1]):
                out[indices[j]] += value
@ti.kernel
def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
                                          indices: ti.types.ndarray(ndim=1),
                                          indptr: ti.types.ndarray(ndim=1),
                                          events: ti.types.ndarray(ndim=1),
                                          out: ti.types.ndarray(ndim=1)):
    # CPU, transposed, float events, heterogeneous weights.
    ti.loop_config(serialize=True)  # serialized mode benchmarks fastest here
    for row_i in range(indptr.shape[0] - 1):
        if events[row_i] != 0.:  # any nonzero value counts as an event
            for j in range(indptr[row_i], indptr[row_i + 1]):
                out[indices[j]] += values[j]
@ti.kernel
def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
                                    indices: ti.types.ndarray(ndim=1),
                                    indptr: ti.types.ndarray(ndim=1),
                                    events: ti.types.ndarray(ndim=1),
                                    out: ti.types.ndarray(ndim=1)):
    # CPU, non-transposed, boolean events, homogeneous weight:
    # gather style — each row accumulates the shared weight once per active
    # column it references, then writes its own output slot.
    value = values[0]
    # Outer loop left parallel: each iteration writes only out[row_i],
    # so there are no cross-row races.
    # ti.loop_config(serialize=True)
    for row_i in range(indptr.shape[0] - 1):
        r = 0.  # row-local accumulator
        for j in range(indptr[row_i], indptr[row_i + 1]):
            if events[indices[j]]:
                r += value
        out[row_i] = r
@ti.kernel
def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
                                     indices: ti.types.ndarray(ndim=1),
                                     indptr: ti.types.ndarray(ndim=1),
                                     events: ti.types.ndarray(ndim=1),
                                     out: ti.types.ndarray(ndim=1)):
    # CPU, non-transposed, boolean events, heterogeneous weights (gather style).
    # Outer loop left parallel: each iteration writes only out[row_i].
    # ti.loop_config(serialize=True)
    for row_i in range(indptr.shape[0] - 1):
        r = 0.  # row-local accumulator
        for j in range(indptr[row_i], indptr[row_i + 1]):
            if events[indices[j]]:
                r += values[j]
        out[row_i] = r
@ti.kernel
def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
                               indices: ti.types.ndarray(ndim=1),
                               indptr: ti.types.ndarray(ndim=1),
                               events: ti.types.ndarray(ndim=1),
                               out: ti.types.ndarray(ndim=1)):
    # CPU, non-transposed, float events, homogeneous weight (gather style).
    value = values[0]
    # Outer loop left parallel: each iteration writes only out[row_i].
    # ti.loop_config(serialize=True)
    for row_i in range(indptr.shape[0] - 1):
        r = 0.  # row-local accumulator
        for j in range(indptr[row_i], indptr[row_i + 1]):
            if events[indices[j]] != 0.:  # any nonzero value is an event
                r += value
        out[row_i] = r
@ti.kernel
def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
                                indices: ti.types.ndarray(ndim=1),
                                indptr: ti.types.ndarray(ndim=1),
                                events: ti.types.ndarray(ndim=1),
                                out: ti.types.ndarray(ndim=1)):
    # CPU, non-transposed, float events, heterogeneous weights (gather style).
    # Outer loop left parallel: each iteration writes only out[row_i].
    # ti.loop_config(serialize=True)
    for row_i in range(indptr.shape[0] - 1):
        r = 0.  # row-local accumulator
        for j in range(indptr[row_i], indptr[row_i + 1]):
            if events[indices[j]] != 0.:  # any nonzero value is an event
                r += values[j]
        out[row_i] = r
# -------------
# GPU operators
# -------------
# 1. GPU kernels are different from the CPU ones, since the GPU kernels need
# to use warp-level parallelism to achieve the best performance.
@ti.kernel
def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
                                              indices: ti.types.ndarray(ndim=1),
                                              indptr: ti.types.ndarray(ndim=1),
                                              events: ti.types.ndarray(ndim=1),
                                              out: ti.types.ndarray(ndim=1)):
    # GPU, transposed, boolean events, homogeneous weight.
    # 32 logical lanes per row: lane ``index`` handles every 32nd nonzero of
    # its row, and the scatter-adds into ``out`` rely on atomic ``+=``.
    value = values[0]
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32: which row this lane works on
        index = i & 31   # i % 32: lane id within the row
        if events[row_i]:  # event-driven: silent rows do no work
            j = indptr[row_i] + index
            end_index = indptr[row_i + 1]
            while j < end_index:
                out[indices[j]] += value
                j += 32  # stride over the row's nonzeros by lane count
@ti.kernel
def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
                                         indices: ti.types.ndarray(ndim=1),
                                         indptr: ti.types.ndarray(ndim=1),
                                         events: ti.types.ndarray(ndim=1),
                                         out: ti.types.ndarray(ndim=1)):
    # GPU, transposed, float events, homogeneous weight.
    # Same 32-lanes-per-row scheme as the bool variant; the only difference
    # is the ``!= 0.`` event test.
    value = values[0]
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        if events[row_i] != 0.:
            j = indptr[row_i] + index
            end_index = indptr[row_i + 1]
            while j < end_index:
                out[indices[j]] += value
                j += 32
# TODO
# It is important to note that the following warp-based kernels
# should be improved, since the atomic_add for each thread is not
# very efficient. Instead, the warp-level reduction primitive
# should be used.
# see ``warp_reduce_sum()`` function in tifunc.py.
# However, currently Taichi does not support general warp-level primitives.
@ti.kernel
def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
                                    indices: ti.types.ndarray(ndim=1),
                                    indptr: ti.types.ndarray(ndim=1),
                                    events: ti.types.ndarray(ndim=1),
                                    out: ti.types.ndarray(ndim=1)):
    # GPU, non-transposed, boolean events, homogeneous weight.
    # 32 lanes per row accumulate partial sums, then atomically add them
    # into the row's output slot (see the TODO above about warp reduction).
    value = values[0]
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        r = 0.  # lane-local partial sum
        j = indptr[row_i] + index
        end_index = indptr[row_i + 1]
        while j < end_index:
            if events[indices[j]]:
                r += value
            j += 32
        out[row_i] += r  # TODO: warp-level primitive
@ti.kernel
def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
                               indices: ti.types.ndarray(ndim=1),
                               indptr: ti.types.ndarray(ndim=1),
                               events: ti.types.ndarray(ndim=1),
                               out: ti.types.ndarray(ndim=1)):
    # GPU, non-transposed, float events, homogeneous weight.
    # Same lane scheme as the bool variant with a ``!= 0.`` event test.
    value = values[0]
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        r = 0.  # lane-local partial sum
        j = indptr[row_i] + index
        end_index = indptr[row_i + 1]
        while j < end_index:
            if events[indices[j]] != 0.:
                r += value
            j += 32
        out[row_i] += r  # TODO: warp-level primitive
@ti.kernel
def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
                                               indices: ti.types.ndarray(ndim=1),
                                               indptr: ti.types.ndarray(ndim=1),
                                               events: ti.types.ndarray(ndim=1),
                                               out: ti.types.ndarray(ndim=1)):
    # GPU, transposed, boolean events, heterogeneous weights.
    # 32 lanes per row scatter the per-nonzero weights ``values[j]``.
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        if events[row_i]:  # silent rows do no work
            j = indptr[row_i] + index
            end_index = indptr[row_i + 1]
            while j < end_index:
                out[indices[j]] += values[j]
                j += 32
@ti.kernel
def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
                                          indices: ti.types.ndarray(ndim=1),
                                          indptr: ti.types.ndarray(ndim=1),
                                          events: ti.types.ndarray(ndim=1),
                                          out: ti.types.ndarray(ndim=1)):
    # GPU, transposed, float events, heterogeneous weights.
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        if events[row_i] != 0.:  # any nonzero value is an event
            j = indptr[row_i] + index
            end_index = indptr[row_i + 1]
            while j < end_index:
                out[indices[j]] += values[j]
                j += 32
@ti.kernel
def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
                                     indices: ti.types.ndarray(ndim=1),
                                     indptr: ti.types.ndarray(ndim=1),
                                     events: ti.types.ndarray(ndim=1),
                                     out: ti.types.ndarray(ndim=1)):
    # GPU, non-transposed, boolean events, heterogeneous weights
    # (gather style, 32 lanes per row, atomic row accumulation).
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        r = 0.  # lane-local partial sum
        j = indptr[row_i] + index
        end_index = indptr[row_i + 1]
        while j < end_index:
            if events[indices[j]]:
                r += values[j]
            j += 32
        out[row_i] += r  # TODO: warp-level primitive
@ti.kernel
def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
                                indices: ti.types.ndarray(ndim=1),
                                indptr: ti.types.ndarray(ndim=1),
                                events: ti.types.ndarray(ndim=1),
                                out: ti.types.ndarray(ndim=1)):
    # GPU, non-transposed, float events, heterogeneous weights
    # (gather style, 32 lanes per row, atomic row accumulation).
    for i in range((indptr.shape[0] - 1) * 32):
        row_i = i >> 5   # i // 32
        index = i & 31   # i % 32
        r = 0.  # lane-local partial sum
        j = indptr[row_i] + index
        end_index = indptr[row_i + 1]
        while j < end_index:
            if events[indices[j]] != 0.:
                r += values[j]
            j += 32
        out[row_i] += r  # TODO: warp-level primitive
def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
    """JVP w.r.t. ``values``: the operator is linear in the nonzero data, so
    the tangent is the ordinary (non-event-driven) CSR matvec of ``val_dot``."""
    return normal_csrmv_taichi(val_dot, indices, indptr, events, transpose=transpose, shape=shape)
def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape):
    """JVP w.r.t. ``events``: the operator is linear in the event vector, so
    the tangent is the ordinary CSR matvec applied to ``evt_dot``."""
    return normal_csrmv_taichi(values, indices, indptr, evt_dot, transpose=transpose, shape=shape)
def _event_csr_matvec_transpose_taichi(
    ct, values, indices, indptr, events, *, outs, transpose, shape
):
    """AD transpose (VJP) rule: route the output cotangent ``ct[0]`` back to
    whichever input (``values`` or ``events``) is the undefined primal."""
    # The sparsity pattern is structural, not differentiable.
    if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
        raise ValueError("Cannot transpose with respect to sparse indices.")
    if ad.is_undefined_primal(events):
        # Cotangent w.r.t. events: a dense CSR matvec of the output cotangent.
        # NOTE(review): the same ``transpose`` flag is reused here — confirm
        # this matches the transposition convention of the primitive.
        ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0]
        return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events)
    else:
        if type(ct[0]) is ad.Zero:
            # Zero cotangent short-circuits to a symbolic zero.
            ct_values = ad.Zero(values)
        else:
            if values.aval.shape[0] == 1:  # scalar (homogeneous) weight
                # Gradient of the single shared weight: run the op with unit
                # data and contract the result with the cotangent.
                ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
                ct_values = jnp.inner(ct[0], ct_values)
            else:  # heterogeneous values
                # Per-nonzero gradient: event at one end times cotangent at
                # the other, laid out in COO order.
                row, col = csr_to_coo(indices, indptr)
                ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
        return ct_values, indices, indptr, events
def _define_op(cpu_kernel, gpu_kernel):
    """Wrap a CPU/GPU Taichi kernel pair into an XLA custom-call primitive,
    registering the JVP rules (for ``data`` and ``events``; ``None`` for the
    non-differentiable ``indices``/``indptr``) and the AD transpose rule."""
    op = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
    op.defjvp(_event_csr_matvec_jvp_values_taichi,
              None,
              None,
              _event_csr_matvec_jvp_events_taichi)
    op.def_transpose_rule(_event_csr_matvec_transpose_taichi)
    return op
# Register the eight concrete primitives, one per combination of
# {transpose or not, boolean or float events, homogeneous or heterogeneous
# weights}; ``raw_csrmv_taichi`` dispatches among them at call time.

# transpose bool homo
_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu,
                                                _event_csr_matvec_transpose_bool_homo_gpu)
# transpose homo
_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu,
                                           _event_csr_matvec_transpose_homo_gpu)
# not transpose bool homo
_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu,
                                      _event_csr_matvec_bool_homo_gpu)
# not transpose homo
_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu,
                                 _event_csr_matvec_homo_gpu)
# transpose bool heter
_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu,
                                                 _event_csr_matvec_transpose_bool_heter_gpu)
# transpose heter
_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu,
                                            _event_csr_matvec_transpose_heter_gpu)
# not transpose bool heter
_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu,
                                       _event_csr_matvec_bool_heter_gpu)
# not transpose heter
_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu,
                                  _event_csr_matvec_heter_gpu)