# Source code for brainpy._src.math.environment

# -*- coding: utf-8 -*-


import functools
import gc
import inspect
import os
import re
import sys
import warnings
from typing import Any, Callable, TypeVar, cast

import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge

from . import modes
from . import scales
from . import defaults
from .object_transform import naming
from brainpy._src.dependency_check import import_taichi

# Optional taichi support: ``import_taichi`` is called with
# ``error_if_not_found=False``. Presumably ``ti`` is ``None`` when taichi is not
# installed (TODO confirm) -- the dtype setters below guard every use of it with
# ``if ti is not None``.
ti = import_taichi(error_if_not_found=False)

# Public API of this module (what ``from ... import *`` exports).
__all__ = [
  # context manage for environment setting
  'environment',
  'batching_environment',
  'training_environment',
  'set_environment',
  'set',

  # default data types
  'set_float', 'get_float',
  'set_int', 'get_int',
  'set_bool', 'get_bool',
  'set_complex', 'get_complex',

  # default numerical integration step
  'set_dt', 'get_dt',

  # default computation modes
  'set_mode', 'get_mode',

  # default membrane_scaling
  'set_membrane_scaling', 'get_membrane_scaling',

  # set jax environments
  'enable_x64', 'disable_x64',
  'set_platform', 'get_platform',
  'set_host_device_count',

  # device memory
  'clear_buffer_memory',
  'enable_gpu_memory_preallocation',
  'disable_gpu_memory_preallocation',

  # deprecated
  'ditype',
  'dftype',

]

# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


class _DecoratorContextManager:
  """Allow a context manager to be used as a decorator.

  Subclasses implement ``__enter__``/``__exit__``. Every call of a decorated
  function (and every resumption of a decorated generator) runs inside a fresh
  context obtained from :meth:`clone`, so subclasses whose ``__init__`` takes
  parameters must override :meth:`clone`.
  """

  def __call__(self, func: F) -> F:
    """Wrap ``func`` so that each call executes inside ``self.clone()``."""
    if inspect.isgeneratorfunction(func):
      return self._wrap_generator(func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
      with self.clone():
        return func(*args, **kwargs)

    return cast(F, decorate_context)

  def _wrap_generator(self, func):
    """Wrap each generator invocation with the context manager."""

    @functools.wraps(func)
    def generator_context(*args, **kwargs):
      gen = func(*args, **kwargs)

      # Generators are suspended and unsuspended at `yield`, hence we
      # make sure the context is properly set every time the execution
      # flow returns into the wrapped generator and restored when it
      # returns through our `yield` to our caller.
      try:
        # Issuing `None` to a generator fires it up
        with self.clone():
          response = gen.send(None)

        while True:
          try:
            # Forward the response to our caller and get its next request
            request = yield response
          except GeneratorExit:
            # Inform the still active generator about its imminent closure
            with self.clone():
              gen.close()
            raise
          except BaseException as exc:
            # Propagate the exception thrown at us by the caller. The
            # single-argument form of ``throw`` is used because the
            # ``(type, value, traceback)`` signature is deprecated since
            # Python 3.12; ``exc`` already carries its traceback.
            with self.clone():
              response = gen.throw(exc)
          else:
            # Pass the last request to the generator and get its response
            with self.clone():
              response = gen.send(request)

      # We let the exceptions raised above by the generator's `.throw` or
      # `.send` methods bubble up to our caller, except for StopIteration
      except StopIteration as e:
        # The generator informed us that it is done: take whatever its
        # returned value (if any) was and indicate that we're done too
        # by returning it (see docs for python's return-statement).
        return e.value

    return generator_context

  def __enter__(self) -> None:
    raise NotImplementedError

  def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    raise NotImplementedError

  def clone(self):
    """Return a fresh context manager; override if ``__init__`` takes parameters."""
    return self.__class__()


class environment(_DecoratorContextManager):
  r"""Context-manager that sets a computing environment for brain dynamics computation.

  In BrainPy, there are several basic computation settings when constructing models,
  including ``mode`` for controlling model computing behavior, ``dt`` for numerical
  integration, ``int_`` for integer precision, and ``float_`` for floating precision.
  :py:class:`~.environment` provides a context for model construction and
  computation. In this temporal environment, models are constructed with the given
  ``mode``, ``dt``, ``int_``, etc., environment settings.

  For instance::

    >>> import brainpy as bp
    >>> import brainpy.math as bm
    >>>
    >>> with bm.environment(mode=bm.training_mode, dt=0.1):
    >>>   lif1 = bp.neurons.LIF(1)
    >>>
    >>> with bm.environment(mode=bm.nonbatching_mode, dt=0.05, float_=bm.float64):
    >>>   lif2 = bp.neurons.LIF(1)

  Settings left as ``None`` are not touched. The global values that are
  overridden are recorded each time the context is entered and restored when
  it exits, so an instance may be reused or nested.
  """

  def __init__(
      self,
      mode: modes.Mode = None,
      membrane_scaling: scales.Scaling = None,
      dt: float = None,
      x64: bool = None,
      complex_: type = None,
      float_: type = None,
      int_: type = None,
      bool_: type = None,
      bp_object_as_pytree: bool = None,
      numpy_func_return: str = None,
  ) -> None:
    super().__init__()
    if dt is not None:
      assert isinstance(dt, float), '"dt" must a float.'
    if mode is not None:
      assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
    if membrane_scaling is not None:
      assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
    if x64 is not None:
      assert isinstance(x64, bool), f'"x64" must be a bool.'
    if float_ is not None:
      assert isinstance(float_, type), '"float_" must a float.'
    if int_ is not None:
      assert isinstance(int_, type), '"int_" must a type.'
    if bool_ is not None:
      assert isinstance(bool_, type), '"bool_" must a type.'
    if complex_ is not None:
      assert isinstance(complex_, type), '"complex_" must a type.'
    if bp_object_as_pytree is not None:
      assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
    if numpy_func_return is not None:
      assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.'
      assert numpy_func_return in ['bp_array', 'jax_array'], \
        f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.'

    self.dt = dt
    self.mode = mode
    self.membrane_scaling = membrane_scaling
    self.x64 = x64
    self.complex_ = complex_
    self.float_ = float_
    self.int_ = int_
    self.bool_ = bool_
    self.bp_object_as_pytree = bp_object_as_pytree
    self.numpy_func_return = numpy_func_return

    # One saved global state per active ``__enter__`` (supports nesting/reuse).
    self._saved_states = []
    # Backward compatibility: the ``old_*`` attributes used to be filled at
    # construction time, so keep exposing them.
    for name, value in self._snapshot().items():
      setattr(self, name, value)

  def _snapshot(self) -> dict:
    """Return the current global values of every setting this environment overrides.

    Keys are the historical ``old_*`` attribute names.
    """
    state = {}
    if self.dt is not None:
      state['old_dt'] = get_dt()
    if self.mode is not None:
      state['old_mode'] = get_mode()
    if self.membrane_scaling is not None:
      state['old_membrane_scaling'] = get_membrane_scaling()
    if self.x64 is not None:
      state['old_x64'] = config.read("jax_enable_x64")
    # ``set_x64`` also resets the default int/float/complex types, so those
    # are saved (and restored) whenever ``x64`` is overridden as well.
    if self.float_ is not None or self.x64 is not None:
      state['old_float'] = get_float()
    if self.int_ is not None or self.x64 is not None:
      state['old_int'] = get_int()
    if self.complex_ is not None or self.x64 is not None:
      state['old_complex'] = get_complex()
    if self.bool_ is not None:
      state['old_bool'] = get_bool()
    if self.bp_object_as_pytree is not None:
      state['old_bp_object_as_pytree'] = defaults.bp_object_as_pytree
    if self.numpy_func_return is not None:
      state['old_numpy_func_return'] = defaults.numpy_func_return
    return state

  def __enter__(self) -> 'environment':
    # Record what is active *now*, so exit restores the entry-time values
    # rather than whatever was active when this object was constructed.
    self._saved_states.append(self._snapshot())
    if self.dt is not None:
      set_dt(self.dt)
    if self.mode is not None:
      set_mode(self.mode)
    if self.membrane_scaling is not None:
      set_membrane_scaling(self.membrane_scaling)
    # x64 first: it resets int/float/complex, which explicit settings then override.
    if self.x64 is not None:
      set_x64(self.x64)
    if self.float_ is not None:
      set_float(self.float_)
    if self.int_ is not None:
      set_int(self.int_)
    if self.complex_ is not None:
      set_complex(self.complex_)
    if self.bool_ is not None:
      set_bool(self.bool_)
    if self.bp_object_as_pytree is not None:
      defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree
    if self.numpy_func_return is not None:
      defaults.__dict__['numpy_func_return'] = self.numpy_func_return
    return self

  def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    state = self._saved_states.pop()
    if self.dt is not None:
      set_dt(state['old_dt'])
    if self.mode is not None:
      set_mode(state['old_mode'])
    if self.membrane_scaling is not None:
      set_membrane_scaling(state['old_membrane_scaling'])
    if self.x64 is not None:
      set_x64(state['old_x64'])
    # Restored after x64, which would otherwise leave its own defaults behind.
    if 'old_int' in state:
      set_int(state['old_int'])
    if 'old_float' in state:
      set_float(state['old_float'])
    if 'old_complex' in state:
      set_complex(state['old_complex'])
    if self.bool_ is not None:
      set_bool(state['old_bool'])
    if self.bp_object_as_pytree is not None:
      defaults.__dict__['bp_object_as_pytree'] = state['old_bp_object_as_pytree']
    if self.numpy_func_return is not None:
      defaults.__dict__['numpy_func_return'] = state['old_numpy_func_return']

  def clone(self):
    return self.__class__(dt=self.dt,
                          mode=self.mode,
                          membrane_scaling=self.membrane_scaling,
                          x64=self.x64,
                          bool_=self.bool_,
                          complex_=self.complex_,
                          float_=self.float_,
                          int_=self.int_,
                          bp_object_as_pytree=self.bp_object_as_pytree,
                          numpy_func_return=self.numpy_func_return)

  def __eq__(self, other):
    return id(self) == id(other)

  def __hash__(self):
    # Defining ``__eq__`` alone sets ``__hash__`` to None; keep instances
    # hashable, consistent with the identity-based equality above.
    return id(self)
class training_environment(environment):
  """Environment with the training mode.

  This is a short-cut context setting for an environment with the training mode.
  It is equivalent to::

    >>> import brainpy.math as bm
    >>> with bm.environment(mode=bm.training_mode):
    >>>   pass

  ``batch_size`` is passed to ``modes.TrainingMode``; every other argument is
  forwarded to :py:class:`~.environment`.
  """

  def __init__(
      self,
      dt: float = None,
      x64: bool = None,
      complex_: type = None,
      float_: type = None,
      int_: type = None,
      bool_: type = None,
      batch_size: int = 1,
      membrane_scaling: scales.Scaling = None,
      bp_object_as_pytree: bool = None,
      numpy_func_return: str = None,
  ):
    super().__init__(dt=dt,
                     x64=x64,
                     complex_=complex_,
                     float_=float_,
                     int_=int_,
                     bool_=bool_,
                     membrane_scaling=membrane_scaling,
                     mode=modes.TrainingMode(batch_size),
                     bp_object_as_pytree=bp_object_as_pytree,
                     numpy_func_return=numpy_func_return)

  def clone(self):
    """Return a fresh copy with identical settings.

    The inherited ``environment.clone`` calls ``self.__class__(mode=...)``, but
    this ``__init__`` has no ``mode`` parameter (which made the class unusable
    as a decorator). Initialize the copy through ``environment.__init__``
    instead, reusing the same mode object.
    """
    new = self.__class__.__new__(self.__class__)
    environment.__init__(new,
                         dt=self.dt,
                         mode=self.mode,
                         membrane_scaling=self.membrane_scaling,
                         x64=self.x64,
                         bool_=self.bool_,
                         complex_=self.complex_,
                         float_=self.float_,
                         int_=self.int_,
                         bp_object_as_pytree=self.bp_object_as_pytree,
                         numpy_func_return=self.numpy_func_return)
    return new
class batching_environment(environment):
  """Environment with the batching mode.

  This is a short-cut context setting for an environment with the batching mode.
  It is equivalent to::

    >>> import brainpy.math as bm
    >>> with bm.environment(mode=bm.batching_mode):
    >>>   pass

  ``batch_size`` is passed to ``modes.BatchingMode``; every other argument is
  forwarded to :py:class:`~.environment`.
  """

  def __init__(
      self,
      dt: float = None,
      x64: bool = None,
      complex_: type = None,
      float_: type = None,
      int_: type = None,
      bool_: type = None,
      batch_size: int = 1,
      membrane_scaling: scales.Scaling = None,
      bp_object_as_pytree: bool = None,
      numpy_func_return: str = None,
  ):
    super().__init__(dt=dt,
                     x64=x64,
                     complex_=complex_,
                     float_=float_,
                     int_=int_,
                     bool_=bool_,
                     mode=modes.BatchingMode(batch_size),
                     membrane_scaling=membrane_scaling,
                     bp_object_as_pytree=bp_object_as_pytree,
                     numpy_func_return=numpy_func_return)

  def clone(self):
    """Return a fresh copy with identical settings.

    The inherited ``environment.clone`` calls ``self.__class__(mode=...)``, but
    this ``__init__`` has no ``mode`` parameter (which made the class unusable
    as a decorator). Initialize the copy through ``environment.__init__``
    instead, reusing the same mode object.
    """
    new = self.__class__.__new__(self.__class__)
    environment.__init__(new,
                         dt=self.dt,
                         mode=self.mode,
                         membrane_scaling=self.membrane_scaling,
                         x64=self.x64,
                         bool_=self.bool_,
                         complex_=self.complex_,
                         float_=self.float_,
                         int_=self.int_,
                         bp_object_as_pytree=self.bp_object_as_pytree,
                         numpy_func_return=self.numpy_func_return)
    return new
def set(
    mode: modes.Mode = None,
    membrane_scaling: scales.Scaling = None,
    dt: float = None,
    x64: bool = None,
    complex_: type = None,
    float_: type = None,
    int_: type = None,
    bool_: type = None,
    bp_object_as_pytree: bool = None,
    numpy_func_return: str = None,
):
  """Set the default computation environment.

  Arguments left as ``None`` are not changed.

  Parameters
  ----------
  mode: Mode
    The computing mode.
  membrane_scaling: Scaling
    The numerical membrane_scaling.
  dt: float
    The numerical integration precision.
  x64: bool
    Enable x64 computation.
  complex_: type
    The complex data type.
  float_
    The floating data type.
  int_
    The integer data type.
  bool_
    The bool data type.
  bp_object_as_pytree: bool
    Whether to register brainpy object as pytree.
  numpy_func_return: str
    The array to return in all numpy functions. Support 'bp_array' and 'jax_array'.
  """
  if dt is not None:
    assert isinstance(dt, float), '"dt" must a float.'
    set_dt(dt)

  if mode is not None:
    assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
    set_mode(mode)

  if membrane_scaling is not None:
    assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
    set_membrane_scaling(membrane_scaling)

  if x64 is not None:
    assert isinstance(x64, bool), f'"x64" must be a bool.'
    set_x64(x64)

  if float_ is not None:
    assert isinstance(float_, type), '"float_" must a float.'
    set_float(float_)

  if int_ is not None:
    assert isinstance(int_, type), '"int_" must a type.'
    set_int(int_)

  if bool_ is not None:
    assert isinstance(bool_, type), '"bool_" must a type.'
    set_bool(bool_)

  if complex_ is not None:
    assert isinstance(complex_, type), '"complex_" must a type.'
    set_complex(complex_)

  if bp_object_as_pytree is not None:
    # Same validation that ``environment.__init__`` applies to this setting.
    assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
    defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree

  if numpy_func_return is not None:
    assert numpy_func_return in ['bp_array', 'jax_array'], \
      f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.'
    defaults.__dict__['numpy_func_return'] = numpy_func_return
# ``set_environment`` is an alias of :func:`set`; both names are listed in ``__all__``.
set_environment = set

# default dtype
# --------------------------
def ditype():
  """Default int type.

  .. deprecated:: 2.3.1
     Use `brainpy.math.int_` instead.

  Returns
  -------
  The current default integer data type (``defaults.int_``).
  """
  # Documented as deprecated, so tell callers (pointing at their call site).
  warnings.warn('"brainpy.math.ditype()" is deprecated since 2.3.1. '
                'Use "brainpy.math.int_" instead.',
                DeprecationWarning,
                stacklevel=2)
  return defaults.int_
def dftype():
  """Default float type.

  .. deprecated:: 2.3.1
     Use `brainpy.math.float_` instead.

  Returns
  -------
  The current default floating data type (``defaults.float_``).
  """
  # Documented as deprecated, so tell callers (pointing at their call site).
  warnings.warn('"brainpy.math.dftype()" is deprecated since 2.3.1. '
                'Use "brainpy.math.float_" instead.',
                DeprecationWarning,
                stacklevel=2)
  return defaults.float_
def set_float(dtype: type):
  """Set global default float type.

  Parameters
  ----------
  dtype: type
    The float type: ``jnp.float16``/``jnp.float32``/``jnp.float64`` or one of the
    string aliases ``'float16'``/``'f16'``, ``'float32'``/``'f32'``,
    ``'float64'``/``'f64'``.

  Raises
  ------
  NotImplementedError
    If ``dtype`` is not one of the supported float types.
  """
  # (accepted aliases, jax dtype, matching taichi attribute name)
  table = (
    ((jnp.float16, 'float16', 'f16'), jnp.float16, 'float16'),
    ((jnp.float32, 'float32', 'f32'), jnp.float32, 'float32'),
    ((jnp.float64, 'float64', 'f64'), jnp.float64, 'float64'),
  )
  for aliases, jnp_type, ti_name in table:
    if dtype in aliases:
      defaults.__dict__['float_'] = jnp_type
      # Keep the taichi default in sync when taichi is available.
      if ti is not None:
        defaults.__dict__['ti_float'] = getattr(ti, ti_name)
      return
  raise NotImplementedError(f'Unsupported float type: {dtype!r}. '
                            f'Expected float16, float32 or float64.')
def get_float():
  """Return the global default floating-point data type.

  Returns
  -------
  dftype: type
    The value currently stored as ``float_`` in the defaults module.
  """
  current = defaults.float_
  return current
def set_int(dtype: type):
  """Set global default integer type.

  Parameters
  ----------
  dtype: type
    The integer type: ``jnp.int8``/``jnp.int16``/``jnp.int32``/``jnp.int64`` or one
    of the string aliases ``'int8'``/``'i8'``, ``'int16'``/``'i16'``,
    ``'int32'``/``'i32'``, ``'int64'``/``'i64'``.

  Raises
  ------
  NotImplementedError
    If ``dtype`` is not one of the supported integer types.
  """
  # (accepted aliases, jax dtype, matching taichi attribute name)
  table = (
    ((jnp.int8, 'int8', 'i8'), jnp.int8, 'int8'),
    ((jnp.int16, 'int16', 'i16'), jnp.int16, 'int16'),
    ((jnp.int32, 'int32', 'i32'), jnp.int32, 'int32'),
    ((jnp.int64, 'int64', 'i64'), jnp.int64, 'int64'),
  )
  for aliases, jnp_type, ti_name in table:
    if dtype in aliases:
      defaults.__dict__['int_'] = jnp_type
      # Keep the taichi default in sync when taichi is available.
      if ti is not None:
        defaults.__dict__['ti_int'] = getattr(ti, ti_name)
      return
  raise NotImplementedError(f'Unsupported integer type: {dtype!r}. '
                            f'Expected int8, int16, int32 or int64.')
def get_int():
  """Return the global default integer data type.

  Returns
  -------
  dftype: type
    The value currently stored as ``int_`` in the defaults module.
  """
  current = defaults.int_
  return current
def set_bool(dtype: type):
  """Store ``dtype`` as the global default boolean type.

  The value is stored as-is; no validation is performed.

  Parameters
  ----------
  dtype: type
    The bool type.
  """
  defaults.__dict__.update(bool_=dtype)
def get_bool():
  """Return the global default boolean data type.

  Returns
  -------
  dftype: type
    The value currently stored as ``bool_`` in the defaults module.
  """
  current = defaults.bool_
  return current
def set_complex(dtype: type):
  """Store ``dtype`` as the global default complex type.

  The value is stored as-is; no validation is performed.

  Parameters
  ----------
  dtype: type
    The complex type.
  """
  defaults.__dict__.update(complex_=dtype)
def get_complex():
  """Return the global default complex data type.

  Returns
  -------
  dftype: type
    The value currently stored as ``complex_`` in the defaults module.
  """
  current = defaults.complex_
  return current
# numerical precision # --------------------------
def set_dt(dt):
  """Store ``dt`` as the default numerical-integration time step.

  Parameters
  ----------
  dt : float
    Numerical integration precision; must be a Python ``float``.
  """
  assert isinstance(dt, float), f'"dt" must a float, but we got {dt}'
  defaults.__dict__.update(dt=dt)
def get_dt():
  """Return the default numerical-integration time step.

  Returns
  -------
  dt : float
    The value currently stored as ``dt`` in the defaults module.
  """
  current = defaults.dt
  return current
def set_mode(mode: modes.Mode):
  """Store ``mode`` as the default computing mode.

  Parameters
  ----------
  mode: Mode
    The instance of :py:class:`~.Mode`.

  Raises
  ------
  TypeError
    If ``mode`` is not a :py:class:`~.Mode` instance.
  """
  if isinstance(mode, modes.Mode):
    defaults.__dict__.update(mode=mode)
    return
  raise TypeError(f'Must be instance of brainpy.math.Mode. '
                  f'But we got {type(mode)}: {mode}')
def get_mode() -> modes.Mode:
  """Get the default computing mode.

  Returns
  -------
  mode: Mode
    The default computing mode.
  """
  return defaults.mode
def set_membrane_scaling(membrane_scaling: scales.Scaling):
  """Set the default computing membrane_scaling.

  Parameters
  ----------
  membrane_scaling: Scaling
    The instance of :py:class:`~.Scaling`.

  Raises
  ------
  TypeError
    If ``membrane_scaling`` is not a :py:class:`~.Scaling` instance.
  """
  if not isinstance(membrane_scaling, scales.Scaling):
    raise TypeError(f'Must be instance of brainpy.math.Scaling. '
                    f'But we got {type(membrane_scaling)}: {membrane_scaling}')
  defaults.__dict__['membrane_scaling'] = membrane_scaling


def get_membrane_scaling() -> scales.Scaling:
  """Get the default computing membrane_scaling.

  Returns
  -------
  membrane_scaling: Scaling
    The default computing membrane_scaling.
  """
  return defaults.membrane_scaling
def enable_x64(x64=None):
  """Enable 64-bit computation and switch the default dtypes to 64-bit ones.

  Sets ``jax_enable_x64`` and the defaults to ``int64``/``float64``/``complex128``.

  Parameters
  ----------
  x64: bool, optional
    Deprecated. ``enable_x64(True)`` is the same as ``enable_x64()`` and
    ``enable_x64(False)`` is the same as :func:`disable_x64`. Passing any value
    emits a ``DeprecationWarning``.
  """
  if x64 is None:
    x64 = True
  else:
    warnings.warn(
      '\n'
      'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n'
      'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n',
      DeprecationWarning,
      # Attribute the warning to the caller, not to this module.
      stacklevel=2,
    )
  if x64:
    config.update("jax_enable_x64", True)
    set_int(jnp.int64)
    set_float(jnp.float64)
    set_complex(jnp.complex128)
  else:
    disable_x64()
def disable_x64():
  """Disable 64-bit computation and reset the default dtypes to 32-bit ones.

  Sets ``jax_enable_x64`` to False and the defaults to ``int32``/``float32``/``complex64``.
  """
  config.update("jax_enable_x64", False)
  for setter, dtype in ((set_int, jnp.int32),
                        (set_float, jnp.float32),
                        (set_complex, jnp.complex64)):
    setter(dtype)
def set_x64(enable: bool):
  """Turn 64-bit mode on (``True``, via :func:`enable_x64`) or off (``False``, via :func:`disable_x64`)."""
  assert isinstance(enable, bool)
  toggle = enable_x64 if enable else disable_x64
  toggle()
def set_platform(platform: str):
  """Select the JAX platform: ``'cpu'``, ``'gpu'`` or ``'tpu'``.

  This utility only takes effect at the beginning of your program.
  """
  supported = ['cpu', 'gpu', 'tpu']
  assert platform in supported
  config.update("jax_platform_name", platform)
def get_platform() -> str:
  """Return the platform name of the first JAX device.

  Returns
  -------
  platform: str
    Either 'cpu', 'gpu' or 'tpu'.
  """
  first_device = devices()[0]
  return first_device.platform
def set_host_device_count(n):
  """Tell XLA that ``n`` host (CPU) devices are available.

  By default, XLA considers all CPU cores as one device. Exposing ``n`` host
  devices allows parallel mapping in JAX :func:`jax.pmap` to work on the CPU
  platform.

  .. note::
     This utility only takes effect at the beginning of your program.
     Under the hood, this sets the environment variable
     `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`,
     where `[num_device]` is the desired number of CPU devices `n`.
     Any previous value of that flag is replaced; other flags are kept.

  .. warning::
     Our understanding of the side effects of using the
     `xla_force_host_platform_device_count` flag in XLA is incomplete.
     If you observe some strange phenomenon when using this utility,
     please let us know through our issue or forum page. More information
     is available in this `JAX issue <https://github.com/google/jax/issues/1408>`_.

  :param int n: number of devices to use.
  """
  stale_flag = r"--xla_force_host_platform_device_count=\S+"
  remaining = re.sub(stale_flag, "", os.getenv("XLA_FLAGS", "")).split()
  new_flag = f"--xla_force_host_platform_device_count={n}"
  os.environ["XLA_FLAGS"] = " ".join([new_flag, *remaining])
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    transform: bool = True,
    compilation: bool = False,
    object_name: bool = False,
):
  """Clear all on-device buffers.

  This function will be very useful when you call models in a Python loop,
  because it can clear all cached arrays, and clear device memory.

  .. warning::

     This operation may cause errors when you use a deleted buffer.
     Therefore, regenerate data always.

  Parameters
  ----------
  platform: str
    The device to clear its memory.
  array: bool
    Clear all buffer array. Default is True.
  compilation: bool
    Clear compilation cache. Default is False.
  transform: bool
    Clear transform cache. Default is True.
  object_name: bool
    Clear name cache. Default is False.
  """
  if array:
    if hasattr(jax, 'live_arrays'):
      # Public API for enumerating live arrays; the backend ``live_buffers()``
      # method used below may not exist in newer jaxlib releases.
      live = jax.live_arrays(platform)
    else:
      live = xla_bridge.get_backend(platform).live_buffers()
    for buf in live:
      buf.delete()
  if compilation:
    jax.clear_caches()
  if transform:
    naming.clear_stack_cache()
  if object_name:
    naming.clear_name_cache()
  gc.collect()
def disable_gpu_memory_preallocation(release_memory: bool = True):
  """Turn off GPU memory preallocation.

  JAX will then allocate GPU memory on demand, which can reduce overall memory
  usage but is more prone to GPU memory fragmentation: a program that uses most
  of the available GPU memory may OOM with preallocation disabled.

  Args:
    release_memory: bool. Whether we release memory during the computation
      (sets ``XLA_PYTHON_CLIENT_ALLOCATOR=platform``).
  """
  env = os.environ
  env['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
  if not release_memory:
    return
  env['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
def enable_gpu_memory_preallocation():
  """Enable pre-allocating the GPU memory.

  Sets ``XLA_PYTHON_CLIENT_PREALLOCATE=true`` and removes any
  ``XLA_PYTHON_CLIENT_ALLOCATOR`` override (e.g. the ``'platform'`` value set by
  ``disable_gpu_memory_preallocation``).
  """
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
  os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
def gpu_memory_preallocation(percent: float):
  """Set the fraction of total GPU memory that JAX preallocates.

  If preallocation is enabled, this makes JAX preallocate ``percent`` of the
  total GPU memory, instead of the default 75%. ``percent`` is a fraction in
  ``[0., 1.)`` (e.g. ``0.5`` for 50%). Lowering the amount preallocated can fix
  OOMs that occur when the JAX program starts.

  Parameters
  ----------
  percent: float
    Fraction of total GPU memory, in ``[0., 1.)``. Written to the
    ``XLA_PYTHON_CLIENT_MEM_FRACTION`` environment variable.
  """
  # The accepted range is half-open; the message now states it accurately.
  assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
  os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)