diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py index f82544e7763909a816cd91a7c81805b159a0c7fb..7564bb81a61f2c6a31b982c28b9be6eb4c21d426 100644 --- a/predtuner/approxapp.py +++ b/predtuner/approxapp.py @@ -93,11 +93,10 @@ class ApproxApp(abc.ABC): return baseline_knob def add_baseline_to_knobs(self, approxes: KnobsT): - approxes = approxes.copy() - for op_name in self.ops: - if op_name not in approxes: - approxes[op_name] = self.baseline_knob.name - return approxes + return { + op_name: approxes.get(op_name, self.baseline_knob.name) + for op_name in self.ops + } class BaselineKnob(ApproxKnob): @@ -230,7 +229,7 @@ class ApproxTuner(Generic[T]): filepath = Path(filepath) os.makedirs(filepath.parent, exist_ok=True) with filepath.open("w") as f: - f.write(encode(self.kept_configs, indent=2)) + f.write(encode(self.best_configs, indent=2)) def plot_configs(self) -> plt.Figure: if not self.tuned: