diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index e6fc056b72f19c7f2c275a17fcfc637109b4bfd4..5f848831bf0fb80709d3b06ddd315bd83ced64f2 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
 
     if os.path.isfile(chkpt_file):
         msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file)
+        checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
         msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
         start_epoch = checkpoint['epoch'] + 1
         best_top1 = checkpoint.get('best_top1', None)
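
Why the map_location override matters: a checkpoint serialized during a CUDA run stores its tensors as CUDA storages, so a bare torch.load() fails on a CPU-only machine. The lambda above remaps every storage onto the host during deserialization, after which the model can be moved wherever it is needed. A minimal sketch, with a hypothetical file name:

    import torch

    # Remap all storages to the CPU while deserializing; equivalent to
    # map_location='cpu' in recent PyTorch versions.
    checkpoint = torch.load('model_best.pth.tar',  # hypothetical path
                            map_location=lambda storage, loc: storage)
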
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index e04ce067d6a7e40a34fb8f2f962d51e32485797c..08808666d02ba39796bd3cdb4ec046545423e151 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -213,7 +213,8 @@ def model_performance_summary(model, dummy_input, batch_size=1):
     model = distiller.make_non_parallel_copy(model)
     model.apply(install_perf_collector)
     # Now run the forward path and collect the data
-    model(dummy_input.cuda())
+    dummy_input = dummy_input.to(distiller.model_device(model))
+    model(dummy_input)
     # Unregister from the forward hooks
     for handle in hook_handles:
         handle.remove()
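
The pattern used here, and throughout this patch, is to stop assuming CUDA and instead move inputs to wherever the model's parameters already live. A minimal sketch with a toy model (the names below are illustrative, not Distiller's API):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)                     # toy stand-in for the real model
    device = next(model.parameters()).device     # what distiller.model_device() derives
    dummy_input = torch.randn(1, 10).to(device)  # follows the model, CPU or GPU alike
    output = model(dummy_input)
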
diff --git a/distiller/thinning.py b/distiller/thinning.py
index ad3ff26d3fb0909901b0021b2de70dadc48b9039..5cdff7f2ff14491073dfa3429d4978a3d793f11d 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -68,12 +68,13 @@ def create_graph(dataset, arch):
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
     elif dataset == 'cifar10':
-        dummy_input = torch.randn((1, 3, 32, 32))
+        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
 
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
-    return SummaryGraph(model, dummy_input.cuda())
+    dummy_input = dummy_input.to(distiller.model_device(model))
+    return SummaryGraph(model, dummy_input)
 
 
 def param_name_2_layer_name(param_name):
@@ -486,7 +487,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                     msglogger.debug("[thinning] {}: setting {} to {}".
                                     format(layer_name, attr, indices_to_select.nelement()))
                     setattr(layers[layer_name], attr,
-                            torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
+                            torch.index_select(running, dim=dim_to_trim, index=indices_to_select.to(running.device)))
             else:
                 msglogger.debug("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
@@ -521,13 +522,13 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                     param.grad = param.grad.resize_(*directive[3])
             else:
                 if param.data.size(dim) != len_indices:
-                    param.data = torch.index_select(param.data, dim, indices)
+                    param.data = torch.index_select(param.data, dim, indices.to(param.device))
                     msglogger.debug("[thinning] changed param {} shape: {}".format(param_name, len_indices))
                 # We also need to change the dimensions of the gradient tensor.
                 # If have not done a backward-pass thus far, then the gradient will
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
-                    param.grad = torch.index_select(param.grad, dim, indices)
+                    param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
                     if optimizer_thinning(optimizer, param, dim, indices):
                         msglogger.debug("Updated velocity buffer %s" % param_name)
 
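The .to(...) calls added above address a hard requirement of torch.index_select: the index tensor must reside on the same device as the tensor being indexed, or PyTorch raises a device-mismatch error. Moving the indices first is a no-op when both already agree. A small self-contained sketch:

    import torch

    weights = torch.randn(4, 3)      # stands in for param.data
    indices = torch.tensor([0, 2])   # rows that survive the thinning
    # Safe on CPU and GPU alike: indices follow weights before the select.
    thinned = torch.index_select(weights, 0, indices.to(weights.device))
    assert thinned.shape == (2, 3)
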
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index da2c3b629446b721aecae93751f566fb2f35dc0a..be03a27410034d216c63021e1a5ddb18c80c0ec9 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -62,7 +62,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         view_2d = param.view(-1, param.size(2) * param.size(3))
         # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
         #    thresholds tensor with length = #IFMs * # OFMs
-        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).to(param.device)
         # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
         #    elements in each channel as the threshold filter.
         # 3. Apply the threshold filter
@@ -71,20 +71,20 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
 
     elif group_type == 'Rows':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(param, thresholds, threshold_criteria)
         return binary_map
 
     elif group_type == 'Cols':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(1)).to(param.device)
         binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0)
         return binary_map
 
     elif group_type == '3D' or group_type == 'Filters':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         view_filters = param.view(param.size(0), -1)
-        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
         return binary_map
 
@@ -109,7 +109,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         # Next, compute the sum of the squares (of the elements in each row/kernel)
         kernel_means = view_2d.abs().mean(dim=1)
         k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
-        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
+        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).to(param.device)
         binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
         return binary_map
 
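All four hunks in this file apply the same substitution: allocate the thresholds tensor on param.device instead of unconditionally calling .cuda(). A sketch of the idea, plus an equivalent construction (the torch.full variant is an alternative suggestion, not what the patch uses):

    import torch

    param = torch.randn(16, 8)   # e.g. a 2D weight matrix
    threshold = 0.1
    # As in the patch: build per-row thresholds wherever param lives.
    thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
    # Equivalent, avoiding the intermediate Python list and extra copy:
    thresholds = torch.full((param.size(0),), threshold, device=param.device)
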
diff --git a/distiller/utils.py b/distiller/utils.py
index b1e70c8c9b1bbc6426fcbef72a9f1aa0563d4cd8..2c4bcf066bf53d7f08b77965088c1f041a6ee2a7 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -25,6 +25,14 @@ import torch.nn as nn
 from copy import deepcopy
 
 
+def model_device(model):
+    """Determine the device the model is allocated on."""
+    # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180
+    if next(model.parameters()).is_cuda:
+        return 'cuda'
+    return 'cpu'
+
+
 def to_np(var):
     return var.data.cpu().numpy()
 
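model_device() inspects the first parameter only, which is sufficient when the whole model lives on a single device; it also assumes the model has at least one parameter (next() would raise StopIteration otherwise). Usage, as the other files in this patch call it:

    import torch.nn as nn
    import distiller

    model = nn.Linear(4, 4)                 # toy model; parameters start on the CPU
    device = distiller.model_device(model)  # -> 'cpu' ('cuda' after model.cuda())
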
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index d0b53db6512d6ebda4043a9789342ebf185f725b..daf46fb8fe1b6d51804b008736d5ed98c3693cb2 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -142,6 +142,10 @@ parser.add_argument('--deterministic', '--det', action='store_true',
                     help='Ensure deterministic execution for re-producible results.')
 parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                     help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
+parser.add_argument('--cpu', action='store_true',
+                    help='Use CPU only.\n'
+                    'Flag not set => uses GPUs according to the --gpus flag value.\n'
+                    'Flag set => overrides the --gpus flag')
 parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
 parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
 parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
@@ -290,20 +294,25 @@ def main():
         # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
         cudnn.benchmark = True
 
-    if args.gpus is not None:
-        try:
-            args.gpus = [int(s) for s in args.gpus.split(',')]
-        except ValueError:
-            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
-            exit(1)
-        available_gpus = torch.cuda.device_count()
-        for dev_id in args.gpus:
-            if dev_id >= available_gpus:
-                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
-                                .format(dev_id, available_gpus))
+    if args.cpu or not torch.cuda.is_available():
+        # Use the CPU when explicitly requested, or when no GPU is available
+        args.device = 'cpu'
+    else:
+        args.device = 'cuda'
+        if args.gpus is not None:
+            try:
+                args.gpus = [int(s) for s in args.gpus.split(',')]
+            except ValueError:
+                msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
                 exit(1)
-        # Set default device in case the first one on the list != 0
-        torch.cuda.set_device(args.gpus[0])
+            available_gpus = torch.cuda.device_count()
+            for dev_id in args.gpus:
+                if dev_id >= available_gpus:
+                    msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
+                                    .format(dev_id, available_gpus))
+                    exit(1)
+            # Set default device in case the first one on the list != 0
+            torch.cuda.set_device(args.gpus[0])
 
     # Infer the dataset from the model name
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
@@ -332,10 +341,11 @@ def main():
     if args.resume:
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
-        model.cuda()
+        model.to(args.device)
 
     # Define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
+    criterion = nn.CrossEntropyLoss().to(args.device)
+
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
@@ -372,7 +382,7 @@ def main():
         # requires a compression schedule configuration file in YAML.
         compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
-        model.cuda()
+        model.to(args.device)
     elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
@@ -476,7 +486,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
-        inputs, target = inputs.to('cuda'), target.to('cuda')
+        inputs, target = inputs.to(args.device), target.to(args.device)
 
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
@@ -600,7 +610,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     end = time.time()
     for validation_step, (inputs, target) in enumerate(data_loader):
         with torch.no_grad():
-            inputs, target = inputs.to('cuda'), target.to('cuda')
+            inputs, target = inputs.to(args.device), target.to(args.device)
             # compute output from model
             output = model(inputs)
 
@@ -675,7 +685,8 @@ def earlyexit_validate_loss(output, target, criterion, args):
     # but with a grouping of samples equal to the batch size.
     # Note that final group might not be a full batch - so determine actual size.
     this_batch_size = target.size()[0]
-    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).cuda()
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).to(args.device)
+
     for exitnum in range(args.num_exits):
         # calculate losses at each sample separately in the minibatch.
         args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
@@ -744,7 +755,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
                                                           args.qe_bits_accum, args.qe_mode, args.qe_clip_acts,
                                                           args.qe_no_clip_layers, args.qe_per_channel)
         quantizer.prepare_model()
-        model.cuda()
+        model.to(args.device)
 
     top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)
 
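The device-selection logic introduced in main() boils down to a single string stored in args.device, which every former .cuda() call site now consumes via .to(args.device); those calls are no-ops when the tensor is already in place. A condensed sketch of that logic (function name and parameters are illustrative):

    import torch

    def select_device(cpu_flag, gpus=None):
        # CPU when explicitly requested, or when CUDA is simply absent.
        if cpu_flag or not torch.cuda.is_available():
            return 'cpu'
        if gpus:  # e.g. [1, 2] parsed from --gpus
            torch.cuda.set_device(gpus[0])  # first listed GPU becomes the default
        return 'cuda'
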
diff --git a/models/__init__.py b/models/__init__.py
index 7bde40697c3cc328541a2f790b43efe47c587071..193951d579dd36ac663998ad2387e41986a56109 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -51,6 +51,10 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         dataset:
         arch:
         parallel:
+        device_ids: Devices on which the model should be created:
+            None - GPU if available, otherwise CPU
+            -1 - CPU
+            >=0 - GPU device IDs
     """
     msglogger.info('==> using %s dataset' % dataset)
 
@@ -81,5 +85,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     elif parallel:
         model = torch.nn.DataParallel(model, device_ids=device_ids)
 
-    model.cuda()
+    if torch.cuda.is_available() and device_ids != -1:
+        model.cuda()
+
     return model
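
With the new sentinel, callers choose the placement explicitly. Hypothetical calls illustrating the device_ids contract documented in the docstring above (arch name chosen for illustration):

    from models import create_model

    # Force CPU, regardless of what hardware is present:
    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False, device_ids=-1)
    # Default behavior: GPU(s) if available, otherwise CPU:
    model = create_model(False, 'cifar10', 'resnet20_cifar')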