import warnings
from functools import cache, wraps
# We don't need to reraise ImportErrors, they should be clear enough by themselves. If not
# and you're reading this: Fix the NEURON install, it's currently not importable ;)
from neuron import h as _h
from .core import (
assert_connectable,
is_nrn_scalar,
is_point_process,
is_section,
is_segment,
transform,
transform_netcon,
)
from .error_handler import CatchNetCon, CatchSectionAccess, catch_hoc_error
from .exceptions import (
BroadcastError,
HocSectionAccessError,
ParallelConnectError,
)
from .objects import (
IClamp,
NetCon,
PointProcess, # noqa: F401 # Function used during metaprogramming
PythonHocObject,
SEClamp,
Section,
SectionRef,
VecStim,
Vector,
_get_obj_registration_queue,
_safe_call,
)
from .version import get_neuron_version
_nrnver = get_neuron_version()
try:
if _nrnver < "7.8": # pragma: nocover
raise ImportError("Patch 3.0+ only supports NEURON v7.8.0 or higher.")
except Exception: # pragma: nocover
warnings.warn(
"Could not establish whether Patch supports installed NEURON version "
f"`{_nrnver}`",
stacklevel=1,
)
[docs]
class TimeSingleton(Vector):
@cache
def __new__(cls, *args, **kwargs):
return super().__new__(cls)
[docs]
class PythonHocInterpreter:
__pc: "ParallelContext"
__point_processes = []
__h = _h
def __init__(self):
self.__loaded_extensions = []
self.load_file("stdrun.hoc")
self.celsius = 32
self._finitialized: bool = False
@classmethod
def _process_registration_queue(cls):
"""
Most PythonHocObject classes (all those provided by Patch for sure) are created
before the PythonHocInterpreter class is available. Yet they require the class to
combine the original pointer from ``h.<object>`` (e.g. ``h.Section``) with a
function that defers to their constructor so that you can call ``p.Section()`` and
create a PythonHocObject wrapped around the underlying ``h`` pointer.
This function is called right after the PythonHocInterpreter class is created so
that PythonHocObjects can place themselves in a queue and have themselves
registered into the class right after it's ready.
"""
for hoc_object_class in _get_obj_registration_queue():
cls.register_hoc_object(hoc_object_class)
[docs]
@classmethod
def register_hoc_object(interpreter_class, hoc_object_class):
h = interpreter_class.__h
if hoc_object_class.__name__ in interpreter_class.__dict__:
# The function call was overridden in the interpreter and should not be
# destroyed.
return
hoc_object_name = hoc_object_class.__name__
# If the original interpreter doesn't have a function with the same name we can't
# simplify the constructor of the PythonHocObject and shouldn't wrap it.
if hasattr(h, hoc_object_name):
# Wrap it in the interpreter with a call to the underlying `h` to obtain a
# pointer and use that to make our PythonHocObject
factory = getattr(h, hoc_object_name)
@wraps(hoc_object_class.__init__)
def wrapper(interpreter_instance, *args, **kwargs):
hoc_ptr = factory(*args, **kwargs)
return hoc_object_class(interpreter_instance, hoc_ptr)
setattr(interpreter_class, hoc_object_class.__name__, wrapper)
def __getattr__(self, attr_name):
# Get the missing attribute from h
return getattr(self.__h, attr_name)
def __setattr__(self, attr, value):
if hasattr(self.__h, attr):
setattr(self.__h, attr, value)
else:
self.__dict__[attr] = value
[docs]
def nrn_load_dll(self, path):
result = self.__h.nrn_load_dll(path)
self.__class__._wrap_point_processes()
return result
[docs]
def NetCon(self, source, target, *args, **kwargs):
nrn_source = transform_netcon(source)
nrn_target = transform_netcon(target)
# Change the NetCon signature so that weight, delay and threshold become
# independent optional keyword arguments.
setters = {}
# Set sensible defaults: NetCons appear not to work sometimes if they're not set.
defaults = {"weight": 0.1, "delay": 1, "threshold": -20}
setter_keys = ["weight", "delay", "threshold"]
for key in setter_keys:
if key not in kwargs:
kwargs[key] = defaults[key]
setters[key] = kwargs[key]
del kwargs[key]
if is_section(source):
kwargs["sec"] = source
elif is_segment(source):
kwargs["sec"] = source.sec
elif is_nrn_scalar(source) and "sec" not in kwargs:
raise ConnectionError(
"Using NetCon with a scalar such as s(0.5)._ref_v is discouraged. "
"Use s(0.5) instead."
)
if "sec" in kwargs:
kwargs["sec"] = transform(kwargs["sec"])
# Execute HOC NetCon and wrap result into `connection`
with catch_hoc_error(CatchNetCon, nrn_source=nrn_source, nrn_target=nrn_target):
connection = NetCon(
self,
self.__h.NetCon(nrn_source, nrn_target, *args, **kwargs),
)
# Set the weight, delay and threshold independently
for k, v in setters.items():
if k == "weight":
if hasattr(type(v), "__iter__"): # pragma: nocover
for i, w in enumerate(v):
connection.weight[i] = w
else:
connection.weight[0] = v
else:
setattr(connection, k, v)
# Have the NetCon reference source and target
connection.__ref__(source)
connection.__ref__(target)
# If target is None, this NetCon is used as a spike detector.
if target is not None:
# Connect source and target.
assert_connectable(source, label="Source")
assert_connectable(target, label="Target")
source._connections[target] = connection
target._connections[source] = connection
elif hasattr(source, "__ref__"):
# Since the connection isn't established, make sure that the source and NetCon
# reference each other both ways
source.__ref__(connection)
return connection
[docs]
def ParallelCon(self, a, b, output=True, *args, **kwargs):
a_int = isinstance(a, int)
b_int = isinstance(b, int)
gid = a if a_int else b
if a_int != b_int:
if b_int:
source = a
nc = self.NetCon(source, None, *args, **kwargs)
self.parallel.set_gid2node(gid, self.parallel.id())
self.parallel.cell(gid, nc)
if output:
self.parallel.outputcell(gid)
# We only set threshold for sending NetCon's, as setting it on receiving
# NetCons may break transmission:
# https://github.com/neuronsimulator/nrn/issues/2135
nc.threshold = kwargs.get("threshold", -20.0)
else:
target = b
nc = self.parallel.gid_connect(gid, target)
nc.delay = kwargs.get("delay", nc.delay)
nc.weight[0] = kwargs.get("weight", nc.weight[0])
nc.gid = gid
return nc
else:
raise ParallelConnectError(
"Either the first or second argument has to be an integer GID."
)
[docs]
def SectionRef(self, *args, sec=None):
if len(args) > 1:
raise TypeError(
f"SectionRef takes 1 positional argument but {len(args)} given."
)
if sec is None:
if args:
sec = args[0]
else:
sec = self.cas()
if not sec: # pragma: nocover
raise RuntimeError(
"SectionRef() failed as there is no currently accessed section "
"available. Please specify a Section."
)
ref = SectionRef(self, self.__h.SectionRef(sec=transform(sec)))
if transform(sec) is sec:
sec = Section(self, sec)
ref.__ref__(sec)
ref.__dict__["sec"] = sec
ref.section = sec
return ref
[docs]
def ParallelContext(self):
return self.parallel
[docs]
def VecStim(self, /, *args, pattern=None, **kwargs):
import glia as g
mod_name = g.resolve("VecStim")
vec_stim = VecStim(self, getattr(self.__h, mod_name)(*args, **kwargs))
if pattern is not None:
pattern_vector = self.Vector(pattern)
vec_stim.play(pattern_vector.__neuron__())
self._vector = pattern_vector
self._pattern = pattern
return vec_stim
[docs]
def IClamp(self, x=0.5, sec=None):
sec = sec if sec is not None else self.cas()
clamp = IClamp(self, self.__h.IClamp(x, sec=transform(sec)))
clamp.__ref__(sec)
if hasattr(sec, "__ref__"):
sec.__ref__(clamp)
return clamp
[docs]
def SEClamp(self, sec, x=0.5):
clamp = SEClamp(self, self.__h.SEClamp(transform(sec(x))))
clamp.__ref__(sec)
if hasattr(sec, "__ref__"):
sec.__ref__(clamp)
return clamp
@property
def time(self):
# Fix for upstream NEURON bug. See https://github.com/neuronsimulator/nrn/issues/416
if not any(self.allsec()): # pragma: nocover
self.__dud_section = self.Section(name="this_is_here_to_record_time")
# Time vectors need to be shared, because it can only be recorded into 1
# target, but should also update every time they are accessed, to resize them
# between simulation sessions
time_singleton = TimeSingleton(self, self.__h.Vector().record(self._ref_t))
return time_singleton
[docs]
def finitialize(self, initial=None):
self.parallel.set_maxstep(10)
self._setup_transfer()
if initial is not None:
self.__h.finitialize(initial)
else:
self.__h.finitialize()
self._finitialized = True
[docs]
def continuerun(self, duration, v_init=None):
self._do_init(v_init)
self.parallel.psolve(self.__h.t + duration, v_init)
[docs]
def run(self, duration, v_init=None, reset=True):
self._do_init(v_init, reset=reset)
self.__h.continuerun(duration)
def _do_init(self, v_init=None, reset=False):
if reset or not self._finitialized:
self.finitialize(v_init)
[docs]
def cas(self):
# Currently error won't be triggered as h.cas() exits on undefined section acces:
# https://github.com/neuronsimulator/nrn/issues/769
try:
with catch_hoc_error(CatchSectionAccess):
return Section(self, self.__h.cas())
except HocSectionAccessError: # pragma: nocover
return None
def _init_pc(self):
if not hasattr(self, "_PythonHocInterpreter__pc"):
pc = ParallelContext(self, self.__h.ParallelContext())
try:
from mpi4py import MPI
except Exception:
self.__h.nrnmpi_init()
msize = pc.nhost()
else:
msize = MPI.COMM_WORLD.size
hosts = pc.nhost()
if msize != hosts: # pragma: nocover
raise RuntimeError(
f"MPI initialization error. `mpi4py` has a universe of size {msize},"
+ f" while NEURON has {hosts} hosts. Make sure that you import"
+ " `mpi4py` before importing either NEURON or Patch. If you did so,"
+ " your tools must not agree on which MPI implementation to use."
)
self.__pc = pc
@property
def parallel(self) -> "ParallelContext":
self._init_pc()
return self.__pc
[docs]
def record(self, target):
v = self.Vector()
v.record(target)
return v
@classmethod
def _wrap_point_processes(cls):
# Filter out all the point processes in the interpreter
point_processes = [k for k in dir(cls.__h) if is_point_process(k)]
old_point_processes = cls.__point_processes
# Check if there are any new things to wrap.
for point_process in set(point_processes) - set(old_point_processes):
# For each point process check if a function already exists, if not, wrap the
# HocInterpreter factory function.
if point_process not in cls.__dict__:
setattr(cls, point_process, cls._wrap_point_process(point_process))
cls.__point_processes = point_processes
@classmethod
def _wrap_point_process(cls, point_process):
# Create a function that has the right `f.__code__.co_name` for error messages.
scope = {}
exec(
f"""def {point_process}(self, target, *args, **kwargs):
h = getattr(self, '_PythonHocInterpreter__h')
factory = getattr(h, '{point_process}')
og_target = target
if hasattr(target, "__arc__"):
target = target(target.__arc__(), ephemeral=True)
nrn_target = transform(target)
nrn_ptr = factory(nrn_target, *args, **kwargs)
point_process = PointProcess(self, nrn_ptr)
if hasattr(og_target, "__ref__"):
og_target.__ref__(point_process)
point_process.__ref__(og_target)
return point_process""",
globals(),
scope,
)
return scope[point_process]
def _setup_transfer(self): # pragma: nocover
v = self.__h.Vector()
self.parallel.allgather(self.parallel._transfer_flag, v)
should_setup = sum(v)
if should_setup:
self.parallel.setup_transfer()
[docs]
class ParallelContext(PythonHocObject):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._warn_new_gids = False
self._transfer_max = -1
self._transfer_flag = False
[docs]
def cell(self, gid, nc):
transform(self).cell(gid, transform(nc))
[docs]
def set_gid2node(self, gid, id=None):
if self._warn_new_gids:
warnings.warn(
f"New GID ({gid}) registered after `spike_record` was called."
" This GID will not be recorded.",
stacklevel=2,
)
if id is None:
id = self.id()
return transform(self).set_gid2node(gid, id)
[docs]
def gid_connect(self, gid, target):
nrn_nc = transform(self).gid_connect(gid, transform_netcon(target))
nc = NetCon(self, nrn_nc)
# Forbid set threshold. See https://github.com/neuronsimulator/nrn/issues/2135
nc._nothreshold = True
nc.__ref__(target)
if hasattr(target, "__ref__"):
target.__ref__(nc)
return nc
@_safe_call
def source_var(self, call_result, *args, **kwargs): # pragma: nocover
key = args[-1]
if key < 0:
raise ValueError("Transfer variable keys must be larger than 0.")
# Store the highest used identifier
self._transfer_max = max(self._transfer_max, args[-1])
self._transfer_flag = True
return call_result
@_safe_call
def target_var(self, call_result, *args, **kwargs): # pragma: nocover
key = args[-1]
if key < 0:
raise ValueError("Transfer variable keys must be larger than 0.")
# Store the highest used identifier
self._transfer_max = max(self._transfer_max, args[-1])
self._transfer_flag = True
return call_result
@_safe_call
def setup_transfer(self, call_result, *args, **kwargs): # pragma: nocover
self._transfer_flag = False
return call_result
[docs]
def spike_record(self, gids=-1, time_vector=None, gid_vector=None, /):
if time_vector is None:
time_vector = self._interpreter.Vector()
if gid_vector is None:
gid_vector = self._interpreter.Vector()
transform(self).spike_record(gids, transform(time_vector), transform(gid_vector))
if gids == -1:
self._warn_new_gids = True
return time_vector, gid_vector
[docs]
def broadcast(self, data, root=0):
"""
Broadcast either a Vector or arbitrary picklable data. If ``data`` is a Vector,
the Vectors are resized and filled with the data from the Vector in the ``root``
node. If ``data`` is not a Vector, it is pickled, transmitted and returned from
this function to all nodes.
:param data: The data to broadcast to the nodes.
:type data: :class:`Vector <.objects.Vector>` or any picklable object.
:param root: The id of the node that is broadcasting the data.
:type root: int
:returns: None (Vectors filled) or the transmitted data
:raises: BroadcastError if ``neuron.hoc.HocObjects`` that aren't Vectors are
transmitted
"""
import neuron
data_ptr = transform(data)
# Is anyone broadcasting a HocObject?
if isinstance(data_ptr, neuron.hoc.HocObject):
# Comparing dir is used as a silly equality check because all NEURON object
# have class 'neuron.hoc.HocObject'
if dir(data_ptr) == dir(neuron.h.Vector()):
# If this node is broadcasting a Vector, then proceed to traditional
# broadcasting. If all nodes are broadcasting a Vector traditional
# broadcasting will occur, otherwise a BroadcastError is thrown.
transform(self).broadcast(data_ptr, root)
# After succesful broadcasts, the Vector is updated, return it.
return data
else:
# Send an empty vector so the other nodes don't hang.
transform(self).broadcast(transform(self._interpreter.Vector()), root)
raise BroadcastError(
"NEURON HocObjects cannot be broadcasted, "
"they need to be created on their own nodes."
)
else:
# If noone is sending a HocObject we proceed with picklable data broadcasting
return self._broadcast(data, root=root)
def _broadcast(self, data, root=0):
import pickle
if self.id() == root:
try:
v = self._interpreter.Vector(list(pickle.dumps(data)))
except Exception as e:
# Send an empty vector
# so the other nodes don't hang waiting for a broadcast.
transform(self).broadcast(transform(self._interpreter.Vector()), root)
raise BroadcastError(str(e)) from None
else:
v = self._interpreter.Vector()
v = transform(v)
transform(self).broadcast(v, root)
try:
return pickle.loads(bytes([int(d) for d in v]))
except EOFError:
raise BroadcastError(
"Root node did not transmit. Look for root node error."
) from None
[docs]
def psolve(self, tstop, v_init=None):
self_ = transform(self)
self_.set_maxstep(10)
self._interpreter._do_init(v_init)
self_.psolve(tstop)
PythonHocInterpreter._process_registration_queue()
PythonHocInterpreter._wrap_point_processes()