From e749ea6288431a53f839b621cc3e38facbf824de Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Sun, 16 Sep 2018 17:53:17 +0300
Subject: [PATCH] Clean up PyTorch 0.3 compatibility code (#46)

* Clean up PyTorch 0.3 compatibility code
We no longer need this compatibility code, and PyTorch 1.0 is just around the corner.

* Explicitly place the inputs tensor on the GPU(s)
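
In PyTorch 0.4 Variable was merged into Tensor, so the PytorchNoGrad and
get_inference_var shims reduce to the built-in torch.no_grad() context
manager and Tensor.to(). A minimal, self-contained sketch of the new
idiom (the model, criterion, and tensors below are illustrative
stand-ins, not the script's own):

    import torch
    import torch.nn as nn

    # Illustrative stand-ins for the script's model, criterion and data.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = nn.Linear(10, 2).to(device)
    criterion = nn.CrossEntropyLoss()
    inputs, target = torch.randn(4, 10), torch.randint(0, 2, (4,))

    # Training step: move tensors explicitly; no Variable wrapping needed.
    inputs, target = inputs.to(device), target.to(device)
    loss = criterion(model(inputs), target)
    loss.backward()

    # Inference: torch.no_grad() replaces volatile=True from PyTorch 0.3.
    with torch.no_grad():
        output = model(inputs)

The async=True keyword of .cuda() became non_blocking=True in 0.4 and
only matters for pinned-memory host-to-device copies, so the plain
.to('cuda') used below is the simpler choice.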
---
 .../compress_classifier.py                    | 77 ++++++-------------
 1 file changed, 24 insertions(+), 53 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index e95d92f..943a970 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -35,8 +35,8 @@ For each epoch:
 train():
     For each training step:
         compression_scheduler.on_minibatch_begin(epoch)
-        output = model(input_var)
-        loss = criterion(output, target_var)
+        output = model(input)
+        loss = criterion(output, target)
         compression_scheduler.before_backward_pass(epoch)
         loss.backward()
         optimizer.step()
@@ -378,27 +378,24 @@ def train(train_loader, model, criterion, optimizer, epoch,
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
-
-        target = target.cuda(async=True)
-        input_var = inputs.cuda()
-        target_var = torch.autograd.Variable(target)
+        inputs, target = inputs.to('cuda'), target.to('cuda')
 
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
             compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)
 
         if args.kd_policy is None:
-            output = model(input_var)
+            output = model(inputs)
         else:
-            output = args.kd_policy.forward(input_var)
+            output = args.kd_policy.forward(inputs)
 
         if not args.earlyexit_lossweights:
-            loss = criterion(output, target_var)
+            loss = criterion(output, target)
             # Measure accuracy and record loss
             classerr.add(output.data, target)
         else:
             # Measure accuracy and record loss
-            loss = earlyexit_loss(output, target_var, criterion, args)
+            loss = earlyexit_loss(output, target, criterion, args)
 
         losses[OBJECTIVE_LOSS_KEY].add(loss.item())
 
@@ -482,7 +479,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             args.exiterrors.append(tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)))
             args.losses_exits.append(tnt.AverageValueMeter())
         args.exit_taken = [0] * args.num_exits
-    
+
     batch_time = tnt.AverageValueMeter()
     total_samples = len(data_loader.sampler)
     batch_size = data_loader.batch_size
@@ -496,16 +493,14 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
 
     end = time.time()
     for validation_step, (inputs, target) in enumerate(data_loader):
-        with PytorchNoGrad():
-            target = target.cuda(async=True)
-            input_var = get_inference_var(inputs)
-            target_var = get_inference_var(target)
+        with torch.no_grad():
+            inputs, target = inputs.to('cuda'), target.to('cuda')
             # compute output from model
-            output = model(input_var)
+            output = model(inputs)
 
             if not args.earlyexit_thresholds:
                 # compute loss
-                loss = criterion(output, target_var)
+                loss = criterion(output, target)
                 # measure accuracy and record loss
                 losses['objective_loss'].add(loss.item())
                 classerr.add(output.data, target)
@@ -514,7 +509,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             else:
                 # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
                 # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
-                earlyexit_validate_loss(output, target_var, criterion, args)
+                earlyexit_validate_loss(output, target, criterion, args)
 
             # measure elapsed time
             batch_time.add(time.time() - end)
@@ -548,7 +543,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     if not args.earlyexit_thresholds:
         msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                        classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
-        
+
         if args.display_confusion:
             msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
         return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
@@ -573,65 +568,41 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
 
 
-class PytorchNoGrad(object):
-    """This is a temporary class to bridge some difference between PyTorch 3.x and 4.x"""
-    def __init__(self):
-        self.no_grad = None
-        if torch.__version__ >= '0.4':
-            self.no_grad = torch.no_grad()
-
-    def __enter__(self):
-        if self.no_grad:
-            return self.no_grad.__enter__()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.no_grad:
-            return self.no_grad.__exit__(self, exc_type, exc_val, exc_tb)
-
-
-def get_inference_var(tensor):
-    """This is a temporary function to bridge some difference between PyTorch 3.x and 4.x"""
-    tensor = tensor.cuda(async=True)
-    if torch.__version__ >= '0.4':
-        return torch.autograd.Variable(tensor)
-    return torch.autograd.Variable(tensor, volatile=True)
-
-
-def earlyexit_loss(output, target_var, criterion, args):
+def earlyexit_loss(output, target, criterion, args):
     loss = 0
     sum_lossweights = 0
     for exitnum in range(args.num_exits-1):
-        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target_var))
+        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target))
         sum_lossweights += args.earlyexit_lossweights[exitnum]
-        args.exiterrors[exitnum].add(output[exitnum].data, target_var)
+        args.exiterrors[exitnum].add(output[exitnum].data, target)
     # handle final exit
-    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target_var)
-    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target_var)
+    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target)
+    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target)
     return loss
 
 
-def earlyexit_validate_loss(output, target_var, criterion, args):
+def earlyexit_validate_loss(output, target, criterion, args):
     for exitnum in range(args.num_exits):
-        args.loss_exits[exitnum] = criterion(output[exitnum], target_var)
+        args.loss_exits[exitnum] = criterion(output[exitnum], target)
         args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())
 
     # We need to go through this batch itself - this is now a vector of losses through the batch.
     # Collecting stats on which exit early can be done across the batch at this time.
     # Note that we can't use batch_size as last batch might be smaller
-    this_batch_size = target_var.size()[0]
+    this_batch_size = target.size()[0]
     for batchnum in range(this_batch_size):
         # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
         for exitnum in range(args.num_exits-1):
             if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
                 # take the results from early exit since lower than threshold
                 args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
-                        torch.full([1], target_var[batchnum], dtype=torch.long))
+                        torch.full([1], target[batchnum], dtype=torch.long))
                 args.exit_taken[exitnum] += 1
             else:
                 # skip the early exits and include results from end of net
                 args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
                                                                             ndmin=2)),
-                        torch.full([1], target_var[batchnum], dtype=torch.long))
+                        torch.full([1], target[batchnum], dtype=torch.long))
                 args.exit_taken[args.num_exits-1] += 1
 
 
-- 
GitLab