diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/approxknobs.json b/hpvm/projects/torch2hpvm/torch2hpvm/approxknobs.json index 974b536c48cd1d5ab96120cfd0c5e9510846df17..9d7cb28a8b3fcc2301735c21e99119beb5a89907 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/approxknobs.json +++ b/hpvm/projects/torch2hpvm/torch2hpvm/approxknobs.json @@ -2,7 +2,8 @@ { "name": "11", "speedup": 1.0, - "applies_to": null + "applies_to": null, + "is_baseline": true }, { "name": "12", diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py index cc2a670dad75661a296dcb4465a8de56358630b5..5e27224298b2f645a6d7375e3101afd0e18bd525 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py @@ -98,14 +98,21 @@ class ModelExporter: KnobInfoT = Tuple[str, float] ty_knobs: Dict[str, List[KnobInfoT]] = defaultdict(list) default_knobs: List[KnobInfoT] = [] + baseline_knob = None for k in knobs: - applies_to = k.pop("applies_to") - k = k["name"], k["speedup"] + kp = k["name"], k["speedup"] + if "is_baseline" in k: + if baseline_knob: + raise ValueError("Multiple baseline knobs") + baseline_knob = k["name"] + applies_to = k["applies_to"] if applies_to is None: - default_knobs.append(k) + default_knobs.append(kp) continue for ty in applies_to: - ty_knobs[ty].append(k) + ty_knobs[ty].append(kp) + if not baseline_knob: + raise ValueError("No baseline knob given") idx = 0 op_cost: Dict[str, int] = {} op_knobs: Dict[str, List[str]] = {} @@ -127,6 +134,9 @@ class ModelExporter: "op_cost": op_cost, "knob_speedup": knob_speedup, "op_knobs": op_knobs, + "baseline_knob": baseline_knob, + "tune_args": "tune", + "test_args": "test" }, f, indent=2,