diff --git a/.gitignore b/.gitignore index b820feca28116dca2f6b739ac23ce4f3adc78f29..8f641a310621fb58fd6d9b2cade1474b9f6e6588 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.pyc __pycache__/ .pytest_cache +.cache *.tar site/ env/ diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index 83ae242b4c0d8c9281e3d87226a0d005ab79341a..16c9c98b90a0e2b49e9df6d2b9db302dd5849f19 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -28,7 +28,8 @@ import distiller msglogger = logging.getLogger() -def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=None, is_best=False, name=None): +def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, + best_top1=None, is_best=False, name=None, dir='.'): """Save a pytorch training checkpoint Args: @@ -40,10 +41,16 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non best_top1: the best top1 score seen so far is_best: True if this is the best (top1 accuracy) model so far name: the name of the checkpoint file + dir: directory in which to save the checkpoint """ msglogger.info("Saving checkpoint") + if not os.path.isdir(dir): + msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir))) + exit(1) filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar' + fullpath = os.path.join(dir, filename) filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar' + fullpath_best = os.path.join(dir, filename_best) checkpoint = {} checkpoint['epoch'] = epoch checkpoint['arch'] = arch @@ -56,10 +63,12 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non checkpoint['compression_sched'] = scheduler.state_dict() if hasattr(model, 'thinning_recipes'): checkpoint['thinning_recipes'] = model.thinning_recipes + if hasattr(model, 'quantizer_metadata'): + checkpoint['quantizer_metadata'] = model.quantizer_metadata - torch.save(checkpoint, filename) + torch.save(checkpoint, fullpath) if is_best: - shutil.copyfile(filename, filename_best) + shutil.copyfile(fullpath, fullpath_best) def load_checkpoint(model, chkpt_file, optimizer=None): @@ -86,20 +95,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None): compression_scheduler.load_state_dict(checkpoint['compression_sched']) msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", checkpoint['epoch']) + else: + msglogger.info("Warning: compression schedule data does not exist in the checkpoint") if 'thinning_recipes' in checkpoint: if 'compression_sched' not in checkpoint: - raise KeyError("Found thinning_recipes key, but missing mandatoy key compression_sched") + raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched") msglogger.info("Loaded a thinning recipe from the checkpoint") # Cache the recipes in case we need them later model.thinning_recipes = checkpoint['thinning_recipes'] distiller.execute_thinning_recipes_list(model, compression_scheduler.zeros_mask_dict, model.thinning_recipes) - else: - msglogger.info("Warning: compression schedule data does not exist in the checkpoint") - msglogger.info("=> loaded checkpoint '%s' (epoch %d)", - chkpt_file, checkpoint['epoch']) + + if 'quantizer_metadata' in checkpoint: + msglogger.info('Loaded quantizer metadata from the checkpoint') + qmd = checkpoint['quantizer_metadata'] + quantizer = qmd['type'](model, **qmd['params']) + quantizer.prepare_model() + + msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, 
checkpoint['epoch']) model.load_state_dict(checkpoint['state_dict']) return model, compression_scheduler, start_epoch diff --git a/apputils/data_loaders.py b/apputils/data_loaders.py index 41f7bd867dc361a5c76fc9dad1cd2e5be1044c6f..0771bef0f693b59503b057c9eeeeaa2325283631 100755 --- a/apputils/data_loaders.py +++ b/apputils/data_loaders.py @@ -29,7 +29,7 @@ import numpy as np DATASETS_NAMES = ['imagenet', 'cifar10'] -def load_data(dataset, data_dir, batch_size, workers, deterministic=False): +def load_data(dataset, data_dir, batch_size, workers, valid_size=0.1, deterministic=False): """Load a dataset. Args: @@ -37,14 +37,15 @@ def load_data(dataset, data_dir, batch_size, workers, deterministic=False): data_dir: the directory where the datset resides batch_size: the batch size workers: the number of worker threads to use for loading the data + valid_size: portion of training dataset to set aside for validation deterministic: set to True if you want the data loading process to be deterministic. Note that deterministic data loading suffers from poor performance. """ assert dataset in DATASETS_NAMES if dataset == 'cifar10': - return cifar10_load_data(data_dir, batch_size, workers, deterministic=deterministic) + return cifar10_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic) if dataset == 'imagenet': - return imagenet_load_data(data_dir, batch_size, workers, deterministic=deterministic) + return imagenet_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic) print("FATAL ERROR: load_data does not support dataset %s" % dataset) exit(1) @@ -73,7 +74,7 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi We transform them to Tensors of normalized range [-1, 1] https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py - Data augmentation: 4 pixels are padded on each side, and a 32x32 crop is randomly sampled + Data augmentation: 4 pixels are padded on each side, and a 32x32 crop is randomly sampled from the padded image or its horizontal flip. This is similar to [1] and some other work that use CIFAR10. 
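A minimal usage sketch of the updated checkpoint API from apputils/checkpoint.py above, assuming model, optimizer and compression_scheduler objects already exist; the experiment directory, metric value and epoch number are illustrative, not part of the patch:

    import apputils

    # New in this change: checkpoints can be written to an arbitrary directory,
    # and any quantizer metadata stashed on the model is saved with the weights.
    apputils.save_checkpoint(epoch=10, arch='preact_resnet20_cifar', model=model,
                             optimizer=optimizer, scheduler=compression_scheduler,
                             best_top1=91.2, is_best=True, name='dorefa',
                             dir='logs/2018.06.21-113510')

    # On resume, load_checkpoint() re-creates the quantizer from the stored
    # 'quantizer_metadata' (qmd['type'](model, **qmd['params'])) and calls
    # prepare_model() before loading the state_dict, so the module structure
    # matches the quantized checkpoint.
    model, compression_scheduler, start_epoch = apputils.load_checkpoint(
        model, chkpt_file='logs/2018.06.21-113510/dorefa_checkpoint.pth.tar')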
@@ -103,7 +104,6 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi train_idx, valid_idx = indices[split:], indices[:split] train_sampler = SubsetRandomSampler(train_idx) - valid_sampler = SubsetRandomSampler(valid_idx) worker_init_fn = __deterministic_worker_init_fn if deterministic else None @@ -112,10 +112,13 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn) - valid_loader = torch.utils.data.DataLoader(train_dataset, - batch_size=batch_size, sampler=valid_sampler, - num_workers=num_workers, pin_memory=True, - worker_init_fn=worker_init_fn) + valid_loader = None + if split > 0: + valid_sampler = SubsetRandomSampler(valid_idx) + valid_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=batch_size, sampler=valid_sampler, + num_workers=num_workers, pin_memory=True, + worker_init_fn=worker_init_fn) testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test) @@ -125,7 +128,9 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi num_workers=num_workers, pin_memory=True) input_shape = __image_size(train_dataset) - return train_loader, valid_loader, test_loader, input_shape + + # If validation split was 0 we use the test set as the validation set + return train_loader, valid_loader or test_loader, test_loader, input_shape def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, deterministic=False): @@ -159,7 +164,6 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ train_idx, valid_idx = indices[split:], indices[:split] train_sampler = SubsetRandomSampler(train_idx) - valid_sampler = SubsetRandomSampler(valid_idx) input_shape = __image_size(train_dataset) @@ -170,10 +174,13 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn) - valid_loader = torch.utils.data.DataLoader(train_dataset, - batch_size=batch_size, sampler=valid_sampler, - num_workers=num_workers, pin_memory=True, - worker_init_fn=worker_init_fn) + valid_loader = None + if split > 0: + valid_sampler = SubsetRandomSampler(valid_idx) + valid_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=batch_size, sampler=valid_sampler, + num_workers=num_workers, pin_memory=True, + worker_init_fn=worker_init_fn) test_loader = torch.utils.data.DataLoader( datasets.ImageFolder(test_dir, transforms.Compose([ @@ -185,4 +192,5 @@ def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determ batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) - return train_loader, valid_loader, test_loader, input_shape + # If validation split was 0 we use the test set as the validation set + return train_loader, valid_loader or test_loader, test_loader, input_shape diff --git a/apputils/execution_env.py b/apputils/execution_env.py index 52441828263f903b00135dd448e0ffee36739913..09d079d5bf9fb58177b2cd41d6ca70cbbbee3961 100755 --- a/apputils/execution_env.py +++ b/apputils/execution_env.py @@ -58,7 +58,6 @@ def log_execution_env_state(app_args, gitroot='.'): if repo.is_dirty(): logger.debug("Git is dirty") - #repo.index.diff(None) try: branch_name = repo.active_branch.name except TypeError: @@ -80,7 +79,7 @@ def log_execution_env_state(app_args, gitroot='.'): logger.debug("App args: %s", app_args) -def config_pylogger(log_cfg_file, 
experiment_name): +def config_pylogger(log_cfg_file, experiment_name, output_dir='logs'): """Configure the Python logger. For each execution of the application, we'd like to create a unique log directory. @@ -90,11 +89,11 @@ def config_pylogger(log_cfg_file, experiment_name): TensorBoard, for example. """ timestr = time.strftime("%Y.%m.%d-%H%M%S") - filename = timestr if experiment_name is None else experiment_name + '___' + timestr - logdir = os.path.join('./logs', filename) + exp_full_name = timestr if experiment_name is None else experiment_name + '___' + timestr + logdir = os.path.join(output_dir, exp_full_name) if not os.path.exists(logdir): os.makedirs(logdir) - log_filename = os.path.join(logdir, filename + '.log') + log_filename = os.path.join(logdir, exp_full_name + '.log') if os.path.isfile(log_cfg_file): logging.config.fileConfig(log_cfg_file, defaults={'logfilename': log_filename}) msglogger = logging.getLogger() diff --git a/distiller/__init__.py b/distiller/__init__.py index 448f08f8dd29812dc3d169e1a60790363fcab8f5..a3213b564dea9d592068f5c2ef9f35fb92558e17 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -16,7 +16,7 @@ from .utils import * from .thresholding import GroupThresholdMixin, threshold_mask -from .config import fileConfig, dictConfig +from .config import file_config, dict_config from .model_summaries import * from .scheduler import * from .sensitivity import * @@ -25,7 +25,7 @@ from .policy import * from .thinning import * #del utils -del dictConfig +del dict_config del thinning #del model_summaries #del scheduler diff --git a/distiller/config.py b/distiller/config.py index 21ede1d52ba8f24351327d99616364afe65e4026..a5cc9cd76a82b7c90e873f9d27a7cafc5244742c 100755 --- a/distiller/config.py +++ b/distiller/config.py @@ -33,27 +33,37 @@ When a YAML file is loaded, its dictionary is extracted and passed to ```dictCon """ import logging +from collections import OrderedDict import yaml import json import inspect -import distiller from torch.optim.lr_scheduler import * import distiller from distiller.thinning import * from distiller.pruning import * from distiller.regularization import L1Regularizer, GroupLassoRegularizer from distiller.learning_rate import * -logger = logging.getLogger("app_cfg") +from distiller.quantization import * + +msglogger = logging.getLogger() +app_cfg_logger = logging.getLogger("app_cfg") + -def dictConfig(model, optimizer, schedule, sched_dict, logger): - logger.debug(json.dumps(sched_dict, indent=1)) +def dict_config(model, optimizer, sched_dict): + app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2)) + + schedule = distiller.CompressionScheduler(model) pruners = __factory('pruners', model, sched_dict) regularizers = __factory('regularizers', model, sched_dict) - lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer) + quantizers = __factory('quantizers', model, sched_dict) + if len(quantizers) > 1: + print("\nError: Multiple Quantizers not supported") + exit(1) extensions = __factory('extensions', model, sched_dict) try: + lr_policies = [] for policy_def in sched_dict['policies']: policy = None if 'pruner' in policy_def: @@ -76,11 +86,23 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger): else: policy = distiller.RegularizationPolicy(regularizer, **args) + elif 'quantizer' in policy_def: + instance_name, args = __policy_params(policy_def, 'quantizer') + assert instance_name in quantizers, "Quantizer {} was not defined in the list of 
quantizers".format(instance_name) + quantizer = quantizers[instance_name] + policy = distiller.QuantizationPolicy(quantizer) + + # Quantizers for training modify the models parameters, need to update the optimizer + if quantizer.train_with_fp_copy: + optimizer_type = type(optimizer) + new_optimizer = optimizer_type(model.parameters(), **optimizer.defaults) + optimizer.__setstate__({'param_groups': new_optimizer.param_groups}) + elif 'lr_scheduler' in policy_def: - instance_name, args = __policy_params(policy_def, 'lr_scheduler') - assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(instance_name) - lr_scheduler = lr_schedulers[instance_name] - policy = distiller.LRPolicy(lr_scheduler) + # LR schedulers take an optimizer in their CTOR, so postpone handling until we're certain + # a quantization policy was initialized (if exists) + lr_policies.append(policy_def) + continue elif 'extension' in policy_def: instance_name, args = __policy_params(policy_def, 'extension') @@ -92,12 +114,18 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger): print("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]" % policy_def) exit(1) - if 'epochs' in policy_def: - schedule.add_policy(policy, epochs=policy_def['epochs']) - else: - schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'], - ending_epoch=policy_def['ending_epoch'], - frequency=policy_def['frequency']) + add_policy_to_scheduler(policy, policy_def, schedule) + + # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers + lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer) + for policy_def in lr_policies: + instance_name, args = __policy_params(policy_def, 'lr_scheduler') + assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format( + instance_name) + lr_scheduler = lr_schedulers[instance_name] + policy = distiller.LRPolicy(lr_scheduler) + add_policy_to_scheduler(policy, policy_def, schedule) + except AssertionError: # propagate the assertion information raise @@ -108,12 +136,23 @@ def dictConfig(model, optimizer, schedule, sched_dict, logger): return schedule -def fileConfig(model, optimizer, schedule, filename, logger): + +def add_policy_to_scheduler(policy, policy_def, schedule): + if 'epochs' in policy_def: + schedule.add_policy(policy, epochs=policy_def['epochs']) + else: + schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'], + ending_epoch=policy_def['ending_epoch'], + frequency=policy_def['frequency']) + + +def file_config(model, optimizer, filename): """Read the schedule from file""" with open(filename, 'r') as stream: + msglogger.info('Reading compression schedule from: %s', filename) try: - sched_dict = yaml.load(stream) - dictConfig(model, optimizer, schedule, sched_dict, logger) + sched_dict = yaml_ordered_load(stream) + return dict_config(model, optimizer, sched_dict) except yaml.YAMLError as exc: print("\nFATAL Parsing error while parsing the pruning schedule configuration file %s" % filename) exit(1) @@ -165,3 +204,22 @@ def __policy_params(policy_def, type): name = policy_def[type]['instance_name'] args = policy_def[type].get('args', None) return name, args + + +def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict): + """ + Function to load YAML file using an OrderedDict + See: 
https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts + """ + class OrderedLoader(Loader): + pass + + def construct_mapping(loader, node): + loader.flatten_mapping(node) + return object_pairs_hook(loader.construct_pairs(node)) + + OrderedLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, + construct_mapping) + + return yaml.load(stream, OrderedLoader) diff --git a/distiller/learning_rate.py b/distiller/learning_rate.py index 32325ee552dd0a712d2834c9b5830037cbad129c..657d0181c40d331a20d2d1ff23792ef74e904aed 100644 --- a/distiller/learning_rate.py +++ b/distiller/learning_rate.py @@ -14,6 +14,7 @@ # limitations under the License. # +from bisect import bisect_right from torch.optim.lr_scheduler import _LRScheduler @@ -36,4 +37,22 @@ class PolynomialLR(_LRScheduler): def get_lr(self): # base_lr * (1 - iter/max_iter) ^ (power) return [base_lr * (1 - self.last_epoch / self.T_max) ** self.power - for base_lr in self.base_lrs] \ No newline at end of file + for base_lr in self.base_lrs] + + +class MultiStepMultiGammaLR(_LRScheduler): + def __init__(self, optimizer, milestones, gammas, last_epoch=-1): + if not list(milestones) == sorted(milestones): + raise ValueError('Milestones should be a list of' + ' increasing integers. Got {}', milestones) + + self.milestones = milestones + self.multiplicative_gammas = [1] + for idx, gamma in enumerate(gammas): + self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx]) + + super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = bisect_right(self.milestones, self.last_epoch) + return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs] diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index b9edddcc28ebf3e778c63d883c754c8e45536fa9..908f3d667f8388d574ce1a4be52fcd8849e09773 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -31,12 +31,12 @@ import torch.optim import distiller msglogger = logging.getLogger() -__all__ = ['model_summary', 'optimizer_summary', \ - 'weights_sparsity_summary', 'weights_sparsity_tbl_summary', \ +__all__ = ['model_summary', + 'weights_sparsity_summary', 'weights_sparsity_tbl_summary', 'model_performance_summary', 'model_performance_tbl_summary'] from .data_loggers import PythonLogger, CsvLogger -def model_summary(model, optimizer, what, dataset=None): +def model_summary(model, what, dataset=None): if what == 'sparsity': pylogger = PythonLogger(msglogger) csvlogger = CsvLogger('weights.csv') @@ -55,8 +55,6 @@ def model_summary(model, optimizer, what, dataset=None): print(t) print("Total MACs: " + "{:,}".format(total_macs)) - elif what == 'optimizer': - optimizer_summary(optimizer) elif what == 'model': # print the simple form of the model print(model) @@ -72,21 +70,6 @@ def model_summary(model, optimizer, what, dataset=None): nodes.append([name, module.__class__.__name__]) print(tabulate(nodes, headers=['Name', 'Type'])) -def optimizer_summary(optimizer): - assert isinstance(optimizer, torch.optim.SGD) - lr = optimizer.param_groups[0]['lr'] - weight_decay = optimizer.param_groups[0]['weight_decay'] - momentum = optimizer.param_groups[0]['momentum'] - dampening = optimizer.param_groups[0]['dampening'] - nesterov = optimizer.param_groups[0]['nesterov'] - - msglogger.info('Optimizer:\n' - '\tmomentum={}' - '\tL2={}' - '\tLR={}' - '\tdampening={}' - '\tnesterov={}'.format(momentum, weight_decay, lr, dampening, nesterov)) - def 
weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]): diff --git a/distiller/policy.py b/distiller/policy.py index 595f35559076ac9a66ff8568f560c8c1908b42e3..ca3f1c1a8de2b9fa6da9b84b75c786c85ff5938a 100755 --- a/distiller/policy.py +++ b/distiller/policy.py @@ -20,10 +20,12 @@ - RegularizationPolicy: regulization scheduling - LRPolicy: learning-rate decay scheduling """ +import torch + import logging msglogger = logging.getLogger() -__all__ = ['PruningPolicy', 'RegularizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy'] +__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy'] class ScheduledTrainingPolicy(object): """ Base class for all scheduled training policies. @@ -130,3 +132,16 @@ class LRPolicy(ScheduledTrainingPolicy): def on_epoch_begin(self, model, zeros_mask_dict, meta): self.lr_scheduler.step() + + +class QuantizationPolicy(ScheduledTrainingPolicy): + def __init__(self, quantizer): + super(QuantizationPolicy, self).__init__() + self.quantizer = quantizer + self.quantizer.prepare_model() + self.quantizer.quantize_params() + + def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict): + # After parameters update, quantize the parameters again + # (Doing this here ensures the model parameters are quantized at training completion (and at validation time) + self.quantizer.quantize_params() diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py index cdb45b7c6bd843eb0290b0cd02008b89df45879b..e50edbfe02fef5ca6eaea7227dfe21fa20c175e9 100644 --- a/distiller/quantization/__init__.py +++ b/distiller/quantization/__init__.py @@ -16,3 +16,8 @@ from .quantizer import Quantizer from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer +from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer + +del quantizer +del range_linear +del clipped_linear diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..7711a05cbe186dabc59bc8eabc5116df928e2d98 --- /dev/null +++ b/distiller/quantization/clipped_linear.py @@ -0,0 +1,109 @@ +import torch.nn as nn + +from .quantizer import Quantizer +from .q_utils import * +import logging +msglogger = logging.getLogger() + +### +# Clipping-based linear quantization (e.g. 
DoReFa, WRPN) +### + + +class LinearQuantizeSTE(torch.autograd.Function): + @staticmethod + def forward(ctx, input, scale_factor, dequantize, inplace): + if inplace: + ctx.mark_dirty(input) + output = linear_quantize(input, scale_factor, inplace) + if dequantize: + output = linear_dequantize(output, scale_factor, inplace) + return output + + @staticmethod + def backward(ctx, grad_output): + # Straight-through estimator + return grad_output, None, None, None + + +class ClippedLinearQuantization(nn.Module): + def __init__(self, num_bits, clip_val, dequantize=True, inplace=False): + super(ClippedLinearQuantization, self).__init__() + self.num_bits = num_bits + self.clip_val = clip_val + self.scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val) + self.dequantize = dequantize + self.inplace = inplace + + def forward(self, input): + input = clamp(input, 0, self.clip_val, self.inplace) + input = LinearQuantizeSTE.apply(input, self.scale_factor, self.dequantize, self.inplace) + return input + + def __repr__(self): + inplace_str = ', inplace' if self.inplace else '' + return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val, + inplace_str) + + +class WRPNQuantizer(Quantizer): + """ + Quantizer using the WRPN quantization scheme, as defined in: + Mishra et al., WRPN: Wide Reduced-Precision Networks (https://arxiv.org/abs/1709.01134) + + Notes: + 1. This class does not take care of layer widening as described in the paper + 2. The paper defines special handling for 1-bit weights which isn't supported here yet + """ + def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}): + super(WRPNQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights, + bits_overrides=bits_overrides, train_with_fp_copy=True) + + def wrpn_quantize_param(param_fp, num_bits): + scale_factor = symmetric_linear_quantization_scale_factor(num_bits, 1) + out = param_fp.clamp(-1, 1) + out = LinearQuantizeSTE.apply(out, scale_factor, True, False) + return out + + def relu_replace_fn(module, name, qbits_map): + bits_acts = qbits_map[name].acts + if bits_acts is None: + return module + return ClippedLinearQuantization(bits_acts, 1, dequantize=True, inplace=module.inplace) + + self.param_quantization_fn = wrpn_quantize_param + + self.replacement_factory[nn.ReLU] = relu_replace_fn + + +class DorefaQuantizer(Quantizer): + """ + Quantizer using the DoReFa scheme, as defined in: + Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients + (https://arxiv.org/abs/1606.06160) + + Notes: + 1. Gradients quantization not supported yet + 2. 
The paper defines special handling for 1-bit weights which isn't supported here yet + """ + def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}): + super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights, + bits_overrides=bits_overrides, train_with_fp_copy=True) + + def dorefa_quantize_param(param_fp, num_bits): + scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1) + out = param_fp.tanh() + out = out / (2 * out.abs().max()) + 0.5 + out = LinearQuantizeSTE.apply(out, scale_factor, True, False) + out = 2 * out - 1 + return out + + def relu_replace_fn(module, name, qbits_map): + bits_acts = qbits_map[name].acts + if bits_acts is None: + return module + return ClippedLinearQuantization(bits_acts, 1, dequantize=True, inplace=module.inplace) + + self.param_quantization_fn = dorefa_quantize_param + + self.replacement_factory[nn.ReLU] = relu_replace_fn diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index c717d8baed40481c0bf40afe8b6a5a40ab5d5f65..6d1850bb19e6904cc88c64c15a36575be39b9588 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -16,16 +16,43 @@ from collections import namedtuple import re +import copy import logging +import torch +import torch.nn as nn msglogger = logging.getLogger() QBits = namedtuple('QBits', ['acts', 'wts']) +FP_BKP_PREFIX = 'float_' + + +def has_bias(module): + return hasattr(module, 'bias') and module.bias is not None + + +def hack_float_backup_parameter(module, name): + try: + data = dict(module.named_parameters())[name].data + except KeyError: + raise ValueError('Module has no Parameter named ' + name) + module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data)) + module.__delattr__(name) + module.register_buffer(name, torch.zeros_like(data)) + + +class _ParamToQuant(object): + def __init__(self, module, fp_attr_name, q_attr_name, num_bits): + self.module = module + self.fp_attr_name = fp_attr_name + self.q_attr_name = q_attr_name + self.num_bits = num_bits + class Quantizer(object): r""" - Base class for quantizers + Base class for quantizers. Args: model (torch.nn.Module): The model to be quantized @@ -33,54 +60,135 @@ class Quantizer(object): Value of None means do not quantize. bits_overrides (dict): Dictionary mapping regular expressions of layer name patterns to dictionary with values for 'acts' and/or 'wts' to override the default values. + quantize_bias (bool): Flag indicating whether to quantize bias (w. same number of bits as weights) or not. + train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and + floating-point copy, such that the following flow occurs in each training iteration: + 1. q_weights = quantize(fp_weights) + 2. Forward through network using q_weights + 3. In back-prop: + 3.1 Gradients calculated with respect to q_weights + 3.2 We also back-prop through the 'quantize' operation from step 1 + 4. 
Update fp_weights with gradients calculated in step 3.2 """ - def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides={}): + def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides={}, quantize_bias=False, + train_with_fp_copy=False): self.default_qbits = QBits(acts=bits_activations, wts=bits_weights) + self.quantize_bias = quantize_bias self.model = model + # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model + self.model.quantizer_metadata = {'type': type(self), + 'params': {'bits_activations': bits_activations, + 'bits_weights': bits_weights, + 'bits_overrides': copy.deepcopy(bits_overrides)}} + for k, v in bits_overrides.items(): qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts)) bits_overrides[k] = qbits # Prepare explicit mapping from each layer to QBits based on default + overrides + regex = None if bits_overrides: regex_str = '' keys_list = list(bits_overrides.keys()) for pattern in keys_list: regex_str += '(^{0}$)|'.format(pattern) - regex_str = regex_str[-1] # Remove trailing '|' + regex_str = regex_str[:-1] # Remove trailing '|' regex = re.compile(regex_str) - self.layer_qbits_map = {} - for layer_full_name, _ in model.named_modules(): - m = regex.match(layer_full_name) + self.module_qbits_map = {} + for module_full_name, module in model.named_modules(): + qbits = self.default_qbits + if regex: + # Need to account for scenario where model is parallelized with DataParallel, which wraps the original + # module with a wrapper module called 'module' :) + name_to_match = module_full_name.replace('module.', '', 1) + m = regex.match(name_to_match) if m: group_idx = 0 groups = m.groups() while groups[group_idx] is None: group_idx += 1 - self.layer_qbits_map[layer_full_name] = bits_overrides[keys_list[group_idx]] - else: - self.layer_qbits_map[layer_full_name] = self.default_qbits - else: - self.layer_qbits_map = {layer_full_name: self.default_qbits for layer_full_name, _ in model.named_modules()} + qbits = bits_overrides[keys_list[group_idx]] + self._add_qbits_entry(module_full_name, type(module), qbits) + # Mapping from module type to function generating a replacement module suited for quantization + # To be populated by child classes self.replacement_factory = {} + # Pointer to parameters quantization function, triggered during training process + # To be populated by child classes + self.param_quantization_fn = None + + self.train_with_fp_copy = train_with_fp_copy + self.params_to_quantize = [] + + def _add_qbits_entry(self, module_name, module_type, qbits): + if module_type not in [nn.Conv2d, nn.Linear]: + # For now we support weights quantization only for Conv and FC layers (so, for example, we don't + # support quantization of batch norm scale parameters) + qbits = QBits(acts=qbits.acts, wts=None) + self.module_qbits_map[module_name] = qbits def prepare_model(self): - msglogger.info('Preparing model for quantization') + r""" + Iterates over the model and replaces modules with their quantized counterparts as defined by + self.replacement_factory + """ + msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__)) self._pre_process_container(self.model) + for module_name, module in self.model.named_modules(): + qbits = self.module_qbits_map[module_name] + if qbits.wts is None: + continue + + curr_parameters = dict(module.named_parameters()) + for param_name, param in curr_parameters.items(): + if 
param_name.endswith('bias') and not self.quantize_bias: + continue + fp_attr_name = param_name + if self.train_with_fp_copy: + hack_float_backup_parameter(module, param_name) + fp_attr_name = FP_BKP_PREFIX + param_name + self.params_to_quantize.append(_ParamToQuant(module, fp_attr_name, param_name, qbits.wts)) + + param_full_name = '.'.join([module_name, param_name]) + msglogger.info( + "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, qbits.wts)) + + msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) + def _pre_process_container(self, container, prefix=''): # Iterate through model, insert quantization functions as appropriate for name, module in container.named_children(): full_name = prefix + name try: - new_module = self.replacement_factory[type(module)](module, full_name, self.layer_qbits_map) + new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map) msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module)) container._modules[name] = new_module + + # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping + if len(module._modules) == 0 and len(new_module._modules) > 0: + current_qbits = self.module_qbits_map[full_name] + for sub_module_name, module in new_module.named_modules(): + self._add_qbits_entry(full_name + '.' + sub_module_name, type(module), current_qbits) + self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None) except KeyError: + pass + + if len(module._modules) > 0: # For container we call recursively - if len(module._modules) > 0: - self._pre_process_container(module, full_name + '.') + self._pre_process_container(module, full_name + '.') + + def quantize_params(self): + """ + Quantize all parameters using the parameters using self.param_quantization_fn (using the defined number + of bits for each parameter) + """ + for ptq in self.params_to_quantize: + q_param = self.param_quantization_fn(ptq.module.__getattr__(ptq.fp_attr_name), ptq.num_bits) + if self.train_with_fp_copy: + ptq.module.__setattr__(ptq.q_attr_name, q_param) + else: + ptq.module.__getattr__(ptq.q_attr_name).data = q_param.data diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 31645cafeac40a2c0473dfa5b9b07a751d6976a0..bca1cb9926773b8511962fcc3b5ef0a2caf0a728 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -161,7 +161,8 @@ class SymmetricLinearQuantizer(Quantizer): """ def __init__(self, model, bits_activations=8, bits_parameters=8): super(SymmetricLinearQuantizer, self).__init__(model, bits_activations=bits_activations, - bits_weights=bits_parameters) + bits_weights=bits_parameters, + train_with_fp_copy=False) def replace_fn(module, name, qbits_map): return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts) diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 2f85151c2981e1f7601fad643550b3f6d9f295ff..e4c1d140c51a91a1de604936bbdf5dbe60d0fd92 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -21,6 +21,7 @@ This implements the scheduling of the compression policies. 
from functools import partial import logging import torch +from .quantization.quantizer import FP_BKP_PREFIX msglogger = logging.getLogger() @@ -127,7 +128,6 @@ class CompressionScheduler(object): policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch, self.zeros_mask_dict) - def on_epoch_end(self, epoch): if epoch in self.policies: for policy in self.policies[epoch]: @@ -135,10 +135,20 @@ class CompressionScheduler(object): meta['current_epoch'] = epoch policy.on_epoch_end(self.model, self.zeros_mask_dict, meta) - def apply_mask(self): for name, param in self.model.named_parameters(): - self.zeros_mask_dict[name].apply_mask(param) + try: + self.zeros_mask_dict[name].apply_mask(param) + except KeyError: + # Quantizers for training modify some model parameters by adding a prefix + # If this is the source of the error, workaround and move on + name_parts = name.split('.') + if name_parts[-1].startswith(FP_BKP_PREFIX): + name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1) + name = '.'.join(name_parts) + self.zeros_mask_dict[name].apply_mask(param) + else: + raise def state_dict(self): @@ -149,10 +159,10 @@ class CompressionScheduler(object): masks = {} for name, masker in self.zeros_mask_dict.items(): masks[name] = masker.mask - state = { 'masks_dict' : masks } + state = {'masks_dict': masks, + 'parallel_model': isinstance(self.model, torch.nn.DataParallel)} return state - def load_state_dict(self, state): """Loads the scheduler state. @@ -173,6 +183,17 @@ class CompressionScheduler(object): print("\t\t" + k) exit(1) + curr_model_parallel = isinstance(self.model, torch.nn.DataParallel) + # Fallback to 'True' for old checkpoints that don't have this attribute, since parallel=True is the + # default for create_model + loaded_model_parallel = state.get('parallel_model', True) for name, mask in self.zeros_mask_dict.items(): + # DataParallel modules wrap the actual module with a module named "module"... + if loaded_model_parallel and not curr_model_parallel: + load_name = 'module.' 
+ name + elif curr_model_parallel and not loaded_model_parallel: + load_name = name.replace('module.', '', 1) + else: + load_name = name masker = self.zeros_mask_dict[name] - masker.mask = loaded_masks[name] + masker.mask = loaded_masks[load_name] diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 338ec78438c0bdbbc757dff475c34e744445983c..f33e171980a20646649a5b0dc3dd82b9b957f99c 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -80,6 +80,15 @@ import distiller.quantization as quantization from models import ALL_MODEL_NAMES, create_model msglogger = None + + +def float_range(val_str): + val = float(val_str) + if val < 0 or val >= 1: + raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str)) + return val + + parser = argparse.ArgumentParser(description='Distiller image classification model compression') parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', @@ -111,7 +120,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true', help='collect activation statistics (WARNING: this slows down training)') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)') -SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'png_w_params'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES, help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES)) @@ -128,6 +137,9 @@ parser.add_argument('--quantize', action='store_true', parser.add_argument('--gpus', metavar='DEV_ID', default=None, help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)') parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name') +parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints') +parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1, + help='Portion of training dataset to set aside for validation') def check_pytorch_version(): @@ -142,11 +154,14 @@ def check_pytorch_version(): " 3. Activate the new environment") exit(1) + def main(): global msglogger check_pytorch_version() args = parser.parse_args() - msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir) # Log various details about the execution environment. It is sometimes useful # to refer to past experiment executions and this information may be useful. 
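The mask-name handling added to CompressionScheduler.load_state_dict above revolves around the 'module.' prefix that torch.nn.DataParallel adds to every parameter name. A small standalone illustration (not part of the patch):

    import torch.nn as nn

    plain = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
    wrapped = nn.DataParallel(plain)

    print([name for name, _ in plain.named_parameters()])
    # ['0.weight', '0.bias']
    print([name for name, _ in wrapped.named_parameters()])
    # ['module.0.weight', 'module.0.bias']

    # This is why the scheduler now records 'parallel_model' in its state_dict:
    # when the saved and current settings differ, the loader either prepends
    # 'module.' or strips its first occurrence before looking up a saved mask.
    name = 'module.0.weight'
    load_name = name.replace('module.', '', 1)   # -> '0.weight'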
@@ -194,7 +209,7 @@ def main(): # Create the model png_summary = args.summary is not None and args.summary.startswith('png') - is_parallel = not png_summary # For PNG summary, parallel graphs are illegible + is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus) compression_scheduler = None @@ -208,18 +223,17 @@ def main(): model, compression_scheduler, start_epoch = apputils.load_checkpoint( model, chkpt_file=args.resume) - if 'resnet' in args.arch and 'cifar' in args.arch: + if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch: distiller.resnet_cifar_remove_layers(model) #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict) # Define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() - optimizer = torch.optim.SGD(model.parameters(), args.lr, + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - msglogger.info("Optimizer (%s): momentum=%s decay=%s", type(optimizer), - args.momentum, args.weight_decay) - + msglogger.info('Optimizer Type: %s', type(optimizer)) + msglogger.info('Optimizer Args: %s', optimizer.defaults) # This sample application can be invoked to produce various summary reports. if args.summary: @@ -227,7 +241,7 @@ def main(): if which_summary.startswith('png'): apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params') else: - distiller.model_summary(model, optimizer, which_summary, args.dataset) + distiller.model_summary(model, which_summary, args.dataset) exit() # Load the datasets: the dataset to load is inferred from the model name passed @@ -235,7 +249,7 @@ def main(): # substring "_cifar", then cifar10 is used. train_loader, val_loader, test_loader, _ = apputils.load_data( args.dataset, os.path.expanduser(args.data), args.batch_size, - args.workers, args.deterministic) + args.workers, args.validation_size, args.deterministic) msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) @@ -276,17 +290,15 @@ def main(): top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq) if args.quantize: checkpoint_name = 'quantized' - apputils.save_checkpoint(0, args.arch, model, optimizer, best_top1=top1, - name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name) + apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, + name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name, + dir=msglogger.logdir) exit() if args.compress: - # The main use-case for this sample application is CNN compression. Compression + # The main use-case for this sample application is CNN compression. Compression # requires a compression schedule configuration file in YAML. - source = args.compress - msglogger.info("Compression schedule (source=%s)", source) - compression_scheduler = distiller.CompressionScheduler(model) - distiller.config.fileConfig(model, optimizer, compression_scheduler, args.compress, msglogger) + compression_scheduler = distiller.file_config(model, optimizer, args.compress) for epoch in range(start_epoch, start_epoch + args.epochs): # This is the main training loop. 
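For context, a condensed sketch of how the scheduler callbacks used in compress_classifier.py's train() function drive the new QuantizationPolicy; the loop is simplified and assumes epoch, steps_per_epoch and the usual training objects already exist:

    for train_step, (inputs, target) in enumerate(train_loader):
        compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch)

        output = model(inputs)            # forward pass runs on the quantized weights
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()                   # LinearQuantizeSTE passes gradients straight through
        optimizer.step()                  # updates the float_* backup copies of the weights

        # QuantizationPolicy.on_minibatch_end() calls quantizer.quantize_params(),
        # refreshing the quantized weights right after the optimizer step.
        compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch)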
@@ -317,7 +329,8 @@ def main(): # remember best top1 and save checkpoint is_best = top1 > best_top1 best_top1 = max(top1, best_top1) - apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best, args.name) + apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best, + args.name, msglogger.logdir) # Finally run results on the test set test(test_loader, model, criterion, [pylogger], args.print_freq) @@ -353,7 +366,7 @@ def train(train_loader, model, criterion, optimizer, epoch, input_var = torch.autograd.Variable(inputs) target_var = torch.autograd.Variable(target) - # Execute the forard phase, compute the output and measure loss + # Execute the forward phase, compute the output and measure loss if compression_scheduler: compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch) output = model(input_var) diff --git a/examples/quantization/alexnet_bn_base_fp32.yaml b/examples/quantization/alexnet_bn_base_fp32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89b57238ac62dcbee7184e5f2c8d94d3d0893a80 --- /dev/null +++ b/examples/quantization/alexnet_bn_base_fp32.yaml @@ -0,0 +1,12 @@ +lr_schedulers: + training_lr: + class: MultiStepLR + milestones: [60, 75] + gamma: 0.2 + +policies: + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 diff --git a/examples/quantization/alexnet_bn_dorefa.yaml b/examples/quantization/alexnet_bn_dorefa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ada4579a90c147466fc1d374f648cd63b667aa6 --- /dev/null +++ b/examples/quantization/alexnet_bn_dorefa.yaml @@ -0,0 +1,37 @@ +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 8 + bits_weights: 3 + bits_overrides: + features.0: + wts: null + acts: null + features.1: + wts: null + acts: null + classifier.5: + wts: null + acts: null + classifier.6: + wts: null + acts: null + +lr_schedulers: + training_lr: + class: MultiStepLR + milestones: [60, 75] + gamma: 0.2 + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 diff --git a/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml b/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b71df90d4abb743cedf2d55910bc77b0d7bff9b --- /dev/null +++ b/examples/quantization/preact_resnet18_imagenet_base_fp32.yaml @@ -0,0 +1,12 @@ +lr_schedulers: + training_lr: + class: MultiStepLR + milestones: [30, 60, 90, 100] + gamma: 0.1 + +policies: + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 diff --git a/examples/quantization/preact_resnet18_imagenet_dorefa.yaml b/examples/quantization/preact_resnet18_imagenet_dorefa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..baf93d9da3f747596fae15d33673928043561cb1 --- /dev/null +++ b/examples/quantization/preact_resnet18_imagenet_dorefa.yaml @@ -0,0 +1,37 @@ +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 8 + bits_weights: 3 + bits_overrides: + conv1: + wts: null + acts: null + relu1: + wts: null + acts: null + final_relu: + wts: null + acts: null + fc: + wts: null + acts: null + +lr_schedulers: + training_lr: + class: MultiStepLR + milestones: [30, 60, 90, 100] + 
gamma: 0.1 + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 diff --git a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..792d971ae469b83d2de306c6c510ec9404d8c6eb --- /dev/null +++ b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml @@ -0,0 +1,12 @@ +lr_schedulers: + training_lr: + class: MultiStepMultiGammaLR + milestones: [80, 120, 160] + gammas: [0.1, 0.1, 0.2] + +policies: + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 diff --git a/examples/quantization/preact_resnet20_cifar_dorefa.yaml b/examples/quantization/preact_resnet20_cifar_dorefa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94851236f4c9a4c11d335a36178a3ad1a0e02b00 --- /dev/null +++ b/examples/quantization/preact_resnet20_cifar_dorefa.yaml @@ -0,0 +1,37 @@ +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 8 + bits_weights: 3 + bits_overrides: + conv1: + wts: null + acts: null + layer1.0.pre_relu: + wts: null + acts: null + final_relu: + wts: null + acts: null + fc: + wts: null + acts: null + +lr_schedulers: + training_lr: + class: MultiStepMultiGammaLR + milestones: [80, 120, 160] + gammas: [0.1, 0.1, 0.2] + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 161 + frequency: 1 diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py index 328eb450feb67a1d38b2219c995cf2b1130c35cf..767b80c57d532fdaaf78c997dfbf463a1260be57 100755 --- a/examples/word_language_model/main.py +++ b/examples/word_language_model/main.py @@ -314,8 +314,7 @@ compression_scheduler = None if args.compress: # Create a CompressionScheduler and configure it from a YAML schedule file source = args.compress - compression_scheduler = distiller.CompressionScheduler(model) - distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger) + compression_scheduler = distiller.config.file_config(model, None, args.compress) optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, diff --git a/models/cifar10/__init__.py b/models/cifar10/__init__.py index ffc6ed3897ea336eeb26b4b8629ca79bf71a19eb..3b72572744f7dc61529c05213624c9208f466678 100755 --- a/models/cifar10/__init__.py +++ b/models/cifar10/__init__.py @@ -18,3 +18,4 @@ from .simplenet_cifar import * from .resnet_cifar import * +from .preresnet_cifar import * diff --git a/models/cifar10/preresnet_cifar.py b/models/cifar10/preresnet_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..f8aa50c3f53ac3191e317101246385aa5cfff313 --- /dev/null +++ b/models/cifar10/preresnet_cifar.py @@ -0,0 +1,208 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Pre-Activation ResNet for CIFAR10 + +Pre-Activation ResNet for CIFAR10, based on "Identity Mappings in Deep Residual Networks". +This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate +changes for pre-activation and the 10-class Cifar-10 dataset. +This ResNet also has layer gates, to be able to dynamically remove layers. + +@article{ + He2016, + author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, + title = {Identity Mappings in Deep Residual Networks}, + journal = {arXiv preprint arXiv:1603.05027}, + year = {2016} +} +""" +import torch.nn as nn +import math + +__all__ = ['preact_resnet20_cifar', 'preact_resnet32_cifar', 'preact_resnet44_cifar', 'preact_resnet56_cifar', + 'preact_resnet110_cifar', 'preact_resnet20_cifar_conv_ds', 'preact_resnet32_cifar_conv_ds', + 'preact_resnet44_cifar_conv_ds', 'preact_resnet56_cifar_conv_ds', 'preact_resnet110_cifar_conv_ds'] + +NUM_CLASSES = 10 + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class PreactBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None, preact_downsample=True): + super(PreactBasicBlock, self).__init__() + self.block_gates = block_gates + self.pre_bn = nn.BatchNorm2d(inplanes) + self.pre_relu = nn.ReLU(inplace=False) # To enable layer removal inplace must be False + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=False) + self.conv2 = conv3x3(planes, planes) + self.downsample = downsample + self.stride = stride + self.preact_downsample = preact_downsample + + def forward(self, x): + need_preact = self.block_gates[0] or self.block_gates[1] or self.downsample and self.preact_downsample + if need_preact: + preact = self.pre_bn(x) + preact = self.pre_relu(preact) + out = preact + else: + preact = out = x + + if self.block_gates[0]: + out = self.conv1(out) + out = self.bn(out) + out = self.relu(out) + + if self.block_gates[1]: + out = self.conv2(out) + + if self.downsample is not None: + if self.preact_downsample: + residual = self.downsample(preact) + else: + residual = self.downsample(x) + else: + residual = x + + out += residual + + return out + + +class PreactResNetCifar(nn.Module): + def __init__(self, block, layers, num_classes=NUM_CLASSES, conv_downsample=False): + self.nlayers = 0 + # Each layer manages its own gates + self.layer_gates = [] + for layer in range(3): + # For each of the 3 layers, create block gates: each block has two layers + self.layer_gates.append([]) # [True, True] * layers[layer]) + for blk in range(layers[layer]): + self.layer_gates[layer].append([True, True]) + + self.inplanes = 16 # 64 + super(PreactResNetCifar, self).__init__() + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0], + conv_downsample=conv_downsample) + self.layer2 = 
self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2, + conv_downsample=conv_downsample) + self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2, + conv_downsample=conv_downsample) + self.final_bn = nn.BatchNorm2d(64 * block.expansion) + self.final_relu = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(8, stride=1) + self.fc = nn.Linear(64 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False): + downsample = None + outplanes = planes * block.expansion + if stride != 1 or self.inplanes != outplanes: + if conv_downsample: + downsample = nn.Conv2d(self.inplanes, outplanes, + kernel_size=1, stride=stride, bias=False) + else: + # Identity downsample uses strided average pooling + padding instead of convolution + pad_amount = int(self.inplanes / 2) + downsample = nn.Sequential( + nn.AvgPool2d(2), + nn.ConstantPad3d((0, 0, 0, 0, pad_amount, pad_amount), 0) + ) + + layers = [] + layers.append(block(layer_gates[0], self.inplanes, planes, stride, downsample, conv_downsample)) + self.inplanes = outplanes + for i in range(1, blocks): + layers.append(block(layer_gates[i], self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_bn(x) + x = self.final_relu(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def preact_resnet20_cifar(**kwargs): + model = PreactResNetCifar(PreactBasicBlock, [3, 3, 3], **kwargs) + return model + + +def preact_resnet32_cifar(**kwargs): + model = PreactResNetCifar(PreactBasicBlock, [5, 5, 5], **kwargs) + return model + + +def preact_resnet44_cifar(**kwargs): + model = PreactResNetCifar(PreactBasicBlock, [7, 7, 7], **kwargs) + return model + + +def preact_resnet56_cifar(**kwargs): + model = PreactResNetCifar(PreactBasicBlock, [9, 9, 9], **kwargs) + return model + + +def preact_resnet110_cifar(**kwargs): + model = PreactResNetCifar(PreactBasicBlock, [18, 18, 18], **kwargs) + return model + + +def preact_resnet20_cifar_conv_ds(**kwargs): + return preact_resnet20_cifar(conv_downsample=True) + + +def preact_resnet32_cifar_conv_ds(**kwargs): + return preact_resnet32_cifar(conv_downsample=True) + + +def preact_resnet44_cifar_conv_ds(**kwargs): + return preact_resnet44_cifar(conv_downsample=True) + + +def preact_resnet56_cifar_conv_ds(**kwargs): + return preact_resnet56_cifar(conv_downsample=True) + + +def preact_resnet110_cifar_conv_ds(**kwargs): + return preact_resnet110_cifar(conv_downsample=True) diff --git a/models/cifar10/resnet_cifar.py b/models/cifar10/resnet_cifar.py index 44d2ba2cbb16a7ae51c4a3757051556a013e48c0..e9ce4e55ab026ba545223d42bce061328f4105d3 100755 --- a/models/cifar10/resnet_cifar.py +++ b/models/cifar10/resnet_cifar.py @@ -53,16 +53,13 @@ class BasicBlock(nn.Module): def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() - #self.layer_id = layer_id - #self.block_id = block_id self.block_gates = block_gates self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) - # This change is required for layer removal - #self.relu = nn.ReLU(inplace=True) - 
self.relu = nn.ReLU(inplace=False) + self.relu1 = nn.ReLU(inplace=False) # To enable layer removal inplace must be False self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=False) self.downsample = downsample self.stride = stride @@ -72,7 +69,7 @@ class BasicBlock(nn.Module): if self.block_gates[0]: out = self.conv1(x) out = self.bn1(out) - out = self.relu(out) + out = self.relu1(out) if self.block_gates[1]: out = self.conv2(out) @@ -82,7 +79,7 @@ class BasicBlock(nn.Module): residual = self.downsample(x) out += residual - out = self.relu(out) + out = self.relu2(out) return out @@ -95,27 +92,19 @@ class ResNetCifar(nn.Module): self.layer_gates = [] for layer in range(3): # For each of the 3 layers, create block gates: each block has two layers - self.layer_gates.append([]) # [True, True] * layers[layer]) + self.layer_gates.append([]) # [True, True] * layers[layer]) for blk in range(layers[layer]): self.layer_gates[layer].append([True, True]) - self.inplanes = 16 # 64 + self.inplanes = 16 # 64 super(ResNetCifar, self).__init__() - #self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - #self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(self.inplanes) self.relu = nn.ReLU(inplace=True) - #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0]) self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2) self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2) - # self.layer1 = self._make_layer(block, 64, layers[0]) - # self.layer2 = self._make_layer(block, 128, layers[1], stride=2) - # self.layer3 = self._make_layer(block, 256, layers[2], stride=2) - #self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.avgpool = nn.AvgPool2d(8, stride=1) - #self.fc = nn.Linear(512 * block.expansion, num_classes) self.fc = nn.Linear(64 * block.expansion, num_classes) for m in self.modules(): @@ -147,12 +136,10 @@ class ResNetCifar(nn.Module): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) - #x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) - #x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) diff --git a/models/imagenet/__init__.py b/models/imagenet/__init__.py index ff4d9e2840fa9103eb5c8e1915e251ab32e48f62..300ebd50ff354f6555d51a224c7f6f4c91491b36 100755 --- a/models/imagenet/__init__.py +++ b/models/imagenet/__init__.py @@ -17,3 +17,5 @@ """This package contains ImageNet image classification models not found in torchvision""" from .mobilenet import * +from .preresnet_imagenet import * +from .alexnet_batchnorm import * diff --git a/models/imagenet/alexnet_batchnorm.py b/models/imagenet/alexnet_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..edd444529580421aa9a1c7ed7e5548db9d9b68fa --- /dev/null +++ b/models/imagenet/alexnet_batchnorm.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +AlexNet model with batch-norm layers. +Model configuration based on the AlexNet DoReFa example in TensorPack: +https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py + +Code based on the AlexNet PyTorch sample, with the required changes. +""" + +import math +import torch.nn as nn + +__all__ = ['AlexNetBN', 'alexnet_bn'] + + +class AlexNetBN(nn.Module): + + def __init__(self, num_classes=1000): + super(AlexNetBN, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 96, kernel_size=12, stride=4), # conv0 (224x224x3) -> (54x54x96) + nn.ReLU(inplace=True), + nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False), # conv1 (54x54x96) -> (54x54x256) + nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn1 (54x54x256) + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # pool1 (54x54x256) -> (27x27x256) + nn.ReLU(inplace=True), + + nn.Conv2d(256, 384, kernel_size=3, padding=1, bias=False), # conv2 (27x27x256) -> (27x27x384) + nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn2 (27x27x384) + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # pool2 (27x27x384) -> (14x14x384) + nn.ReLU(inplace=True), + + nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False), # conv3 (14x14x384) -> (14x14x384) + nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn3 (14x14x384) + nn.ReLU(inplace=True), + + nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False), # conv4 (14x14x384) -> (14x14x256) + nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn4 (14x14x256) + nn.MaxPool2d(kernel_size=3, stride=2), # pool4 (14x14x256) -> (6x6x256) + nn.ReLU(inplace=True), + ) + self.classifier = nn.Sequential( + nn.Linear(256 * 6 * 6, 4096, bias=False), # fc0 + nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc0 + nn.ReLU(inplace=True), + nn.Linear(4096, 4096, bias=False), # fc1 + nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc1 + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), # fct + ) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + fan_in, k_size = (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) if isinstance(m, nn.Conv2d) \ + else (m.in_features, 1) + n = k_size * fan_in + m.weight.data.normal_(0, math.sqrt(2. / n)) + if hasattr(m, 'bias') and m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + return x + + +def alexnet_bn(**kwargs): + r"""AlexNet model with batch-norm layers. 
+ Model configuration based on the AlexNet DoReFa example in `TensorPack + <https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py>` + """ + model = AlexNetBN(**kwargs) + return model diff --git a/models/imagenet/preresnet_imagenet.py b/models/imagenet/preresnet_imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..412ee99972b70ed1621ead4c97c5f4142eeacafe --- /dev/null +++ b/models/imagenet/preresnet_imagenet.py @@ -0,0 +1,231 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Pre-Activation ResNet for ImageNet + +Pre-Activation ResNet for ImageNet, based on "Identity Mappings in Deep Residual Networks". +This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate changes for pre-activation. + +@article{ + He2016, + author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, + title = {Identity Mappings in Deep Residual Networks}, + journal = {arXiv preprint arXiv:1603.05027}, + year = {2016} +} +""" + +import torch.nn as nn +import math + + +__all__ = ['PreactResNet', 'preact_resnet18', 'preact_resnet34', 'preact_resnet50', 'preact_resnet101', + 'preact_resnet152'] + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class PreactBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True): + super(PreactBasicBlock, self).__init__() + self.pre_bn = self.pre_relu = None + if preactivate: + self.pre_bn = nn.BatchNorm2d(inplanes) + self.pre_relu = nn.ReLU(inplace=True) + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1_2 = nn.BatchNorm2d(planes) + self.relu1_2 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.downsample = downsample + self.stride = stride + self.preactivate = preactivate + + def forward(self, x): + if self.preactivate: + preact = self.pre_bn(x) + preact = self.pre_relu(preact) + else: + preact = x + + out = self.conv1(preact) + out = self.bn1_2(out) + out = self.relu1_2(out) + out = self.conv2(out) + + if self.downsample is not None: + residual = self.downsample(preact) + else: + residual = x + + out += residual + + return out + + +class PreactBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True): + super(PreactBottleneck, self).__init__() + self.pre_bn = self.pre_relu = None + if preactivate: + self.pre_bn = nn.BatchNorm2d(inplanes) + self.pre_relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1_2 = nn.BatchNorm2d(planes) + self.relu1_2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2_3 = nn.BatchNorm2d(planes) + self.relu2_3 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(planes, planes * 4, 
kernel_size=1, bias=False) + self.downsample = downsample + self.stride = stride + self.preactivate = preactivate + + def forward(self, x): + if self.preactivate: + preact = self.pre_bn(x) + preact = self.pre_relu(preact) + else: + preact = x + + out = self.conv1(preact) + out = self.bn1_2(out) + out = self.relu1_2(out) + + out = self.conv2(out) + out = self.bn2_3(out) + out = self.relu2_3(out) + + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(preact) + else: + residual = x + + out += residual + + return out + + +class PreactResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(PreactResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.final_bn = nn.BatchNorm2d(512 * block.expansion) + self.final_relu = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + ) + + # On the first residual block in the first residual layer we don't pre-activate, + # because we take care of that (+ maxpool) after the initial conv layer + preactivate_first = stride != 1 + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, preactivate_first)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.final_bn(x) + x = self.final_relu(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def preact_resnet18(**kwargs): + """Constructs a ResNet-18 model. + """ + model = PreactResNet(PreactBasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def preact_resnet34(**kwargs): + """Constructs a ResNet-34 model. + """ + model = PreactResNet(PreactBasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def preact_resnet50(**kwargs): + """Constructs a ResNet-50 model. + """ + model = PreactResNet(PreactBottleneck, [3, 4, 6, 3], **kwargs) + return model + + +def preact_resnet101(**kwargs): + """Constructs a ResNet-101 model. + """ + model = PreactResNet(PreactBottleneck, [3, 4, 23, 3], **kwargs) + return model + + +def preact_resnet152(**kwargs): + """Constructs a ResNet-152 model. 
+ """ + model = PreactResNet(PreactBottleneck, [3, 8, 36, 3], **kwargs) + return model diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..6e63121c0122e246194738e14aa34eafa5ac05af --- /dev/null +++ b/tests/test_learning_rate.py @@ -0,0 +1,50 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) + +import torch +from torch.optim import Optimizer +from distiller.learning_rate import MultiStepMultiGammaLR + + +def test_multi_step_multi_gamma_lr(): + dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True) + dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1}) + lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2]) + expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2] + expected_lrs = [0.1 * gamma for gamma in expected_gammas] + assert lr_sched.multiplicative_gammas == expected_gammas + lr_sched.step(0) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0] + lr_sched.step(15) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0] + lr_sched.step(30) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1] + lr_sched.step(33) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1] + lr_sched.step(60) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2] + lr_sched.step(79) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2] + lr_sched.step(80) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3] + lr_sched.step(100) + assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3] diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 50ce5753da84840bf1a8a89859ae67f3d950eec7..5bcb82da417a4035d222044a6fc2c8df03fcf393 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -80,10 +80,10 @@ def test_connectivity(): preds = g.predecessors(op, 2) assert preds == ['layer1.0.bn2', 'relu'] - op = g.find_op('layer1.0.relu1') + op = g.find_op('layer1.0.relu2') assert op is not None succs = g.successors(op, 4) - assert succs == ['layer1.1.bn1', 'layer1.1.relu1'] + assert succs == ['layer1.1.bn1', 'layer1.1.relu2'] preds = g.predecessors(g.find_op('bn1'), 10) assert preds == []