Source code for brainpy.dnn.linear

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numbers
from typing import Dict, Optional, Union, Callable

import jax
import jax.numpy as jnp
import numpy as np
from brainevent import (
    update_csr_on_binary_pre,
    update_csr_on_binary_post,
    update_dense_on_binary_pre,
    update_dense_on_binary_post,
)

from brainpy import connect, initialize as init
from brainpy import math as bm
from brainpy._errors import MathError
from brainpy.check import is_initializer
from brainpy.connect import csr2csc
from brainpy.context import share
from brainpy.dnn.base import Layer
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.mixin import SupportOnline, SupportOffline, SupportSTDP
from brainpy.types import ArrayType, Sharding

__all__ = [
    'Dense', 'Linear',
    'Identity',
    'AllToAll',
    'OneToOne',
    'MaskedLinear',
    'CSRLinear', 'EventCSRLinear',
    'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
    'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
]


class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline):
    r"""A linear transformation applied over the last dimension of the input.

    Mathematically, this node can be defined as:

    .. math::

       y = x \cdot weight + b

    Parameters::

    num_in: int
      The number of the input features. Negative values are rejected.
    num_out: int
      The number of the output features. Negative values are rejected.
    W_initializer: optional, Initializer
      The weight initialization.
    b_initializer: optional, Initializer
      The bias initialization. ``None`` is accepted by the initializer check.
    mode: Mode
      The computing mode. Under a ``TrainingMode`` the weight (and the bias,
      if there is one) are wrapped as ``TrainVar``.
    name: str
      The name of the node.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
        b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super(Dense, self).__init__(mode=mode, name=name)

        # shape
        self.num_in = num_in
        self.num_out = num_out
        if num_in < 0:
            # The message used to blame `num_out` for an invalid `num_in`.
            raise ValueError(f'Received an invalid value for `num_in`, expected '
                             f'a positive integer. Received: num_in={num_in}')
        if num_out < 0:
            raise ValueError(f'Received an invalid value for `num_out`, expected '
                             f'a positive integer. Received: num_out={num_out}')

        # weight initializer
        self.W_initializer = W_initializer
        self.bias_initializer = b_initializer
        is_initializer(W_initializer, 'weight_initializer')
        is_initializer(b_initializer, 'bias_initializer', allow_none=True)

        # parameter initialization
        W = parameter(self.W_initializer, (num_in, self.num_out))
        # NOTE(review): `b` is treated as possibly None below; presumably
        # `parameter` returns None for a None initializer -- confirm.
        b = parameter(self.bias_initializer, (self.num_out,))
        if isinstance(self.mode, bm.TrainingMode):
            W = bm.TrainVar(W)
            b = None if (b is None) else bm.TrainVar(b)
        self.W = W
        self.b = b

        # fitting parameters
        self.online_fit_by = None  # support online training
        self.offline_fit_by = None  # support offline training
        self.fit_record = dict()

    def __repr__(self):
        return (f'{self.__class__.__name__}(name={self.name}, '
                f'num_in={self.num_in}, '
                f'num_out={self.num_out}, '
                f'mode={self.mode})')

    def update(self, x):
        """Compute ``x @ W`` (plus ``b`` when a bias exists).

        When the shared ``'fit'`` flag is set and an online or offline fitting
        method is attached, the input and output are stored in ``fit_record``.
        """
        x = bm.as_jax(x)
        res = x @ self.W
        if self.b is not None:
            res += self.b

        # online fitting data
        if share.load('fit', False) and self.online_fit_by is not None:
            self.fit_record['input'] = x
            self.fit_record['output'] = res

        # offline fitting data
        if share.load('fit', False) and self.offline_fit_by is not None:
            self.fit_record['input'] = x
            self.fit_record['output'] = res
        return res

    def online_init(self):
        """Register this node as a target of the online fitting method.

        The registered feature size is ``num_in``, plus one when a bias exists.
        """
        if self.b is None:
            num_input = self.num_in
        else:
            num_input = self.num_in + 1
        self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)

    def online_fit(self,
                   target: ArrayType,
                   fit_record: Dict[str, ArrayType]):
        """Apply one online-fitting step and add the result to ``W`` (and ``b``).

        Raises:
          MathError: if ``target`` is not an array, or its feature dimension
            differs from that of the recorded output.
          ValueError: if the recorded input or ``target`` is not 2D, or their
            batch sizes differ.
        """
        if not isinstance(target, (bm.ndarray, jnp.ndarray)):
            raise MathError(f'"target" must be a tensor, but got {type(target)}')
        x = fit_record['input']
        y = fit_record['output']
        if x.ndim != 2:
            raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
                             f'num_feature), but we got {x.shape}')
        if target.ndim != 2:
            raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
                             f'num_feature), but we got {target.shape}')
        if x.shape[0] != target.shape[0]:
            raise ValueError(f'Batch size of the input and target data should be '
                             f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
        if target.shape[1] != y.shape[1]:
            raise MathError(f'The output dimension of output and target data should be '
                            f'the same, while we got {target.shape[1]} != {y.shape[1]}')

        # data: prepend a column of ones so the bias is fitted with the weights
        if self.b is not None:
            x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1)

        # fitting
        dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)

        # assign trained weights
        if self.b is None:
            self.W += dW
        else:
            # the first row belongs to the prepended ones column, i.e. the bias
            db, dW = jnp.split(dW, [1])
            self.b += db[0]
            self.W += dW

    def offline_fit(self,
                    target: ArrayType,
                    fit_record: Dict[str, ArrayType]):
        """The offline training interface for the Dense node.

        Raises:
          MathError: if ``target`` is not an array, or the time dimensions of
            the recorded input and ``target`` differ.
          ValueError: if the recorded input or ``target`` is not 3D, or the
            recorded output and ``target`` shapes / batch sizes differ.
        """
        # data checking
        if not isinstance(target, (bm.ndarray, jnp.ndarray)):
            raise MathError(f'"targets" must be a tensor, but got {type(target)}')
        xs = fit_record['input']
        ys = fit_record['output']
        if xs.ndim != 3:
            raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
                             f'num_feature), but we got {xs.shape}')
        if target.ndim != 3:
            raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
                             f'num_feature), but we got {target.shape}')
        if ys.shape != target.shape:
            raise ValueError(f'The shapes of output and target data should be '
                             f'the same, while we got {ys.shape} != {target.shape}.')
        if xs.shape[0] != target.shape[0]:
            raise ValueError(f'Batch size of the input and target data should be '
                             f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
        if xs.shape[1] != target.shape[1]:
            raise MathError(f'The time dimension of input and target data should be '
                            f'the same, while we got {xs.shape[1]} != {target.shape[1]}')

        # get input and target training data
        if self.b is not None:
            xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1)  # (..., 1 + num_ff_input)

        # solve weights by offline training methods
        weights = self.offline_fit_by(target, xs, ys)

        # assign trained weights
        if self.b is None:
            self.W.value = weights
        else:
            # the first row belongs to the prepended ones column, i.e. the bias
            bias, Wff = jnp.split(weights, [1])
            self.W.value = Wff
            self.b.value = bias[0]

    def stdp_update(
        self,
        on_pre: Dict = None,
        on_post: Dict = None,
        w_min: numbers.Number = None,
        w_max: numbers.Number = None
    ):
        """Update ``W`` from pre- and/or post-synaptic spike events.

        Args:
          on_pre: dict with ``'spike'`` and ``'trace'`` entries, or None.
          on_post: dict with ``'spike'`` and ``'trace'`` entries, or None.
          w_min: forwarded to the update kernels.
          w_max: forwarded to the update kernels.

        Raises:
          ValueError: if ``W`` is not a ``bm.Variable``.
        """
        if not isinstance(self.W, bm.Variable):
            raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')

        if on_pre is not None:
            spike = on_pre['spike']
            trace = on_pre['trace']
            self.W.value = update_dense_on_binary_pre(self.W.value, spike, trace, w_min, w_max)
        if on_post is not None:
            spike = on_post['spike']
            trace = on_post['trace']
            # the post kernel is called with (trace, spike), unlike the pre kernel
            self.W.value = update_dense_on_binary_post(self.W.value, trace, spike, w_min, w_max)
# `Linear` is an alias bound to the very same class object as `Dense`.
Linear = Dense
class Identity(Layer):
    r"""A pass-through layer: ``update`` returns its input untouched.

    Constructor arguments are forwarded unchanged to the base ``Layer``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def update(self, x):
        """Return ``x`` as is."""
        return x
class AllToAll(Layer, SupportSTDP):
    """Synaptic matrix multiplication with All2All connections.

    Args:
      num_pre: int. The number of neurons in the presynaptic neuron group.
      num_post: int. The number of neurons in the postsynaptic neuron group.
      weight: The synaptic weights.
      sharding: The sharding strategy.
      include_self: bool. Whether a neuron is connected to the neuron at the
        same position.
      mode: Mode. The computing mode.
      name: str. The object name.
    """

    def __init__(
        self,
        num_pre: int,
        num_post: int,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        include_self: bool = True,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(mode=mode, name=name)

        self.num_pre = num_pre
        self.num_post = num_post
        self.include_self = include_self
        self.sharding = sharding

        weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding)
        if isinstance(self.mode, bm.TrainingMode):
            weight = bm.TrainVar(weight)
        self.weight = weight

    def update(self, pre_val):
        """Project presynaptic values onto the postsynaptic group.

        With a scalar weight the input is summed over the neuron axis; when
        ``include_self`` is False each neuron's own value is then subtracted.
        With a matrix weight a plain matrix product is used, with the
        diagonal zeroed when ``include_self`` is False.

        Args:
          pre_val: 2D under a batching mode, 1D otherwise (enforced by
            assertions in the scalar-weight branch).
        """
        if bm.ndim(self.weight) == 0:  # weight is a scalar
            if isinstance(self.mode, bm.BatchingMode):
                assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.'
                post_val = bm.sum(pre_val, keepdims=True, axis=1)
            else:
                assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.'
                post_val = bm.sum(pre_val)
            if not self.include_self:
                if self.num_pre == self.num_post:
                    post_val = post_val - pre_val
                elif self.num_pre > self.num_post:
                    # Slice the neuron (last) axis. A bare `[:num_post]` would
                    # slice the batch axis of a 2D input.
                    val = pre_val[..., :self.num_post]
                    post_val = post_val - val
                else:
                    # Pad the neuron (last) axis with zeros, keeping any
                    # leading batch axis intact.
                    pad = bm.zeros(tuple(pre_val.shape[:-1]) + (self.num_post - self.num_pre,))
                    val = bm.concatenate([pre_val, pad], axis=-1)
                    post_val = post_val - val
            post_val = self.weight * post_val

        else:  # weight is a matrix
            assert self.weight.ndim == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False)
            else:
                post_val = pre_val @ self.weight
        return post_val

    def stdp_update(
        self,
        on_pre: Dict = None,
        on_post: Dict = None,
        w_min: numbers.Number = None,
        w_max: numbers.Number = None
    ):
        """Update ``weight`` from pre- and/or post-synaptic spike events.

        Args:
          on_pre: dict with ``'spike'`` and ``'trace'`` entries, or None.
          on_post: dict with ``'spike'`` and ``'trace'`` entries, or None.
          w_min: forwarded to the update kernels.
          w_max: forwarded to the update kernels.

        Raises:
          ValueError: if ``weight`` is not a ``bm.Variable``.
        """
        if not isinstance(self.weight, bm.Variable):
            raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')

        if on_pre is not None:
            spike = on_pre['spike']
            trace = on_pre['trace']
            self.weight.value = update_dense_on_binary_pre(self.weight.value, spike, trace, w_min, w_max)
        if on_post is not None:
            spike = on_post['spike']
            trace = on_post['trace']
            # the post kernel is called with (trace, spike), unlike the pre kernel
            self.weight.value = update_dense_on_binary_post(self.weight.value, trace, spike, w_min, w_max)
class OneToOne(Layer, SupportSTDP):
    """Synaptic matrix multiplication with One2One connection.

    Args:
      num: int. The number of neurons.
      weight: The synaptic weight.
      sharding: The sharding strategy.
      mode: The computing mode.
      name: The object name.
    """

    def __init__(
        self,
        num: int,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(mode=mode, name=name)

        self.num = num
        self.sharding = sharding

        weight = init.parameter(weight, (self.num,), sharding=sharding)
        if isinstance(self.mode, bm.TrainingMode):
            weight = bm.TrainVar(weight)
        self.weight = weight

    def update(self, pre_val):
        """Return the element-wise product ``pre_val * weight``."""
        return pre_val * self.weight

    def stdp_update(
        self,
        on_pre: Dict = None,
        on_post: Dict = None,
        w_min: numbers.Number = None,
        w_max: numbers.Number = None
    ):
        """Update ``weight`` from pre- and/or post-synaptic spike events.

        Each event adds ``spike * trace`` to the weight. When ``w_min`` and/or
        ``w_max`` is given, the weight is clipped to those bounds after each
        update. Previously the bounds were accepted but never applied.

        Args:
          on_pre: dict with ``'spike'`` and ``'trace'`` entries, or None.
          on_post: dict with ``'spike'`` and ``'trace'`` entries, or None.
          w_min: optional lower bound of the weight.
          w_max: optional upper bound of the weight.

        Raises:
          ValueError: if ``weight`` is a Python float.
        """
        if isinstance(self.weight, float):
            raise ValueError(f'Cannot update the weight of a constant node.')
        if not isinstance(self.weight, bm.Variable):
            self.tracing_variable('weight', self.weight, self.weight.shape)

        # Skip clipping when no bound is given, so the unbounded behaviour
        # stays exactly as before.
        bounded = (w_min is not None) or (w_max is not None)

        if on_pre is not None:
            spike = on_pre['spike']
            trace = on_pre['trace']
            self.weight.value += spike * trace
            if bounded:
                self.weight.value = bm.clip(self.weight.value, w_min, w_max)

        if on_post is not None:
            spike = on_post['spike']
            trace = on_post['trace']
            self.weight.value += spike * trace
            if bounded:
                self.weight.value = bm.clip(self.weight.value, w_min, w_max)
class MaskedLinear(Layer, SupportSTDP):
    r"""Synaptic matrix multiplication with masked dense computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
    :math:`M` the synaptic weight using a dense matrix.

    >>> import brainpy as bp
    >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100),
    ...                         weight=0.1)

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      mask_fun: Masking function, applied to ``weight * mask``.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        mask_fun: Callable = Identity(),
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        assert isinstance(conn, connect.TwoEndConnector)
        self.conn = conn
        self.sharding = sharding
        self.mask_fun = mask_fun

        # dense weight of shape (pre, post); trainable under a training mode
        dense_shape = (conn.pre_num, conn.post_num)
        w = init.parameter(weight, dense_shape, sharding=sharding)
        self.weight = bm.TrainVar(w) if isinstance(self.mode, bm.TrainingMode) else w

        # dense connection matrix used as the mask
        conn_mat = self.conn.require('conn_mat')
        self.mask = bm.sharding.partition(conn_mat, sharding=sharding)

    def update(self, x):
        """Return ``x @ mask_fun(weight * mask)``."""
        masked_weight = self.weight * self.mask
        return x @ self.mask_fun(masked_weight)

    def stdp_update(
        self,
        on_pre: Dict = None,
        on_post: Dict = None,
        w_min: numbers.Number = None,
        w_max: numbers.Number = None
    ):
        """Update ``weight`` from pre- and/or post-synaptic spike events.

        Raises:
          ValueError: if ``weight`` is a Python float.
        """
        if isinstance(self.weight, float):
            raise ValueError('Cannot update the weight of a constant node.')
        if not isinstance(self.weight, bm.Variable):
            self.tracing_variable('weight', self.weight, self.weight.shape)

        if on_pre is not None:
            self.weight.value = update_dense_on_binary_pre(
                self.weight.value, on_pre['spike'], on_pre['trace'], w_min, w_max
            )
        if on_post is not None:
            # keep the lookup order of the original: spike first, then trace
            post_spike = on_post['spike']
            post_trace = on_post['trace']
            self.weight.value = update_dense_on_binary_post(
                self.weight.value, post_trace, post_spike, w_min, w_max
            )
class _CSRLayer(Layer, SupportSTDP):
    """Shared base of the CSR layers: holds the CSR structure, the weight and
    the STDP update rule.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: Must be None; sharding is rejected by an assertion.
      mode: The synaptic computing mode.
      name: The synapse model name.
      transpose: Stored on the instance for use by subclasses.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = True,
    ):
        super().__init__(name=name, mode=mode)

        assert isinstance(conn, connect.TwoEndConnector)
        assert sharding is None, 'Currently this model does not support sharding.'
        self.conn = conn
        self.sharding = sharding
        self.transpose = transpose

        # CSR connection structure
        csr_indices, csr_indptr = self.conn.require('csr')
        self.indices = csr_indices
        self.indptr = csr_indptr

        # one weight per stored connection; trainable under a training mode
        w = init.parameter(weight, (self.indices.size,))
        self.weight = bm.TrainVar(w) if isinstance(self.mode, bm.TrainingMode) else w

    def stdp_update(
        self,
        on_pre: Dict = None,
        on_post: Dict = None,
        w_min: numbers.Number = None,
        w_max: numbers.Number = None
    ):
        """Update ``weight`` from pre- and/or post-synaptic spike events.

        Raises:
          ValueError: if ``weight`` is a scalar, or its shape differs from the
            shape of ``indices``.
        """
        if bm.isscalar(self.weight):
            raise ValueError('When using STDP to update synaptic weights, the weight cannot be a scalar.')
        if self.weight.shape != self.indices.shape:
            raise ValueError(
                f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
        if not isinstance(self.weight, bm.Variable):
            self.tracing_variable('weight', self.weight, self.weight.shape)

        # update on presynaptic spike
        if on_pre is not None:
            pre_spike = on_pre['spike']
            pre_trace = on_pre['trace']
            self.weight.value = update_csr_on_binary_pre(
                self.weight.value, self.indices, self.indptr, pre_spike, pre_trace, w_min, w_max,
                shape=(pre_spike.shape[0], pre_trace.shape[0]),
            )

        # update on postsynaptic spike
        if on_post is not None:
            if not hasattr(self, '_pre_ids'):
                # build the CSC view once and cache it on the instance
                with jax.ensure_compile_time_eval():
                    csc = csr2csc(
                        [self.indices, self.indptr],
                        self.conn.post_num,
                        data=np.arange(self.weight.size)
                    )
                    self._pre_ids, self._post_indptr, self.w_indices = csc
            post_spike = on_post['spike']
            post_trace = on_post['trace']
            self.weight.value = update_csr_on_binary_post(
                self.weight.value, self._pre_ids, self._post_indptr, self.w_indices,
                post_trace, post_spike, w_min, w_max,
                shape=(post_trace.shape[0], post_spike.shape[0]),
            )
class CSRLinear(_CSRLayer):
    r"""Synaptic matrix multiplication with CSR sparse computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
    :math:`M` the synaptic weight using a CSR sparse matrix.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
      method: Stored on the instance; not read anywhere in this class.
      transpose: Forwarded to ``bm.sparse.csrmv``.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        method: str = None,
        transpose: bool = True,
    ):
        super().__init__(
            name=name,
            mode=mode,
            conn=conn,
            weight=weight,
            sharding=sharding,
            transpose=transpose,
        )
        self.method = method

    def update(self, x):
        """Apply the sparse product over the last axis of ``x``.

        A 1D input is multiplied directly. Inputs with more axes are flattened
        to 2D, mapped row by row, and reshaped back to the leading shape.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_csrmv(x)
        if x.ndim > 1:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_csrmv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_csrmv(self, x):
        """Sparse matrix-vector product for one 1D vector."""
        dense_shape = (self.conn.pre_num, self.conn.post_num)
        return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
                               shape=dense_shape,
                               transpose=self.transpose)
class EventCSRLinear(_CSRLayer):
    r"""Synaptic matrix multiplication with event CSR sparse computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
    :math:`M` the synaptic weight using a CSR sparse matrix.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
      transpose: Forwarded to ``bm.event.csrmv``.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = True,
    ):
        super().__init__(
            name=name,
            mode=mode,
            conn=conn,
            weight=weight,
            sharding=sharding,
            transpose=transpose,
        )

    def update(self, x):
        """Apply the event-driven sparse product over the last axis of ``x``.

        A 1D input is multiplied directly. Inputs with more axes are flattened
        to 2D, mapped row by row, and reshaped back to the leading shape.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_csrmv(x)
        if x.ndim > 1:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_csrmv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_csrmv(self, x):
        """Event-driven sparse matrix-vector product for one 1D vector."""
        dense_shape = (self.conn.pre_num, self.conn.post_num)
        return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
                              shape=dense_shape,
                              transpose=self.transpose)
class CSCLinear(Layer):
    r"""Synaptic matrix multiplication with CSC sparse computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
    :math:`M` the synaptic weight using a CSC sparse matrix.

    Only the constructor is defined here; it stores ``conn`` and ``sharding``
    and does not keep ``weight``.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        assert isinstance(conn, connect.TwoEndConnector)
        self.sharding = sharding
        self.conn = conn


class BcsrMM(Layer):
    r"""Synaptic matrix multiplication with BCSR sparse computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
    :math:`M` the synaptic weight using a BCSR sparse matrix.

    Only the constructor is defined here; it stores ``conn`` and ``sharding``
    and does not keep ``weight``.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        assert isinstance(conn, connect.TwoEndConnector)
        self.sharding = sharding
        self.conn = conn


class BcscMM(Layer):
    r"""Synaptic matrix multiplication with BCSC sparse computation.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
    :math:`M` the synaptic weight using a BCSC sparse matrix.

    Only the constructor is defined here; it stores ``conn`` and ``sharding``
    and does not keep ``weight``.

    Args:
      conn: TwoEndConnector. The connection.
      weight: Synaptic weights. Can be a scalar, array, or callable function.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        conn: connect.TwoEndConnector,
        weight: Union[float, ArrayType, Callable],
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        assert isinstance(conn, connect.TwoEndConnector)
        self.sharding = sharding
        self.conn = conn


class JitLinear(Layer):
    """Base of the just-in-time connectivity layers."""

    def get_conn_matrix(self):
        """Placeholder that returns None; subclasses override it."""
        return None


class JitFPHomoLayer(JitLinear):
    """Just-in-time layer whose connections all share one weight value."""

    def get_conn_matrix(self):
        """Return the matrix built by ``bm.jitconn.get_homo_weight_matrix``."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.get_homo_weight_matrix(
            self.weight, self.prob, self.seed,
            shape=mat_shape,
            transpose=self.transpose,
            outdim_parallel=not self.atomic,
        )


class JitFPUniformLayer(JitLinear):
    """Just-in-time layer parameterised by ``w_low`` and ``w_high``."""

    def get_conn_matrix(self):
        """Return the matrix built by ``bm.jitconn.get_uniform_weight_matrix``."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.get_uniform_weight_matrix(
            self.w_low, self.w_high, self.prob, self.seed,
            shape=mat_shape,
            transpose=self.transpose,
            outdim_parallel=not self.atomic,
        )


class JitFPNormalLayer(JitLinear):
    """Just-in-time layer parameterised by ``w_mu`` and ``w_sigma``."""

    def get_conn_matrix(self):
        """Return the matrix built by ``bm.jitconn.get_normal_weight_matrix``."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.get_normal_weight_matrix(
            self.w_mu, self.w_sigma, self.prob, self.seed,
            shape=mat_shape,
            transpose=self.transpose,
            outdim_parallel=not self.atomic,
        )
class JitFPHomoLinear(JitFPHomoLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is the same :math:`weight`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      weight: float. The synaptic value at each position.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 100000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default False. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        weight: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = False,
        atomic: bool = False,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 100000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # weight, trainable under a training mode
        self.weight = bm.TrainVar(weight) if isinstance(self.mode, bm.TrainingMode) else weight

    def update(self, x):
        """Apply the JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed,
                                       shape=mat_shape,
                                       transpose=self.transpose,
                                       outdim_parallel=not self.atomic)
class JitFPUniformLinear(JitFPUniformLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is sampled from a uniform distribution
    :math:`U(w_{low}, w_{high})`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      w_low: float. The lowest value of the uniform distribution.
      w_high: float. The highest value of the uniform distribution.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 100000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default False. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        w_low: float,
        w_high: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = False,
        atomic: bool = False,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 100000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # bounds of the weight distribution
        self.w_low = w_low
        self.w_high = w_high

    def update(self, x):
        """Apply the JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
                                          shape=mat_shape,
                                          transpose=self.transpose,
                                          outdim_parallel=not self.atomic)
class JitFPNormalLinear(JitFPNormalLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is sampled from a normal distribution
    :math:`N(\mu, \sigma)`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      w_mu: float. The center of the normal distribution.
      w_sigma: float. The :math:`\sigma` of the normal distribution.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 100000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default False. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        w_mu: float,
        w_sigma: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        transpose: bool = False,
        atomic: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 100000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # parameters of the weight distribution
        self.w_mu = w_mu
        self.w_sigma = w_sigma

    def update(self, x):
        """Apply the JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
                                         shape=mat_shape,
                                         transpose=self.transpose,
                                         outdim_parallel=not self.atomic)
class EventJitFPHomoLinear(JitFPHomoLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is the same :math:`weight`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      weight: float. The synaptic value at each position.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 1000000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default True. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        weight: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = False,
        atomic: bool = True,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 1000000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # weight, trainable under a training mode
        self.weight = bm.TrainVar(weight) if isinstance(self.mode, bm.TrainingMode) else weight

    def update(self, x):
        """Apply the event-driven JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """Event-driven JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed,
                                             shape=mat_shape,
                                             transpose=self.transpose,
                                             outdim_parallel=not self.atomic)
class EventJitFPUniformLinear(JitFPUniformLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is sampled from a uniform distribution
    :math:`U(w_{low}, w_{high})`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      w_low: float. The lowest value of the uniform distribution.
      w_high: float. The highest value of the uniform distribution.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 100000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default True. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        w_low: float,
        w_high: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
        transpose: bool = False,
        atomic: bool = True,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 100000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # bounds of the weight distribution
        self.w_low = w_low
        self.w_high = w_high

    def update(self, x):
        """Apply the event-driven JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """Event-driven JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed,
                                                shape=mat_shape,
                                                transpose=self.transpose,
                                                outdim_parallel=not self.atomic)
class EventJitFPNormalLinear(JitFPNormalLayer):
    r"""Synaptic matrix multiplication with the just-in-time connectivity.

    It performs the computation of:

    .. math::

       y = x @ M

    where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
    :math:`M` the synaptic weights which has the fixed sparse connectivity and weights.
    Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
    and at each connection, the synaptic value is sampled from a normal distribution
    :math:`N(\mu, \sigma)`.

    Args:
      num_in: int. The number of the input features. A positive integer.
      num_out: int. The number of the output features. A positive integer.
      prob: float. The connectivity probability.
      w_mu: float. The center of the normal distribution.
      w_sigma: float. The :math:`\sigma` of the normal distribution.
      seed: int. The random seed used to keep the reproducibility of the
        connectivity. Drawn with ``np.random.randint(0, 100000)`` when None.
      transpose: bool. Transpose the JIT matrix or not. Default False.
      atomic: bool. Compute the post-synaptic value with the atomic summation.
        Default True. May be changed in the future.
      sharding: The sharding strategy.
      mode: The synaptic computing mode.
      name: The synapse model name.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        prob: float,
        w_mu: float,
        w_sigma: float,
        seed: Optional[int] = None,
        sharding: Optional[Sharding] = None,
        transpose: bool = False,
        atomic: bool = True,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)

        # sizes
        self.num_in = num_in
        self.num_out = num_out

        # connectivity settings
        self.prob = prob
        self.seed = np.random.randint(0, 100000) if seed is None else seed
        self.transpose = transpose
        self.atomic = atomic
        self.sharding = sharding

        # parameters of the weight distribution
        self.w_mu = w_mu
        self.w_sigma = w_sigma

    def update(self, x):
        """Apply the event-driven JIT product over the last axis of ``x``.

        Raises:
          ValueError: if ``x`` has no axis at all.
        """
        if x.ndim == 1:
            return self._batch_mv(x)
        if x.ndim == 2:
            return jax.vmap(self._batch_mv)(x)
        if x.ndim > 2:
            lead_shape = x.shape[:-1]
            rows = bm.flatten(x, end_dim=-2)
            out = jax.vmap(self._batch_mv)(rows)
            return bm.reshape(out, lead_shape + (out.shape[-1],))
        raise ValueError

    def _batch_mv(self, x):
        """Event-driven JIT matrix-vector product for one 1D vector."""
        mat_shape = (self.num_out, self.num_in)
        return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed,
                                               shape=mat_shape,
                                               transpose=self.transpose,
                                               outdim_parallel=not self.atomic)