From 40251fc874bc2c4d6399674a597c5e3caee8c477 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Sat, 23 Jan 2021 23:35:24 -0600 Subject: [PATCH] Handle float64 downstream --- predtuner/approxapp.py | 6 ++++-- predtuner/torchapp.py | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py index fbc7eda..c7e9415 100644 --- a/predtuner/approxapp.py +++ b/predtuner/approxapp.py @@ -241,7 +241,8 @@ class TunerInterface(MeasurementInterface): self.pbar = tqdm(total=test_limit, leave=False) self.app_kwargs = app_kwargs - objective = ThresholdAccuracyMinimizeTime(tuner_thres) + # tune_thres can come in as np.float64 and opentuner doesn't like that + objective = ThresholdAccuracyMinimizeTime(float(tuner_thres)) input_manager = FixedInputManager(size=len(self.app.op_knobs)) super(TunerInterface, self).__init__( args, @@ -263,7 +264,8 @@ class TunerInterface(MeasurementInterface): from opentuner.resultsdb.models import Result cfg = desired_result.configuration.data - qos, perf = self.app.measure_qos_perf(cfg, False) + qos, perf = self.app.measure_qos_perf(cfg, False, **self.app_kwargs) + qos, perf = float(qos), float(perf) # Print a debug message for each config in tuning and keep threshold self.print_debug_config(qos, perf) self.pbar.update() diff --git a/predtuner/torchapp.py b/predtuner/torchapp.py index 172abb0..052a378 100644 --- a/predtuner/torchapp.py +++ b/predtuner/torchapp.py @@ -110,8 +110,7 @@ class TorchApp(ModeledApp, abc.ABC): end = begin + len(target) qos = self.tensor_to_qos(tensor_output[begin:end], target) qoses.append(qos) - # float64 -> float - return float(self.combine_qos(np.array(qoses))) + return self.combine_qos(np.array(qoses)) return [ LinearPerfModel(self._op_costs, self._knob_speedups), @@ -137,7 +136,7 @@ class TorchApp(ModeledApp, abc.ABC): qoses.append(self.tensor_to_qos(outputs, targets)) time_end = time_ns() / (10 ** 9) qos = self.combine_qos(np.array(qoses)) - return float(qos), time_end - time_begin # float64->float + return qos, time_end - time_begin def __repr__(self) -> str: class_name = self.__class__.__name__ -- GitLab