Source code for brainpy._src.math.sparse.utils

# -*- coding: utf-8 -*-

import warnings
from typing import Tuple

import numpy as np
from jax import core, numpy as jnp
from jax.interpreters import mlir, ad
from jaxlib import gpu_sparse

from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import register_general_batching

__all__ = [
  'coo_to_csr',
  'csr_to_coo',
  'csr_to_dense'
]


[docs] def coo_to_csr( pre_ids: jnp.ndarray, post_ids: jnp.ndarray, *, num_row: int ) -> Tuple[jnp.ndarray, jnp.ndarray]: """convert pre_ids, post_ids to (indices, indptr).""" pre_ids = as_jax(pre_ids) post_ids = as_jax(post_ids) # sorting sort_ids = jnp.argsort(pre_ids, kind='stable') post_ids = post_ids[sort_ids] indices = post_ids unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) final_pre_count = jnp.zeros(num_row) final_pre_count[unique_pre_ids] = pre_count indptr = final_pre_count.cumsum() indptr = jnp.insert(indptr, 0, 0) return indices, indptr
[docs] def csr_to_coo( indices: jnp.ndarray, indptr: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Given CSR (indices, indptr) return COO (row, col)""" indices = as_jax(indices) indptr = as_jax(indptr) return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices
def csr_to_csc(): pass def coo_to_dense( data: jnp.ndarray, rows: jnp.ndarray, cols: jnp.ndarray, *, shape: Tuple[int, int] ) -> jnp.ndarray: pass
[docs] def csr_to_dense( data: jnp.ndarray, indices: jnp.ndarray, indptr: jnp.ndarray, *, shape: Tuple[int, int] ) -> jnp.ndarray: data = as_jax(data) indices = as_jax(indices) indptr = as_jax(indptr) return csr_to_dense_p.bind(data, indices, indptr, shape=shape)
def _coo_extract(row, col, mat): """Extract values of dense matrix mat at given COO indices.""" return mat[row, col] def _csr_extract(indices, indptr, mat): """Extract values of dense matrix mat at given CSR indices.""" return _coo_extract(*csr_to_coo(indices, indptr), mat) def _coo_todense(data, row, col, *, shape): """Convert CSR-format sparse matrix to a dense matrix. Args: data : array of shape ``(nse,)``. row : array of shape ``(nse,)`` col : array of shape ``(nse,)`` and dtype ``row.dtype`` shape : COOInfo object containing matrix metadata Returns: mat : array with specified shape and dtype matching ``data`` """ return jnp.zeros(shape, data.dtype).at[row, col].add(data) def _csr_to_dense_impl(data, indices, indptr, *, shape): return _coo_todense(data, *csr_to_coo(indices, indptr), shape=shape) def _csr_to_dense_abstract_eval(data, indices, indptr, *, shape): assert data.ndim == indices.ndim == indptr.ndim == 1 assert indices.dtype == indptr.dtype assert data.shape == indices.shape assert indptr.shape[0] == shape[0] + 1 return core.ShapedArray(shape, data.dtype) _csr_to_dense_lowering = mlir.lower_fun(_csr_to_dense_impl, multiple_results=False) def _csr_to_dense_gpu_lowering(ctx, data, indices, indptr, *, shape): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. " "Falling back to default implementation.", UserWarning) return _csr_to_dense_lowering(ctx, data, indices, indptr, shape=shape) return [gpu_sparse.cuda_csr_todense(data, indices, indptr, shape=shape, data_dtype=dtype, index_dtype=indices_aval.dtype)] def _csr_to_dense_jvp(data_dot, data, indices, indptr, *, shape): return csr_to_dense(data_dot, indices, indptr, shape=shape) def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): # Note: we assume that transpose has the same sparsity pattern. # Can we check this? assert ad.is_undefined_primal(data) if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): raise ValueError("Cannot transpose with respect to sparse indices") assert ct.shape == shape assert indices.aval.dtype == indptr.aval.dtype assert ct.dtype == data.aval.dtype return _csr_extract(indices, indptr, ct), indices, indptr csr_to_dense_p = core.Primitive('csr_to_dense') csr_to_dense_p.def_impl(_csr_to_dense_impl) csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval) ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None) ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering) register_general_batching(csr_to_dense_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda')