Source code for brainpy._src.math.modes
# -*- coding: utf-8 -*-
# Public names exported by this module: the base class ``Mode`` and its
# three subclasses, followed by the three module-level default instances.
__all__ = [
'Mode',
'NonBatchingMode',
'BatchingMode',
'TrainingMode',
'nonbatching_mode',
'batching_mode',
'training_mode',
]
[docs]
class Mode(object):
    """Base class for computation modes.

    Two modes compare equal when they are instances of exactly the same
    class; instance attributes take no part in the comparison. The hash
    follows the same rule, so modes can be used as dictionary keys and
    set members.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other: 'Mode'):
        """Return True only when ``other`` is a mode of exactly the same class."""
        if not isinstance(other, Mode):
            return False
        return other.__class__ == self.__class__

    def __hash__(self):
        # A class that defines ``__eq__`` without ``__hash__`` gets
        # ``__hash__ = None`` and its instances become unhashable. Hash on
        # the class so that equal modes always share the same hash.
        return hash(self.__class__)

    @staticmethod
    def _check_types(modes):
        """Raise ``TypeError`` unless every entry of ``modes`` is a class."""
        for m_ in modes:
            if not isinstance(m_, type):
                raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}')

    def is_one_of(self, *modes):
        """Check whether the class of this mode is exactly one of ``modes``.

        Raises:
            TypeError: if any entry of ``modes`` is not a class.
        """
        self._check_types(modes)
        return self.__class__ in modes

    def is_a(self, mode: type):
        """Check whether the mode is exactly the desired mode."""
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def is_parent_of(self, *modes):
        """Check whether the mode is a parent of the given modes.

        Returns True when at least one of ``modes`` is a subclass of this
        mode's class (a class counts as a subclass of itself), and False
        otherwise, including when no modes are given.

        Raises:
            TypeError: if any entry of ``modes`` is not a class.
        """
        self._check_types(modes)
        cls = self.__class__
        return any(issubclass(m_, cls) for m_ in modes)

    def is_child_of(self, *modes):
        """Check whether the mode is a child of one of the given modes.

        Raises:
            TypeError: if any entry of ``modes`` is not a class.
        """
        self._check_types(modes)
        return isinstance(self, modes)

    def is_batch_mode(self):
        """Return True when this mode is an instance of ``BatchingMode``."""
        return isinstance(self, BatchingMode)

    def is_train_mode(self):
        """Return True when this mode is an instance of ``TrainingMode``."""
        return isinstance(self, TrainingMode)

    def is_nonbatch_mode(self):
        """Return True when this mode is an instance of ``NonBatchingMode``."""
        return isinstance(self, NonBatchingMode)
[docs]
class NonBatchingMode(Mode):
    """Normal mode for computations that carry no batch dimension.

    :py:class:`~.NonBatchingMode` is usually used in models of traditional
    computational neuroscience.
    """

    @property
    def batch_size(self):
        """Always an empty tuple for this mode."""
        return ()
[docs]
class BatchingMode(Mode):
    """Computation mode that carries a batch size.

    :py:class:`~.BatchingMode` is usually used in models of model trainings.

    Args:
        batch_size: value stored on the instance as ``batch_size``;
            defaults to 1.
    """

    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size

    def __repr__(self):
        # Show the concrete class name together with the stored batch size.
        name = self.__class__.__name__
        return f'{name}(batch_size={self.batch_size})'
[docs]
class TrainingMode(BatchingMode):
    """Training mode; a batching mode, so it carries a batch size too."""

    def to_batch_mode(self):
        """Return a new ``BatchingMode`` with the same ``batch_size``."""
        size = self.batch_size
        return BatchingMode(size)
# Default instances, one per concrete mode class. Each is built with no
# arguments, so ``batching_mode`` and ``training_mode`` get the default
# ``batch_size`` of 1. The string literal after each assignment describes
# the instance created on the line above it.
nonbatching_mode = NonBatchingMode()
'''Default instance of the non-batching computation mode.'''
batching_mode = BatchingMode()
'''Default instance of the batching computation mode.'''
training_mode = TrainingMode()
'''Default instance of the training computation mode.'''