-
Yb Tweezer authoredYb Tweezer authored
control.py 11.86 KiB
from __future__ import annotations
import threading
import inspect
import ctypes
import sys
import time
import timeit
from math import prod, log10
from itertools import product
from typing import Callable, Iterable, Sequence, TypeVar
def _async_raise(tid, exctype):
"""
Raises the exception, performs cleanup if needed.
"""
if not inspect.isclass(exctype):
raise TypeError("Only types can be raised (not instances)")
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(tid), ctypes.py_object(exctype))
if res == 0:
raise ValueError("invalid thread id")
elif res != 1:
# if it returns a number greater than one, you're in trouble,
# and you should call it again with exc=NULL to revert the effect"""
ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, 0)
raise SystemError("PyThreadState_SetAsyncExc failed")
class TerminableThread(threading.Thread):
    """
    A ``threading.Thread`` that can be stopped by an asynchronous exception and
    whose ``join`` re-raises whatever the target raised.

    ``run`` captures any exception (including ``SystemExit`` from ``terminate``)
    in ``self.exception``; ``join`` re-raises it in the joining thread.
    """

    # Class-level default so join() works even if the thread was stopped before
    # run() reached its first statement.
    exception: BaseException | None = None

    def _get_my_tid(self):
        """
        Return this thread's identifier.

        Raises:
            threading.ThreadError: the thread is not running.
            AssertionError: the identifier is unavailable.
        """
        if not self.is_alive():
            raise threading.ThreadError("the thread is not active")
        # Thread.ident is the same key threading uses to track live threads.
        tid = self.ident
        if tid is None:
            raise AssertionError("could not determine the thread's id")
        return tid

    def raise_exc(self, exctype):
        """
        Raises the given exception type in the context of this thread.
        """
        _async_raise(self._get_my_tid(), exctype)

    def terminate(self):
        """
        Raises SystemExit in the context of this thread, which should
        cause the thread to exit silently (unless caught).
        """
        self.raise_exc(SystemExit)

    def run(self):
        """Run the target, storing any exception it raises in ``self.exception``."""
        self.exception = None
        try:
            super().run()
        except BaseException as err:
            self.exception = err

    def join(self, timeout=None):
        """
        Wait for the thread, then re-raise any exception captured by ``run``.

        ``timeout`` is forwarded to ``threading.Thread.join``.
        """
        super().join(timeout)
        if self.exception is not None:
            raise self.exception
class Controller:
    """
    Command-line driver that dispatches ``cmd_*`` methods and runs ``run_*``
    actions concurrently.

    ``RUN`` selects a ``cmd_<name>`` method from the first command-line argument.
    ``cmd_run`` starts every ``run_<name>`` method in its own ``TerminableThread``.
    The hooks ``precmd``, ``postcmd``, ``on_interrupt`` and ``on_error`` are
    optional: their defaults raise ``NotImplementedError``, which the driver
    reports with a message and otherwise ignores.
    """

    def precmd(self, *args):
        """Hook called before the actions start, with the command arguments."""
        raise NotImplementedError

    def postcmd(self, *args):
        """Hook called after all actions finished without error."""
        raise NotImplementedError

    def on_interrupt(self, *args):
        """
        Hook called after a KeyboardInterrupt stopped the actions.

        By default forwards a ``KeyboardInterrupt()`` instance to ``on_error``.
        """
        self.on_error(KeyboardInterrupt(), *args)

    def on_error(self, ERR: BaseException, *args):
        """Hook called after an action raised ``ERR`` and all actions were stopped."""
        raise NotImplementedError

    def _get_dirnames(self, prefix: str = "") -> list[str]:
        """Return attribute names that start with ``prefix``, with the prefix removed."""
        n = len(prefix)
        return [name[n:] for name in dir(self) if name.startswith(prefix)]

    def _get_diritems(self, prefix: str = "") -> dict[str, object]:
        """Map prefix-stripped attribute names starting with ``prefix`` to their values."""
        n = len(prefix)
        return {
            name[n:]: getattr(self, name)
            for name in dir(self) if name.startswith(prefix)
        }

    def cmd_help(self, *args, exit=0):
        """
        Print the available commands, or one command's docstring, then exit.

        With a known command name, prints its docstring and exits with status 0.
        With an unknown name, prints the command list and exits with status 1.
        With no argument, prints the command list and exits with ``exit``.
        """
        cmds = self._get_diritems("cmd_")
        helptext = "Available commands:\n\t" + ", ".join(cmds)
        if args:
            if args[0] in cmds:
                print(f"Help for command {args[0]}:")
                print(cmds[args[0]].__doc__)
                sys.exit(0)
            print(helptext)
            sys.exit(1)
        print(helptext)
        sys.exit(exit)

    def _watch_threads(self, threads: list[tuple[str, TerminableThread]]):
        """
        Poll ``threads`` (once per second) until all have finished.

        Finished entries are joined and removed from ``threads`` in place. If a
        join re-raises an action's exception, it is reported and propagated; the
        failing entry is left in ``threads`` for the caller to clean up.
        """
        while threads:
            # Iterate over a snapshot: entries are removed from `threads` below,
            # and removing from the list being iterated would skip elements.
            for name, thread in list(threads):
                if thread.is_alive():
                    continue
                try:
                    thread.join()
                except BaseException as ERR:
                    print(f"\n[control] action `{name}` encountered exception"
                          f" {type(ERR).__name__}: '{ERR}'")
                    print("[control] terminating all actions...")
                    raise
                print(f"[control] action `{name}` finished")
                threads.remove((name, thread))
            if threads:
                time.sleep(1)

    def _terminate_threads(self, threads: list[tuple[str, TerminableThread]]) \
            -> list[BaseException]:
        """
        Ask every live thread to terminate, then wait for all of them.

        Entries are removed from ``threads`` in place as they finish.

        Returns:
            The exceptions re-raised by the joins (e.g. the ``SystemExit`` injected
            by ``terminate``).
        """
        for name, thread in threads:
            if thread.is_alive():
                try:
                    thread.terminate()
                except (threading.ThreadError, ValueError):
                    # The thread exited between is_alive() and the async raise;
                    # the join loop below still collects it.
                    pass
        exceptions = []
        while threads:
            for name, thread in list(threads):  # snapshot; we remove below
                if thread.is_alive():
                    continue
                try:
                    thread.join()
                except BaseException as err:
                    exceptions.append(err)
                print(f"[control] action `{name}` terminated")
                threads.remove((name, thread))
            if threads:
                time.sleep(1)
        return exceptions

    def _perform_actions(self, prefix: str, args):
        """
        Run every method named ``<prefix><name>`` in its own thread with ``args``.

        Calls ``precmd`` first and ``postcmd`` after all actions succeed. On a
        KeyboardInterrupt or on any exception from an action, all actions are
        terminated and ``on_interrupt`` / ``on_error`` is called instead of
        ``postcmd``.
        """
        actions = self._get_diritems(prefix)
        # Only the name list is conditional; the header is always printed.
        listing = (":\n\t" + ", ".join(actions)) if actions else ""
        print(f"[control] Found {len(actions)} actions" + listing)
        try:
            self.precmd(*args)
        except NotImplementedError:
            print("[control] no pre-command defined")
        threads = [
            (name, TerminableThread(target=fn, args=args))
            for name, fn in actions.items()
        ]
        for name, thread in threads:
            print(f"[control] action `{name}` started")
            thread.start()
        try:
            self._watch_threads(threads)
        except KeyboardInterrupt:
            print("\n[control] received keyboard interrupt; terminating all actions...")
            try:
                self._terminate_threads(threads)
            except KeyboardInterrupt:
                # A second Ctrl-C during cleanup restarts the termination pass.
                self._terminate_threads(threads)
            try:
                self.on_interrupt(*args)
            except NotImplementedError:
                print("[control] no interrupt handler defined")
            return None
        except BaseException as ERR:
            self._terminate_threads(threads)
            try:
                self.on_error(ERR, *args)
            except NotImplementedError:
                print("[control] no error handler defined")
            return None
        try:
            self.postcmd(*args)
        except NotImplementedError:
            print("[control] no post-command defined")

    def cmd_run(self, *args):
        """Run every ``run_*`` action concurrently, passing along the command arguments."""
        self._perform_actions("run_", args)

    def RUN(self, argv=None):
        """
        Dispatch to ``cmd_<argv[0]>`` with the remaining arguments.

        ``argv`` defaults to ``sys.argv[1:]``. A missing or unknown command prints
        the help text and exits with status 1.
        """
        argv = sys.argv[1:] if argv is None else argv
        command_handler = self._get_diritems("cmd_")
        if not argv or argv[0] not in command_handler:
            self.cmd_help(exit=1)
        else:
            self.cmd = argv[0]
            command_handler[argv[0]](*argv[1:])
# Generic element type shared by the scanning helpers in this module.
T = TypeVar("T")


class ScanVals:
    """
    Values for one scan dimension, plus optional per-value hooks.

    ``pre_loop`` and ``post_loop`` are stored as callables. ``_nested_loop_over``
    calls them as ``hook(x, *loop_args, **loop_kwargs)`` immediately before and
    after the inner loops for each value ``x`` of a non-innermost dimension.
    When a hook is not given, a no-op accepting any arguments is used.

    Indexing, ``in``, ``len()`` and iteration are delegated to ``vals``.
    """

    def __init__(
        self,
        vals: Sequence[T],
        pre_loop: Callable[..., object] | None = None,
        post_loop: Callable[..., object] | None = None,
    ):
        self.vals = vals
        # The conditional must wrap the lambda rather than sit inside its body;
        # otherwise the stored hook would only *return* the user's callable.
        self.pre_loop = (lambda *args, **kwargs: None) if pre_loop is None else pre_loop
        self.post_loop = (lambda *args, **kwargs: None) if post_loop is None else post_loop

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.vals!r})"

    def __getitem__(self, idx: int) -> T:
        return self.vals[idx]

    def __contains__(self, thing: T) -> bool:
        return thing in self.vals

    def __len__(self) -> int:
        return len(self.vals)

    def __iter__(self) -> Iterable[T]:
        return iter(self.vals)
def _nested_loop_over(
vals: list[ScanVals],
loop_func: Callable[[T, ...], ...],
Q: list[int],
NN: list[int],
fmt: str,
loop_args: Sequence[...],
loop_kwargs: dict[str, ...],
reps_first: bool,
printflag: bool,
) -> None:
if len(vals) == 0:
raise Exception("_nested_loop_over: must have values to loop over")
elif len(vals) == 1:
_fmt = fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}"
_fmt += " ({:6.2f}%) \r"
for k, x in enumerate(vals[0]):
if printflag:
print(
_fmt.format(
*([q + 1 for q in Q] + [k + 1]),
100.0 * (
sum(q * nnq for q, nnq in zip(Q, NN[1:])) + k
) / NN[0]
),
end="", flush=True
)
# vals[0].pre_loop(x, *loop_args, **loop_kwargs)
loop_func(x, *loop_args, **loop_kwargs)
# vals[0].post_loop(x, *loop_args, **loop_kwargs)
if printflag:
print(
_fmt.format(
*([q + 1 for q in Q] + [k + 1]),
100.0 * (
sum(q * nnq for q, nnq in zip(Q, NN[1:])) + k + 1
) / NN[0]
),
end="", flush=True
)
elif len(vals) == 2 and not reps_first:
for k, x in enumerate(vals[0]):
vals[0].pre_loop(x, *loop_args, **loop_kwargs)
_nested_loop_over(
[ScanVals(len(vals[1]) * [x])],
loop_func,
Q + [k],
NN + [prod([len(X) for X in vals[1:]])],
fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}",
loop_args,
loop_kwargs,
reps_first,
printflag,
)
vals[0].post_loop(x, *loop_args, **loop_kwargs)
else:
for k, x in enumerate(vals[0]):
vals[0].pre_loop(x, *loop_args, **loop_kwargs)
_nested_loop_over(
vals[1:],
loop_func,
Q + [k],
NN + [prod([len(X) for X in vals[1:]])],
fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}",
loop_args,
loop_kwargs,
reps_first,
printflag,
)
vals[0].post_loop(x, *loop_args, **loop_kwargs)
def ndscan(
    vals: list[ScanVals],
    reps: int,
    loop_func: Callable[..., object],
    loop_args: Sequence[object] = None,
    loop_kwargs: dict[str, object] = None,
    reps_first: bool = True,
    printflag: bool = True,
) -> None:
    """
    Run ``loop_func`` over every combination of ``vals``, repeated ``reps`` times.

    A repetition axis ``range(reps)`` is added as the outermost dimension when
    ``reps_first`` is true, otherwise as the innermost one. ``loop_func`` is
    called as ``loop_func(x, *loop_args, **loop_kwargs)``, where ``x`` is the
    current value of the last dimension in ``vals`` (or the repetition index when
    ``vals`` is empty). Values of outer dimensions reach the caller only through
    their ``ScanVals`` ``pre_loop``/``post_loop`` hooks.

    When ``printflag`` is true, a progress line and a timing summary are printed.
    The per-shot average is omitted when there were no shots (``reps == 0`` or
    an empty dimension).
    """
    rep_axis = ScanVals(list(range(reps)))
    all_vals = [rep_axis, *vals] if reps_first else [*vals, rep_axis]
    total = prod(len(X) for X in all_vals)
    t0 = timeit.default_timer()
    _nested_loop_over(
        all_vals,
        loop_func,
        [],
        [total],
        "",
        [] if loop_args is None else loop_args,
        {} if loop_kwargs is None else loop_kwargs,
        reps_first,
        printflag,
    )
    elapsed = timeit.default_timer() - t0
    if printflag:
        summary = f"\n total time elapsed: {elapsed:.2f} s"
        if total > 0:
            # Guard: dividing by zero shots would raise ZeroDivisionError.
            summary += f"\n average time per shot: {elapsed / total:.2f} s"
        print(summary)
def map_scanvars(
    scan_maps: Sequence[tuple[Callable[..., object], dict[str, Sequence[object]]]],
    map_args: Sequence[object] = None,
    map_kwargs: dict[str, object] = None,
    reps: int | None = None,
    reps_first: bool = True,
    printflag: bool = True,
) -> tuple[list[str], list[list[object]]]:
    """
    Evaluate mapping functions over the grids of their scan variables.

    Each entry of ``scan_maps`` is ``(func, {name: values, ...})``. ``func`` is
    called as ``func(*point, *map_args, **map_kwargs)`` for every point of the
    Cartesian product of that entry's value lists (in dict order).

    One label is also built per point of the product over *all* scan variables.
    When ``reps`` is given, a ``"REP"`` axis of ``range(reps)`` is included,
    placed first if ``reps_first`` else last. Labels look like
    ``"a=+1.00000_b=-2.00000"``. The repetition value is formatted the same way
    but without a ``REP=`` prefix.

    Returns:
        ``(names, mapped)``, where ``mapped[i]`` lists the results of ``scan_maps[i]``.
    """
    _map_args = () if map_args is None else map_args
    _map_kwargs = {} if map_kwargs is None else map_kwargs
    allvars: dict[str, Sequence[object]] = {}
    if reps is not None and reps_first:
        allvars["REP"] = list(range(reps))
    for m in scan_maps:
        allvars.update(m[1])
    if reps is not None and not reps_first:
        # Append the repetition axis last (dicts preserve insertion order).
        allvars["REP"] = list(range(reps))
    # The repetition index gets no "REP=" label; match the key exactly so that
    # names merely ending in "REP" are left intact.
    fmt = "_".join(
        ("" if v == "REP" else f"{v}=") + "{:+.5f}" for v in allvars
    )
    if printflag:
        print("[control] computing iterable scanning variables ... ",
              end="", flush=True)
    # product(*lists) iterates the grid; product(lists) would treat each whole
    # list as a single element.
    names = [fmt.format(*point) for point in product(*allvars.values())]
    mapped = [
        [m[0](*point, *_map_args, **_map_kwargs) for point in product(*m[1].values())]
        for m in scan_maps
    ]
    if printflag:
        print(f"done {tuple(len(X) for X in mapped)}")
    return names, mapped