# -*- coding: utf-8 -*-
import inspect
import math
import time
import warnings
from typing import Callable, Union, Dict, Sequence, Tuple
import jax.numpy as jnp
import numpy as np
import jax
from jax.scipy.optimize import minimize
from jax.tree_util import tree_flatten, tree_map
import brainpy._src.math as bm
from brainpy import optim, losses
from brainpy._src.analysis import utils, base, constants
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
from brainpy._src.deprecations import _input_deprecate_msg
# Public API of this module.
__all__ = [
  'SlowPointFinder',
]

# Cache keys under which compiled optimization functions are stored in
# ``SlowPointFinder._opt_functions``.
F_OPT_SOLVER = 'function_for_opt_solver'
F_GRADIENT_DESCENT = 'function_for_gradient_descent'

# Nonlinear solvers usable by ``find_fps_with_opt_solver``: each entry maps a
# solver name to a callable ``(loss_fn, x0) -> OptimizeResults``.
SUPPORTED_OPT_SOLVERS = {
  'BFGS': lambda f, x0: minimize(f, x0, method='BFGS')
}
class SlowPointFinder(base.DSAnalyzer):
  """Find fixed/slow points by numerical optimization.

  This class can help you:

  - optimize to find the closest fixed points / slow points
  - exclude any fixed points whose fixed point loss is above threshold
  - exclude any non-unique fixed points according to a tolerance
  - exclude any far-away "outlier" fixed points

  Parameters
  ----------
  f_cell : callable, function, DynamicalSystem
    The target of computing the recurrent units.

  f_type : str
    The system's type: continuous system or discrete system.

    - 'continuous': continuous derivative function, denotes this is a continuous system, or
    - 'discrete': discrete update function, denotes this is a discrete system.

  verbose : bool
    Whether output the optimization progress.

  f_loss: callable
    The loss function.

    - If ``f_type`` is `"discrete"`, the loss function must receive three
      arguments, i.e., ``loss(outputs, targets, axis)``.
    - If ``f_type`` is `"continuous"`, the loss function must receive two
      arguments, i.e., ``loss(outputs, axis)``.

    .. versionadded:: 2.2.0

  t: float
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    The time to evaluate the fixed points. Default is 0.

    .. versionadded:: 2.2.0

  dt: float
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    The numerical integration step, which can be used when .
    The default is given by `brainpy.math.get_dt()`.

    .. versionadded:: 2.2.0

  inputs: sequence, callable
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    Same as ``inputs`` in :py:class:`~.DSRunner`.

    .. versionadded:: 2.2.0

  excluded_vars: sequence, dict
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    The excluded variables (can be a sequence of `Variable` instances).
    These variables will not be included for optimization of fixed points.

    .. versionadded:: 2.2.0

  target_vars: dict
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    The target variables (can be a dict of `Variable` instances).
    These variables will be included for optimization of fixed points.
    The candidate points later provided should have same keys as in ``target_vars``.

    .. versionadded:: 2.2.0

  f_loss_batch : callable, function
    Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`.
    The function to compute the loss.

    .. deprecated:: 2.2.0
       Has been removed. Please use ``f_loss`` to set different loss function.

  fun_inputs: callable
    .. deprecated:: 2.3.1
       Will be removed since version 2.4.0.
  """
  def __init__(
      self,
      f_cell: Union[Callable, DynamicalSystem],
      f_type: str = None,
      f_loss: Callable = None,
      verbose: bool = True,
      args: Tuple = (),
      # parameters for `f_cell` is DynamicalSystem instance
      inputs: Sequence = None,
      t: float = None,
      dt: float = None,
      target_vars: Dict[str, bm.Variable] = None,
      excluded_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
      # deprecated
      f_loss_batch: Callable = None,
      fun_inputs: Callable = None,
  ):
    super().__init__()

    # static arguments: extra positional arguments forwarded to every call of `f_cell`
    if not isinstance(args, tuple):
      raise ValueError(f'args must be an instance of tuple, but we got {type(args)}')
    self.args = args

    # normalize "target_vars" (the variables optimized over) into an ArrayCollector
    if target_vars is None:
      self.target_vars = bm.ArrayCollector()
    else:
      if not isinstance(target_vars, dict):
        raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
      self.target_vars = bm.ArrayCollector(target_vars)

    # normalize "excluded_vars" (variables held fixed) into a name->Variable dict
    excluded_vars = () if excluded_vars is None else excluded_vars
    if isinstance(excluded_vars, dict):
      excluded_vars = tuple(excluded_vars.values())
    if not isinstance(excluded_vars, (tuple, list)):
      raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}')
    for v in excluded_vars:
      if not isinstance(v, bm.Variable):
        raise TypeError(f'"excluded_vars" must be a sequence of Variable, '
                        f'but we got {type(v)}')
    self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)}
    # the two selections are mutually exclusive ways of describing the same split
    if len(self.target_vars) > 0 and len(self.excluded_vars) > 0:
      raise ValueError('"target_vars" and "excluded_vars" cannot be provided simultaneously.')
    self.target = f_cell

    if isinstance(f_cell, DynamicalSystem):
      # included variables: all unique variables reachable from the system
      all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique()

      # exclude variables
      if len(self.target_vars) > 0:
        # everything not explicitly targeted is treated as excluded (matched by identity)
        _all_ids = [id(v) for v in self.target_vars.values()]
        for k, v in all_vars.items():
          if id(v) not in _all_ids:
            self.excluded_vars[k] = v
      else:
        # target everything, then drop explicitly excluded variables (matched by identity)
        self.target_vars = all_vars
        if len(excluded_vars):
          excluded_vars = [id(v) for v in excluded_vars]
          for key, val in tuple(self.target_vars.items()):
            if id(val) in excluded_vars:
              self.target_vars.pop(key)

      # input function
      if callable(inputs):
        self._inputs = inputs
      else:
        if inputs is None:
          self._inputs = None
        else:
          self._inputs = check_and_format_inputs(host=self.target, inputs=inputs)

      # check included variables: batched variables must carry a batch of exactly one
      for var in self.target_vars.values():
        if var.batch_axis is not None:
          if var.shape[var.batch_axis] != 1:
            raise ValueError(f'Batched variables should has only one batch. '
                             f'But we got {var.shape[var.batch_axis]}. Maybe '
                             f'you need to call ".reset_state(batch_size=1)" '
                             f'for your system.')

      # update function: a pure dict-of-states -> dict-of-states map built from the system
      self.f_cell = self._generate_ds_cell_function(self.target, t, dt)

      # check function type: a DynamicalSystem update is always a discrete map
      if f_type is not None:
        if f_type != constants.DISCRETE:
          raise ValueError(f'"f_type" must be "{constants.DISCRETE}" when "f_cell" '
                           f'is instance of {DynamicalSystem.__name__}')
      f_type = constants.DISCRETE

      # original data: snapshots used to restore variables after traced computations
      self.target_data = {k: v.value for k, v in self.target_vars.items()}
      self.excluded_data = {k: v.value for k, v in self.excluded_vars.items()}

    elif callable(f_cell):
      # plain-function target: DynamicalSystem-only options are rejected
      if len(self.args) > 0:
        self.f_cell = lambda x: f_cell(x, *self.args)
      else:
        self.f_cell = f_cell
      if inputs is not None:
        raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of '
                               f'{DynamicalSystem.__name__}')
      self._inputs = inputs
      if t is not None:
        raise UnsupportedError('Do not support "t" when "f_cell" is not instance of '
                               f'{DynamicalSystem.__name__}')
      if dt is not None:
        raise UnsupportedError('Do not support "dt" when "f_cell" is not instance of '
                               f'{DynamicalSystem.__name__}')
      if target_vars is not None:
        raise UnsupportedError('Do not support "target_vars" when "f_cell" is not instance of '
                               f'{DynamicalSystem.__name__}')
      if len(excluded_vars) > 0:
        raise UnsupportedError('Do not support "excluded_vars" when "f_cell" is not instance of '
                               f'{DynamicalSystem.__name__}')

    else:
      raise ValueError(f'Unknown type of "f_type": {type(f_cell)}')

    if f_type not in [constants.DISCRETE, constants.CONTINUOUS]:
      raise AnalyzerError(f'Only support "{constants.CONTINUOUS}" (continuous derivative function) or '
                          f'"{constants.DISCRETE}" (discrete update function), not {f_type}.')
    self.verbose = verbose
    self.f_type = f_type

    # loss function
    if f_loss_batch is not None:
      raise UnsupportedError('"f_loss_batch" is no longer supported, please '
                             'use "f_loss" instead.')
    if fun_inputs is not None:
      raise UnsupportedError('"fun_inputs" is no longer supported.')
    if f_loss is None:
      # default: discrete maps compare h to F(h); continuous flows penalize F(h) itself
      f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square
    self.f_loss = f_loss

    # essential variables, populated by the `find_fps_*` methods
    self._losses = None
    self._fixed_points = None
    self._selected_ids = None
    self._opt_losses = None

    # cache of compiled helper functions
    self._opt_functions = dict()
  @property
  def opt_losses(self) -> np.ndarray:
    """The optimization losses."""
    return np.asarray(self._opt_losses)

  @opt_losses.setter
  def opt_losses(self, val):
    # read-only: populated internally by the `find_fps_*` methods
    raise UnsupportedError('Do not support set "opt_losses" by users.')

  @property
  def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """The final fixed points found."""
    return tree_map(lambda a: np.asarray(a), self._fixed_points)

  @fixed_points.setter
  def fixed_points(self, val):
    # read-only: populated internally by the `find_fps_*` methods
    raise UnsupportedError('Do not support set "fixed_points" by users.')

  @property
  def num_fps(self) -> int:
    """The number of fixed points currently stored."""
    if isinstance(self._fixed_points, dict):
      # all dict entries share the same leading (batch) size; use the first
      return tuple(self._fixed_points.values())[0].shape[0]
    else:
      return self._fixed_points.shape[0]

  @property
  def losses(self) -> np.ndarray:
    """Losses of fixed points."""
    return np.asarray(self._losses)

  @losses.setter
  def losses(self, val):
    # read-only: populated internally by the `find_fps_*` methods
    raise UnsupportedError('Do not support set "losses" by users.')

  @property
  def selected_ids(self) -> np.ndarray:
    """The selected ids of candidate points."""
    return np.asarray(self._selected_ids)

  @selected_ids.setter
  def selected_ids(self, val):
    # read-only: populated internally by the `find_fps_*` methods
    raise UnsupportedError('Do not support set "selected_ids" by users.')
  def find_fps_with_gd_method(
      self,
      candidates: Union[ArrayType, Dict[str, ArrayType]],
      tolerance: Union[float, Dict[str, float]] = 1e-5,
      num_batch: int = 100,
      num_opt: int = 10000,
      optimizer: optim.Optimizer = None,
  ):
    """Optimize fixed points with gradient descent methods.

    Parameters
    ----------
    candidates : ArrayType, dict
      The array with the shape of (batch size, state dim) of hidden states
      of RNN to start training for fixed points.
    tolerance: float
      The loss threshold during optimization
    num_opt : int
      The maximum number of optimization.
    num_batch : int
      Print training information during optimization every so often.
    optimizer: optim.Optimizer
      The optimizer instance.

      .. versionadded:: 2.1.2
    """
    # optimization settings
    if optimizer is None:
      # default: Adam with an exponentially decaying learning rate
      optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
                             beta1=0.9, beta2=0.999, eps=1e-8)
    else:
      if not isinstance(optimizer, optim.Optimizer):
        raise ValueError(f'Must be an instance of {optim.Optimizer.__name__}, '
                         f'while we got {type(optimizer)}')

    # set up optimization
    num_candidate = self._check_candidates(candidates)
    if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
      raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.')
    # wrap each candidate array in a TrainVar so the optimizer can update it
    fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.Array))

    f_eval_loss = self._get_f_eval_loss()

    def f_loss():
      # mean fixed-point loss over the whole batch of candidates
      return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                  fixed_points,
                                  is_leaf=lambda x: isinstance(x, bm.Array))).mean()

    grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
    optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})

    def train(idx):
      # one optimization step; `idx` is supplied (and required) by `bm.for_loop`
      gradients, loss = grad_f()
      optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients})
      optimizer.lr.step_epoch()
      return loss

    def batch_train(start_i, n_batch):
      # run `n_batch` consecutive steps, returning the per-step losses
      return bm.for_loop(train, bm.arange(start_i, start_i + n_batch))

    # Run the optimization
    if self.verbose:
      print(f"Optimizing with {optimizer} to find fixed points:")
    opt_losses = []
    do_stop = False
    num_opt_loops = int(num_opt / num_batch)
    for oidx in range(num_opt_loops):
      if do_stop:
        break
      batch_idx_start = oidx * num_batch
      start_time = time.time()
      train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
      batch_time = time.time() - start_time
      opt_losses.append(train_losses)

      if self.verbose:
        print(f" "
              f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} "
              f"in {batch_time:0.2f} sec, Training loss {train_losses[-1]:0.10f}")

      # early stopping once the latest loss drops below the tolerance
      if train_losses[-1] < tolerance:
        do_stop = True
        if self.verbose:
          print(f' '
                f'Stop optimization as mean training loss {train_losses[-1]:0.10f} '
                f'is below tolerance {tolerance:0.10f}.')

    # record results; every candidate is kept at this stage
    self._opt_losses = jnp.concatenate(opt_losses)
    self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                        fixed_points,
                                        is_leaf=lambda x: isinstance(x, bm.Array)))
    self._fixed_points = tree_map(lambda a: bm.as_jax(a),
                                  fixed_points,
                                  is_leaf=lambda x: isinstance(x, bm.Array))
    self._selected_ids = jnp.arange(num_candidate)

    # restore the original variable values of the target system
    if isinstance(self.target, DynamicalSystem):
      for k, v in self.excluded_vars.items():
        v.value = self.excluded_data[k]
      for k, v in self.target_vars.items():
        v.value = self.target_data[k]
  def find_fps_with_opt_solver(
      self,
      candidates: Union[ArrayType, Dict[str, ArrayType]],
      opt_solver: str = 'BFGS'
  ):
    """Optimize fixed points with nonlinear optimization solvers.

    Parameters
    ----------
    candidates: ArrayType, dict
      The candidate (initial) fixed points.
    opt_solver: str
      The solver of the optimization. Must be a key of ``SUPPORTED_OPT_SOLVERS``.
    """
    # optimization function
    num_candidate = self._check_candidates(candidates)
    # the solver operates on flat 1-D state vectors only
    for var in self.target_vars.values():
      if jnp.ndim(var) != 1:
        raise ValueError('Cannot use opt solver.')
    if self._opt_functions.get(F_OPT_SOLVER, None) is None:
      self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver(candidates, SUPPORTED_OPT_SOLVERS[opt_solver])
    f_opt = self._opt_functions[F_OPT_SOLVER]

    if self.verbose:
      print(f"Optimizing with {opt_solver} to find fixed points:")

    # optimizing
    res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.Array)))

    # results: keep only the candidates for which the solver reported success
    valid_ids = jnp.where(res.success)[0]
    fixed_points = res.x[valid_ids]
    if isinstance(candidates, dict):
      # split the flat solution vectors back into per-variable column slices
      indices = [0]
      for v in candidates.values():
        indices.append(v.shape[1])
      indices = np.cumsum(indices)
      keys = tuple(candidates.keys())
      self._fixed_points = {key: fixed_points[:, indices[i]: indices[i + 1]]
                            for i, key in enumerate(keys)}
    else:
      self._fixed_points = fixed_points
    self._losses = res.fun[valid_ids]
    self._selected_ids = jnp.asarray(valid_ids)
    if self.verbose:
      print(f' '
            f'Found {len(valid_ids)} fixed points from {num_candidate} initial points.')
def filter_loss(self, tolerance: float = 1e-5):
"""Filter fixed points whose speed larger than a given tolerance.
Parameters
----------
tolerance: float
Discard fixed points with squared speed larger than this value.
"""
if self.verbose:
print(f"Excluding fixed points with squared speed above "
f"tolerance {tolerance}:")
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
ids = self._losses < tolerance
keep_ids = bm.as_jax(bm.where(ids)[0])
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
print(f" "
f"Kept {len(keep_ids)}/{num_fps} "
f"fixed points with tolerance under {tolerance}.")
def keep_unique(self, tolerance: float = 2.5e-2):
"""Filter unique fixed points by choosing a representative within tolerance.
Parameters
----------
tolerance: float
Tolerance for determination of identical fixed points.
"""
if self.verbose:
print("Excluding non-unique fixed points:")
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance)
self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
print(f" Kept {keep_ids.shape[0]}/{num_fps} unique fixed points "
f"with uniqueness tolerance {tolerance}.")
def exclude_outliers(self, tolerance: float = 1e0):
"""Exclude points whose closest neighbor is further than threshold.
Parameters
----------
tolerance: float
Any point whose closest fixed point is greater than tol is an outlier.
"""
if self.verbose:
print("Excluding outliers:")
if np.isinf(tolerance):
return
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
if num_fps <= 1:
return
# Compute pairwise distances between all fixed points.
distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps))
# Find the second smallest element in each column of the pairwise distance matrix.
# This corresponds to the closest neighbor for each fixed point.
closest_neighbor = np.partition(distances, kth=1, axis=0)[1]
# Return data with outliers removed and indices of kept datapoints.
keep_ids = np.where(closest_neighbor < tolerance)[0]
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._selected_ids = self._selected_ids[keep_ids]
self._losses = self._losses[keep_ids]
if self.verbose:
print(f" "
f"Kept {keep_ids.shape[0]}/{num_fps} fixed points "
f"with within outlier tolerance {tolerance}.")
  def compute_jacobians(
      self,
      points: Union[ArrayType, Dict[str, ArrayType]],
      stack_dict_var: bool = True,
      plot: bool = False,
      num_col: int = 4,
      len_col: int = 3,
      len_row: int = 2,
  ):
    """Compute the Jacobian matrices at the points.

    Parameters
    ----------
    points: np.ndarray, bm.ArrayType, jax.ndarray
      The fixed points with the shape of (num_point, num_dim).
    stack_dict_var: bool
      Stack dictionary variables to calculate Jacobian matrix?
    plot: bool
      Plot the decomposition results of the Jacobian matrix.
    num_col: int
      The number of the figure column.
    len_col: int
      The length of each column.
    len_row: int
      The length of each row.
    """
    # check data: all leaves must agree on dimensionality, and (if 2D) on the
    # number of points
    info = np.asarray([(l.ndim, l.shape[0])
                       for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.Array))[0]])
    ndim = np.unique(info[:, 0])
    if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}')
    if ndim[0] == 1:
      # a single point: promote to a batch of one
      points = tree_map(lambda a: bm.asarray([a]), points)
      num_point = 1
    elif ndim[0] == 2:
      nsize = np.unique(info[:, 1])
      if len(nsize) != 1: raise ValueError(f'Number of the evaluated points are mis-matched. {nsize}')
      num_point = nsize[0]
    else:
      raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)')
    if isinstance(points, dict) and stack_dict_var:
      # concatenate per-variable columns into one flat state matrix
      points = jnp.hstack(tuple(points.values()))

    # get Jacobian matrix
    jacobian = self._get_f_jocabian(stack_dict_var)(points)

    # visualization: scatter the eigenvalues of each Jacobian in the complex plane
    if plot:
      import matplotlib.pyplot as plt
      from brainpy import visualize

      jacobian = bm.as_numpy(jacobian)
      num_col = min(num_col, num_point)
      num_row = int(math.ceil(num_point / num_col))
      fig, gs = visualize.get_figure(num_row, num_col, len_row, len_col)
      for i in range(num_point):
        eigval, eigvec = np.linalg.eig(np.asarray(jacobian[i]))
        ax = fig.add_subplot(gs[i // num_col, i % num_col])
        ax.scatter(np.real(eigval), np.imag(eigval))
        # dashed reference line at Re=1 for discrete systems, Re=0 for continuous
        ax.plot([1, 1] if self.f_type == constants.DISCRETE else [0, 0], [-1, 1], '--')
        ax.set_xlabel('Real')
        ax.set_ylabel('Imaginary')
        ax.set_title(f'Point {i}')
      plt.show()

    return jacobian
@staticmethod
def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False):
"""Compute the eigenvalues of the matrices.
Parameters
----------
matrices: np.ndarray, bm.ArrayType, jax.ndarray
A 3D array with the shape of (num_matrices, dim, dim).
sort_by: str
The method of sorting.
do_compute_lefts: bool
Compute the left eigenvectors? Requires a pseudo-inverse call.
Returns
-------
decompositions : list
A list of dictionaries with sorted eigenvalues components:
(eigenvalues, right eigenvectors, and left eigenvectors).
"""
if sort_by == 'magnitude':
sort_fun = np.abs
elif sort_by == 'real':
sort_fun = np.real
else:
raise ValueError("Not implemented yet.")
matrices = np.asarray(matrices)
decompositions = []
for mat in matrices:
eig_values, eig_vectors = np.linalg.eig(mat)
indices = np.flipud(np.argsort(sort_fun(eig_values)))
L = None
if do_compute_lefts:
L = np.linalg.pinv(eig_vectors).T # as columns
L = L[:, indices]
decompositions.append({'eig_values': eig_values[indices],
'R': eig_vectors[:, indices],
'L': L})
return decompositions
  def _step_func_input(self):
    """Apply the user-provided inputs to the target system for the current step."""
    if self._inputs is None:
      return
    elif callable(self._inputs):
      try:
        # old-style input functions accept one argument (the shared-argument
        # dict); signature().bind raises TypeError when none is accepted
        ba = inspect.signature(self._inputs).bind(dict())
        self._inputs(share.get_shargs())
        warnings.warn(_input_deprecate_msg, UserWarning)
      except TypeError:
        # new-style input functions take no arguments
        self._inputs()
    else:
      # structured inputs produced by `check_and_format_inputs`
      for ops, values in self._inputs['fixed'].items():
        for var, data in values:
          _f_ops(ops, var, data)
      for ops, values in self._inputs['array'].items():
        # array-valued (per-step) inputs are not supported here
        if len(values) > 0:
          raise UnsupportedError
      for ops, values in self._inputs['functional'].items():
        for var, data in values:
          _f_ops(ops, var, data(share.get_shargs()))
      for ops, values in self._inputs['iterated'].items():
        # iterated inputs are not supported here
        if len(values) > 0:
          raise UnsupportedError
def _get_f_eval_loss(self, ):
name = 'f_eval_loss'
if name not in self._opt_functions:
self._opt_functions[name] = self._generate_f_eval_loss()
return self._opt_functions[name]
  def _generate_f_eval_loss(self):
    """Build a jitted function mapping a batch of states to per-candidate losses."""
    # evaluate losses of a batch of inputs
    if self.f_type == constants.DISCRETE:
      # discrete map: the loss compares h with F(h)
      f_eval_loss = lambda h: self.f_loss(h, jax.vmap(self.f_cell)(h), axis=1)
    else:
      # continuous flow: the loss penalizes the derivative F(h) directly
      f_eval_loss = lambda h: self.f_loss(jax.vmap(self.f_cell)(h), axis=1)

    if isinstance(self.target, DynamicalSystem):
      @jax.jit
      def loss_func(h):
        r = f_eval_loss(h)
        # restore variables mutated by `f_cell` before returning
        for k, v in self.excluded_vars.items():
          v.value = self.excluded_data[k]
        for k, v in self.target_vars.items():
          v.value = self.target_data[k]
        return r

      return loss_func
    else:
      return jax.jit(f_eval_loss)
  def _get_f_for_opt_solver(self, candidates, opt_method):
    """Build a jitted, vmapped function that runs `opt_method` from each candidate."""
    # loss function
    if self.f_type == constants.DISCRETE:
      # overall loss function for fixed points optimization
      if isinstance(candidates, dict):
        keys = tuple(self.target_vars.keys())
        # slice boundaries for unpacking a flat vector into per-variable pieces
        indices = [0]
        for v in self.target_vars.values():
          indices.append(v.shape[0])
        indices = np.cumsum(indices)

        def f_loss(h):
          h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)}
          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
      else:
        def f_loss(h):
          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
    else:
      # overall loss function for fixed points optimization
      def f_loss(h):
        return self.f_loss(self.f_cell(h))

    @jax.jit
    @jax.vmap
    def f_opt(x0):
      # seed the system variables from the candidate, then optimize a flat vector
      for k, v in self.target_vars.items():
        v.value = x0[k] if (v.batch_axis is None) else jnp.expand_dims(x0[k], axis=v.batch_axis)
      for k, v in self.excluded_vars.items():
        v.value = self.excluded_data[k]
      if isinstance(x0, dict):
        x0 = jnp.concatenate(tuple(x0.values()))
      return opt_method(f_loss, x0)

    def call_opt(x):
      r = f_opt(x)
      # restore variables mutated during the optimization
      for k, v in self.excluded_vars.items():
        v.value = self.excluded_data[k]
      for k, v in self.target_vars.items():
        v.value = self.target_data[k]
      return r

    return call_opt if isinstance(self.target, DynamicalSystem) else f_opt
  def _generate_ds_cell_function(
      self, target,
      t: float = None,
      dt: float = None,
  ):
    """Wrap a `DynamicalSystem` update into a pure dict-to-dict state map."""
    if dt is None: dt = bm.get_dt()
    if t is None: t = 0.

    def f_cell(h: Dict):
      # fix the shared time/step context at (t, i=0, dt)
      share.save(t=t, i=0, dt=dt)

      # update target variables from the provided state dict, restoring the
      # batch axis (of size 1) where the variable carries one
      for k, v in self.target_vars.items():
        v.value = (bm.asarray(h[k], dtype=v.dtype)
                   if v.batch_axis is None else
                   bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype))

      # update excluded variables to their snapshot values
      for k, v in self.excluded_vars.items():
        v.value = self.excluded_data[k]

      # add inputs
      clear_input(target)
      self._step_func_input()

      # call update functions
      target(*self.args)

      # get new states, squeezing the batch axis back out where needed
      new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
               for k, v in self.target_vars.items()}
      return new_h

    return f_cell
def _get_f_jocabian(self, stack=True):
name = f'f_eval_jacobian_stack={stack}'
if name not in self._opt_functions:
self._opt_functions[name] = self._generate_ds_jocabian(stack)
return self._opt_functions[name]
  def _generate_ds_jocabian(self, stack=True):
    """Build a jitted, vmapped Jacobian function of `f_cell`."""
    if stack and isinstance(self.target, DynamicalSystem):
      # boundaries for splitting a flat state vector into per-variable slices
      # (the batch axis, if any, does not contribute to the flat size)
      indices = [0]
      for var in self.target_vars.values():
        shape = list(var.shape)
        if var.batch_axis is not None:
          shape.pop(var.batch_axis)
        indices.append(np.prod(shape))
      indices = np.cumsum(indices)

      def jacob(x0):
        # unpack flat vector -> dict, step the system, repack -> flat vector
        x0 = {k: x0[indices[i]:indices[i + 1]] for i, k in enumerate(self.target_vars.keys())}
        r = self.f_cell(x0)
        return jnp.concatenate(list(r.values()))
    else:
      jacob = self.f_cell

    f_jac = jax.jit(jax.vmap(bm.jacobian(jacob)))

    if isinstance(self.target, DynamicalSystem):
      def jacobian_func(x):
        r = f_jac(x)
        # restore variables mutated by the traced update
        for k, v in self.excluded_vars.items():
          v.value = self.excluded_data[k]
        for k, v in self.target_vars.items():
          v.value = self.target_data[k]
        return r

      return jacobian_func
    else:
      return f_jac
def _check_candidates(self, candidates):
if isinstance(self.target, DynamicalSystem):
if not isinstance(candidates, dict):
raise ValueError(f'When "f_cell" is instance of {DynamicalSystem.__name__}, '
f'we should provide "candidates" as a dict, in which the key is '
f'the variable name with relative path, and the value '
f'is the candidate fixed point values. ')
for key in candidates:
if key not in self.target_vars:
raise KeyError(f'"{key}" is not defined in required variables '
f'for fixed point optimization of {self.target}. '
f'Please do not provide its initial values.')
for key in self.target_vars.keys():
if key not in candidates:
raise KeyError(f'"{key}" is defined in required variables '
f'for fixed point optimization of {self.target}. '
f'Please provide its initial values.')
for key, value in candidates.items():
if self.target_vars[key].batch_axis is None:
if value.ndim != self.target_vars[key].ndim + 1:
raise ValueError(f'"{key}" is defined in the required variables for fixed '
f'point optimization of {self.target}. \n'
f'We expect the provided candidate has a batch size, '
f'but we got {value.shape} for variable with shape of '
f'{self.target_vars[key].shape}')
else:
if value.ndim != self.target_vars[key].ndim:
raise ValueError(f'"{key}" is defined in the required variables for fixed '
f'point optimization of {self.target}. \n'
f'We expect the provided candidate has a batch size, '
f'but we got {value.shape} for variable with shape of '
f'{self.target_vars[key].shape}')
if isinstance(candidates, dict):
num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()])
if len(num_candidate) != 1:
raise ValueError('The numbers of candidates for each variable should be the same. '
f'But we got {num_candidate}')
num_candidate = num_candidate[0]
else:
num_candidate = candidates.shape[0]
return num_candidate