Skip to content
Snippets Groups Projects
control.py 11.86 KiB
from __future__ import annotations
import threading
import inspect
import ctypes
import sys
import time
import timeit
from math import prod, log10
from itertools import product
from typing import Callable, Iterable, Sequence, TypeVar

def _async_raise(tid, exctype):
    """
    Asynchronously raise the exception class ``exctype`` in thread ``tid``.

    Thin wrapper around the CPython C-API ``PyThreadState_SetAsyncExc``.

    :param tid: identifier of the target thread (e.g. ``Thread.ident``).
    :param exctype: exception *class* to raise; instances are rejected.
    :raises TypeError: if ``exctype`` is not a class.
    :raises ValueError: if no thread has identifier ``tid``.
    :raises SystemError: if more than one thread state was modified; the
        request is reverted before raising.
    """
    if not inspect.isclass(exctype):
        raise TypeError("Only types can be raised (not instances)")
    # The C signature takes ``unsigned long id`` (since Python 3.7); wrap it
    # once and reuse it, so the revert call below gets the same C type.
    thread_id = ctypes.c_ulong(tid)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        thread_id, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # More than one thread state was modified: undo the request by
        # calling again with exc=NULL (ctypes passes ``None`` as NULL).
        ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


class TerminableThread(threading.Thread):
    """
    A ``threading.Thread`` that can be stopped from another thread and that
    re-raises, from ``join()``, any exception raised by its target.
    """

    # Exception (BaseException or None) recorded by ``run()``.  Declared at
    # class level so ``join()`` can read it even before ``run()`` has begun.
    exception = None

    def _get_my_tid(self):
        """
        Return this thread's identifier (``Thread.ident``).

        :raises threading.ThreadError: if the thread is not currently alive.
        """
        if not self.is_alive():
            raise threading.ThreadError("the thread is not active")
        # ``ident`` is the public form of the id previously looked up in the
        # private ``threading._active`` dict.
        return self.ident

    def raise_exc(self, exctype):
        """
        Asynchronously raise the exception class ``exctype`` in this thread.
        """
        _async_raise(self._get_my_tid(), exctype)

    def terminate(self):
        """
        Raise ``SystemExit`` in this thread, which should make it exit
        silently unless its target catches it.
        """
        self.raise_exc(SystemExit)

    def run(self):
        """Run the target, storing any exception in ``self.exception``."""
        self.exception = None
        try:
            threading.Thread.run(self)
        except BaseException as err:
            self.exception = err

    def join(self, timeout=None):
        """
        Wait like ``threading.Thread.join`` (``timeout`` is forwarded), then
        re-raise the exception recorded by ``run()``, if there is one.
        """
        threading.Thread.join(self, timeout)
        if self.exception is not None:
            raise self.exception

class Controller:
    """
    Base class for a small command-line driver built from named methods.

    ``RUN`` dispatches ``argv[0]`` to the method ``cmd_<argv[0]>``.  The
    built-in ``cmd_run`` starts every ``run_<name>`` method in its own
    ``TerminableThread`` and watches them until they finish.  Subclasses may
    override the ``precmd``/``postcmd``/``on_interrupt``/``on_error`` hooks;
    the defaults raise ``NotImplementedError``, which the driver catches and
    reports.
    """

    def precmd(self, *args):
        """Hook called with the command's arguments before any action starts."""
        raise NotImplementedError

    def postcmd(self, *args):
        """Hook called after every action finished without error."""
        raise NotImplementedError

    def on_interrupt(self, *args):
        """Hook called after Ctrl-C stopped the actions; defers to ``on_error``."""
        self.on_error(KeyboardInterrupt(), *args)

    def on_error(self, ERR: BaseException, *args):
        """Hook called with the exception that made the actions stop."""
        raise NotImplementedError

    def _get_dirnames(self, prefix: str = ""):
        """Return attribute names starting with ``prefix``, prefix stripped."""
        n = len(prefix)
        return [name[n:] for name in dir(self) if name.startswith(prefix)]

    def _get_diritems(self, prefix: str = ""):
        """Map attribute names starting with ``prefix`` (stripped) to values."""
        n = len(prefix)
        return {
            name[n:]: getattr(self, name)
            for name in dir(self) if name.startswith(prefix)
        }

    def cmd_help(self, *args, exit=0):
        """
        Print the available commands, or the docstring of command ``args[0]``,
        then exit: status 0 after a command's help, 1 for an unknown command,
        otherwise ``exit``.
        """
        cmds = self._get_diritems("cmd_")
        helptext = "Available commands:\n\t" + ", ".join(cmds)
        if not args:
            print(helptext)
            sys.exit(exit)
        if args[0] in cmds:
            print(f"Help for command {args[0]}:")
            print(cmds[args[0]].__doc__)
            sys.exit(0)
        print(helptext)
        sys.exit(1)

    def _watch_threads(self, threads: list[tuple[str, TerminableThread]]):
        """
        Poll the ``(name, thread)`` pairs in ``threads`` until all finished.

        Finished entries are removed from ``threads`` in place and joined.  If
        a joined action raised, the exception is reported and re-raised; the
        remaining entries stay in ``threads`` for the caller to stop.
        """
        while threads:
            # Iterate over a snapshot: removing from the list being iterated
            # would silently skip the entry after each removed one.
            for entry in list(threads):
                name, thread = entry
                if thread.is_alive():
                    continue
                threads.remove(entry)
                try:
                    thread.join()
                except BaseException as ERR:
                    print(f"\n[control] action `{name}` encountered exception"
                        f" {type(ERR).__name__}: '{ERR}'")
                    print("[control] terminating all actions...")
                    raise
                print(f"[control] action `{name}` finished")
            # Sleep once per sweep (not once per thread) while work remains.
            if threads:
                time.sleep(1)

    def _terminate_threads(self, threads: list[tuple[str, TerminableThread]]) \
            -> list[BaseException]:
        """
        Ask every live thread in ``threads`` to exit, then join all of them.

        Entries are removed from ``threads`` in place as they are joined.
        Returns the exceptions re-raised by ``join`` -- for a
        ``TerminableThread`` this includes the ``SystemExit`` injected by
        ``terminate()``, which its ``run()`` records.
        """
        for name, thread in threads:
            if thread.is_alive():
                try:
                    thread.terminate()
                except (threading.ThreadError, ValueError):
                    # The thread finished between the liveness check and the
                    # request; it is joined below like any other.
                    pass
        exceptions = []
        # NOTE(review): a target that catches SystemExit and keeps running
        # keeps this loop polling indefinitely.
        while threads:
            for entry in list(threads):
                name, thread = entry
                if thread.is_alive():
                    continue
                threads.remove(entry)
                try:
                    thread.join()
                except BaseException as err:
                    exceptions.append(err)
                print(f"[control] action `{name}` terminated")
            if threads:
                time.sleep(1)
        return exceptions

    def _perform_actions(self, prefix: str, args):
        """
        Run every ``<prefix><name>`` method concurrently with ``args``.

        ``precmd`` runs first; then one ``TerminableThread`` per action is
        started and watched.  On Ctrl-C all actions are terminated and
        ``on_interrupt`` is called; if an action raises, the others are
        terminated and ``on_error`` is called.  ``postcmd`` runs only when
        every action finished normally.  Always returns ``None``.
        """
        actions = self._get_diritems(prefix)
        # Built explicitly: the former one-expression form let the conditional
        # swallow the whole message and printed an empty line for 0 actions.
        header = f"[control] Found {len(actions)} actions"
        if actions:
            header += ":\n\t" + ", ".join(actions)
        print(header)
        try:
            self.precmd(*args)
        except NotImplementedError:
            print("[control] no pre-command defined")
        threads = [
            (name, TerminableThread(target=fn, args=args))
            for name, fn in actions.items()
        ]
        for name, thread in threads:
            print(f"[control] action `{name}` started")
            thread.start()
        try:
            self._watch_threads(threads)
        except KeyboardInterrupt:
            print("\n[control] received keyboard interrupt; terminating all actions...")
            try:
                self._terminate_threads(threads)
            except KeyboardInterrupt:
                # A second Ctrl-C during shutdown restarts the termination pass.
                self._terminate_threads(threads)
            try:
                self.on_interrupt(*args)
            except NotImplementedError:
                print("[control] no interrupt handler defined")
            return None
        except BaseException as ERR:
            self._terminate_threads(threads)
            try:
                self.on_error(ERR, *args)
            except NotImplementedError:
                print("[control] no error handler defined")
            return None
        try:
            self.postcmd(*args)
        except NotImplementedError:
            print("[control] no post-command defined")

    def cmd_run(self, *args):
        """Run all ``run_*`` actions concurrently, passing ``args`` to each."""
        self._perform_actions("run_", args)

    def RUN(self, argv=None):
        """
        Dispatch ``argv[0]`` (default: ``sys.argv[1:]``) to ``cmd_<argv[0]>``
        with the remaining items as arguments, storing the name in
        ``self.cmd``.  Without a known command, show help and exit with 1.
        """
        argv = sys.argv[1:] if argv is None else argv
        command_handler = self._get_diritems("cmd_")
        if len(argv) < 1 or argv[0] not in command_handler:
            self.cmd_help(exit=1)
        else:
            self.cmd = argv[0]
            command_handler[argv[0]](*argv[1:])

# Generic element type used in the annotations of ``ScanVals`` and the scan helpers.
T = TypeVar("T")

class ScanVals:
    """
    Values of one scan dimension, plus optional per-value hooks.

    ``pre_loop`` / ``post_loop`` are called as ``hook(x, *args, **kwargs)``;
    ``_nested_loop_over`` calls them around the inner loops for each value
    ``x`` of an outer dimension.  When not given they default to no-ops that
    return ``None``.  Indexing, ``in``, ``len`` and iteration delegate to
    ``vals``.
    """

    def __init__(
        self,
        vals: Sequence[T],
        pre_loop: Callable[..., object] | None = None,
        post_loop: Callable[..., object] | None = None,
    ):
        self.vals = vals
        # Store the user's callable itself.  The conditional must sit outside
        # the lambda: inside it, the hook was only returned, never called.
        self.pre_loop = (
            (lambda *args, **kwargs: None) if pre_loop is None else pre_loop
        )
        self.post_loop = (
            (lambda *args, **kwargs: None) if post_loop is None else post_loop
        )

    def __getitem__(self, idx: int) -> T:
        return self.vals[idx]

    def __contains__(self, thing: T) -> bool:
        return thing in self.vals

    def __len__(self) -> int:
        return len(self.vals)

    def __iter__(self) -> Iterable:
        return iter(self.vals)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.vals!r})"

def _nested_loop_over(
    vals: list[ScanVals],
    loop_func: Callable[..., object],
    Q: list[int],
    NN: list[int],
    fmt: str,
    loop_args: Sequence,
    loop_kwargs: dict[str, object],
    reps_first: bool,
    printflag: bool,
) -> None:
    """
    Recursively iterate over the Cartesian product of ``vals``.

    Only the innermost level's values reach ``loop_func``, called as
    ``loop_func(x, *loop_args, **loop_kwargs)``.  Each outer level calls its
    own ``pre_loop``/``post_loop`` hooks (same arguments) around the inner
    loops.

    When ``reps_first`` is False and exactly two levels remain, the last
    level contributes only its length (``ndscan`` uses it for repetitions).
    ``loop_func`` then receives the preceding level's value that many times.

    :param vals: remaining scan dimensions, outermost first.
    :param Q: zero-based indices already chosen at the enclosing levels.
    :param NN: ``NN[0]`` is the total number of shots; ``NN[i + 1]`` is the
        number of shots covered by one step of ``Q[i]``.
    :param fmt: progress-line prefix with one counter field per entry of Q.
    :param printflag: print a carriage-return-terminated progress line per
        shot, plus a final one after each innermost loop.
    :raises ValueError: if ``vals`` is empty.
    """
    if len(vals) == 0:
        raise ValueError("_nested_loop_over: must have values to loop over")
    level = vals[0]
    n = len(level)
    # Right-aligned "k/n" counter; len(str(n)) is the digit count of n and,
    # unlike int(log10(n)) + 1, is also defined for an empty level.
    counter = f"  {{:{len(str(n))}}}/{n}"

    if len(vals) == 1:
        _fmt = fmt + counter + "  ({:6.2f}%)  \r"
        # Shots completed before this innermost loop starts.
        offset = sum(q * nnq for q, nnq in zip(Q, NN[1:]))
        outer = [q + 1 for q in Q]
        # NOTE(review): this level's pre_loop/post_loop are not called (the
        # calls were commented out), so with reps_first=True the last scan
        # variable's hooks never run -- confirm this is intended.
        for k, x in enumerate(level):
            if printflag:
                print(
                    _fmt.format(*outer, k + 1, 100.0 * (offset + k) / NN[0]),
                    end="", flush=True,
                )
            loop_func(x, *loop_args, **loop_kwargs)
        if printflag and n > 0:
            # Closing line for this loop: reports its last shot as completed.
            print(
                _fmt.format(*outer, n, 100.0 * (offset + n) / NN[0]),
                end="", flush=True,
            )
        return

    # Shots covered by one step of this level.
    stride = prod(len(X) for X in vals[1:])
    expand_reps = len(vals) == 2 and not reps_first
    for k, x in enumerate(level):
        level.pre_loop(x, *loop_args, **loop_kwargs)
        if expand_reps:
            # Repeat x once per repetition so loop_func sees the scan value.
            inner = [[x] * len(vals[1])]
        else:
            inner = vals[1:]
        _nested_loop_over(
            inner,
            loop_func,
            Q + [k],
            NN + [stride],
            fmt + counter,
            loop_args,
            loop_kwargs,
            reps_first,
            printflag,
        )
        level.post_loop(x, *loop_args, **loop_kwargs)

def ndscan(
    vals: list[ScanVals],
    reps: int,
    loop_func: Callable[..., object],
    loop_args: Sequence | None = None,
    loop_kwargs: dict[str, object] | None = None,
    reps_first: bool = True,
    printflag: bool = True,
) -> None:
    """
    Scan ``loop_func`` over every combination of ``vals``, ``reps`` times.

    A repetition dimension ``0..reps-1`` is added as the outermost level
    (``reps_first=True``) or innermost level (``reps_first=False``), and the
    product is iterated with ``_nested_loop_over``.  When ``printflag`` is
    set, progress and timing are printed.

    :param vals: scan dimensions, outermost first (any iterable of ScanVals).
    :param reps: number of repetitions of the full scan.
    :param loop_args: extra positional arguments for ``loop_func`` and hooks.
    :param loop_kwargs: extra keyword arguments for ``loop_func`` and hooks.
    """
    rep_vals = ScanVals(list(range(reps)))
    _vals = [rep_vals, *vals] if reps_first else [*vals, rep_vals]
    TOT = prod(len(X) for X in _vals)
    t0 = timeit.default_timer()
    _nested_loop_over(
        _vals,
        loop_func,
        list(),
        [TOT],
        "",
        list() if loop_args is None else loop_args,
        dict() if loop_kwargs is None else loop_kwargs,
        reps_first,
        printflag,
    )
    # Named ``elapsed`` so it does not shadow the module-level TypeVar ``T``.
    elapsed = timeit.default_timer() - t0
    if printflag:
        # reps == 0 or an empty dimension means zero shots: avoid dividing by 0.
        per_shot = f"{elapsed / TOT:.2f} s" if TOT > 0 else "n/a"
        print(
            f"\n  total time elapsed: {elapsed:.2f} s"
            f"\n  average time per shot: {per_shot}"
        )

def map_scanvars(
    scan_maps: Sequence[tuple[Callable[..., object], dict[str, Sequence[T]]]],
    map_args: Sequence | None = None,
    map_kwargs: dict[str, object] | None = None,
    reps: int | None = None,
    reps_first: bool = True,
    printflag: bool = True,
) -> tuple[list[str], list[list]]:
    """
    Evaluate scan maps over the Cartesian product of their variables.

    Each entry of ``scan_maps`` is ``(func, {name: values, ...})``.  For every
    combination of that entry's values (``itertools.product`` order, first
    variable outermost), ``func(*combo, *map_args, **map_kwargs)`` is called.

    :param reps: if given, a ``"REP"`` variable ``0..reps-1`` is included in
        the labels only -- first if ``reps_first`` else last.
    :returns: ``(names, mapped)``.  ``names`` has one label per combination of
        all variables of all maps (a later map overrides an earlier variable
        with the same name).  Labels have the form ``"<var>=<value:+.5f>"``
        joined by ``"_"``, with the REP value unlabelled.  ``mapped[i]`` lists
        the results of ``scan_maps[i]``.
    """
    _map_args = tuple() if map_args is None else map_args
    _map_kwargs = dict() if map_kwargs is None else map_kwargs
    allvars = {}
    if reps is not None and reps_first:
        allvars["REP"] = list(range(reps))
    for m in scan_maps:
        allvars.update(m[1])
    if reps is not None and not reps_first:
        # Must assign: a bare ``allvars["REP"]`` lookup raised KeyError here.
        allvars["REP"] = list(range(reps))
    # The REP value carries no "REP=" label.  Decide per key rather than via
    # str.replace, which would also mangle names ending in "REP".
    fmt = "_".join(
        "{:+.5f}" if v == "REP" else v + "={:+.5f}" for v in allvars
    )
    if printflag:
        print("[control] computing iterable scanning variables ... ",
            end="", flush=True)
    # product(*...) takes one iterable per variable and so yields combinations.
    names = [fmt.format(*V) for V in product(*allvars.values())]
    mapped = [
        [m[0](*XYZ, *_map_args, **_map_kwargs) for XYZ in product(*m[1].values())]
        for m in scan_maps
    ]
    if printflag:
        print(f"done {tuple(len(X) for X in mapped)}")
    return names, mapped