# -*- coding: utf-8 -*-
import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from copy import deepcopy
import brainpy.math as bm
from brainpy import errors, math
from brainpy._src.analysis import stability, plotstyle, constants as C, utils
from brainpy._src.analysis.lowdim.lowdim_analyzer import *
pyplot = None
__all__ = [
'PhasePlane1D',
'PhasePlane2D',
]
[docs]
class PhasePlane1D(Num1DAnalyzer):
"""Phase plane analyzer for 1D dynamical system.
This class can help users fast check:
- Vector fields
- Fixed points
Parameters
----------
model : Any
A model of the population, the integrator function,
or a list/tuple of integrator functions.
target_vars : dict
The target/dynamical variables.
fixed_vars : dict
The fixed variables.
target_pars : dict, optional
The parameters which can be dynamical varied.
pars_update : dict, optional
The parameters to update.
resolutions : float, dict
"""
def __init__(self,
model,
target_vars,
fixed_vars=None,
target_pars=None,
pars_update=None,
resolutions=None,
**kwargs):
if (target_pars is not None) and len(target_pars) > 0:
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
f'While we detect "target_pars={target_pars}".')
super().__init__(model=model,
target_vars=target_vars,
fixed_vars=fixed_vars,
target_pars=target_pars,
pars_update=pars_update,
resolutions=resolutions,
**kwargs)
# utils.output(f'I am {PhasePlane1D.__name__}.')
[docs]
def plot_vector_field(self, show=False, with_plot=True, with_return=False):
"""Plot the vector filed."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am creating the vector field ...')
# Nullcline of the x variable
y_val = self.F_fx(self.resolutions[self.x_var])
y_val = np.asarray(y_val)
# visualization
if with_plot:
label = f"d{self.x_var}dt"
x_style = dict(color='lightcoral', alpha=.7, linewidth=4)
pyplot.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label)
pyplot.axhline(0)
pyplot.xlabel(self.x_var)
pyplot.ylabel(label)
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2))
pyplot.legend()
if show: pyplot.show()
# return
if with_return:
return y_val
[docs]
def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
"""Plot the fixed point."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am searching fixed points ...')
# fixed points and stability analysis
fps, _, pars = self._get_fixed_points(self.resolutions[self.x_var])
container = {a: [] for a in stability.get_1d_stability_types()}
for i in range(len(fps)):
x = fps[i]
dfdx = self.F_dfxdx(x)
fp_type = stability.stability_analysis(dfdx)
utils.output(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
container[fp_type].append(x)
# visualization
if with_plot:
for fp_type, points in container.items():
if len(points):
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
# return
if with_return:
return fps
[docs]
class PhasePlane2D(Num2DAnalyzer):
"""Phase plane analyzer for 2D dynamical system.
Parameters
----------
model : Any
A model of the population, the integrator function,
or a list/tuple of integrator functions.
target_vars : dict
The target/dynamical variables.
fixed_vars : dict
The fixed variables.
target_pars : dict, optional
The parameters which can be dynamical varied.
pars_update : dict, optional
The parameters to update.
resolutions : float, dict
"""
def __init__(self,
model,
target_vars,
fixed_vars=None,
target_pars=None,
pars_update=None,
resolutions=None,
**kwargs):
if (target_pars is not None) and len(target_pars) > 0:
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
f'While we detect "target_pars={target_pars}".')
super().__init__(model=model,
target_vars=target_vars,
fixed_vars=fixed_vars,
target_pars=target_pars,
pars_update=pars_update,
resolutions=resolutions,
**kwargs)
@property
def F_vmap_brentq_fy(self):
if C.F_vmap_brentq_fy not in self.analyzed_results:
f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy)))
self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
return self.analyzed_results[C.F_vmap_brentq_fy]
[docs]
def plot_vector_field(self, with_plot=True, with_return=False,
plot_method='streamplot', plot_style=None, show=False):
"""Plot the vector field.
Parameters
----------
with_plot: bool
with_return : bool
show : bool
plot_method : str
The method to plot the vector filed. It can be "streamplot" or "quiver".
plot_style : dict, optional
The style for vector filed plotting.
- For ``plot_method="streamplot"``, it can set the keywords like "density",
"linewidth", "color", "arrowsize". More settings please check
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
- For ``plot_method="quiver"``, it can set the keywords like "color",
"units", "angles", "scale". More settings please check
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am creating the vector field ...')
# get vector fields
xs = self.resolutions[self.x_var]
ys = self.resolutions[self.y_var]
X, Y = jnp.meshgrid(xs, ys)
dx = self.F_fx(X, Y)
dy = self.F_fy(X, Y)
X, Y = np.asarray(X), np.asarray(Y)
dx, dy = np.asarray(dx), np.asarray(dy)
if with_plot: # plot vector fields
if plot_method == 'quiver':
if plot_style is None:
plot_style = dict(units='xy')
if (not np.isnan(dx).any()) and (not np.isnan(dy).any()):
speed = np.sqrt(dx ** 2 + dy ** 2)
dx = dx / speed
dy = dy / speed
pyplot.quiver(X, Y, dx, dy, **plot_style)
elif plot_method == 'streamplot':
if plot_style is None:
plot_style = dict(arrowsize=1.2, density=1, color='thistle')
linewidth = plot_style.get('linewidth', None)
if linewidth is None:
if (not np.isnan(dx).any()) and (not np.isnan(dy).any()):
min_width, max_width = 0.5, 5.5
speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2))
linewidth = min_width + max_width * (speed / speed.max())
pyplot.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
else:
raise errors.AnalyzerError(f'Unknown plot_method "{plot_method}", '
f'only supports "quiver" and "streamplot".')
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
if show:
pyplot.show()
if with_return: # return vector fields
return dx, dy
[docs]
def plot_nullcline(self, with_plot=True, with_return=False,
y_style=None, x_style=None, show=False,
coords=None, tol_nullcline=1e-7):
"""Plot the nullcline."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am computing fx-nullcline ...')
if coords is None:
coords = dict()
x_coord = coords.get(self.x_var, None)
y_coord = coords.get(self.y_var, None)
# Nullcline of the x variable
xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord, tol=tol_nullcline)
x_values_in_fx = np.asarray(xy_values_in_fx[:, 0])
y_values_in_fx = np.asarray(xy_values_in_fx[:, 1])
if with_plot:
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple()
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")
# Nullcline of the y variable
utils.output('I am computing fy-nullcline ...')
xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord, tol=tol_nullcline)
x_values_in_fy = np.asarray(xy_values_in_fy[:, 0])
y_values_in_fy = np.asarray(xy_values_in_fy[:, 1])
if with_plot:
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple()
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")
if with_plot:
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()
if show:
pyplot.show()
if with_return:
return {self.x_var: (x_values_in_fx, y_values_in_fx),
self.y_var: (x_values_in_fy, y_values_in_fy)}
[docs]
def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
tol_unique=1e-2, tol_aux=1e-8, tol_opt_screen=None,
select_candidates='fx-nullcline', num_rank=100, ):
"""Plot the fixed point and analyze its stability.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am searching fixed points ...')
if self._can_convert_to_one_eq():
if self.convert_type() == C.x_by_y:
candidates = bm.as_jax(self.resolutions[self.y_var])
else:
candidates = bm.as_jax(self.resolutions[self.x_var])
else:
if select_candidates == 'fx-nullcline':
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
if key.startswith(C.fx_nullcline_points)]
if len(candidates) == 0:
raise errors.AnalyzerError(f'No nullcline points are found, please call '
f'".{self.plot_nullcline.__name__}()" first.')
candidates = jnp.vstack(candidates)
elif select_candidates == 'fy-nullcline':
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
if key.startswith(C.fy_nullcline_points)]
if len(candidates) == 0:
raise errors.AnalyzerError(f'No nullcline points are found, please call '
f'".{self.plot_nullcline.__name__}()" first.')
candidates = jnp.vstack(candidates)
elif select_candidates == 'nullclines':
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
if key.startswith(C.fy_nullcline_points) or key.startswith(C.fy_nullcline_points)]
if len(candidates) == 0:
raise errors.AnalyzerError(f'No nullcline points are found, please call '
f'".{self.plot_nullcline.__name__}()" first.')
candidates = jnp.vstack(candidates)
elif select_candidates == 'aux_rank':
candidates, _ = self._get_fp_candidates_by_aux_rank(num_rank=num_rank)
else:
raise ValueError
# get fixed points
if len(candidates):
fixed_points, _, _ = self._get_fixed_points(jnp.asarray(candidates),
tol_aux=tol_aux,
tol_unique=tol_unique,
tol_opt_candidate=tol_opt_screen)
utils.output('I am trying to filter out duplicate fixed points ...')
fixed_points = np.asarray(fixed_points)
fixed_points, _ = utils.keep_unique(fixed_points, tolerance=tol_unique)
utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.')
else:
utils.output(f'{C.prefix}Found no fixed points.')
return
# stability analysis
# ------------------
container = {a: {'x': [], 'y': []} for a in stability.get_2d_stability_types()}
for i in range(len(fixed_points)):
x = fixed_points[i, 0]
y = fixed_points[i, 1]
fp_type = stability.stability_analysis(self.F_jacobian(x, y))
utils.output(f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
container[fp_type]['x'].append(x)
container[fp_type]['y'].append(y)
# visualization
# -------------
if with_plot:
for fp_type, points in container.items():
if len(points['x']):
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
if with_return:
return fixed_points
[docs]
def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v',
dt=None, show=False, with_plot=True, with_return=False, **kwargs):
"""Plot trajectories according to the settings.
Parameters
----------
initials : list, tuple, dict
The initial value setting of the targets. It can be a tuple/list of floats to specify
each value of dynamical variables (for example, ``(a, b)``). It can also be a
tuple/list of tuple to specify multiple initial values (for example,
``[(a1, b1), (a2, b2)]``).
duration : int, float, tuple, list
The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
- It can be a int/float (``t_end``) to specify the same running end time,
- Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
the start and end simulation time.
- Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``)
to specify the specific start and end simulation time for each initial value.
plot_durations : tuple, list, optional
The duration to plot. It can be a tuple with ``(start, end)``. It can
also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
the plot duration for each initial value running.
axes : str
The axes to plot. It can be:
- 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis.
- 't-v': Plot the trajectory in the 'time'-'var' axis.
show : bool
Whether show or not.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the trajectory ...')
if axes not in ['v-v', 't-v']:
raise errors.AnalyzerError(f'Unknown axes "{axes}", only support "v-v" and "t-v".')
# check the initial values
initials = utils.check_initials(initials, self.target_var_names)
# 2. format the running duration
assert isinstance(duration, (int, float))
# 3. format the plot duration
plot_durations = utils.check_plot_durations(plot_durations, duration, initials)
# 5. run the network
dt = math.get_dt() if dt is None else dt
traject_model = utils.TrajectModel(
initial_vars=initials,
integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
dt=dt)
mon_res = traject_model.run(duration=duration)
if with_plot:
# plots
for i, initial in enumerate(zip(*list(initials.values()))):
# legend
legend = f'$traj_{i}$: '
for j, key in enumerate(self.target_var_names):
legend += f'{key}={round(float(initial[j]), 4)}, '
legend = legend[:-2]
# visualization
start = int(plot_durations[i][0] / dt)
end = int(plot_durations[i][1] / dt)
if axes == 'v-v':
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
mon_res[self.y_var][start: end, i],
label=legend, **kwargs)
utils.add_arrow(lines[0])
else:
pyplot.plot(mon_res.ts[start: end],
mon_res[self.x_var][start: end, i],
label=legend + f', {self.x_var}', **kwargs)
pyplot.plot(mon_res.ts[start: end],
mon_res[self.y_var][start: end, i],
label=legend + f', {self.y_var}', **kwargs)
# visualization of others
if axes == 'v-v':
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()
else:
pyplot.legend(title='Initial values')
if show:
pyplot.show()
if with_return:
return mon_res
[docs]
def plot_limit_cycle_by_sim(self, initials, duration, tol=0.01, show=False, dt=None):
"""Plot trajectories according to the settings.
Parameters
----------
initials : list, tuple
The initial value setting of the targets.
- It can be a tuple/list of floats to specify each value of dynamical variables
(for example, ``(a, b)``).
- It can also be a tuple/list of tuple to specify multiple initial values (for
example, ``[(a1, b1), (a2, b2)]``).
duration : int, float, tuple, list
The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
- It can be a int/float (``t_end``) to specify the same running end time,
- Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
the start and end simulation time.
- Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``)
to specify the specific start and end simulation time for each initial value.
show : bool
Whether show or not.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')
# 1. format the initial values
initials = utils.check_initials(initials, self.target_var_names)
# 2. format the running duration
assert isinstance(duration, (int, float))
dt = math.get_dt() if dt is None else dt
traject_model = utils.TrajectModel(
initial_vars=initials,
integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
dt=dt)
mon_res = traject_model.run(duration=duration)
# 5. run the network
for init_i, initial in enumerate(zip(*list(initials.values()))):
# 5.2 run the model
x_data = mon_res[self.x_var][:, init_i]
y_data = mon_res[self.y_var][:, init_i]
max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol)
if max_index[0] != -1:
x_cycle = x_data[max_index[0]: max_index[1]]
y_cycle = y_data[max_index[0]: max_index[1]]
# 5.5 visualization
lines = pyplot.plot(x_cycle, y_cycle, label='limit cycle')
utils.add_arrow(lines[0])
else:
utils.output(f'No limit cycle found for initial value {initial}')
# 6. visualization
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()
if show:
pyplot.show()