#!/usr/bin/env python
"""
Cached networks to use with the population classes, as the only
variables being used is "nodes_ex" and "nodes_in" VERSION THAT WORKS.
"""
import numpy as np
import os
import glob
if 'DISPLAY' not in os.environ:
import matplotlib
matplotlib.use('Agg')
from .gdf import GDF
import matplotlib.pyplot as plt
from mpi4py import MPI
################# Initialization of MPI stuff ##################################
COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()
############## Functions #######################################################
def remove_axis_junk(ax, which=['right', 'top']):
"""
Remove axis lines from axes object that exist in list which.
Parameters
----------
ax : `matplotlib.axes.AxesSubplot` object
which : list of str
Entries in ['right', 'top', 'bottom', 'left'].
Returns
-------
None
"""
for loc, spine in ax.spines.items():
if loc in which:
spine.set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
################ Classes #######################################################
[docs]class CachedNetwork(object):
"""
Offline processing and storing of network spike events, used by other
class objects in the package hybridLFPy.
Parameters
----------
simtime : float
Simulation duration.
dt : float,
Simulation timestep size.
spike_output_path : str
Path to gdf-files with spikes.
label : str
Prefix of spiking gdf-files.
ext : str
File extension of gdf-files.
GIDs : dict
dictionary keys are population names and item a list with first
GID in population and population size
autocollect : bool
If True, class init will process gdf files.
cmap : str
Name of colormap, must be in `dir(plt.cm)`.
Returns
-------
`hybridLFPy.cachednetworks.CachedNetwork` object
See also
--------
CachedFixedSpikesNetwork, CachedNoiseNetwork
"""
def __init__(self,
simtime = 1000.,
dt = 0.1,
spike_output_path='spike_output_path',
label = 'spikes',
ext = 'gdf',
GIDs={'EX' : [1, 400], 'IN' : [401, 100]},
autocollect=True,
cmap='Set1',
):
"""
Offline processing and storing of network spike events, used by other
class objects in the package `hybridLFPy`.
Parameters
----------
simtime : float
Simulation duration.
dt : float
Simulation timestep size.
spike_output_path : str
Path to gdf-files with spikes.
label : str
Prefix of spiking gdf-files.
ext : str
File extension of gdf-files.
GIDs : dict
dictionary keys are population names and item a list with first
GID in population and population size
autocollect : bool
If True, class init will process gdf files.
cmap : str
Name of colormap, must be in dir(plt.cm).
Returns
-------
`hybridLFPy.cachednetworks.CachedNetwork` object
See also
--------
CachedFixedSpikesNetwork, CachedNoiseNetwork
"""
# Set some attributes
self.simtime = simtime
self.dt = dt
self.spike_output_path = spike_output_path
self.label = label
self.ext = ext
self.dbname = ':memory:'
self.GIDs = GIDs
self.X = GIDs.keys()
self.X.sort()
self.autocollect = autocollect
# Create a dictionary of nodes with proper layernames
self.nodes = {}
for X in self.X:
self.nodes[X] = np.arange(self.GIDs[X][1]) + self.GIDs[X][0]
#list population sizes
self.N_X = np.array([self.GIDs[X][1] for X in self.X])
if self.autocollect:
#collect the gdf files
self.collect_gdf()
# Specify some plot colors used for each population:
if 'TC' in self.X:
numcolors = len(self.X)-1
else:
numcolors = len(self.X)
self.colors = []
for i in range(numcolors):
self.colors += [plt.get_cmap(cmap, numcolors)(i)]
if 'TC' in self.X:
self.colors += ['k']
[docs] def collect_gdf(self):
"""
Collect the gdf-files from network sim in folder `spike_output_path`
into sqlite database, using the GDF-class.
Parameters
----------
None
Returns
-------
None
"""
# Resync
COMM.Barrier()
# Raise Exception if there are no gdf files to be read
if len(glob.glob(os.path.join(self.spike_output_path,
self.label + '*.'+ self.ext))) == 0:
raise Exception('path to files contain no gdf-files!')
#create in-memory databases of spikes
if not hasattr(self, 'dbs'):
self.dbs = {}
for X in self.X:
db = GDF(os.path.join(self.dbname),
debug=True, new_db=True)
db.create(re=os.path.join(self.spike_output_path,
'{0}*{1}*{2}'.format(self.label, X,
self.ext)),
index=True)
self.dbs.update({
X : db
})
COMM.Barrier()
[docs] def get_xy(self, xlim, fraction=1.):
"""
Get pairs of node units and spike trains on specific time interval.
Parameters
----------
xlim : list of floats
Spike time interval, e.g., [0., 1000.].
fraction : float in [0, 1.]
If less than one, sample a fraction of nodes in random order.
Returns
-------
x : dict
In `x` key-value entries are population name and neuron spike times.
y : dict
Where in `y` key-value entries are population name and neuron gid number.
"""
x = {}
y = {}
for X, nodes in self.nodes.items():
x[X] = np.array([])
y[X] = np.array([])
if fraction != 1:
nodes = np.random.permutation(nodes)[:int(nodes.size*fraction)]
nodes.sort()
spiketimes = self.dbs[X].select_neurons_interval(nodes, T=xlim)
i = 0
for times in spiketimes:
x[X] = np.r_[x[X], times]
y[X] = np.r_[y[X], np.zeros(times.size) + nodes[i]]
i += 1
return x, y
[docs] def plot_raster(self, ax, xlim, x, y, pop_names=False,
markersize=20., alpha=1., legend=True,
marker='o', rasterized=True):
"""
Plot network raster plot in subplot object.
Parameters
----------
ax : `matplotlib.axes.AxesSubplot` object
plot axes
xlim : list
List of floats. Spike time interval, e.g., [0., 1000.].
x : dict
Key-value entries are population name and neuron spike times.
y : dict
Key-value entries are population name and neuron gid number.
pop_names: bool
If True, show population names on yaxis instead of gid number.
markersize : float
raster plot marker size
alpha : float in [0, 1]
transparency of marker
legend : bool
Switch on axes legends.
marker : str
marker symbol for matplotlib.pyplot.plot
rasterized : bool
if True, the scatter plot will be treated as a bitmap embedded in
pdf file output
Returns
-------
None
"""
yoffset = [sum(self.N_X) if X=='TC' else 0 for X in self.X]
for i, X in enumerate(self.X):
if y[X].size > 0:
ax.plot(x[X], y[X]+yoffset[i], marker,
markersize=markersize,
mfc=self.colors[i],
mec='none' if marker in '.ov><v^1234sp*hHDd' else self.colors[i],
alpha=alpha,
label=X, rasterized=rasterized,
clip_on=True)
#don't draw anything for the may-be-quiet TC population
N_X_sum = 0
for i, X in enumerate(self.X):
if y[X].size > 0:
N_X_sum += self.N_X[i]
ax.axis([xlim[0], xlim[1],
self.GIDs[self.X[0]][0], self.GIDs[self.X[0]][0]+N_X_sum])
ax.set_ylim(ax.get_ylim()[::-1])
ax.set_ylabel('cell id', labelpad=0)
ax.set_xlabel('$t$ (ms)', labelpad=0)
if legend:
ax.legend()
if pop_names:
yticks = []
yticklabels = []
for i, X in enumerate(self.X):
if y[X] != []:
yticks.append(y[X].mean()+yoffset[i])
yticklabels.append(self.X[i])
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
# Add some horizontal lines separating the populations
for i, X in enumerate(self.X):
if y[X].size > 0:
ax.plot([xlim[0], xlim[1]],
[y[X].max()+yoffset[i], y[X].max()+yoffset[i]],
'k', lw=0.25)
[docs] def plot_f_rate(self, ax, X, i, xlim, x, y, binsize=1, yscale='linear',
plottype='fill_between', show_label=False, rasterized=False):
"""
Plot network firing rate plot in subplot object.
Parameters
----------
ax : `matplotlib.axes.AxesSubplot` object.
X : str
Population name.
i : int
Population index in class attribute `X`.
xlim : list of floats
Spike time interval, e.g., [0., 1000.].
x : dict
Key-value entries are population name and neuron spike times.
y : dict
Key-value entries are population name and neuron gid number.
yscale : 'str'
Linear, log, or symlog y-axes in rate plot.
plottype : str
plot type string in `['fill_between', 'bar']`
show_label : bool
whether or not to show labels
Returns
-------
None
"""
bins = np.arange(xlim[0], xlim[1]+binsize, binsize)
(hist, bins) = np.histogram(x[X], bins=bins)
if plottype == 'fill_between':
ax.fill_between(bins[:-1], hist * 1000. / self.N_X[i],
color=self.colors[i], lw=0.5, label=X, rasterized=rasterized,
clip_on=False)
ax.plot(bins[:-1], hist * 1000. / self.N_X[i],
color='k', lw=0.5, label=X, rasterized=rasterized,
clip_on=False)
elif plottype == 'bar':
ax.bar(bins[:-1], hist * 1000. / self.N_X[i],
color=self.colors[i], label=X, rasterized=rasterized ,
linewidth=0.25, width=0.9, clip_on=False)
else:
mssg = "plottype={} not in ['fill_between', 'bar']".format(plottype)
raise Exception(mssg)
remove_axis_junk(ax)
ax.axis(ax.axis('tight'))
ax.set_yscale(yscale)
ax.set_xlim(xlim[0], xlim[1])
if show_label:
ax.text(xlim[0] + .05*(xlim[1]-xlim[0]), ax.axis()[3]*1.5, X,
va='center', ha='left')
[docs] def raster_plots(self, xlim=[0, 1000], markersize=1, alpha=1., marker='o'):
"""
Pretty plot of the spiking output of each population as raster and rate.
Parameters
----------
xlim : list
List of floats. Spike time interval, e.g., `[0., 1000.]`.
markersize : float
marker size for plot, see `matplotlib.pyplot.plot`
alpha : float
transparency for markers, see `matplotlib.pyplot.plot`
marker : :mod:`A valid marker style <matplotlib.markers>`
Returns
-------
fig : `matplotlib.figure.Figure` object
"""
x, y = self.get_xy(xlim)
fig = plt.figure()
fig.subplots_adjust(left=0.12, hspace=0.15)
ax0 = fig.add_subplot(211)
self.plot_raster(ax0, xlim, x, y, markersize=markersize, alpha=alpha,
marker=marker)
remove_axis_junk(ax0)
ax0.set_title('spike raster')
ax0.set_xlabel("")
nrows = len(self.X)
bottom = np.linspace(0.1, 0.45, nrows+1)[::-1][1:]
thickn = np.abs(np.diff(bottom))[0]*0.9
for i, layer in enumerate(self.X):
ax1 = fig.add_axes([0.12, bottom[i], 0.78, thickn])
self.plot_f_rate(ax1, layer, i, xlim, x, y, )
if i == nrows-1:
ax1.set_xlabel('time (ms)')
else:
ax1.set_xticklabels([])
if i == 4:
ax1.set_ylabel(r'population rates ($s^{-1}$)')
if i == 0:
ax1.set_title(r'population firing rates ($s^{-1}$)')
return fig
[docs]class CachedFixedSpikesNetwork(CachedNetwork):
"""
Subclass of CachedNetwork.
Fake nest output, where each cell in a subpopulation spike
simultaneously, and each subpopulation is activated at times given in
kwarg activationtimes.
Parameters
----------
activationtimes : list of floats
Each entry set spike times of all cells in each population
autocollect : bool
whether or not to automatically gather gdf file output
**kwargs : see parent class `hybridLFPy.cachednetworks.CachedNetwork`
Returns
-------
`hybridLFPy.cachednetworks.CachedFixedSpikesNetwork` object
See also
--------
CachedNetwork, CachedNoiseNetwork,
"""
def __init__(self,
activationtimes=[200, 300, 400, 500, 600, 700, 800, 900, 1000],
autocollect=False,
**kwargs):
"""
Subclass of CachedNetwork
Fake nest output, where each cell in a subpopulation spike
simultaneously, and each subpopulation is activated at times given in
kwarg activationtimes.
Parameters
----------
activationtimes : list
Each entry set spike times of all cells in each population
autocollect : bool
whether or not to automatically gather gdf file output
**kwargs : see parent class `hybridLFPy.cachednetworks.CachedNetwork`
Returns
-------
`hybridLFPy.cachednetworks.CachedFixedSpikesNetwork` object
See also
--------
CachedNetwork, CachedNoiseNetwork,
"""
CachedNetwork.__init__(self, autocollect=autocollect, **kwargs)
# Set some attributes
self.activationtimes = activationtimes
if len(activationtimes) != len(self.N_X):
raise Exception('len(activationtimes != len(self.N_X))')
""" Create a dictionary of nodes with proper layernames
self.nodes = {}.
"""
if RANK == 0:
for i, N in enumerate(self.N_X):
nodes = self.nodes[self.X[i]]
cell_spt = list(zip(nodes, [self.activationtimes[i]
for x in range(nodes.size)]))
cell_spt = np.array(cell_spt, dtype=[('a', int), ('b', float)])
np.savetxt(os.path.join(self.spike_output_path,
self.label + '_{}.gdf'.format(self.X[i])),
cell_spt, fmt=['%i', '%.1f'])
# Resync
COMM.barrier()
# Collect the gdf files
self.collect_gdf()
[docs]class CachedNoiseNetwork(CachedNetwork):
"""
Subclass of CachedNetwork.
Use Nest to generate N_X poisson-generators each with rate frate,
and record every vector, and create database with spikes.
Parameters
----------
frate : list
Rate of each layer, may be tuple (onset, rate, offset)
autocollect : bool
whether or not to automatically gather gdf file output
**kwargs : see parent class `hybridLFPy.cachednetworks.CachedNetwork`
Returns
-------
`hybridLFPy.cachednetworks.CachedNoiseNetwork` object
See also
--------
CachedNetwork, CachedFixedSpikesNetwork
"""
def __init__(self,
frate=[(200., 15., 210.), 0.992, 3.027, 4.339, 5.962,
7.628, 8.669, 1.118, 7.859],
autocollect=False,
**kwargs):
"""
Subclass of `CachedNetwork`.
Use Nest to generate N_X poisson-generators each with rate frate,
and record every vector, and create database with spikes.
Parameters
----------
frate : list
Rate of each layer, may be tuple (onset, rate, offset).
autocollect : bool
whether or not to automatically gather gdf file output
**kwargs : see parent class `hybridLFPy.cachednetworks.CachedNetwork`
Returns
-------
`hybridLFPy.cachednetworks.CachedNoiseNetwork` object
See also
--------
CachedNetwork, CachedFixedSpikesNetwork
"""
CachedNetwork.__init__(self, autocollect=autocollect, **kwargs)
"""
Putting import nest here, avoid making `nest` a mandatory
`hybridLFPy` dependency.
"""
import nest
#set some attributes:
self.frate = frate
if len(self.frate) != self.N_X.size:
raise Exception('self.frate.size != self.N_X.size')
self.spike_output_path = spike_output_path
self.total_num_virtual_procs = SIZE
# Reset nest kernel and set some kernel status variables, destroy old
# nodes etc in the process
nest.ResetKernel()
#if dt is in powers of two, dt must be multiple of ms_per_tic
if self.dt in 2**np.arange(-32., 0):
nest.SetKernelStatus({
"tics_per_ms" : 2**2 / self.dt,
"resolution": self.dt,
"print_time": True,
"overwrite_files" : True,
"total_num_virtual_procs" : self.total_num_virtual_procs,
})
else:
nest.SetKernelStatus({
"resolution": self.dt,
"print_time": True,
"overwrite_files" : True,
"total_num_virtual_procs" : self.total_num_virtual_procs,
})
nest.SetDefaults("spike_detector", {
'withtime' : True,
'withgid' : True,
'to_file' : True,
'to_memory' : False,
})
# Create some populations of parrot neurons that echo the poisson noise
self.nodes = {}
for i, N in enumerate(self.N_X):
self.nodes[self.X[i]] = nest.Create('parrot_neuron', N)
if os.path.isfile(os.path.join(self.spike_output_path, self.dbname)):
mystring = os.path.join(self.spike_output_path, self.dbname)
print('db %s exist, will not rerun sim or collect gdf!' % mystring)
else:
# Create spike detector
self.spikes = nest.Create("spike_detector", 1,
{'label' : os.path.join(self.spike_output_path,
self.label)})
""" Create independent poisson spike trains with the some rate,
but each layer population should really have different rates.
"""
self.noise = []
for rate in self.frate:
if type(rate) == tuple:
self.noise.append(nest.Create("poisson_generator", 1,
{ "start" : rate[0],
"rate" : rate[1],
"stop" : rate[2]}))
else:
self.noise.append(nest.Create("poisson_generator", 1,
{"rate" : rate}))
## Connect parrots and spike detector
for layer in self.X:
nest.ConvergentConnect(self.nodes[layer], self.spikes,
model='static_synapse')
# Connect noise generators and nodes
for i, layer in enumerate(self.X):
nest.ConvergentConnect(self.noise[i], self.nodes[layer],
model='static_synapse')
# Run simulation
nest.Simulate(self.simtime)
# Collect the gdf files
self.collect_gdf()
# Nodes need to be collected in np.ndarrays:
for key in list(self.nodes.keys()):
self.nodes[key] = np.array(self.nodes[key])
if __name__ == '__main__':
import doctest
doctest.testmod()