Commit 379a3c2f authored by whooie's avatar whooie
Browse files

add convenience functions for looping over parameter values

parent 4c417806
......@@ -4,6 +4,9 @@ import inspect
import ctypes
import sys
import time
from math import prod, log10
from itertools import product
from typing import Callable, Iterable, Sequence, TypeVar
def _async_raise(tid, exctype):
"""
......@@ -198,3 +201,148 @@ class Controller:
self.cmd = argv[0]
command_handler[argv[0]](*argv[1:])
T = TypeVar("T")
class ScanVals:
def __init__(
self,
vals: Sequence[T],
pre_loop: Callable[[T, ...], ...]=None,
post_loop: Callable[[T, ...], ...]=None,
):
self.vals = vals
self.pre_loop = lambda *args, **kwargs: None \
if pre_loop is None else pre_loop
self.post_loop = lambda *args, **kwargs: None \
if post_loop is None else post_loop
def __getitem__(self, idx: int) -> T:
return self.vals[idx]
def __contains__(self, thing: T) -> bool:
return thing in self.vals
def __len__(self) -> int:
return len(self.vals)
def __iter__(self) -> Iterable:
return iter(self.vals)
def _nested_loop_over(
vals: list[ScanVals],
loop_func: Callable[[T, ...], ...],
Q: list[int],
NN: list[int],
fmt: str,
loop_args: Sequence[...],
loop_kwargs: dict[str, ...],
reps_first: bool,
printflag: bool,
) -> None:
if len(vals) == 0:
raise Exception("_nested_loop_over: must have values to loop over")
elif len(vals) == 1:
_fmt = fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}"
_fmt += " ({:6.2f}%) "
for k, x in enumerate(vals[0]):
if printflag:
print(
_fmt.format(
*(Q + [k]),
100.0 * (
sum(q * nnq for q, nnq in zip(Q, NN[1:])) + k
) / NN[0]
),
end="", flush=True
)
# vals[0].pre_loop(x, *loop_args, **loop_kwargs)
loop_func(x, *loop_args, **loop_kwargs)
# vals[0].post_loop(x, *loop_args, **loop_kwargs)
if printflag:
print(
_fmt.format(
*([q + 1 for q in Q] + [k]),
100.0 * (
sum(q * nnq for q, nnq in zip(Q, NN[1:])) + k + 1
) / NN[0]
),
end="", flush=True
)
elif len(vals) == 2 and not reps_first:
print(loop_args, loop_kwargs)
for k, x in enumerate(vals[0]):
vals[0].pre_loop(x, *loop_args, **loop_kwargs)
_nested_loop_over(
[ScanVals(len(vals[1]) * [x])],
loop_func,
Q + [k],
NN + [prod([len(X) for X in vals[1:]])],
fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}",
loop_args,
loop_kwargs,
reps_first,
printflag,
)
vals[0].post_loop(x, *loop_args, **loop_kwargs)
else:
for k, x in enumerate(vals[0]):
vals[0].pre_loop(x, *loop_args, **loop_kwargs)
_nested_loop_over(
vals[1:],
loop_func,
Q + [k],
NN + [prod([len(X) for X in vals[1:]])],
fmt + f" {{:{int(log10(len(vals[0]))) + 1}}}/{len(vals[0])}",
loop_args,
loop_kwargs,
reps_first,
printflag,
)
vals[0].post_loop(x, *loop_args, **loop_kwargs)
def ndscan(
vals: list[ScanVals],
reps: int,
loop_func: Callable[[...], ...],
loop_args: Sequence[...]=None,
loop_kwargs: dict[str, ...]=None,
reps_first: bool=True,
printflag: bool=True,
) -> None:
if reps_first:
_vals = [ScanVals(list(range(reps)))] + vals
else:
_vals = vals + [ScanVals(list(range(reps)))]
_nested_loop_over(
_vals,
loop_func,
list(),
[prod(len(X) for X in _vals)],
f"\r ",
list() if loop_args is None else loop_args,
dict() if loop_kwargs is None else loop_kwargs,
reps_first,
printflag,
)
if printflag:
print("")
def map_scanvars(
scan_maps: Sequence[tuple[Callable[[T, ...], ...], Sequence[Sequence[T]]]],
map_args: Sequence[...]=None,
map_kwargs: dict[str, ...]=None,
printflag: bool=True,
) -> list[list[...]]:
_map_args = tuple() if map_args is None else map_args
_map_kwargs = dict() if map_kwargs is None else map_kwargs
if printflag:
print("[control] computing iterable scanning variables ... ",
end="", flush=True)
ret = [
[ m[0](*XYZ, *_map_args, **_map_kwargs) for XYZ in product(*m[1]) ]
for m in scan_maps
]
if printflag:
print(f"done {tuple(len(X) for X in ret)}")
return ret
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment