# -*- coding: utf-8 -*-
import functools
import inspect
import os
import re
import sys
import warnings
from typing import Any, Callable, TypeVar, cast
import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge
from . import modes
from . import scales
# Handle for the lazily-imported `brainpy.math` module (filled in on first use).
bm = None
__all__ = [
# context managers for environment setting
'environment',
'batching_environment',
'training_environment',
'set_environment',
'set',
# default data types
'set_float', 'get_float',
'set_int', 'get_int',
'set_bool', 'get_bool',
'set_complex', 'get_complex',
# default numerical integration step
'set_dt', 'get_dt',
# default computation modes
'set_mode', 'get_mode',
# default membrane_scaling
'set_membrane_scaling', 'get_membrane_scaling',
# set jax environments
'enable_x64', 'disable_x64',
'set_platform', 'get_platform',
'set_host_device_count',
# device memory
'clear_buffer_memory',
'enable_gpu_memory_preallocation',
'disable_gpu_memory_preallocation',
# deprecated
'ditype',
'dftype',
]
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func: F) -> F:
if inspect.isgeneratorfunction(func):
return self._wrap_generator(func)
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with self.clone():
return func(*args, **kwargs)
return cast(F, decorate_context)
def _wrap_generator(self, func):
"""Wrap each generator invocation with the context manager"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
# Generators are suspended and resumed at `yield`, hence we make
# sure the environment settings are properly applied every time the
# execution flow returns into the wrapped generator, and restored
# when it returns through our `yield` to our caller (see PR #49017).
try:
# Issuing `None` to a generator fires it up
with self.clone():
response = gen.send(None)
while True:
try:
# Forward the response to our caller and get its next request
request = yield response
except GeneratorExit:
# Inform the still active generator about its imminent closure
with self.clone():
gen.close()
raise
except BaseException:
# Propagate the exception thrown at us by the caller
with self.clone():
response = gen.throw(*sys.exc_info())
else:
# Pass the last request to the generator and get its response
with self.clone():
response = gen.send(request)
# We let the exceptions raised above by the generator's `.throw` or
# `.send` methods bubble up to our caller, except for StopIteration
except StopIteration as e:
# The generator informed us that it is done: take whatever its
# returned value (if any) was and indicate that we're done too
# by returning it (see docs for python's return-statement).
return e.value
return generator_context
def __enter__(self) -> None:
raise NotImplementedError
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
raise NotImplementedError
def clone(self):
# Override this method if your subclass takes __init__ parameters.
return self.__class__()
class environment(_DecoratorContextManager):
r"""Context-manager that sets a computing environment for brain dynamics computation.
In BrainPy, there are several basic computation settings when constructing models,
including ``mode`` for controlling model computing behavior, ``dt`` for numerical
integration, ``int_`` for integer precision, and ``float_`` for floating precision.
:py:class:`~.environment`` provides a context for model construction and
computation. In this temporal environment, models are constructed with the given
``mode``, ``dt``, ``int_``, etc., environment settings.
For instance::
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> with bm.environment(mode=bm.training_mode, dt=0.1):
>>> lif1 = bp.neurons.LIF(1)
>>>
>>> with bm.environment(mode=bm.nonbatching_mode, dt=0.05, float_=bm.float64):
>>> lif2 = bp.neurons.LIF(1)
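Because :py:class:`~.environment` can also be used as a decorator, a
model-construction function can be wrapped directly. A minimal sketch
(the function name ``build_net`` is hypothetical)::
>>> @bm.environment(mode=bm.training_mode, dt=0.1)
... def build_net():
...     return bp.neurons.LIF(1)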
"""
def __init__(
self,
mode: modes.Mode = None,
membrane_scaling: scales.Scaling = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
) -> None:
super().__init__()
if dt is not None:
assert isinstance(dt, float), '"dt" must be a float.'
self.old_dt = get_dt()
if mode is not None:
assert isinstance(mode, modes.Mode), f'"mode" must be a {modes.Mode}.'
self.old_mode = get_mode()
if membrane_scaling is not None:
assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must be a {scales.Scaling}.'
self.old_membrane_scaling = get_membrane_scaling()
if x64 is not None:
assert isinstance(x64, bool), '"x64" must be a bool.'
self.old_x64 = config.read("jax_enable_x64")
if float_ is not None:
assert isinstance(float_, type), '"float_" must be a type.'
self.old_float = get_float()
if int_ is not None:
assert isinstance(int_, type), '"int_" must be a type.'
self.old_int = get_int()
if bool_ is not None:
assert isinstance(bool_, type), '"bool_" must be a type.'
self.old_bool = get_bool()
if complex_ is not None:
assert isinstance(complex_, type), '"complex_" must be a type.'
self.old_complex = get_complex()
self.dt = dt
self.mode = mode
self.membrane_scaling = membrane_scaling
self.x64 = x64
self.complex_ = complex_
self.float_ = float_
self.int_ = int_
self.bool_ = bool_
def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
if self.mode is not None: set_mode(self.mode)
if self.membrane_scaling is not None: set_membrane_scaling(self.membrane_scaling)
if self.x64 is not None: set_x64(self.x64)
if self.float_ is not None: set_float(self.float_)
if self.int_ is not None: set_int(self.int_)
if self.complex_ is not None: set_complex(self.complex_)
if self.bool_ is not None: set_bool(self.bool_)
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.dt is not None: set_dt(self.old_dt)
if self.mode is not None: set_mode(self.old_mode)
if self.membrane_scaling is not None: set_membrane_scaling(self.old_membrane_scaling)
if self.x64 is not None: set_x64(self.old_x64)
if self.int_ is not None: set_int(self.old_int)
if self.float_ is not None: set_float(self.old_float)
if self.complex_ is not None: set_complex(self.old_complex)
if self.bool_ is not None: set_bool(self.old_bool)
def clone(self):
return self.__class__(dt=self.dt,
mode=self.mode,
membrane_scaling=self.membrane_scaling,
x64=self.x64,
bool_=self.bool_,
complex_=self.complex_,
float_=self.float_,
int_=self.int_)
def __eq__(self, other):
return id(self) == id(other)
class training_environment(environment):
"""Environment with the training mode.
This is a shortcut context setting for an environment with the training mode.
It is equivalent to::
>>> import brainpy.math as bm
>>> with bm.environment(mode=bm.training_mode):
...     pass
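A ``batch_size`` other than the default of 1 can also be given. A minimal
sketch::
>>> with bm.training_environment(batch_size=32, dt=0.1):
...     pass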
"""
def __init__(
self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
):
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
membrane_scaling=membrane_scaling,
mode=modes.TrainingMode(batch_size))
class batching_environment(environment):
"""Environment with the batching mode.
This is a shortcut context setting for an environment with the batching mode.
It is equivalent to::
>>> import brainpy.math as bm
>>> with bm.environment(mode=bm.batching_mode):
...     pass
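A ``batch_size`` other than the default of 1 can also be given. A minimal
sketch::
>>> with bm.batching_environment(batch_size=128):
...     pass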
"""
def __init__(
self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
):
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
mode=modes.BatchingMode(batch_size),
membrane_scaling=membrane_scaling)
def set(
mode: modes.Mode = None,
membrane_scaling: scales.Scaling = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
):
"""Set the default computation environment.
Parameters
----------
mode: Mode
The computing mode.
membrane_scaling: Scaling
The numerical membrane_scaling.
dt: float
The numerical integration precision.
x64: bool
Enable x64 computation.
complex_: type
The complex data type.
float_
The floating data type.
int_
The integer data type.
bool_
The bool data type.
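For example, a minimal sketch (using ``jax.numpy`` dtypes directly)::
>>> import brainpy.math as bm
>>> from jax import numpy as jnp
>>> bm.set(mode=bm.training_mode, dt=0.05, float_=jnp.float32)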
"""
if dt is not None:
assert isinstance(dt, float), '"dt" must be a float.'
set_dt(dt)
if mode is not None:
assert isinstance(mode, modes.Mode), f'"mode" must be a {modes.Mode}.'
set_mode(mode)
if membrane_scaling is not None:
assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must be a {scales.Scaling}.'
set_membrane_scaling(membrane_scaling)
if x64 is not None:
assert isinstance(x64, bool), '"x64" must be a bool.'
set_x64(x64)
if float_ is not None:
assert isinstance(float_, type), '"float_" must be a type.'
set_float(float_)
if int_ is not None:
assert isinstance(int_, type), '"int_" must be a type.'
set_int(int_)
if bool_ is not None:
assert isinstance(bool_, type), '"bool_" must be a type.'
set_bool(bool_)
if complex_ is not None:
assert isinstance(complex_, type), '"complex_" must be a type.'
set_complex(complex_)
# `set_environment` is an alias of `set`.
set_environment = set
# default dtype
# --------------------------
def ditype():
"""Default int type.
.. deprecated:: 2.3.1
Use `brainpy.math.int_` instead.
"""
# raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n'
# 'Use `brainpy.math.int_` instead.')
global bm
if bm is None: from brainpy import math as bm
return bm.int_
def dftype():
"""Default float type.
.. deprecated:: 2.3.1
Use `brainpy.math.float_` instead.
"""
# raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n'
# 'Use `brainpy.math.float_` instead.')
global bm
if bm is None: from brainpy import math as bm
return bm.float_
def set_float(dtype: type):
"""Set global default float type.
Parameters
----------
dtype: type
The float type.
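For example, to compute with half precision by default (a minimal sketch)::
>>> import brainpy.math as bm
>>> from jax import numpy as jnp
>>> bm.set_float(jnp.float16)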
"""
if dtype not in [jnp.float16, jnp.float32, jnp.float64, ]:
raise TypeError(f'Float data type {dtype} is not supported.')
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['float_'] = dtype
def get_float():
"""Get the default float data type.
Returns
-------
float_: type
The default float data type.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.float_
def set_int(dtype: type):
"""Set global default integer type.
Parameters
----------
dtype: type
The integer type.
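For example, a minimal sketch::
>>> import brainpy.math as bm
>>> from jax import numpy as jnp
>>> bm.set_int(jnp.int32)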
"""
if dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, ]:
raise TypeError(f'Integer data type {dtype} is not supported.')
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['int_'] = dtype
def get_int():
"""Get the default int data type.
Returns
-------
int_: type
The default int data type.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.int_
def set_bool(dtype: type):
"""Set global default boolean type.
Parameters
----------
dtype: type
The bool type.
"""
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['bool_'] = dtype
def get_bool():
"""Get the default boolean data type.
Returns
-------
bool_: type
The default bool data type.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.bool_
def set_complex(dtype: type):
"""Set global default complex type.
Parameters
----------
dtype: type
The complex type.
"""
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['complex_'] = dtype
def get_complex():
"""Get the default complex data type.
Returns
-------
complex_: type
The default complex data type.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.complex_
# numerical precision
# --------------------------
def set_dt(dt):
"""Set the default numerical integrator precision.
Parameters
----------
dt : float
Numerical integration precision.
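For example, to integrate with a step of 0.01 (a minimal sketch; the time
unit follows your model's convention)::
>>> import brainpy.math as bm
>>> bm.set_dt(0.01)
>>> bm.get_dt()
0.01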
"""
assert isinstance(dt, float), f'"dt" must be a float, but we got {dt}'
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['dt'] = dt
def get_dt():
"""Get the numerical integrator precision.
Returns
-------
dt : float
Numerical integration precision.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.dt
def set_mode(mode: modes.Mode):
"""Set the default computing mode.
Parameters
----------
mode: Mode
The instance of :py:class:`~.Mode`.
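For example, a minimal sketch using the pre-built ``bm.training_mode``
instance::
>>> import brainpy.math as bm
>>> bm.set_mode(bm.training_mode)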
"""
if not isinstance(mode, modes.Mode):
raise TypeError(f'Must be instance of brainpy.math.Mode. '
f'But we got {type(mode)}: {mode}')
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['mode'] = mode
def get_mode() -> modes.Mode:
"""Get the default computing mode.
Returns
-------
mode: Mode
The default computing mode.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.mode
def set_membrane_scaling(membrane_scaling: scales.Scaling):
"""Set the default membrane scaling.
Parameters
----------
membrane_scaling: Scaling
The instance of :py:class:`~.Scaling`.
"""
if not isinstance(membrane_scaling, scales.Scaling):
raise TypeError(f'Must be instance of brainpy.math.Scaling. '
f'But we got {type(membrane_scaling)}: {membrane_scaling}')
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['membrane_scaling'] = membrane_scaling
def get_membrane_scaling() -> scales.Scaling:
"""Get the default membrane scaling.
Returns
-------
membrane_scaling: Scaling
The default membrane scaling.
"""
global bm
if bm is None: from brainpy import math as bm
return bm.membrane_scaling
def enable_x64(x64=None):
"""Enable 64-bit computation: turn on ``jax_enable_x64`` and widen the
default int/float/complex types to ``int64``/``float64``/``complex128``."""
if x64 is None:
x64 = True
else:
warnings.warn(
'\n'
'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n'
'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n',
DeprecationWarning
)
if x64:
config.update("jax_enable_x64", True)
set_int(jnp.int64)
set_float(jnp.float64)
set_complex(jnp.complex128)
else:
disable_x64()
def disable_x64():
"""Disable 64-bit computation: turn off ``jax_enable_x64`` and set the
default int/float/complex types to ``int32``/``float32``/``complex64``."""
config.update("jax_enable_x64", False)
set_int(jnp.int32)
set_float(jnp.float32)
set_complex(jnp.complex64)
def set_x64(enable: bool):
"""Enable or disable x64 computation according to ``enable``."""
assert isinstance(enable, bool)
if enable:
enable_x64()
else:
disable_x64()
def set_host_device_count(n):
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA
that there are `n` host (CPU) devices available to use. As a consequence, this
allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
.. note:: This utility only takes effect at the beginning of your program.
Under the hood, this sets the environment variable
`XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
`[num_device]` is the desired number of CPU devices `n`.
.. warning:: Our understanding of the side effects of using the
`xla_force_host_platform_device_count` flag in XLA is incomplete. If you
observe some strange phenomenon when using this utility, please let us
know through our issue or forum page. More information is available in this
`JAX issue <https://github.com/google/jax/issues/1408>`_.
:param int n: number of devices to use.
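For example, a minimal sketch (this must run before JAX initializes its
backends, e.g. at the very top of your script)::
>>> import brainpy.math as bm
>>> bm.set_host_device_count(4)
>>> import jax
>>> jax.local_device_count()  # reports 4 CPU devices if set early enough
4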
"""
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
def clear_buffer_memory(
platform: str = None,
array: bool = True,
compilation: bool = False
):
"""Clear all on-device buffers.
This function will be very useful when you call models in a Python loop,
because it can clear all cached arrays, and clear device memory.
.. warning::
This operation may cause errors when you use a deleted buffer.
Therefore, regenerate data always.
Parameters
----------
platform: str
The device to clear its memory.
array: bool
Clear all buffer array.
compilation: bool
Clear compilation cache.
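For example, a minimal sketch::
>>> import brainpy.math as bm
>>> for batch in range(10):
...     ...  # build and run a model here, regenerating its data each time
...     bm.clear_buffer_memory()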
"""
if array:
for buf in xla_bridge.get_backend(platform=platform).live_buffers():
buf.delete()
if compilation:
jax.clear_caches()
def disable_gpu_memory_preallocation(release_memory: bool = True):
"""Disable pre-allocating the GPU memory.
This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
potentially decreasing the overall memory usage. However, this behavior is more prone to
GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
may OOM with preallocation disabled.
Args:
release_memory: bool. Whether to release memory back to the system as soon
as it is unused (by switching to the ``platform`` allocator).
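For example, a minimal sketch (this must run before any GPU computation)::
>>> import brainpy.math as bm
>>> bm.disable_gpu_memory_preallocation()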
"""
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
if release_memory:
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
def enable_gpu_memory_preallocation():
"""Enable pre-allocating the GPU memory."""
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
def gpu_memory_preallocation(percent: float):
"""GPU memory allocation.
If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory,
instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
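For example, to preallocate half of the total GPU memory (a minimal sketch,
assuming this helper is importable as ``bm.gpu_memory_preallocation``; it
must run before any GPU computation)::
>>> import brainpy.math as bm
>>> bm.gpu_memory_preallocation(0.5)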
"""
assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)