Source code for brainpy.math.sparse.utils

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Tuple

from jax import numpy as jnp
from jax.experimental.sparse import csr_todense

from brainpy.math.interoperability import as_jax

__all__ = [
    'coo_to_csr',
    'csr_to_coo',
    'csr_to_dense'
]


def coo_to_csr(
    pre_ids: jnp.ndarray,
    post_ids: jnp.ndarray,
    *,
    num_row: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Convert COO ``(pre_ids, post_ids)`` to CSR ``(indices, indptr)``.

    Parameters
    ----------
    pre_ids: jnp.ndarray
      1-D integer array with the row index of every non-zero entry. Values
      must lie in ``[0, num_row)``; out-of-range values are not counted in
      ``indptr`` (``jnp.bincount`` drops them), which would make ``indptr``
      inconsistent with ``indices``.
    post_ids: jnp.ndarray
      1-D array, same shape as ``pre_ids``, with the column index of every
      non-zero entry.
    num_row: int
      Number of rows of the sparse matrix. Must be a concrete Python int
      (it fixes the static length of ``indptr``).

    Returns
    -------
    indices: jnp.ndarray
      ``post_ids`` reordered so that entries are grouped by row in ascending
      row order; entries of the same row keep their original relative order.
    indptr: jnp.ndarray
      Integer array of length ``num_row + 1``; row ``i`` owns
      ``indices[indptr[i]:indptr[i + 1]]``.

    Raises
    ------
    ValueError
      If ``pre_ids`` and ``post_ids`` do not have the same shape.
    """
    pre_ids = as_jax(pre_ids)
    post_ids = as_jax(post_ids)
    if pre_ids.shape != post_ids.shape:
        raise ValueError(
            f'pre_ids and post_ids must have the same shape, '
            f'but got {pre_ids.shape} and {post_ids.shape}.'
        )

    # Group column ids by row. A stable sort keeps the original relative
    # order of entries that share a row.
    order = jnp.argsort(pre_ids, stable=True)
    indices = post_ids[order]

    # Non-zeros per row (rows without entries get 0). JAX arrays are
    # immutable, so the counts cannot be filled in with item assignment;
    # bincount produces them directly, with an integer dtype, and stays
    # jit-compatible because ``length`` fixes the output size.
    row_counts = jnp.bincount(pre_ids, length=num_row)

    # Row offsets: a leading 0 followed by the running total of counts.
    indptr = jnp.concatenate(
        [jnp.zeros((1,), dtype=row_counts.dtype), jnp.cumsum(row_counts)]
    )
    return indices, indptr
def csr_to_coo(
    indices: jnp.ndarray,
    indptr: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Expand CSR ``(indices, indptr)`` into COO ``(row, col)``.

    Parameters
    ----------
    indices: jnp.ndarray
      Column index of every stored entry, grouped by row.
    indptr: jnp.ndarray
      Row offsets; row ``i`` owns ``indices[indptr[i]:indptr[i + 1]]``.

    Returns
    -------
    row: jnp.ndarray
      Row index of every stored entry (same length as ``indices``).
    col: jnp.ndarray
      ``indices`` after conversion with ``as_jax``.
    """
    col = as_jax(indices)
    row_starts = as_jax(indptr)

    # Drop a 1 at every offset where a row begins. Empty rows share an offset
    # with the next row, so their marks add up at the same position. Offsets
    # equal to len(indices) (the final indptr entry and any trailing empty
    # rows) are out of bounds and are discarded by ``mode='drop'``, which is
    # also JAX's default behaviour for scatter updates.
    boundary_marks = jnp.zeros_like(col).at[row_starts].add(1, mode='drop')

    # The number of row starts seen so far, minus one, is the row index.
    row = jnp.cumsum(boundary_marks) - 1
    return row, col
# Public alias of ``jax.experimental.sparse.csr_todense`` (imported at the top
# of this module), exported under the ``csr_to_dense`` name listed in
# ``__all__`` so it matches the ``coo_to_csr`` / ``csr_to_coo`` naming here.
csr_to_dense = csr_todense