Unverified commit e749ea62, authored by Neta Zmora, committed by GitHub

Clean up PyTorch 0.3 compatibility code (#46)

* Clean up PyTorch 0.3 compatibility code
We don't need this anymore and PyTorch 1.0 is just around the corner.

* Explicitly place the inputs tensor on the GPU(s)
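
For context: the heart of this clean-up is the Tensor/Variable merge in PyTorch 0.4. Wrapping tensors in torch.autograd.Variable (and the volatile flag) became unnecessary, and tensor.to(device) replaced the explicit .cuda() calls. A minimal sketch of the two idioms, illustrative rather than code from this commit:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = torch.randn(4, 3, 224, 224)

    # PyTorch 0.3 style (removed by this commit):
    #     inputs = torch.autograd.Variable(inputs.cuda(async=True), volatile=True)
    # PyTorch 0.4+ style: Tensor and Variable are merged, so moving the tensor
    # to the device is enough; inference uses torch.no_grad() instead of
    # volatile=True. (The async= argument was later renamed non_blocking=,
    # since async became a reserved word in Python.)
    inputs = inputs.to(device)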
parent 5d3d6d8d
@@ -35,8 +35,8 @@ For each epoch:
     train():
         For each training step:
             compression_scheduler.on_minibatch_begin(epoch)
-            output = model(input_var)
-            loss = criterion(output, target_var)
+            output = model(input)
+            loss = criterion(output, target)
             compression_scheduler.before_backward_pass(epoch)
             loss.backward()
             optimizer.step()
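
A runnable sketch of the training-step pattern shown in this pseudocode. The model, criterion, optimizer and the no-op scheduler below are placeholders, not Distiller code; only the callback names come from the documented loop:

    import torch
    import torch.nn as nn

    class NoOpScheduler:
        # Stand-in for Distiller's CompressionScheduler in this sketch.
        def on_minibatch_begin(self, epoch):
            pass
        def before_backward_pass(self, epoch):
            pass

    model = nn.Linear(10, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    compression_scheduler = NoOpScheduler()

    inputs, target = torch.randn(8, 10), torch.randint(0, 2, (8,))
    for epoch in range(1):
        compression_scheduler.on_minibatch_begin(epoch)
        output = model(inputs)
        loss = criterion(output, target)
        compression_scheduler.before_backward_pass(epoch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()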
@@ -378,27 +378,24 @@ def train(train_loader, model, criterion, optimizer, epoch,
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
-        target = target.cuda(async=True)
-        input_var = inputs.cuda()
-        target_var = torch.autograd.Variable(target)
+        inputs, target = inputs.to('cuda'), target.to('cuda')

         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
             compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)

         if args.kd_policy is None:
-            output = model(input_var)
+            output = model(inputs)
         else:
-            output = args.kd_policy.forward(input_var)
+            output = args.kd_policy.forward(inputs)

         if not args.earlyexit_lossweights:
-            loss = criterion(output, target_var)
+            loss = criterion(output, target)
             # Measure accuracy and record loss
             classerr.add(output.data, target)
         else:
             # Measure accuracy and record loss
-            loss = earlyexit_loss(output, target_var, criterion, args)
+            loss = earlyexit_loss(output, target, criterion, args)
         losses[OBJECTIVE_LOSS_KEY].add(loss.item())
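
The forward dispatch in this hunk routes the pass through a knowledge-distillation policy when one is configured, so the policy can also run the teacher model. A sketch of that pattern; DummyKDPolicy is a hypothetical placeholder, not Distiller's API:

    import torch
    import torch.nn as nn

    class DummyKDPolicy:
        def __init__(self, student):
            self.student = student
        def forward(self, inputs):
            # A real policy would also run the teacher here and cache its
            # output for computing the distillation loss later.
            return self.student(inputs)

    model = nn.Linear(10, 2)
    kd_policy = None  # or DummyKDPolicy(model) when distillation is enabled
    inputs = torch.randn(8, 10)
    output = model(inputs) if kd_policy is None else kd_policy.forward(inputs)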
@@ -482,7 +479,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             args.exiterrors.append(tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)))
             args.losses_exits.append(tnt.AverageValueMeter())
         args.exit_taken = [0] * args.num_exits

     batch_time = tnt.AverageValueMeter()
     total_samples = len(data_loader.sampler)
     batch_size = data_loader.batch_size
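
The _validate() setup above builds on torchnet meters. A brief sketch of the two meter types involved, with illustrative values (not code from this commit):

    import torch
    import torchnet.meter as tnt

    classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
    objective_loss = tnt.AverageValueMeter()

    output = torch.randn(8, 10)              # scores: 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))
    classerr.add(output, target)             # accumulate top-1/top-5 accuracy
    objective_loss.add(0.7)                  # accumulate one scalar loss value

    print(classerr.value(1), classerr.value(5))
    print(objective_loss.mean)               # running mean, as used in the logs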
@@ -496,16 +493,14 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     end = time.time()
     for validation_step, (inputs, target) in enumerate(data_loader):
-        with PytorchNoGrad():
-            target = target.cuda(async=True)
-            input_var = get_inference_var(inputs)
-            target_var = get_inference_var(target)
+        with torch.no_grad():
+            inputs, target = inputs.to('cuda'), target.to('cuda')
             # compute output from model
-            output = model(input_var)
+            output = model(inputs)

             if not args.earlyexit_thresholds:
                 # compute loss
-                loss = criterion(output, target_var)
+                loss = criterion(output, target)
                 # measure accuracy and record loss
                 losses['objective_loss'].add(loss.item())
                 classerr.add(output.data, target)
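
The replacement context manager is the key 0.4 idiom here: inside torch.no_grad(), autograd records nothing, so validation builds no graph and its outputs do not require gradients. A self-contained sketch:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    inputs = torch.randn(4, 10)
    with torch.no_grad():
        output = model(inputs)
    print(output.requires_grad)  # False: no graph was recorded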
@@ -514,7 +509,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             else:
                 # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
                 # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
-                earlyexit_validate_loss(output, target_var, criterion, args)
+                earlyexit_validate_loss(output, target, criterion, args)

             # measure elapsed time
             batch_time.add(time.time() - end)
@@ -548,7 +543,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     if not args.earlyexit_thresholds:
         msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                        classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)

         if args.display_confusion:
             msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
         return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
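
The confusion matrix logged above also comes from torchnet. A small sketch of how such a meter accumulates, with illustrative values (not code from this commit):

    import torch
    import torchnet.meter as tnt

    confusion = tnt.ConfusionMeter(3)        # 3-class problem
    scores = torch.randn(8, 3)               # model outputs for 8 samples
    labels = torch.randint(0, 3, (8,))
    confusion.add(scores, labels)
    print(confusion.value())                 # 3x3 matrix of prediction counts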
@@ -573,65 +568,41 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]

-class PytorchNoGrad(object):
-    """This is a temporary class to bridge some difference between PyTorch 3.x and 4.x"""
-    def __init__(self):
-        self.no_grad = None
-        if torch.__version__ >= '0.4':
-            self.no_grad = torch.no_grad()
-
-    def __enter__(self):
-        if self.no_grad:
-            return self.no_grad.__enter__()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.no_grad:
-            return self.no_grad.__exit__(self, exc_type, exc_val, exc_tb)
-
-def get_inference_var(tensor):
-    """This is a temporary function to bridge some difference between PyTorch 3.x and 4.x"""
-    tensor = tensor.cuda(async=True)
-    if torch.__version__ >= '0.4':
-        return torch.autograd.Variable(tensor)
-    return torch.autograd.Variable(tensor, volatile=True)
-
-def earlyexit_loss(output, target_var, criterion, args):
+def earlyexit_loss(output, target, criterion, args):
     loss = 0
     sum_lossweights = 0
     for exitnum in range(args.num_exits-1):
-        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target_var))
+        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target))
         sum_lossweights += args.earlyexit_lossweights[exitnum]
-        args.exiterrors[exitnum].add(output[exitnum].data, target_var)
+        args.exiterrors[exitnum].add(output[exitnum].data, target)
     # handle final exit
-    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target_var)
-    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target_var)
+    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target)
+    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target)
     return loss

-def earlyexit_validate_loss(output, target_var, criterion, args):
+def earlyexit_validate_loss(output, target, criterion, args):
     for exitnum in range(args.num_exits):
-        args.loss_exits[exitnum] = criterion(output[exitnum], target_var)
+        args.loss_exits[exitnum] = criterion(output[exitnum], target)
         args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())

     # We need to go through this batch itself - this is now a vector of losses through the batch.
     # Collecting stats on which exit early can be done across the batch at this time.
     # Note that we can't use batch_size as last batch might be smaller
-    this_batch_size = target_var.size()[0]
+    this_batch_size = target.size()[0]
     for batchnum in range(this_batch_size):
         # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
         for exitnum in range(args.num_exits-1):
             if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
                 # take the results from early exit since lower than threshold
                 args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
-                                             torch.full([1], target_var[batchnum], dtype=torch.long))
+                                             torch.full([1], target[batchnum], dtype=torch.long))
                 args.exit_taken[exitnum] += 1
             else:
                 # skip the early exits and include results from end of net
                 args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
                                                                             ndmin=2)),
-                                                      torch.full([1], target_var[batchnum], dtype=torch.long))
+                                                      torch.full([1], target[batchnum], dtype=torch.long))
                 args.exit_taken[args.num_exits-1] += 1
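
For reference, earlyexit_loss computes a weighted sum: each early exit contributes with its configured weight, and the final exit receives the remaining weight so the total weight is 1.0. A self-contained sketch with illustrative weights:

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    earlyexit_lossweights = [0.3, 0.3]                 # weights for exit0, exit1
    target = torch.randint(0, 10, (8,))
    output = [torch.randn(8, 10) for _ in range(3)]    # [exit0, exit1, final]

    loss = sum(w * criterion(o, target)
               for w, o in zip(earlyexit_lossweights, output[:-1]))
    loss += (1.0 - sum(earlyexit_lossweights)) * criterion(output[-1], target)

At validation time, earlyexit_validate_loss instead uses the per-exit loss as a confidence measure: a sample leaves at the first exit whose loss falls below its threshold, and otherwise falls through to the end of the network.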