# Source code for brainpy.math.sparse.coo_mv
# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Tuple
import brainevent
from jax import numpy as jnp
from brainpy.math.ndarray import Array as Array
# Public names exported by this module.
__all__ = ['coomv']
# [docs]
def coomv(
    data: Union[float, jnp.ndarray, Array],
    row: Union[jnp.ndarray, Array],
    col: Union[jnp.ndarray, Array],
    vector: Union[jnp.ndarray, Array],
    *,
    shape: Tuple[int, int],
    transpose: bool = False,
):
    """Multiply a COO-format sparse matrix by a dense vector.

    ``brainevent`` dropped its COO format in v0.1.0, so the COO indices are
    first converted to CSR with :func:`brainevent.coo2csr` and the actual
    product is computed by :class:`brainevent.CSR`.

    This function supports JAX transformations, including ``jit()``,
    ``grad()``, ``vmap()`` and ``pmap()``.

    Parameters
    ----------
    data : ndarray, float
        The non-zero values, an array of shape ``(nse,)``. A scalar
        (0-dimensional) value is broadcast to one entry per stored index.
    row : ndarray
        Row indices, an array of shape ``(nse,)``.
    col : ndarray
        Column indices, an array of shape ``(nse,)`` and dtype ``row.dtype``.
    vector : ndarray
        The dense operand, an array of shape
        ``(shape[0] if transpose else shape[1],)`` and dtype ``data.dtype``.
    shape : tuple of int
        The shape of the sparse matrix.
    transpose : bool
        If ``True``, the sparse matrix is transposed before computing,
        i.e. ``vector @ matrix`` is evaluated instead of ``matrix @ vector``.

    Returns
    -------
    y : ndarray
        An array of shape ``(shape[1] if transpose else shape[0],)`` holding
        the matrix-vector product.
    """
    # Strip the brainpy ``Array`` wrapper so only raw arrays are passed on.
    data, row, col, vector = (
        arg.value if isinstance(arg, Array) else arg
        for arg in (data, row, col, vector)
    )

    # The COO format was removed in brainevent 0.1.0: build the CSR index
    # arrays from the COO indices and delegate to ``brainevent.CSR``.
    indptr, indices, order = brainevent.coo2csr(row, col, shape=shape)

    values = jnp.asarray(data)
    if values.ndim == 0:
        # Scalar weight: expand it to one entry per stored index.
        values = jnp.broadcast_to(values, (indices.shape[0],))
    # Rearrange the values with the ordering returned by ``coo2csr`` so that
    # they line up with ``indices``.
    values = values[order]

    matrix = brainevent.CSR((values, indices, indptr), shape=shape)
    return vector @ matrix if transpose else matrix @ vector