Skip to content
Snippets Groups Projects
Commit 9e26e065 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Save model P1 with torch.save

parent 895fc7cd
No related branches found
No related tags found
No related merge requests found
import abc
import json
import logging
import os
import pickle
from pathlib import Path
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
......@@ -239,7 +240,7 @@ class QoSModelP1(IQoSModel):
:param qos_metric: A function that compute a QoS level from the return value
of `tensor_output_getter`.
:param storage: A `pickle` file to store this model into, if the file doesn't exist,
:param storage: A file of PyTorch format to store this model into, if the file doesn't exist,
or load the model from if the file exists.
If not given, the model will not be stored.
"""
......@@ -277,42 +278,22 @@ class QoSModelP1(IQoSModel):
return float(self.qos_metric(ret))
def _init(self):
if self.storage and self.storage.is_file():
self.delta_tensors, self.baseline_tensor = torch.load(self.storage)
dt = self.delta_tensors
btensor = self.baseline_tensor = self.output_f({})
if self.storage and self.storage.is_file():
for op, knob, delta_tensor in self._load(self.storage):
dt[op][knob] = delta_tensor
updated = False
for op, knob in barred_ravel_knobs(self.app):
if dt[op][knob] is not None:
continue
updated = True
delta_tensor = self.output_f({op: knob}) - btensor
dt[op][knob] = delta_tensor
self._try_append_save(self.storage, op, knob, delta_tensor)
if self.storage and updated:
os.makedirs(self.storage.parent, exist_ok=True)
torch.save((dt, btensor), self.storage)
super()._init()
def _load(self, path: Path) -> Iterator[Tuple[str, str, torch.Tensor]]:
msg_logger.info(f"Model {self.name} found saved model at {path}")
with path.open("rb") as f:
while True:
try:
op_name, knob_name, tensor = pickle.load(f)
yield op_name, knob_name, tensor
except EOFError:
return
@staticmethod
def _try_append_save(
path: Optional[Path], op_name: str, knob_name: str, tensor: torch.Tensor
):
import os
if not path:
return
if not path.parent.is_dir():
os.makedirs(path.parent)
with path.open("ab") as f:
pickle.dump((op_name, knob_name, tensor), f)
class QoSModelP2(IQoSModel):
"""QoS model `P1` in ApproxTuner.
......@@ -389,8 +370,6 @@ class QoSModelP2(IQoSModel):
self.baseline_qos = float(data["bqos"])
def _save(self, path: Path):
import os
if not path.parent.is_dir():
os.makedirs(path.parent)
with path.open("w") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment