Source code for brainpy.math.environment

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import functools
import gc
import inspect
import os
import re
import sys
import warnings
from typing import Any, Callable, TypeVar, cast

import brainstate.environ
import jax
from jax import config, numpy as jnp, devices

from . import modes
from . import scales
from .defaults import defaults
from .object_transform import naming

# Names exported by ``from <this module> import *``.
# NOTE(review): ``set_x64`` and ``gpu_memory_preallocation`` are defined in this
# module but not listed here -- confirm whether that is intentional.
__all__ = [
    # context managers for environment setting
    'environment',
    'batching_environment',
    'training_environment',
    'set_environment',
    'set',

    # default data types
    'set_float', 'get_float',
    'set_int', 'get_int',
    'set_bool', 'get_bool',
    'set_complex', 'get_complex',

    # default numerical integration step
    'set_dt', 'get_dt',

    # default computation modes
    'set_mode', 'get_mode',

    # default membrane_scaling
    'set_membrane_scaling', 'get_membrane_scaling',

    # set jax environments
    'enable_x64', 'disable_x64',
    'set_platform', 'get_platform',
    'set_host_device_count',

    # device memory
    'clear_buffer_memory',
    'enable_gpu_memory_preallocation',
    'disable_gpu_memory_preallocation',

    # deprecated
    'ditype',
    'dftype',

]

# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator"""

    def __call__(self, func: F) -> F:
        if inspect.isgeneratorfunction(func):
            return self._wrap_generator(func)

        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with self.clone():
                return func(*args, **kwargs)

        return cast(F, decorate_context)

    def _wrap_generator(self, func):
        """Wrap each generator invocation with the context manager"""

        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)

            # Generators are suspended and unsuspended at `yield`, hence we
            # make sure the grad modes is properly set every time the execution
            # flow returns into the wrapped generator and restored when it
            # returns through our `yield` to our caller (see PR #49017).
            try:
                # Issuing `None` to a generator fires it up
                with self.clone():
                    response = gen.send(None)

                while True:
                    try:
                        # Forward the response to our caller and get its next request
                        request = yield response
                    except GeneratorExit:
                        # Inform the still active generator about its imminent closure
                        with self.clone():
                            gen.close()
                        raise
                    except BaseException:
                        # Propagate the exception thrown at us by the caller
                        with self.clone():
                            response = gen.throw(*sys.exc_info())
                    else:
                        # Pass the last request to the generator and get its response
                        with self.clone():
                            response = gen.send(request)

            # We let the exceptions raised above by the generator's `.throw` or
            # `.send` methods bubble up to our caller, except for StopIteration
            except StopIteration as e:
                # The generator informed us that it is done: take whatever its
                # returned value (if any) was and indicate that we're done too
                # by returning it (see docs for python's return-statement).
                return e.value

        return generator_context

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

    def clone(self):
        # override this method if your children class takes __init__ parameters
        return self.__class__()


class environment(_DecoratorContextManager):
    r"""Context-manager that sets a computing environment for brain dynamics computation.

    In BrainPy, there are several basic computation settings when constructing
    models, including ``mode`` for controlling model computing behavior, ``dt``
    for numerical integration, ``int_`` for integer precision, and ``float_``
    for floating precision. :py:class:`~.environment` provides a context for
    model construction and computation. In this temporary environment, models
    are constructed with the given ``mode``, ``dt``, ``int_``, etc.,
    environment settings.

    For instance::

      >>> import brainpy as bp
      >>> import brainpy.math as bm
      >>>
      >>> with bm.environment(mode=bm.training_mode, dt=0.1):
      >>>   lif1 = bp.neurons.LIF(1)
      >>>
      >>> with bm.environment(mode=bm.nonbatching_mode, dt=0.05, float_=bm.float64):
      >>>   lif2 = bp.neurons.LIF(1)

    Settings left as ``None`` are not touched. The current global value of every
    given setting is captured each time the context is *entered* and restored
    when it exits, so an instance may be created ahead of time, entered later,
    entered repeatedly, or nested with itself.

    Parameters::

      mode: Mode
        The computing mode.
      membrane_scaling: Scaling
        The numerical membrane scaling.
      dt: float
        The numerical integration step.
      x64: bool
        Whether to enable 64-bit computation.
      complex_, float_, int_, bool_: type
        The default complex / float / int / bool data types.
      bp_object_as_pytree: bool
        Whether to register brainpy objects as pytrees.
      numpy_func_return: str
        Either ``'bp_array'`` or ``'jax_array'``.
    """

    def __init__(
        self,
        mode: modes.Mode = None,
        membrane_scaling: scales.Scaling = None,
        dt: float = None,
        x64: bool = None,
        complex_: type = None,
        float_: type = None,
        int_: type = None,
        bool_: type = None,
        bp_object_as_pytree: bool = None,
        numpy_func_return: str = None,
    ) -> None:
        super().__init__()

        # Validate eagerly so that misuse is reported at construction time.
        if dt is not None:
            assert isinstance(dt, float), '"dt" must be a float.'
        if mode is not None:
            assert isinstance(mode, modes.Mode), f'"mode" must be a {modes.Mode}.'
        if membrane_scaling is not None:
            assert isinstance(membrane_scaling, scales.Scaling), \
                f'"membrane_scaling" must be a {scales.Scaling}.'
        if x64 is not None:
            assert isinstance(x64, bool), '"x64" must be a bool.'
        if float_ is not None:
            assert isinstance(float_, type), '"float_" must be a type.'
        if int_ is not None:
            assert isinstance(int_, type), '"int_" must be a type.'
        if bool_ is not None:
            assert isinstance(bool_, type), '"bool_" must be a type.'
        if complex_ is not None:
            assert isinstance(complex_, type), '"complex_" must be a type.'
        if bp_object_as_pytree is not None:
            assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
        if numpy_func_return is not None:
            assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.'
            assert numpy_func_return in ['bp_array', 'jax_array'], \
                f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.'

        self.dt = dt
        self.mode = mode
        self.membrane_scaling = membrane_scaling
        self.x64 = x64
        self.complex_ = complex_
        self.float_ = float_
        self.int_ = int_
        self.bool_ = bool_
        self.bp_object_as_pytree = bp_object_as_pytree
        self.numpy_func_return = numpy_func_return

        # One snapshot per active ``__enter__``; a stack makes re-entering the
        # same instance (``with env: with env: ...``) restore correctly.
        self._saved_settings = []
        # The ``old_*`` attributes have always existed right after construction;
        # keep them for compatibility (they are refreshed on every ``__enter__``).
        for name, value in self._current_settings().items():
            setattr(self, name, value)

    def _current_settings(self) -> dict:
        """Read the current global value of every setting this context may change.

        Keys are the names of the matching ``old_*`` attributes. When ``x64`` is
        given, the default int/float/complex types are captured too, because
        ``set_x64`` overwrites all three of them.
        """
        touches_dtypes = self.x64 is not None
        current = {}
        if self.dt is not None:
            current['old_dt'] = get_dt()
        if self.mode is not None:
            current['old_mode'] = get_mode()
        if self.membrane_scaling is not None:
            current['old_membrane_scaling'] = get_membrane_scaling()
        if self.x64 is not None:
            current['old_x64'] = config.read("jax_enable_x64")
        if self.float_ is not None or touches_dtypes:
            current['old_float'] = get_float()
        if self.int_ is not None or touches_dtypes:
            current['old_int'] = get_int()
        if self.bool_ is not None:
            current['old_bool'] = get_bool()
        if self.complex_ is not None or touches_dtypes:
            current['old_complex'] = get_complex()
        if self.bp_object_as_pytree is not None:
            current['old_bp_object_as_pytree'] = defaults.bp_object_as_pytree
        if self.numpy_func_return is not None:
            current['old_numpy_func_return'] = defaults.numpy_func_return
        return current

    def __enter__(self) -> 'environment':
        # Snapshot *before* applying anything: ``set_x64`` below also resets the
        # default int/float/complex types, whose previous values must survive.
        saved = self._current_settings()
        self._saved_settings.append(saved)
        for name, value in saved.items():
            setattr(self, name, value)

        if self.dt is not None:
            set_dt(self.dt)
        if self.mode is not None:
            set_mode(self.mode)
        if self.membrane_scaling is not None:
            set_membrane_scaling(self.membrane_scaling)
        # x64 goes before the explicit dtypes so that those override its defaults.
        if self.x64 is not None:
            set_x64(self.x64)
        if self.float_ is not None:
            set_float(self.float_)
        if self.int_ is not None:
            set_int(self.int_)
        if self.complex_ is not None:
            set_complex(self.complex_)
        if self.bool_ is not None:
            set_bool(self.bool_)
        if self.bp_object_as_pytree is not None:
            defaults.bp_object_as_pytree = self.bp_object_as_pytree
        if self.numpy_func_return is not None:
            defaults.numpy_func_return = self.numpy_func_return
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        saved = self._saved_settings.pop()
        if 'old_dt' in saved:
            set_dt(saved['old_dt'])
        if 'old_mode' in saved:
            set_mode(saved['old_mode'])
        if 'old_membrane_scaling' in saved:
            set_membrane_scaling(saved['old_membrane_scaling'])
        # Restore x64 before the dtypes: ``set_x64`` resets the default
        # int/float/complex types, which are then put back just below.
        if 'old_x64' in saved:
            set_x64(saved['old_x64'])
        if 'old_int' in saved:
            set_int(saved['old_int'])
        if 'old_float' in saved:
            set_float(saved['old_float'])
        if 'old_complex' in saved:
            set_complex(saved['old_complex'])
        if 'old_bool' in saved:
            set_bool(saved['old_bool'])
        if 'old_bp_object_as_pytree' in saved:
            defaults.bp_object_as_pytree = saved['old_bp_object_as_pytree']
        if 'old_numpy_func_return' in saved:
            defaults.numpy_func_return = saved['old_numpy_func_return']

    def clone(self):
        return self.__class__(dt=self.dt,
                              mode=self.mode,
                              membrane_scaling=self.membrane_scaling,
                              x64=self.x64,
                              bool_=self.bool_,
                              complex_=self.complex_,
                              float_=self.float_,
                              int_=self.int_,
                              bp_object_as_pytree=self.bp_object_as_pytree,
                              numpy_func_return=self.numpy_func_return)

    def __eq__(self, other):
        return id(self) == id(other)

    # Defining ``__eq__`` alone sets ``__hash__`` to ``None`` and makes instances
    # unhashable; identity equality pairs naturally with identity hashing.
    __hash__ = object.__hash__
class training_environment(environment):
    """Environment with the training mode.

    This is a short-cut context setting for an environment with the training mode.
    It is equivalent to::

      >>> import brainpy.math as bm
      >>> with bm.environment(mode=bm.training_mode):
      >>>   pass

    ``batch_size`` is forwarded to ``modes.TrainingMode``.
    """

    def __init__(
        self,
        dt: float = None,
        x64: bool = None,
        complex_: type = None,
        float_: type = None,
        int_: type = None,
        bool_: type = None,
        batch_size: int = 1,
        membrane_scaling: scales.Scaling = None,
        bp_object_as_pytree: bool = None,
        numpy_func_return: str = None,
    ):
        super().__init__(dt=dt,
                         x64=x64,
                         complex_=complex_,
                         float_=float_,
                         int_=int_,
                         bool_=bool_,
                         membrane_scaling=membrane_scaling,
                         mode=modes.TrainingMode(batch_size),
                         bp_object_as_pytree=bp_object_as_pytree,
                         numpy_func_return=numpy_func_return)
        # Remembered so that ``clone`` can rebuild an equivalent instance.
        self.batch_size = batch_size

    def clone(self):
        # ``environment.clone`` passes ``mode=...``, which this constructor does
        # not accept (using this class as a decorator raised ``TypeError``);
        # rebuild from this class's own parameters instead.
        return self.__class__(dt=self.dt,
                              x64=self.x64,
                              complex_=self.complex_,
                              float_=self.float_,
                              int_=self.int_,
                              bool_=self.bool_,
                              batch_size=self.batch_size,
                              membrane_scaling=self.membrane_scaling,
                              bp_object_as_pytree=self.bp_object_as_pytree,
                              numpy_func_return=self.numpy_func_return)
class batching_environment(environment):
    """Environment with the batching mode.

    This is a short-cut context setting for an environment with the batching mode.
    It is equivalent to::

      >>> import brainpy.math as bm
      >>> with bm.environment(mode=bm.batching_mode):
      >>>   pass

    ``batch_size`` is forwarded to ``modes.BatchingMode``.
    """

    def __init__(
        self,
        dt: float = None,
        x64: bool = None,
        complex_: type = None,
        float_: type = None,
        int_: type = None,
        bool_: type = None,
        batch_size: int = 1,
        membrane_scaling: scales.Scaling = None,
        bp_object_as_pytree: bool = None,
        numpy_func_return: str = None,
    ):
        super().__init__(dt=dt,
                         x64=x64,
                         complex_=complex_,
                         float_=float_,
                         int_=int_,
                         bool_=bool_,
                         mode=modes.BatchingMode(batch_size),
                         membrane_scaling=membrane_scaling,
                         bp_object_as_pytree=bp_object_as_pytree,
                         numpy_func_return=numpy_func_return)
        # Remembered so that ``clone`` can rebuild an equivalent instance.
        self.batch_size = batch_size

    def clone(self):
        # ``environment.clone`` passes ``mode=...``, which this constructor does
        # not accept (using this class as a decorator raised ``TypeError``);
        # rebuild from this class's own parameters instead.
        return self.__class__(dt=self.dt,
                              x64=self.x64,
                              complex_=self.complex_,
                              float_=self.float_,
                              int_=self.int_,
                              bool_=self.bool_,
                              batch_size=self.batch_size,
                              membrane_scaling=self.membrane_scaling,
                              bp_object_as_pytree=self.bp_object_as_pytree,
                              numpy_func_return=self.numpy_func_return)
def set(
    mode: modes.Mode = None,
    membrane_scaling: scales.Scaling = None,
    dt: float = None,
    x64: bool = None,
    complex_: type = None,
    float_: type = None,
    int_: type = None,
    bool_: type = None,
    bp_object_as_pytree: bool = None,
    numpy_func_return: str = None,
):
    """Set the default computation environment.

    Arguments left as ``None`` are not changed. The settings are changed
    globally and are not restored; use :py:class:`~.environment` for a scoped
    change.

    Parameters::

      mode: Mode
        The computing mode.
      membrane_scaling: Scaling
        The numerical membrane_scaling.
      dt: float
        The numerical integration precision.
      x64: bool
        Enable x64 computation.
      complex_: type
        The complex data type.
      float_
        The floating data type.
      int_
        The integer data type.
      bool_
        The bool data type.
      bp_object_as_pytree: bool
        Whether to register brainpy object as pytree.
      numpy_func_return: str
        The array to return in all numpy functions. Support 'bp_array' and 'jax_array'.
    """
    if dt is not None:
        assert isinstance(dt, float), '"dt" must be a float.'
        set_dt(dt)

    if mode is not None:
        assert isinstance(mode, modes.Mode), f'"mode" must be a {modes.Mode}.'
        set_mode(mode)

    if membrane_scaling is not None:
        assert isinstance(membrane_scaling, scales.Scaling), \
            f'"membrane_scaling" must be a {scales.Scaling}.'
        set_membrane_scaling(membrane_scaling)

    # x64 goes before the explicit dtypes so that those override its defaults.
    if x64 is not None:
        assert isinstance(x64, bool), '"x64" must be a bool.'
        set_x64(x64)

    if float_ is not None:
        assert isinstance(float_, type), '"float_" must be a type.'
        set_float(float_)

    if int_ is not None:
        assert isinstance(int_, type), '"int_" must be a type.'
        set_int(int_)

    if bool_ is not None:
        assert isinstance(bool_, type), '"bool_" must be a type.'
        set_bool(bool_)

    if complex_ is not None:
        assert isinstance(complex_, type), '"complex_" must be a type.'
        set_complex(complex_)

    if bp_object_as_pytree is not None:
        # Same check as ``environment`` performs for this setting.
        assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
        defaults.bp_object_as_pytree = bp_object_as_pytree

    if numpy_func_return is not None:
        assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.'
        assert numpy_func_return in ['bp_array', 'jax_array'], \
            f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.'
        defaults.numpy_func_return = numpy_func_return
# ``set_environment`` is an alias of ``set``; both names are exported in ``__all__``.
set_environment = set

# default dtype
# --------------------------
def ditype():
    """Default int type.

    .. deprecated:: 2.3.1
       Use `brainpy.math.int_` instead.

    Returns::

      dtype: type
        The current default integer data type.
    """
    # The function is documented as deprecated; tell callers at runtime too.
    warnings.warn(
        'Getting the default integer data type through `ditype()` is deprecated. '
        'Use `brainpy.math.int_` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    return defaults.int_
def dftype():
    """Default float type.

    .. deprecated:: 2.3.1
       Use `brainpy.math.float_` instead.

    Returns::

      dtype: type
        The current default floating data type.
    """
    # The function is documented as deprecated; tell callers at runtime too.
    warnings.warn(
        'Getting the default floating data type through `dftype()` is deprecated. '
        'Use `brainpy.math.float_` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    return defaults.float_
def set_float(dtype: type):
    """Make ``dtype`` the global default floating-point data type.

    Parameters::

      dtype: type
        The float type, e.g. ``jnp.float32`` or ``jnp.float64``.
    """
    defaults.float_ = dtype


def get_float():
    """Return the global default floating-point data type.

    Returns::

      dtype: type
        The default float data type.
    """
    return defaults.float_
def set_int(dtype: type):
    """Make ``dtype`` the global default integer data type.

    Parameters::

      dtype: type
        The integer type, e.g. ``jnp.int32`` or ``jnp.int64``.
    """
    defaults.int_ = dtype


def get_int():
    """Return the global default integer data type.

    Returns::

      dtype: type
        The default int data type.
    """
    return defaults.int_
def set_bool(dtype: type):
    """Make ``dtype`` the global default boolean data type.

    Parameters::

      dtype: type
        The bool type.
    """
    defaults.bool_ = dtype


def get_bool():
    """Return the global default boolean data type.

    Returns::

      dtype: type
        The default bool data type.
    """
    return defaults.bool_
def set_complex(dtype: type):
    """Make ``dtype`` the global default complex data type.

    Parameters::

      dtype: type
        The complex type, e.g. ``jnp.complex64`` or ``jnp.complex128``.
    """
    defaults.complex_ = dtype


def get_complex():
    """Return the global default complex data type.

    Returns::

      dtype: type
        The default complex data type.
    """
    return defaults.complex_
# numerical precision # --------------------------
def set_dt(dt):
    """Change the global default step size used by numerical integrators.

    Parameters::

      dt : float
        The integration step; must be a Python ``float``.
    """
    assert isinstance(dt, float), f'"dt" must a float, but we got {dt}'
    defaults.dt = dt


def get_dt():
    """Return the global default numerical integration step.

    Returns::

      dt : float
        The integration step.
    """
    return defaults.dt
def set_mode(mode: modes.Mode):
    """Change the global default computing mode.

    Parameters::

      mode: Mode
        An instance of :py:class:`~.Mode`.

    Raises::

      TypeError
        If ``mode`` is not a :py:class:`~.Mode` instance.
    """
    if isinstance(mode, modes.Mode):
        defaults.mode = mode
        return
    raise TypeError(f'Must be instance of brainpy.math.Mode. '
                    f'But we got {type(mode)}: {mode}')


def get_mode() -> modes.Mode:
    """Return the global default computing mode.

    Returns::

      mode: Mode
        The default computing mode.
    """
    return defaults.mode
def set_membrane_scaling(membrane_scaling: scales.Scaling):
    """Change the global default membrane scaling.

    Parameters::

      membrane_scaling: Scaling
        An instance of :py:class:`~.Scaling`.

    Raises::

      TypeError
        If ``membrane_scaling`` is not a :py:class:`~.Scaling` instance.
    """
    if isinstance(membrane_scaling, scales.Scaling):
        defaults.membrane_scaling = membrane_scaling
        return
    raise TypeError(f'Must be instance of brainpy.math.Scaling. '
                    f'But we got {type(membrane_scaling)}: {membrane_scaling}')


def get_membrane_scaling() -> scales.Scaling:
    """Return the global default membrane scaling.

    Returns::

      membrane_scaling: Scaling
        The default membrane scaling.
    """
    return defaults.membrane_scaling
def enable_x64(x64=None):
    """Enable 64-bit computation.

    Sets the brainstate precision to 64, turns on JAX's ``jax_enable_x64``
    flag, and makes ``jnp.int64`` / ``jnp.float64`` / ``jnp.complex128`` the
    default int / float / complex data types.

    Parameters::

      x64: bool, optional
        Deprecated. ``enable_x64(True)`` behaves like ``enable_x64()`` and
        ``enable_x64(False)`` behaves like ``disable_x64()``; passing any
        value emits a ``DeprecationWarning``.
    """
    if x64 is None:
        x64 = True
    else:
        warnings.warn(
            '\n'
            'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n'
            'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n',
            DeprecationWarning
        )
    if x64:
        brainstate.environ.set(precision=64)
        # Set the JAX flag explicitly as well, mirroring ``disable_x64``, so that
        # ``config.read("jax_enable_x64")`` always reflects this call.
        config.update("jax_enable_x64", True)
        set_int(jnp.int64)
        set_float(jnp.float64)
        set_complex(jnp.complex128)
    else:
        brainstate.environ.set(precision=32)
        disable_x64()
def disable_x64():
    """Disable 64-bit computation and return to 32-bit defaults.

    Sets the brainstate precision to 32, turns off JAX's ``jax_enable_x64``
    flag, and makes ``jnp.int32`` / ``jnp.float32`` / ``jnp.complex64`` the
    default int / float / complex data types.
    """
    # ``enable_x64`` raises the brainstate precision to 64; without resetting
    # it here, a direct ``disable_x64()`` call (or ``environment(x64=...)``
    # restoring ``False`` on exit) would leave brainstate at 64-bit precision.
    brainstate.environ.set(precision=32)
    config.update("jax_enable_x64", False)
    set_int(jnp.int32)
    set_float(jnp.float32)
    set_complex(jnp.complex64)
def set_x64(enable: bool):
    """Switch 64-bit computation on or off.

    Dispatches to :func:`enable_x64` when ``enable`` is ``True`` and to
    :func:`disable_x64` otherwise.

    Parameters::

      enable: bool
        Whether 64-bit computation should be enabled.
    """
    assert isinstance(enable, bool)
    switch = enable_x64 if enable else disable_x64
    switch()
def set_platform(platform: str):
    """Select the JAX platform: ``'cpu'``, ``'gpu'`` or ``'tpu'``.

    This utility only takes effect at the beginning of your program.

    Parameters::

      platform: str
        Written to JAX's ``jax_platform_name`` configuration option.
    """
    supported = ['cpu', 'gpu', 'tpu']
    assert platform in supported
    config.update("jax_platform_name", platform)
def get_platform() -> str:
    """Return the computing platform.

    Returns::

      platform: str
        The ``platform`` attribute of the first device reported by
        ``jax.devices()``, e.g. 'cpu', 'gpu' or 'tpu'.
    """
    first_device = devices()[0]
    return first_device.platform
def set_host_device_count(n):
    """
    By default, XLA considers all CPU cores as one device. This utility tells XLA
    that there are `n` host (CPU) devices available to use. As a consequence, this
    allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
        Under the hood, this sets the environment variable
        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
        `[num_device]` is the desired number of CPU devices `n`. Any previous
        value of that flag is removed; other flags in `XLA_FLAGS` are kept
        after the new one.

    .. warning:: Our understanding of the side effects of using the
        `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
        observe some strange phenomenon when using this utility, please let us
        know through our issue or forum page. More information is available in this
        `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of devices to use.
    """
    current_flags = os.environ.get("XLA_FLAGS", "")
    # Drop any existing device-count flag so it is not specified twice.
    remaining = re.sub(r"--xla_force_host_platform_device_count=\S+", "", current_flags).split()
    new_flag = "--xla_force_host_platform_device_count={}".format(n)
    os.environ["XLA_FLAGS"] = " ".join([new_flag, *remaining])
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    transform: bool = True,
    compilation: bool = False,
    object_name: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters::

      platform: str
        The device to clear its memory. Passed through to ``get_backend``.
      array: bool
        Clear all buffer array. Default is True.
      transform: bool
        Clear transform cache. Default is True.
      compilation: bool
        Clear compilation cache (``jax.clear_caches()``). Default is False.
      object_name: bool
        Clear name cache. Default is False.
    """
    from brainstate._compatible_import import get_backend
    if array:
        # Delete every live buffer on the selected backend; any array still
        # referring to one becomes unusable (see the warning above).
        for buf in get_backend(platform).live_buffers():
            buf.delete()
    if compilation:
        jax.clear_caches()
    if transform:
        naming.clear_stack_cache()
    if object_name:
        naming.clear_name_cache()
    gc.collect()
def disable_gpu_memory_preallocation(release_memory: bool = True):
    """Turn off GPU memory pre-allocation.

    JAX will then allocate GPU memory on demand instead of reserving it up
    front, which can lower overall memory usage. The trade-off is a higher
    risk of GPU memory fragmentation, so a program that uses most of the
    available GPU memory may run out of memory with preallocation disabled.

    Args:
      release_memory: bool. Whether we release memory during the computation.
        When true, ``XLA_PYTHON_CLIENT_ALLOCATOR`` is set to ``'platform'``.
    """
    updates = {'XLA_PYTHON_CLIENT_PREALLOCATE': 'false'}
    if release_memory:
        updates['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
    os.environ.update(updates)
def enable_gpu_memory_preallocation():
    """Enable pre-allocating the GPU memory.

    Sets ``XLA_PYTHON_CLIENT_PREALLOCATE=true`` and removes any
    ``XLA_PYTHON_CLIENT_ALLOCATOR`` override (such as the ``'platform'`` value
    that :func:`disable_gpu_memory_preallocation` may set).
    """
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
    # ``pop`` with a default so the call is safe when no override is present.
    os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
def gpu_memory_preallocation(percent: float):
    """GPU memory allocation.

    If preallocation is enabled, this makes JAX preallocate ``percent`` of the
    total GPU memory, instead of the default 75%. Lowering the amount
    preallocated can fix OOMs that occur when the JAX program starts.

    Parameters::

      percent: float
        Fraction of the total GPU memory to preallocate, in ``[0., 1.)``.
        Written to the ``XLA_PYTHON_CLIENT_MEM_FRACTION`` environment variable.
    """
    # The check excludes 1.0, so the message states the half-open interval.
    assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)