diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index e6fc056b72f19c7f2c275a17fcfc637109b4bfd4..5f848831bf0fb80709d3b06ddd315bd83ced64f2 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
     if os.path.isfile(chkpt_file):
         msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file)
+        checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
         msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
         start_epoch = checkpoint['epoch'] + 1
         best_top1 = checkpoint.get('best_top1', None)
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index e04ce067d6a7e40a34fb8f2f962d51e32485797c..08808666d02ba39796bd3cdb4ec046545423e151 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -213,7 +213,8 @@ def model_performance_summary(model, dummy_input, batch_size=1):
     model = distiller.make_non_parallel_copy(model)
     model.apply(install_perf_collector)
     # Now run the forward path and collect the data
-    model(dummy_input.cuda())
+    dummy_input = dummy_input.to(distiller.model_device(model))
+    model(dummy_input)
     # Unregister from the forward hooks
     for handle in hook_handles:
         handle.remove()
diff --git a/distiller/thinning.py b/distiller/thinning.py
index ad3ff26d3fb0909901b0021b2de70dadc48b9039..5cdff7f2ff14491073dfa3429d4978a3d793f11d 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -68,12 +68,13 @@ def create_graph(dataset, arch):
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
     elif dataset == 'cifar10':
-        dummy_input = torch.randn((1, 3, 32, 32))
+        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
 
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
-    return SummaryGraph(model, dummy_input.cuda())
+    dummy_input = dummy_input.to(distiller.model_device(model))
+    return SummaryGraph(model, dummy_input)
 
 
 def param_name_2_layer_name(param_name):
@@ -486,7 +487,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 msglogger.debug("[thinning] {}: setting {} to {}".
                                 format(layer_name, attr, indices_to_select.nelement()))
                 setattr(layers[layer_name], attr,
-                        torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
+                        torch.index_select(running, dim=dim_to_trim, index=indices_to_select.to(running.device)))
             else:
                 msglogger.debug("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
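A note on the index_select changes in this file: torch.index_select requires its index tensor to be on the same device as the input, and the thinning recipe builds its index tensors on the CPU. A minimal sketch of the pattern being applied, not part of the patch itself:

    import torch

    # index tensors are created on the CPU by default; move them to the
    # input's device before calling index_select to avoid a device mismatch
    running = torch.randn(4, 3)      # may live on a GPU in practice
    indices = torch.tensor([0, 2])   # int64 index tensor, on the CPU
    pruned = torch.index_select(running, 0, indices.to(running.device))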
@@ -521,13 +522,13 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 param.grad = param.grad.resize_(*directive[3])
         else:
             if param.data.size(dim) != len_indices:
-                param.data = torch.index_select(param.data, dim, indices)
+                param.data = torch.index_select(param.data, dim, indices.to(param.device))
                 msglogger.debug("[thinning] changed param {} shape: {}".format(param_name, len_indices))
 
             # We also need to change the dimensions of the gradient tensor.
             # If have not done a backward-pass thus far, then the gradient will
             # not exist, and therefore won't need to be re-dimensioned.
             if param.grad is not None and param.grad.size(dim) != len_indices:
-                param.grad = torch.index_select(param.grad, dim, indices)
+                param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
         if optimizer_thinning(optimizer, param, dim, indices):
             msglogger.debug("Updated velocity buffer %s" % param_name)
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index da2c3b629446b721aecae93751f566fb2f35dc0a..be03a27410034d216c63021e1a5ddb18c80c0ec9 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -62,7 +62,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         view_2d = param.view(-1, param.size(2) * param.size(3))
         # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
         #    thresholds tensor with length = #IFMs * # OFMs
-        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).to(param.device)
         # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
         #    elements in each channel as the threshold filter.
         # 3. Apply the threshold filter
@@ -71,20 +71,20 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
 
     elif group_type == 'Rows':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(param, thresholds, threshold_criteria)
         return binary_map
 
     elif group_type == 'Cols':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(1)).to(param.device)
         binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0)
         return binary_map
 
     elif group_type == '3D' or group_type == 'Filters':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         view_filters = param.view(param.size(0), -1)
-        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
         return binary_map
 
@@ -109,7 +109,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         # Next, compute the sum of the squares (of the elements in each row/kernel)
         kernel_means = view_2d.abs().mean(dim=1)
         k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
-        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
+        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).to(param.device)
         binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
         return binary_map
 
diff --git a/distiller/utils.py b/distiller/utils.py
index b1e70c8c9b1bbc6426fcbef72a9f1aa0563d4cd8..2c4bcf066bf53d7f08b77965088c1f041a6ee2a7 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -25,6 +25,14 @@ import torch.nn as nn
 from copy import deepcopy
 
 
+def model_device(model):
+    """Determine the device the model is allocated on."""
+    # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180
+    if next(model.parameters()).is_cuda:
+        return 'cuda'
+    return 'cpu'
+
+
 def to_np(var):
     return var.data.cpu().numpy()
 
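The model_device() helper added to distiller/utils.py is what the earlier hunks rely on to place dummy inputs correctly. A short sketch of the intended usage, assuming a model with at least one parameter (next(model.parameters()) raises StopIteration on parameter-less models, which the helper does not guard against):

    import torch
    import torch.nn as nn

    def model_device(model):
        # same logic as the new helper: the first parameter decides the device
        return 'cuda' if next(model.parameters()).is_cuda else 'cpu'

    model = nn.Linear(10, 2)            # stays on the CPU in this example
    dummy_input = torch.randn(1, 10).to(model_device(model))
    output = model(dummy_input)         # no device mismatch either way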
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index d0b53db6512d6ebda4043a9789342ebf185f725b..daf46fb8fe1b6d51804b008736d5ed98c3693cb2 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -142,6 +142,10 @@ parser.add_argument('--deterministic', '--det', action='store_true',
                     help='Ensure deterministic execution for re-producible results.')
 parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                     help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
+parser.add_argument('--cpu', action='store_true',
+                    help='Use CPU only. \n'
+                         'Flag not set => uses GPUs according to the --gpus flag value.\n'
+                         'Flag set => overrides the --gpus flag')
 parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
 parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
 parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
@@ -290,20 +294,25 @@ def main():
         # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
         cudnn.benchmark = True
 
-    if args.gpus is not None:
-        try:
-            args.gpus = [int(s) for s in args.gpus.split(',')]
-        except ValueError:
-            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
-            exit(1)
-        available_gpus = torch.cuda.device_count()
-        for dev_id in args.gpus:
-            if dev_id >= available_gpus:
-                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
-                                .format(dev_id, available_gpus))
-                exit(1)
-        # Set default device in case the first one on the list != 0
-        torch.cuda.set_device(args.gpus[0])
+    if args.cpu or not torch.cuda.is_available():
+        # Set GPU index to -1 if using CPU
+        args.device = 'cpu'
+    else:
+        args.device = 'cuda'
+        if args.gpus is not None:
+            try:
+                args.gpus = [int(s) for s in args.gpus.split(',')]
+            except ValueError:
+                msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
+                exit(1)
+            available_gpus = torch.cuda.device_count()
+            for dev_id in args.gpus:
+                if dev_id >= available_gpus:
+                    msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
+                                    .format(dev_id, available_gpus))
+                    exit(1)
+            # Set default device in case the first one on the list != 0
+            torch.cuda.set_device(args.gpus[0])
 
     # Infer the dataset from the model name
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
@@ -332,10 +341,11 @@ def main():
     if args.resume:
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
-        model.cuda()
+        model.to(args.device)
 
     # Define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
+    criterion = nn.CrossEntropyLoss().to(args.device)
+
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
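The two kinds of .to() used in the hunks above and below behave differently, which is why the model and criterion can be moved once while input tensors are re-targeted on every step. A sketch of the distinction, plain PyTorch and independent of this patch:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    assert model.to('cpu') is model   # Module.to() mutates in place, returns self

    x = torch.randn(1, 4)
    y = x.to('cpu')                   # Tensor.to() is out-of-place in general,
    assert y is x                     # but returns self when nothing changes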
@@ -372,7 +382,7 @@ def main():
         # requires a compression schedule configuration file in YAML.
         compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
-        model.cuda()
+        model.to(args.device)
     elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
@@ -476,7 +486,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
-        inputs, target = inputs.to('cuda'), target.to('cuda')
+        inputs, target = inputs.to(args.device), target.to(args.device)
 
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
@@ -600,7 +610,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     end = time.time()
     for validation_step, (inputs, target) in enumerate(data_loader):
         with torch.no_grad():
-            inputs, target = inputs.to('cuda'), target.to('cuda')
+            inputs, target = inputs.to(args.device), target.to(args.device)
             # compute output from model
             output = model(inputs)
 
@@ -675,7 +685,8 @@ def earlyexit_validate_loss(output, target, criterion, args):
     # but with a grouping of samples equal to the batch size.
     # Note that final group might not be a full batch - so determine actual size.
     this_batch_size = target.size()[0]
-    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).cuda()
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).to(args.device)
+
     for exitnum in range(args.num_exits):
         # calculate losses at each sample separately in the minibatch.
         args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
@@ -744,7 +755,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
                                        args.qe_bits_accum, args.qe_mode, args.qe_clip_acts,
                                        args.qe_no_clip_layers, args.qe_per_channel)
         quantizer.prepare_model()
-        model.cuda()
+        model.to(args.device)
 
     top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)
diff --git a/models/__init__.py b/models/__init__.py
index 7bde40697c3cc328541a2f790b43efe47c587071..193951d579dd36ac663998ad2387e41986a56109 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -51,6 +51,10 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         dataset:
         arch:
         parallel:
+        device_ids: Devices on which model should be created -
+                    None - GPU if available, otherwise CPU
+                    -1 - CPU
+                    >=0 - GPU device IDs
     """
     msglogger.info('==> using %s dataset' % dataset)
@@ -81,5 +85,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     elif parallel:
         model = torch.nn.DataParallel(model, device_ids=device_ids)
 
-    model.cuda()
+    if torch.cuda.is_available() and device_ids != -1:
+        model.cuda()
+
     return model
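Taken together, the patch lets the same entry point run with or without CUDA. Hypothetical invocations and a CPU-only model creation, assuming the resnet20_cifar architecture name and a dataset path in the style of the Distiller examples:

    # python compress_classifier.py --arch resnet20_cifar ../data.cifar10 --cpu
    # python compress_classifier.py --arch resnet20_cifar ../data.cifar10 --gpus 0,1

    from models import create_model

    # device_ids=-1 forces CPU placement even when CUDA is available
    model = create_model(pretrained=False, dataset='cifar10',
                         arch='resnet20_cifar', parallel=False, device_ids=-1)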