From d260f0a6d855992ddbb0f736782376f0c7b394b1 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sat, 23 Jan 2021 21:55:57 -0600
Subject: [PATCH] Added integrated test case (also example)

---
 predtuner/approxapp.py    |  6 +++++-
 test/integrated_tuning.py | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 test/integrated_tuning.py

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index a4c276a..cd24abc 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -3,8 +3,8 @@ import logging
 from pathlib import Path
 from typing import Dict, Generic, List, NamedTuple, Optional, Tuple, TypeVar, Union
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from opentuner.measurement.interface import MeasurementInterface
 from opentuner.resultsdb.models import Configuration, Result
 from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
@@ -162,12 +162,16 @@ class ApproxTuner(Generic[T]):
         return [configs[i] for i in taken_idx]
 
     def write_configs_to_dir(self, directory: PathLike):
+        import os
+
         from jsonpickle import encode
 
         if not self.tuned:
             raise RuntimeError(
                 f"No tuning session has been run; call self.tune() first."
             )
+        directory = Path(directory)
+        os.makedirs(directory, exist_ok=True)
         encode(self.kept_configs, directory)
 
     def plot_configs(self) -> plt.Figure:
diff --git a/test/integrated_tuning.py b/test/integrated_tuning.py
new file mode 100644
index 0000000..bfe1fb0
--- /dev/null
+++ b/test/integrated_tuning.py
@@ -0,0 +1,27 @@
+import site
+from pathlib import Path
+
+import torch
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Subset
+
+site.addsitedir(Path(__file__).absolute().parent.parent)
+from model_zoo import CIFAR, VGG16Cifar10
+from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
+
+msg_logger = config_pylogger(output_dir="tuner_results/logs", verbose=True)
+
+dataset = CIFAR.from_file(
+    "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
+)
+tune_loader = DataLoader(Subset(dataset, range(5000)), batch_size=500)
+calib_loader = DataLoader(Subset(dataset, range(5000, 10000)), batch_size=500)
+module = VGG16Cifar10()
+module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
+app = TorchApp(
+    "TestTorchApp", module, tune_loader, calib_loader, get_knobs_from_file(), accuracy,
+)
+baseline, _ = app.measure_qos_perf({}, False)
+tuner = app.get_tuner()
+tuner.tune(500, 2.1, 3.0, True, 50)
+tuner.write_configs_to_dir("tuner_results/test")
-- 
GitLab