Skip to content
Snippets Groups Projects
Commit 9870739c authored by Peter Pao-Huang's avatar Peter Pao-Huang
Browse files

Changed evaluate method to conform to distiller requirements

parent c9ce1871
No related branches found
No related tags found
No related merge requests found
...@@ -54,7 +54,7 @@ def quantize( ...@@ -54,7 +54,7 @@ def quantize(
if not os.path.isfile(stats_file): if not os.path.isfile(stats_file):
# generates `stats_file` # generates `stats_file`
collect_quant_stats( collect_quant_stats(
model, lambda model: get_loss(model, dataloader), save_dir=working_dir model, get_loss, dataloader, save_dir=working_dir
) )
# Generate Quantized Scales # Generate Quantized Scales
...@@ -142,7 +142,7 @@ def get_loss(model: nn.Module, dataloader: DataLoader): ...@@ -142,7 +142,7 @@ def get_loss(model: nn.Module, dataloader: DataLoader):
@torch.no_grad() @torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader): def evaluate(model: nn.Module, dataloader: DataLoader = None):
model.eval() model.eval()
correct = 0 correct = 0
total = 0 total = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment