Skip to content
Snippets Groups Projects
Commit 9870739c authored by Peter Pao-Huang's avatar Peter Pao-Huang
Browse files

Changed evaluate method to conform to distiller requirements

parent c9ce1871
No related branches found
No related tags found
No related merge requests found
......@@ -54,7 +54,7 @@ def quantize(
if not os.path.isfile(stats_file):
# generates `stats_file`
collect_quant_stats(
model, lambda model: get_loss(model, dataloader), save_dir=working_dir
model, get_loss, dataloader, save_dir=working_dir
)
# Generate Quantized Scales
......@@ -142,7 +142,7 @@ def get_loss(model: nn.Module, dataloader: DataLoader):
@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader):
def evaluate(model: nn.Module, dataloader: DataLoader = None):
model.eval()
correct = 0
total = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment