diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 6b42383b4b13ffa6a8d6e9f3f41481fe5074af74..f73d6b898686a4afdec865390530b03962cd64d6 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -566,11 +566,11 @@ def train(train_loader, model, criterion, optimizer, epoch,
         if not early_exit_mode(args):
             loss = criterion(output, target)
             # Measure accuracy
-            classerr.add(output.data, target)
+            classerr.add(output.detach(), target)
             acc_stats.append([classerr.value(1), classerr.value(5)])
         else:
             # Measure accuracy and record loss
-            classerr.add(output[args.num_exits-1].data, target) # add the last exit (original exit)
+            classerr.add(output[args.num_exits-1].detach(), target) # add the last exit (original exit)
             loss = earlyexit_loss(output, target, criterion, args)
         # Record loss
         losses[OBJECTIVE_LOSS_KEY].add(loss.item())
@@ -698,9 +698,9 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
                 loss = criterion(output, target)
                 # measure accuracy and record loss
                 losses['objective_loss'].add(loss.item())
-                classerr.add(output.data, target)
+                classerr.add(output.detach(), target)
                 if args.display_confusion:
-                    confusion.add(output.data, target)
+                    confusion.add(output.detach(), target)
             else:
                 earlyexit_validate_loss(output, target, criterion, args)
 
@@ -751,10 +751,10 @@ def earlyexit_loss(output, target, criterion, args):
     for exitnum in range(args.num_exits-1):
         exit_loss = criterion(output[exitnum], target)
         weighted_loss += args.earlyexit_lossweights[exitnum] * exit_loss
-        args.exiterrors[exitnum].add(output[exitnum].data, target)
+        args.exiterrors[exitnum].add(output[exitnum].detach(), target)
     # handle final exit
     weighted_loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target)
-    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target)
+    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].detach(), target)
     return weighted_loss