Source code for brainpy.math.sparse.utils

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Tuple

import brainevent
from jax import numpy as jnp

from brainpy.math.interoperability import as_jax

# Public API of this module: the COO / CSR / dense conversion helpers.
__all__ = [
    'coo_to_csr',
    'csr_to_coo',
    'csr_to_dense',
]


def coo_to_csr(
    pre_ids: jnp.ndarray,
    post_ids: jnp.ndarray,
    *,
    num_row: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Convert COO ``(pre_ids, post_ids)`` connectivity to CSR ``(indices, indptr)``.

    Parameters
    ----------
    pre_ids : ndarray
        1-D array of the row (pre-synaptic) index of each non-zero entry.
        Every value must be in ``[0, num_row)``.
    post_ids : ndarray
        Column (post-synaptic) index of each non-zero entry, aligned with
        ``pre_ids``. Must have the same shape as ``pre_ids``.
    num_row : int
        Number of rows of the sparse matrix (``shape[0]``).

    Returns
    -------
    indices : ndarray
        CSR column indices of shape ``(nse,)``. Entries are grouped by row;
        entries that share a row keep their original COO order.
    indptr : ndarray
        CSR row pointers of shape ``(num_row + 1,)`` and dtype ``int32``.

    Raises
    ------
    ValueError
        If ``pre_ids`` and ``post_ids`` have different shapes. JAX clamps
        out-of-bounds gather indices, so without this check a shorter
        ``post_ids`` would silently repeat its last element and a longer one
        would be silently truncated.
    ValueError
        If any ``pre_ids`` falls outside ``[0, num_row)``. Such an entry would
        otherwise not be counted in ``indptr``, producing a structurally
        invalid CSR in which ``indptr[-1] != len(indices)``.

    Notes
    -----
    This is an eager preprocessing helper: the range check converts the
    minimum and maximum of ``pre_ids`` to Python ints, so it cannot be traced
    under ``jit``/``vmap``.
    """
    pre_ids = as_jax(pre_ids)
    post_ids = as_jax(post_ids)

    # Both arrays describe the same entries and must line up one-to-one.
    # A mismatch is not an error in JAX itself: the ``post_ids[sort_ids]``
    # gather below would clamp out-of-range indices and return garbage.
    if pre_ids.shape != post_ids.shape:
        raise ValueError(
            f'"pre_ids" and "post_ids" must have the same shape, '
            f'but got {pre_ids.shape} and {post_ids.shape}.'
        )

    # Validate the pre (row) indices eagerly. An out-of-range ``pre_id`` would
    # not be counted by the per-row histogram below, yielding a corrupt CSR
    # (``indptr[-1] != nse``) instead of an error.
    if pre_ids.size > 0:
        pre_min = int(jnp.min(pre_ids))
        pre_max = int(jnp.max(pre_ids))
        if pre_min < 0 or pre_max >= num_row:
            raise ValueError(
                f'"pre_ids" must lie in [0, num_row) = [0, {num_row}), '
                f'but got values in [{pre_min}, {pre_max}].'
            )

    # Group the entries by row. A stable sort preserves the original relative
    # order of entries that share the same row.
    sort_ids = jnp.argsort(pre_ids, stable=True)
    indices = post_ids[sort_ids]

    # Number of entries in each row. All ``pre_ids`` are known to lie in
    # ``[0, num_row)``, so a fixed-length histogram counts every entry once.
    row_counts = jnp.bincount(pre_ids, length=num_row)

    # Row pointers are the exclusive prefix sum of the per-row counts.
    indptr = jnp.insert(jnp.cumsum(row_counts), 0, 0).astype(jnp.int32)
    return indices, indptr


def csr_to_coo(
    indices: jnp.ndarray,
    indptr: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Given CSR ``(indices, indptr)`` return COO ``(row, col)``.

    Parameters
    ----------
    indices : ndarray
        CSR column indices of shape ``(nse,)``.
    indptr : ndarray
        CSR row pointers of shape ``(num_row + 1,)``. Assumed to be a
        well-formed CSR pointer array (``indptr[0] == 0``, non-decreasing,
        ``indptr[-1] == len(indices)``); this is not checked.

    Returns
    -------
    row : ndarray
        Row index of every stored entry, shape ``(nse,)``.
    col : ndarray
        Column index of every stored entry (``indices`` itself).
    """
    indices = as_jax(indices)
    indptr = as_jax(indptr)
    # Mark the position where each row starts: entry ``indptr[r]`` is the first
    # entry of row ``r``, and empty rows share a start position, so their marks
    # accumulate there. ``indptr[-1] == nse`` (and the start of any trailing
    # empty row) lies one past the end of the array. ``mode='drop'`` discards
    # those updates explicitly instead of relying on the default
    # ``promise_in_bounds`` mode, whose contract such indices would violate.
    row_starts = jnp.zeros_like(indices).at[indptr].add(1, mode='drop')
    # The running count of row starts seen so far, minus one, is the row id.
    row = jnp.cumsum(row_starts) - 1
    return row, indices


def csr_to_dense(data, indices, indptr, *, shape):
    """Materialise a CSR sparse matrix as a dense array.

    Parameters
    ----------
    data : ndarray
        The ``(nse,)`` array of stored (non-zero) values.
    indices : ndarray
        The ``(nse,)`` array giving the column of each stored value.
    indptr : ndarray
        The ``(shape[0] + 1,)`` array of row pointers.
    shape : tuple of int
        ``(n_rows, n_cols)`` of the resulting dense matrix.

    Returns
    -------
    dense : ndarray
        Dense matrix of shape ``shape``, built by ``brainevent.CSR.todense``.
    """
    csr_matrix = brainevent.CSR((data, indices, indptr), shape=shape)
    dense = csr_matrix.todense()
    return dense