From 697b3cfee5b23b205cc843377ce1a054c6bed0c4 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Mon, 2 Dec 2019 16:38:59 +0200 Subject: [PATCH] Object Detection Compression (#343) Add an example of compressing OD pytorch models. In this example we compress torchvision's object detection models - FasterRCNN / MaskRCNN / KeypointRCNN. We've modified the reference code for object detection to allow easy compression scheduling with YAML configuration. --- distiller/model_transforms.py | 3 +- distiller/models/__init__.py | 4 +- distiller/quantization/sim_bn_fold.py | 21 +- .../object_detection_compression/README.md | 57 +++ .../object_detection_compression/__init__.py | 0 .../object_detection_compression/coco_eval.py | 351 +++++++++++++++++ .../coco_utils.py | 253 +++++++++++++ .../compress_detector.py | 356 ++++++++++++++++++ .../data/download_dataset.sh | 9 + .../object_detection_compression/engine.py | 125 ++++++ .../group_by_aspect_ratio.py | 189 ++++++++++ .../object_detection_compression/logging.conf | 54 +++ .../maskrcnn.scheduler_agp.non_parallel.yaml | 210 +++++++++++ .../maskrcnn.scheduler_agp.yaml | 267 +++++++++++++ .../requirements.txt | 2 + .../transforms.py | 52 +++ .../object_detection_compression/utils.py | 328 ++++++++++++++++ 17 files changed, 2274 insertions(+), 7 deletions(-) create mode 100644 examples/object_detection_compression/README.md create mode 100644 examples/object_detection_compression/__init__.py create mode 100644 examples/object_detection_compression/coco_eval.py create mode 100644 examples/object_detection_compression/coco_utils.py create mode 100644 examples/object_detection_compression/compress_detector.py create mode 100644 examples/object_detection_compression/data/download_dataset.sh create mode 100644 examples/object_detection_compression/engine.py create mode 100644 examples/object_detection_compression/group_by_aspect_ratio.py create mode 100644 examples/object_detection_compression/logging.conf create mode 100644 examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml create mode 100644 examples/object_detection_compression/maskrcnn.scheduler_agp.yaml create mode 100644 examples/object_detection_compression/requirements.txt create mode 100644 examples/object_detection_compression/transforms.py create mode 100644 examples/object_detection_compression/utils.py diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py index 5761431..647d2ed 100644 --- a/distiller/model_transforms.py +++ b/distiller/model_transforms.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d from collections import OrderedDict import distiller import distiller.modules @@ -129,7 +130,7 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True return folded_module foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) - batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) + batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d) return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map) diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 8f94bc9..366edd4 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -20,13 +20,15 @@ import copy from functools import partial import torch import torchvision.models as torch_models +from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN +from 
torchvision.ops.misc import FrozenBatchNorm2d import torch.nn as nn from . import cifar10 as cifar10_models from . import mnist as mnist_models from . import imagenet as imagenet_extra_models import pretrainedmodels -from distiller.utils import set_model_input_shape_attr +from distiller.utils import set_model_input_shape_attr, model_setattr from distiller.modules import Mean, EltwiseAdd import logging diff --git a/distiller/quantization/sim_bn_fold.py b/distiller/quantization/sim_bn_fold.py index 8c3a6f3..5fa17ea 100644 --- a/distiller/quantization/sim_bn_fold.py +++ b/distiller/quantization/sim_bn_fold.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from torch.nn import functional as F +from torchvision.ops.misc import FrozenBatchNorm2d __all__ = ['SimulatedFoldedBatchNorm'] @@ -28,22 +29,32 @@ class SimulatedFoldedBatchNorm(nn.Module): Wrapper for simulated folding of BatchNorm into convolution / linear layers during training Args: param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module - bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module + bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d or FrozenBatchNorm2d): batch normalization module freeze_bn_delay (int): number of steps before freezing the batch-norm running stats param_quantization_fn (function): function to be used for weight/bias quantization Note: The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2. """ - SimulatedFoldedBatchNorm.verify_module_types(param_module, bn) - if not bn.track_running_stats: - raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats") - super(SimulatedFoldedBatchNorm, self).__init__() self.param_module = param_module self.bn = bn self.freeze_bn_delay = freeze_bn_delay self.frozen = False self._has_bias = (self.param_module.bias is not None) self.param_quant_fn = param_quantization_fn + if isinstance(bn, FrozenBatchNorm2d): + if not isinstance(param_module, nn.Conv2d): + error_msg = "Can't fold sequence of {} --> {}. ".format( + param_module.__class__.__name__, bn.__class__.__name__ + ) + raise TypeError(error_msg + ' FrozenBatchNorm2d must follow a nn.Conv2d.') + # This torchvision op is frozen from the beginning, so we fuse it + # directly into the linear layer. + self.freeze() + return + SimulatedFoldedBatchNorm.verify_module_types(param_module, bn) + if not bn.track_running_stats: + raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats") + super(SimulatedFoldedBatchNorm, self).__init__() if isinstance(param_module, nn.Linear): self.param_forward_fn = self._linear_layer_forward self.param_module_type = "fc" diff --git a/examples/object_detection_compression/README.md b/examples/object_detection_compression/README.md new file mode 100644 index 0000000..9ac368a --- /dev/null +++ b/examples/object_detection_compression/README.md @@ -0,0 +1,57 @@ +# Object Detection Compression + +In this example we compress torchvision's object detection models - FasterRCNN / MaskRCNN / KeypointRCNN. +We've modified the [reference code for object detection](https://github.com/pytorch/vision/tree/master/references/detection) +to allow easy compression scheduling with yaml configuration. 
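In code, "easy compression scheduling" boils down to handing the training script a YAML schedule and letting Distiller build the scheduler from it. The sketch below is a condensed, hypothetical illustration of that hook-up, not part of the patch itself: the full wiring lives in `compress_detector.py`, and the schedule path and optimizer hyper-parameters used here are simply this example's defaults.

```python
# Minimal sketch: attach a YAML compression schedule to a torchvision detection
# model, mirroring what compress_detector.py does with the --compress argument.
import torch
import torchvision.models.detection as detection
import distiller

model = detection.maskrcnn_resnet50_fpn(pretrained=True)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.02, momentum=0.9, weight_decay=1e-4)

# The YAML file (e.g. maskrcnn.scheduler_agp.yaml from this example) defines the
# pruners and their schedule; file_config returns a CompressionScheduler.
compression_scheduler = distiller.file_config(
    model, optimizer, 'maskrcnn.scheduler_agp.yaml', None, None)
```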
+ + ## Setup +Install the dependencies +(most of which are already installed as Distiller dependencies; the rest are `Cython` and `pycocotools`): + + cd <distiller root>/examples/object_detection_compression/ + pip3 install -r requirements.txt + +The dataset can be downloaded from the [COCO dataset website](http://cocodataset.org/#download). +Please keep in mind that the COCO dataset takes up about 18 GB of storage. +In this example we'll use the 2017 training+validation sets, which you can download from the command line: + + cd data + bash download_dataset.sh + +## Running the Example +The command line for running this example is closely related to +[`compress_classifier.py`](../classifier_compression/compress_classifier.py), i.e. the +compression scheduler format and most of the Distiller-related arguments are the same. +However, running in a multi-GPU environment is different from `compress_classifier.py`, because this script is a modified +[`train.py` from the torchvision references](https://github.com/pytorch/vision/tree/master/references/detection/train.py), +which uses `torch.distributed.launch` for multi-GPU (and, more generally, multi-node) training. + +**Note:** `torch.distributed.launch` spawns multiple processes, each holding its own copy of the model and its weights, wrapped in +[`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel). +During the backward pass, the gradients from all processes are averaged and broadcast back to every process, +which guarantees that the weights stay identical across processes. +This also guarantees that the pruning masks remain identical on all processes. + + Example single-GPU command line: + + python compress_detector.py --data-path /path/to/dataset.COCO --pretrained --compress maskrcnn.scheduler_agp.non_parallel.yaml + + Example multi-GPU command line: + + python -m torch.distributed.launch --nproc_per_node=4 --use_env compress_detector.py --data-path /path/to/dataset.COCO \ + --compress maskrcnn.scheduler_agp.yaml --pretrained --world-size 4 --batch-size 2 --epochs 80 + +Since the dataset is large and the FasterRCNN-based models are compute-heavy, we strongly recommend +running the script in a multi-GPU environment. Keep in mind that the multi-GPU case +runs multiple processes via `torch.distributed.launch`, and killing one of the processes +may break all of them and leave them in an undefined state (in that case you'll have to terminate +them manually). Also, even though multi-GPU training distributes the memory over all the GPUs, the +model is quite memory-intensive, so a large batch size will quickly exhaust GPU memory. +Our GPUs are TITAN X (Pascal) cards with 12 GB of memory, and a batch size of 3 is the largest we could run without +memory errors. + + +The default model is `torchvision.models.detection.maskrcnn_resnet50_fpn`; you can specify +any model that is part of `torchvision.models.detection` using +the `--model` argument, e.g. `--model maskrcnn_resnet50_fpn`. 
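For orientation, the sketch below condenses how the modified training loop in `engine.py` interleaves the Distiller scheduler callbacks with the torchvision reference training step. It is an illustration, not a drop-in replacement: the loss handling is simplified, and logging, the LR warm-up and the distributed loss reduction are omitted.

```python
# Condensed sketch of the callback placement used in engine.py.
# compress_detector.py calls compression_scheduler.on_epoch_begin(epoch)
# before invoking the per-epoch training function.
def train_one_epoch_sketch(model, optimizer, data_loader, device, epoch, scheduler):
    model.train()
    steps_per_epoch = len(data_loader)
    for step, (images, targets) in enumerate(data_loader):
        scheduler.on_minibatch_begin(epoch, step, steps_per_epoch, optimizer)

        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # In train mode, torchvision detection models return a dict of losses.
        losses = sum(loss for loss in model(images, targets).values())

        # Gives the scheduler a chance to add regularization terms to the loss.
        losses = scheduler.before_backward_pass(epoch, step, steps_per_epoch, losses,
                                                optimizer=optimizer)
        optimizer.zero_grad()
        losses.backward()
        scheduler.before_parameter_optimization(epoch, step, steps_per_epoch, optimizer)
        optimizer.step()

        scheduler.on_minibatch_end(epoch, step, steps_per_epoch, optimizer)
```

The pruning masks are applied inside these callbacks, which is why the note above about DistributedDataParallel keeping the weights identical across processes carries over to the masks as well.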
\ No newline at end of file diff --git a/examples/object_detection_compression/__init__.py b/examples/object_detection_compression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/object_detection_compression/coco_eval.py b/examples/object_detection_compression/coco_eval.py new file mode 100644 index 0000000..5b68514 --- /dev/null +++ b/examples/object_detection_compression/coco_eval.py @@ -0,0 +1,351 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection +import json +import tempfile + +import numpy as np +import copy +import time +import torch +import torch._six + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from collections import defaultdict + +import utils + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + coco_dt = loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + 
rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = utils.all_gather(img_ids) + all_eval_imgs = utils.all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + +# Ideally, pycocotools wouldn't have hard-coded prints +# so that we could avoid copy-pasting those two functions + +def createIndex(self): + # create index + # print('creating index...') + anns, cats, imgs = {}, {}, {} + imgToAnns, catToImgs = defaultdict(list), defaultdict(list) + if 'annotations' in self.dataset: + for ann in self.dataset['annotations']: + imgToAnns[ann['image_id']].append(ann) + anns[ann['id']] = ann + + if 'images' in self.dataset: + for img in self.dataset['images']: + imgs[img['id']] = img + + if 'categories' in self.dataset: + for cat in self.dataset['categories']: + cats[cat['id']] = cat + + if 'annotations' in self.dataset and 'categories' in self.dataset: + for ann in self.dataset['annotations']: + catToImgs[ann['category_id']].append(ann['image_id']) + + # print('index created!') + + # create class members + self.anns = anns + self.imgToAnns = imgToAnns + self.catToImgs = catToImgs + self.imgs = imgs + self.cats = cats + + +maskUtils = mask_util + + +def loadRes(self, resFile): + """ + Load result file and return a result api object. 
+ :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = COCO() + res.dataset['images'] = [img for img in self.dataset['images']] + + # print('Loading and preparing results...') + # tic = time.time() + if isinstance(resFile, torch._six.string_classes): + anns = json.load(open(resFile)) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, 'results in not an array of objects' + annsImgIds = [ann['image_id'] for ann in anns] + assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ + 'Results do not correspond to current coco set' + if 'caption' in anns[0]: + imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) + res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] + for id, ann in enumerate(anns): + ann['id'] = id + 1 + elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + bb = ann['bbox'] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if 'segmentation' not in ann: + ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann['area'] = bb[2] * bb[3] + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'segmentation' in anns[0]: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + # now only support compressed RLE format as segmentation results + ann['area'] = maskUtils.area(ann['segmentation']) + if 'bbox' not in ann: + ann['bbox'] = maskUtils.toBbox(ann['segmentation']) + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'keypoints' in anns[0]: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + s = ann['keypoints'] + x = s[0::3] + y = s[1::3] + x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y) + ann['area'] = (x2 - x1) * (y2 - y1) + ann['id'] = id + 1 + ann['bbox'] = [x1, y1, x2 - x1, y2 - y1] + # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) + + res.dataset['annotations'] = anns + createIndex(res) + return res + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/examples/object_detection_compression/coco_utils.py b/examples/object_detection_compression/coco_utils.py new file mode 100644 index 0000000..b9bb5e8 --- /dev/null +++ b/examples/object_detection_compression/coco_utils.py @@ -0,0 +1,253 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection +import copy +import os +from PIL import Image + +import torch +import torch.utils.data +import torchvision + +from pycocotools import mask as coco_mask +from pycocotools.coco import COCO + +import transforms as T + + +class FilterAndRemapCocoCategories(object): + def __init__(self, categories, remap=True): + self.categories = categories + self.remap = remap + + def __call__(self, image, target): + anno = target["annotations"] + anno = [obj for obj in anno if obj["category_id"] in self.categories] + if not self.remap: + target["annotations"] = anno + return image, target + anno = copy.deepcopy(anno) + for obj in anno: + obj["category_id"] = self.categories.index(obj["category_id"]) + target["annotations"] = anno + return image, target + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + segmentations = 
[obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + + return image, target + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + min_keypoints_per_image = 10 + + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different critera for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + assert isinstance(dataset, torchvision.datasets.CocoDetection) + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def convert_to_coco_api(ds): + coco_ds = COCO() + ann_id = 0 + dataset = {'images': [], 'categories': [], 'annotations': []} + categories = set() + for img_idx in range(len(ds)): + # find better way to get target + # targets = ds.get_annotations(img_idx) + img, targets = ds[img_idx] + image_id = targets["image_id"].item() + img_dict = {} + img_dict['id'] = image_id + img_dict['height'] = img.shape[-2] + img_dict['width'] = img.shape[-1] + dataset['images'].append(img_dict) + bboxes = targets["boxes"] + bboxes[:, 2:] -= bboxes[:, :2] + bboxes = bboxes.tolist() + labels = targets['labels'].tolist() + areas = targets['area'].tolist() + iscrowd = targets['iscrowd'].tolist() + if 'masks' in targets: + masks = targets['masks'] + # make masks Fortran contiguous for coco_mask + masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) + if 'keypoints' in targets: + keypoints = targets['keypoints'] + keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() + num_objs = len(bboxes) + for i in range(num_objs): + ann = {} + ann['image_id'] = image_id + ann['bbox'] = bboxes[i] + ann['category_id'] = labels[i] + categories.add(labels[i]) + ann['area'] = areas[i] + ann['iscrowd'] = iscrowd[i] + ann['id'] = 
ann_id + if 'masks' in targets: + ann["segmentation"] = coco_mask.encode(masks[i].numpy()) + if 'keypoints' in targets: + ann['keypoints'] = keypoints[i] + ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) + dataset['annotations'].append(ann) + ann_id += 1 + dataset['categories'] = [{'id': i} for i in sorted(categories)] + coco_ds.dataset = dataset + coco_ds.createIndex() + return coco_ds + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + if isinstance(dataset, torchvision.datasets.CocoDetection): + break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, torchvision.datasets.CocoDetection): + return dataset.coco + return convert_to_coco_api(dataset) + + +class CocoDetection(torchvision.datasets.CocoDetection): + def __init__(self, img_folder, ann_file, transforms): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = dict(image_id=image_id, annotations=target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +def get_coco(root, image_set, transforms, mode='instances'): + anno_file_template = "{}_{}2017.json" + PATHS = { + "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), + "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), + # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) + } + + t = [ConvertCocoPolysToMask()] + + if transforms is not None: + t.append(transforms) + transforms = T.Compose(t) + + img_folder, ann_file = PATHS[image_set] + img_folder = os.path.join(root, img_folder) + ann_file = os.path.join(root, ann_file) + + dataset = CocoDetection(img_folder, ann_file, transforms=transforms) + + if image_set == "train": + dataset = _coco_remove_images_without_annotations(dataset) + + # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) + + return dataset + + +def get_coco_kp(root, image_set, transforms): + return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/examples/object_detection_compression/compress_detector.py b/examples/object_detection_compression/compress_detector.py new file mode 100644 index 0000000..4f477c3 --- /dev/null +++ b/examples/object_detection_compression/compress_detector.py @@ -0,0 +1,356 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection/train.py +# It contains code to support compression (distiller) +r"""PyTorch Detection Training. + +To run in a multi-gpu environment, use the distributed launcher:: + + python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \ + compress_detector.py ... 
--world-size $NGPU + +""" +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn +import torchvision +import torchvision.models.detection as detection +from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN +from torchvision.ops.misc import FrozenBatchNorm2d +import torch.distributed as dist + +import distiller +from distiller.data_loggers import * +import distiller.apputils as apputils +import distiller.pruning +import distiller.models +from distiller.model_transforms import fold_batch_norms + +from coco_utils import get_coco, get_coco_kp + +from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups +from engine import train_one_epoch, evaluate + +import utils +import transforms as T + +import logging +logging.getLogger().setLevel(logging.INFO) # Allow distiller info to be logged. + + +def get_dataset(name, image_set, transform, data_path): + paths = { + "coco": (data_path, get_coco, 91), + "coco_kp": (data_path, get_coco_kp, 2) + } + p, ds_fn, num_classes = paths[name] + + ds = ds_fn(p, image_set=image_set, transforms=transform) + return ds, num_classes + + +def get_transform(train): + transforms = [] + transforms.append(T.ToTensor()) + if train: + transforms.append(T.RandomHorizontalFlip(0.5)) + return T.Compose(transforms) + + +def patch_fastrcnn(model): + """ + TODO - complete quantization example + Partial patch for torchvision's FastRCNN models to allow quantization, by replacing all FrozenBatchNorm2d + with regular nn.BatchNorm2d-s. + Args: + model (GeneralizedRCNN): the model to patch + """ + assert isinstance(model, GeneralizedRCNN) + + def replace_frozen_bn(frozen_bn: FrozenBatchNorm2d): + num_features = frozen_bn.weight.shape[0] + bn = nn.BatchNorm2d(num_features) + eps = bn.eps + bn.weight.data = frozen_bn.weight.data + bn.bias.data = frozen_bn.bias.data + bn.running_mean.data = frozen_bn.running_mean.data + bn.running_var.data = frozen_bn.running_var.data + return bn.eval() + + for n, m in model.named_modules(): + if isinstance(m, FrozenBatchNorm2d): + distiller.model_setattr(model, n, replace_frozen_bn(m)) + + +def main(args): + utils.init_distributed_mode(args) + print(args) + + device = torch.device(args.device) + + script_dir = os.path.dirname(__file__) + module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if utils.is_main_process(): + msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir, + args.verbose) + + # Log various details about the execution environment. It is sometimes useful + # to refer to past experiment executions and this information may be useful. 
+ apputils.log_execution_env_state( + filter(None, [args.compress, args.qe_stats_file]), # remove both None and empty strings + msglogger.logdir) + msglogger.debug("Distiller: %s", distiller.__version__) + else: + msglogger = logging.getLogger() + msglogger.disabled = True + + # Data loading code + print("Loading data") + dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path) + dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path) + + print("Creating data loaders") + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + if args.aspect_ratio_group_factor >= 0: + group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) + train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) + else: + train_batch_sampler = torch.utils.data.BatchSampler( + train_sampler, args.batch_size, drop_last=True) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, + collate_fn=utils.collate_fn) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers, + collate_fn=utils.collate_fn) + + print("Creating model") + model = detection.__dict__[args.model](num_classes=num_classes, + pretrained=args.pretrained) + patch_fastrcnn(model) + model.to(device) + + if args.summary and utils.is_main_process(): + for summary in args.summary: + distiller.model_summary(model, summary, args.dataset) + return + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD( + params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) + + compression_scheduler = None + if utils.is_main_process(): + # Create a couple of logging backends. TensorBoardLogger writes log files in a format + # that can be read by Google's Tensor Board. PythonLogger writes to the Python logger. + tflogger = TensorBoardLogger(msglogger.logdir) + pylogger = PythonLogger(msglogger) + + if args.compress: + # The main use-case for this sample application is CNN compression. Compression + # requires a compression schedule configuration file in YAML. + compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler, None) + # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer) + model.to(args.device) + elif compression_scheduler is None: + compression_scheduler = distiller.CompressionScheduler(model) + + if args.qe_calibration: + def test_fn(model): + return evaluate(model, data_loader_test, device=device) + collect_quant_stats(model_without_ddp, test_fn, save_dir=args.output_dir, + modules_to_collect=['backbone', 'rpn', 'roi_heads']) + # We skip `.transform` because it is a pre-processing unit. 
+ return + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + if compression_scheduler and 'compression_scheduler' in checkpoint: + compression_scheduler.load_state_dict(checkpoint['compression_scheduler']) + + if args.test_only: + evaluate(model, data_loader_test, device=device) + return + activations_collectors = create_activation_stats_collectors(model, *args.activation_stats) + print("Start training") + start_time = time.time() + + # if not isinstance(model, nn.DataParallel) and torch.cuda.is_available() \ + # and torch.cuda.device_count() > 1: + # msglogger.info("Using %d GPUs on DataParallel." % torch.cuda.device_count()) + # model = nn.DataParallel(model) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + dist.barrier() + + if compression_scheduler: + compression_scheduler.on_epoch_begin(epoch) + + with collectors_context(activations_collectors["train"]) as collectors: + train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, compression_scheduler) + if utils.is_main_process(): + distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger]) + distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger], + collector=collectors["sparsity"]) + if args.masks_sparsity and utils.is_main_process(): + msglogger.info(distiller.masks_sparsity_tbl_summary(model, compression_scheduler)) + + lr_scheduler.step() + if args.output_dir: + save_dict = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'args': args} + if compression_scheduler: + save_dict['compression_scheduler'] = compression_scheduler.state_dict() + utils.save_on_master(save_dict, + os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + + # evaluate after every epoch + evaluate(model, data_loader_test, device=device) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +class missingdict(dict): + """This is a little trick to prevent KeyError""" + def __missing__(self, key): + return None # note, does *not* set self[key] - we don't want defaultdict's behavior + + +def create_activation_stats_collectors(model, *phases): + """Create objects that collect activation statistics. + + This is a utility function that creates two collectors: + 1. Fine-grade sparsity levels of the activations + 2. L1-magnitude of each of the activation channels + + Args: + model - the model on which we want to collect statistics + phases - the statistics collection phases: train, valid, and/or test + + WARNING! Enabling activation statsitics collection will significantly slow down training! 
+ """ + genCollectors = lambda: missingdict({ + "sparsity": SummaryActivationStatsCollector(model, "sparsity", + lambda t: 100 * distiller.utils.sparsity(t)), + "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", + distiller.utils.activation_channels_l1), + "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", + distiller.utils.activation_channels_apoz), + "mean_channels": SummaryActivationStatsCollector(model, "mean_channels", + distiller.utils.activation_channels_means), + "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) + }) + + return {k: (genCollectors() if k in phases else missingdict()) + for k in ('train', 'valid', 'test')} + + +def add_distiller_compression_args(parser): + SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] + distiller_parser = parser.add_argument_group('Distiller related arguemnts') + distiller_parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, action='append', + help='print a summary of the model, and exit - options: | '.join(SUMMARY_CHOICES)) + distiller_parser.add_argument('--export-onnx', action='store', nargs='?', type=str, const='model.onnx', + default=None, + help='export model to ONNX format') + distiller_parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', + help='configuration file for pruning the model ' + '(default is to use hard-coded schedule)') + distiller.pruning.greedy_filter_pruning.add_greedy_pruner_args(distiller_parser) + distiller_parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name') + distiller_parser.add_argument('--verbose', '-v', action='store_true', help='Emit debug log messages') + distiller.quantization.add_post_train_quant_args(distiller_parser) + distiller_parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(), + help='collect activation statistics on phases: train, valid, and/or test' + ' (WARNING: this slows down training)') + distiller_parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False, + help='print masks sparsity table at end of each epoch') + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser( + description=__doc__) + + parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset') + parser.add_argument('--dataset', default='coco', help='dataset') + parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model') + parser.add_argument('--device', default='cuda', help='device') + parser.add_argument('-b', '--batch-size', default=2, type=int) + parser.add_argument('--epochs', default=13, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('--start-epoch', default=0, type=int, help='starting epoch number') + parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 16)') + parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs') + parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, 
help='decrease lr every step-size epochs') + parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') + parser.add_argument('--print-freq', default=20, type=int, help='print frequency') + parser.add_argument('--output-dir', default='.', help='path where to save') + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--aspect-ratio-group-factor', default=0, type=int) + parser.add_argument( + "--evaluate", + dest="test_only", + help="Only test the model", + action="store_true", + ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) + + # distributed training parameters + parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + + add_distiller_compression_args(parser) + + args = parser.parse_args() + + if args.output_dir: + utils.mkdir(args.output_dir) + + main(args) diff --git a/examples/object_detection_compression/data/download_dataset.sh b/examples/object_detection_compression/data/download_dataset.sh new file mode 100644 index 0000000..60fdc6d --- /dev/null +++ b/examples/object_detection_compression/data/download_dataset.sh @@ -0,0 +1,9 @@ +echo "Downloading Dataset..." +wget http://images.cocodataset.org/zips/train2017.zip +wget http://images.cocodataset.org/zips/val2017.zip +wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip + +echo "Extracting Dataset..." +unzip train2017.zip +unzip val2017.zip +unzip annotations_trainval2017.zip \ No newline at end of file diff --git a/examples/object_detection_compression/engine.py b/examples/object_detection_compression/engine.py new file mode 100644 index 0000000..ecc4c52 --- /dev/null +++ b/examples/object_detection_compression/engine.py @@ -0,0 +1,125 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection +# (old commit) +# It contains code to support compression (distiller) +import math +import sys +import time +import torch + +import torchvision.models.detection.mask_rcnn + +from coco_utils import get_coco_api_from_dataset +from coco_eval import CocoEvaluator +import utils + + +def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, compression_scheduler=None): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + steps_per_epoch = len(data_loader) + lr_scheduler = None + if epoch == 0: + warmup_factor = 1. 
/ 1000 + warmup_iters = min(1000, len(data_loader) - 1) + + lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) + + for train_step, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + if compression_scheduler: + compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer) + + images = list(image.to(device) for image in images) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + loss_dict = model(images, targets) + + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + loss_value = losses_reduced.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + if compression_scheduler: + losses = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, losses, + optimizer=optimizer) + + optimizer.zero_grad() + losses.backward() + + if compression_scheduler: + compression_scheduler.before_parameter_optimization(epoch, train_step, steps_per_epoch, optimizer) + optimizer.step() + + if compression_scheduler: + compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch, optimizer) + + if lr_scheduler is not None: + lr_scheduler.step() + + metric_logger.update(loss=losses_reduced, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + + +def _get_iou_types(model): + model_without_ddp = model + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_without_ddp = model.module + iou_types = ["bbox"] + if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): + iou_types.append("segm") + if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): + iou_types.append("keypoints") + return iou_types + + +@torch.no_grad() +def evaluate(model, data_loader, device): + n_threads = torch.get_num_threads() + # FIXME remove this and make paste_masks_in_image run on the GPU + torch.set_num_threads(1) + cpu_device = torch.device("cpu") + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + coco = get_coco_api_from_dataset(data_loader.dataset) + iou_types = _get_iou_types(model) + coco_evaluator = CocoEvaluator(coco, iou_types) + + for image, targets in metric_logger.log_every(data_loader, 100, header): + image = list(img.to(device) for img in image) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + torch.cuda.synchronize() + model_time = time.time() + outputs = model(image) + + outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] + model_time = time.time() - model_time + + res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + evaluator_time = time.time() + coco_evaluator.update(res) + evaluator_time = time.time() - evaluator_time + metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + coco_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + coco_evaluator.accumulate() + coco_evaluator.summarize() + torch.set_num_threads(n_threads) + return coco_evaluator diff --git a/examples/object_detection_compression/group_by_aspect_ratio.py 
b/examples/object_detection_compression/group_by_aspect_ratio.py new file mode 100644 index 0000000..b82a884 --- /dev/null +++ b/examples/object_detection_compression/group_by_aspect_ratio.py @@ -0,0 +1,189 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection +import bisect +from collections import defaultdict +import copy +import numpy as np + +import torch +import torch.utils.data +from torch.utils.data.sampler import BatchSampler, Sampler +from torch.utils.model_zoo import tqdm +import torchvision + +from PIL import Image + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that the batch only contain elements from the same group. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + Arguments: + sampler (Sampler): Base sampler. + group_ids (list[int]): If the sampler produces indices in range [0, N), + `group_ids` must be a list of `N` ints which contains the group id of each sample. + The group ids must be a continuous set of integers starting from + 0, i.e. they must be in the range [0, num_groups). + batch_size (int): Size of mini-batch. + """ + def __init__(self, sampler, group_ids, batch_size): + if not isinstance(sampler, Sampler): + raise ValueError( + "sampler should be an instance of " + "torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = group_ids + self.batch_size = batch_size + + def __iter__(self): + buffer_per_group = defaultdict(list) + samples_per_group = defaultdict(list) + + num_batches = 0 + for idx in self.sampler: + group_id = self.group_ids[idx] + buffer_per_group[group_id].append(idx) + samples_per_group[group_id].append(idx) + if len(buffer_per_group[group_id]) == self.batch_size: + yield buffer_per_group[group_id] + num_batches += 1 + del buffer_per_group[group_id] + assert len(buffer_per_group[group_id]) < self.batch_size + + # now we have run out of elements that satisfy + # the group criteria, let's return the remaining + # elements so that the size of the sampler is + # deterministic + expected_num_batches = len(self) + num_remaining = expected_num_batches - num_batches + if num_remaining > 0: + # for the remaining batches, take first the buffers with largest number + # of elements + for group_id, _ in sorted(buffer_per_group.items(), + key=lambda x: len(x[1]), reverse=True): + remaining = self.batch_size - len(buffer_per_group[group_id]) + buffer_per_group[group_id].extend( + samples_per_group[group_id][:remaining]) + assert len(buffer_per_group[group_id]) == self.batch_size + yield buffer_per_group[group_id] + num_remaining -= 1 + if num_remaining == 0: + break + assert num_remaining == 0 + + def __len__(self): + return len(self.sampler) // self.batch_size + + +def _compute_aspect_ratios_slow(dataset, indices=None): + print("Your dataset doesn't support the fast path for " + "computing the aspect ratios, so will iterate over " + "the full dataset and load every image instead. 
" + "This might take some time...") + if indices is None: + indices = range(len(dataset)) + + class SubsetSampler(Sampler): + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return iter(self.indices) + + def __len__(self): + return len(self.indices) + + sampler = SubsetSampler(indices) + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=1, sampler=sampler, + num_workers=14, # you might want to increase it for faster processing + collate_fn=lambda x: x[0]) + aspect_ratios = [] + with tqdm(total=len(dataset)) as pbar: + for _i, (img, _) in enumerate(data_loader): + pbar.update(1) + height, width = img.shape[-2:] + aspect_ratio = float(height) / float(width) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def _compute_aspect_ratios_custom_dataset(dataset, indices=None): + if indices is None: + indices = range(len(dataset)) + aspect_ratios = [] + for i in indices: + height, width = dataset.get_height_and_width(i) + aspect_ratio = float(height) / float(width) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def _compute_aspect_ratios_coco_dataset(dataset, indices=None): + if indices is None: + indices = range(len(dataset)) + aspect_ratios = [] + for i in indices: + img_info = dataset.coco.imgs[dataset.ids[i]] + aspect_ratio = float(img_info["height"]) / float(img_info["width"]) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def _compute_aspect_ratios_voc_dataset(dataset, indices=None): + if indices is None: + indices = range(len(dataset)) + aspect_ratios = [] + for i in indices: + # this doesn't load the data into memory, because PIL loads it lazily + width, height = Image.open(dataset.images[i]).size + aspect_ratio = float(height) / float(width) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def _compute_aspect_ratios_subset_dataset(dataset, indices=None): + if indices is None: + indices = range(len(dataset)) + + ds_indices = [dataset.indices[i] for i in indices] + return compute_aspect_ratios(dataset.dataset, ds_indices) + + +def compute_aspect_ratios(dataset, indices=None): + if hasattr(dataset, "get_height_and_width"): + return _compute_aspect_ratios_custom_dataset(dataset, indices) + + if isinstance(dataset, torchvision.datasets.CocoDetection): + return _compute_aspect_ratios_coco_dataset(dataset, indices) + + if isinstance(dataset, torchvision.datasets.VOCDetection): + return _compute_aspect_ratios_voc_dataset(dataset, indices) + + if isinstance(dataset, torch.utils.data.Subset): + return _compute_aspect_ratios_subset_dataset(dataset, indices) + + # slow path + return _compute_aspect_ratios_slow(dataset, indices) + + +def _quantize(x, bins): + bins = copy.deepcopy(bins) + bins = sorted(bins) + quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) + return quantized + + +def create_aspect_ratio_groups(dataset, k=0): + aspect_ratios = compute_aspect_ratios(dataset) + bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0] + groups = _quantize(aspect_ratios, bins) + # count number of elements per group + counts = np.unique(groups, return_counts=True)[1] + fbins = [0] + bins + [np.inf] + print("Using {} as bins for aspect ratio quantization".format(fbins)) + print("Count of instances per bin: {}".format(counts)) + return groups diff --git a/examples/object_detection_compression/logging.conf b/examples/object_detection_compression/logging.conf new file mode 100644 index 0000000..429419b --- /dev/null +++ 
b/examples/object_detection_compression/logging.conf @@ -0,0 +1,54 @@ +[formatters] +keys: simple, time_simple + +[handlers] +keys: console, file + +[loggers] +keys: root, app_cfg, distiller.thinning, apputils.model_summaries + +[formatter_simple] +format: %(message)s + +[formatter_time_simple] +format: %(asctime)s - %(message)s + +[handler_console] +class: StreamHandler +propagate: 0 +args: [] +formatter: simple + +[handler_file] +class: FileHandler +mode: 'w' +args=('%(logfilename)s', 'w') +formatter: time_simple + +[logger_root] +level: INFO +propagate: 1 +handlers: console, file + +[logger_app_cfg] +# Use this logger to log the application configuration and execution environment +level: DEBUG +qualname: app_cfg +propagate: 0 +handlers: file + +# Example of adding a module-specific logger +# Do not forget to add distiller.thinning to the list of keys in section [loggers] +[logger_distiller.thinning] +level: INFO +qualname: distiller.thinning +propagate: 0 +handlers: console, file + +# Example of adding a module-specific logger +# Do not forget to add apputils.model_summaries to the list of keys in section [loggers] +[logger_apputils.model_summaries] +level: INFO +qualname: apputils.model_summaries +propagate: 0 +handlers: console, file diff --git a/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml b/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml new file mode 100644 index 0000000..579582b --- /dev/null +++ b/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml @@ -0,0 +1,210 @@ +# +# Command line: +# python compress_detector.py --data-path $DATASET_COCO \ +# --compress maskrcnn.scheduler_agp.non_parallel.yaml +# +#Parameters: +#+----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +#| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +#|----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +#| 0 | backbone.body.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12322 | -0.00050 | 0.07605 | +#| 1 | backbone.body.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07083 | -0.00439 | 0.04032 | +#| 2 | backbone.body.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02915 | 0.00079 | 0.01717 | +#| 3 | backbone.body.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03527 | 0.00042 | 0.02105 | +#| 4 | backbone.body.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05750 | -0.00342 | 0.03227 | +#| 5 | backbone.body.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03062 | 0.00135 | 0.02029 | +#| 6 | backbone.body.layer1.1.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02881 | 0.00005 | 0.01964 | +#| 7 | backbone.body.layer1.1.conv3.weight | (256, 64, 
1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03254 | -0.00019 | 0.02029 | +#| 8 | backbone.body.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03008 | 0.00012 | 0.02169 | +#| 9 | backbone.body.layer1.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03178 | -0.00085 | 0.02378 | +#| 10 | backbone.body.layer1.2.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03109 | -0.00255 | 0.01857 | +#| 11 | backbone.body.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03542 | -0.00152 | 0.02484 | +#| 12 | backbone.body.layer2.0.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02229 | -0.00054 | 0.01661 | +#| 13 | backbone.body.layer2.0.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02808 | -0.00013 | 0.01740 | +#| 14 | backbone.body.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02311 | -0.00035 | 0.01352 | +#| 15 | backbone.body.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01666 | -0.00007 | 0.01004 | +#| 16 | backbone.body.layer2.1.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01934 | -0.00004 | 0.01239 | +#| 17 | backbone.body.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02206 | -0.00121 | 0.01247 | +#| 18 | backbone.body.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02332 | -0.00078 | 0.01618 | +#| 19 | backbone.body.layer2.2.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02159 | -0.00033 | 0.01536 | +#| 20 | backbone.body.layer2.2.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02617 | -0.00047 | 0.01840 | +#| 21 | backbone.body.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02432 | -0.00098 | 0.01803 | +#| 22 | backbone.body.layer2.3.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02241 | -0.00074 | 0.01691 | +#| 23 | backbone.body.layer2.3.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02432 | -0.00120 | 0.01681 | +#| 24 | backbone.body.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03077 | -0.00127 | 0.02202 | +#| 25 | backbone.body.layer3.0.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01736 | -0.00041 | 0.01279 | +#| 26 | backbone.body.layer3.0.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02338 | -0.00042 | 0.01660 | +#| 27 | backbone.body.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 
0.00000 | 0.00000 | 0.01615 | 0.00004 | 0.01106 | +#| 28 | backbone.body.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01521 | -0.00055 | 0.01063 | +#| 29 | backbone.body.layer3.1.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01510 | -0.00031 | 0.01102 | +#| 30 | backbone.body.layer3.1.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02036 | -0.00109 | 0.01457 | +#| 31 | backbone.body.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01575 | -0.00052 | 0.01125 | +#| 32 | backbone.body.layer3.2.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01500 | -0.00071 | 0.01127 | +#| 33 | backbone.body.layer3.2.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01908 | -0.00077 | 0.01393 | +#| 34 | backbone.body.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01734 | -0.00077 | 0.01283 | +#| 35 | backbone.body.layer3.3.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01493 | -0.00077 | 0.01142 | +#| 36 | backbone.body.layer3.3.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01821 | -0.00117 | 0.01341 | +#| 37 | backbone.body.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01811 | -0.00101 | 0.01362 | +#| 38 | backbone.body.layer3.4.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01487 | -0.00094 | 0.01141 | +#| 39 | backbone.body.layer3.4.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01815 | -0.00160 | 0.01334 | +#| 40 | backbone.body.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01970 | -0.00096 | 0.01498 | +#| 41 | backbone.body.layer3.5.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01524 | -0.00090 | 0.01170 | +#| 42 | backbone.body.layer3.5.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01923 | -0.00231 | 0.01438 | +#| 43 | backbone.body.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02380 | -0.00136 | 0.01837 | +#| 44 | backbone.body.layer4.0.conv2.weight | (512, 512, 3, 3) | 2359296 | 2359296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01255 | -0.00051 | 0.00981 | +#| 45 | backbone.body.layer4.0.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01529 | -0.00061 | 0.01172 | +#| 46 | backbone.body.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00994 | -0.00007 | 0.00750 | +#| 47 | backbone.body.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 
0.01472 | -0.00092 | 0.01143 | +#| 48 | backbone.body.layer4.1.conv2.weight | (512, 512, 3, 3) | 2359296 | 2359296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01230 | -0.00084 | 0.00971 | +#| 49 | backbone.body.layer4.1.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01495 | -0.00011 | 0.01149 | +#| 50 | backbone.body.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01813 | -0.00050 | 0.01413 | +#| 51 | backbone.body.layer4.2.conv2.weight | (512, 512, 3, 3) | 2359296 | 2359296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01104 | -0.00066 | 0.00874 | +#| 52 | backbone.body.layer4.2.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01417 | -0.00003 | 0.01062 | +#| 53 | backbone.fpn.inner_blocks.0.weight | (256, 256, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03602 | 0.00015 | 0.03119 | +#| 54 | backbone.fpn.inner_blocks.1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02552 | -0.00001 | 0.02209 | +#| 55 | backbone.fpn.inner_blocks.2.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01804 | -0.00005 | 0.01563 | +#| 56 | backbone.fpn.inner_blocks.3.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01276 | 0.00001 | 0.01105 | +#| 57 | backbone.fpn.layer_blocks.0.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01202 | -0.00001 | 0.01041 | +#| 58 | backbone.fpn.layer_blocks.1.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01204 | -0.00000 | 0.01043 | +#| 59 | backbone.fpn.layer_blocks.2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01204 | -0.00001 | 0.01043 | +#| 60 | backbone.fpn.layer_blocks.3.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01202 | 0.00001 | 0.01041 | +#| 61 | rpn.head.conv.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00999 | -0.00001 | 0.00797 | +#| 62 | rpn.head.cls_logits.weight | (3, 256, 1, 1) | 768 | 768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00978 | 0.00015 | 0.00772 | +#| 63 | rpn.head.bbox_pred.weight | (12, 256, 1, 1) | 3072 | 3072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00982 | -0.00034 | 0.00786 | +#| 64 | roi_heads.box_head.fc6.weight | (1024, 12544) | 12845056 | 12845053 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00002 | 0.00515 | -0.00000 | 0.00446 | +#| 65 | roi_heads.box_head.fc7.weight | (1024, 1024) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01805 | -0.00003 | 0.01563 | +#| 66 | roi_heads.box_predictor.cls_score.weight | (91, 1024) | 93184 | 93184 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01806 | 0.00009 | 0.01564 | +#| 67 | roi_heads.box_predictor.bbox_pred.weight | (364, 1024) | 372736 | 372736 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01805 | 0.00001 | 0.01563 | +#| 68 | roi_heads.mask_head.mask_fcn1.weight | (256, 256, 3, 3) | 589824 | 589824 | 
0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02944 | 0.00004 | 0.02350 | +#| 69 | roi_heads.mask_head.mask_fcn2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02947 | -0.00004 | 0.02351 | +#| 70 | roi_heads.mask_head.mask_fcn3.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02949 | 0.00004 | 0.02352 | +#| 71 | roi_heads.mask_head.mask_fcn4.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02949 | 0.00000 | 0.02354 | +#| 72 | roi_heads.mask_predictor.conv5_mask.weight | (256, 256, 2, 2) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04420 | -0.00003 | 0.03524 | +#| 73 | roi_heads.mask_predictor.mask_fcn_logits.weight | (91, 256, 1, 1) | 23296 | 23296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14817 | 0.00125 | 0.11834 | +#| 74 | Total sparsity: | - | 44395200 | 44395197 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00001 | 0.00000 | 0.00000 | 0.00000 | +#+----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ + +version: 1 + +pruners: + + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.85 + weights: [ + roi_heads.box_head.fc6.weight, + roi_heads.box_head.fc7.weight + ] + + agp_pruner_75: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.75 + weights: [ + backbone.body.layer1.0.conv1.weight, + backbone.body.layer1.0.conv2.weight, + backbone.body.layer1.0.conv3.weight, + backbone.body.layer1.0.downsample.0.weight, + backbone.body.layer1.1.conv1.weight, + backbone.body.layer1.1.conv2.weight, + backbone.body.layer1.1.conv3.weight, + backbone.body.layer1.2.conv1.weight, + backbone.body.layer1.2.conv2.weight, + backbone.body.layer1.2.conv3.weight, + backbone.body.layer2.0.conv1.weight, + backbone.body.layer2.0.conv2.weight, + backbone.body.layer2.0.conv3.weight,] + + agp_pruner_85: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.85 + weights: [ + backbone.body.layer2.0.downsample.0.weight, + backbone.body.layer2.1.conv1.weight, + backbone.body.layer2.1.conv2.weight, + backbone.body.layer2.1.conv3.weight, + backbone.body.layer2.2.conv1.weight, + backbone.body.layer2.2.conv2.weight, + backbone.body.layer2.2.conv3.weight, + backbone.body.layer2.3.conv1.weight, + backbone.body.layer2.3.conv2.weight, + backbone.body.layer2.3.conv3.weight, + backbone.body.layer3.0.conv1.weight, + backbone.body.layer3.0.conv2.weight, + backbone.body.layer3.0.conv3.weight, + backbone.body.layer3.0.downsample.0.weight, + backbone.body.layer3.1.conv1.weight, + backbone.body.layer3.1.conv2.weight, + backbone.body.layer3.1.conv3.weight, + backbone.body.layer3.2.conv1.weight, + backbone.body.layer3.2.conv2.weight, + backbone.body.layer3.2.conv3.weight, + backbone.body.layer3.3.conv1.weight, + backbone.body.layer3.3.conv2.weight, + backbone.body.layer3.3.conv3.weight, + backbone.body.layer3.4.conv1.weight, + backbone.body.layer3.4.conv2.weight, + backbone.body.layer3.4.conv3.weight, + backbone.body.layer3.5.conv1.weight, + backbone.body.layer3.5.conv2.weight, + backbone.body.layer3.5.conv3.weight, + backbone.body.layer4.2.conv3.weight, + backbone.fpn.inner_blocks.0.weight, + backbone.fpn.inner_blocks.1.weight, + 
backbone.fpn.inner_blocks.2.weight, + backbone.fpn.inner_blocks.3.weight, + backbone.fpn.layer_blocks.0.weight, + backbone.fpn.layer_blocks.1.weight, + backbone.fpn.layer_blocks.2.weight, + backbone.fpn.layer_blocks.3.weight, + rpn.head.conv.weight, + roi_heads.mask_head.mask_fcn1.weight, + roi_heads.mask_head.mask_fcn2.weight, + roi_heads.mask_head.mask_fcn3.weight, + roi_heads.mask_head.mask_fcn4.weight, + roi_heads.mask_predictor.conv5_mask.weight, + ] + + agp_pruner_90: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.90 + weights: [ + backbone.body.layer4.0.conv1.weight, + backbone.body.layer4.0.conv2.weight, + backbone.body.layer4.0.conv3.weight, + backbone.body.layer4.0.downsample.0.weight, + backbone.body.layer4.1.conv1.weight, + backbone.body.layer4.1.conv2.weight, + backbone.body.layer4.1.conv3.weight, + backbone.body.layer4.2.conv1.weight, + backbone.body.layer4.2.conv2.weight, + ] + + +policies: + - pruner: + instance_name : agp_pruner_75 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 + + - pruner: + instance_name : agp_pruner_85 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 + + - pruner: + instance_name : fc_pruner + starting_epoch: 0 + ending_epoch: 45 + frequency: 3 + + - pruner: + instance_name : agp_pruner_90 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 diff --git a/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml b/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml new file mode 100644 index 0000000..fda56c6 --- /dev/null +++ b/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml @@ -0,0 +1,267 @@ +# +# Command line: +# python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env compress_detector.py --data-path $DATASET_COCO \ +# --compress maskrcnn.scheduler_agp.yaml --world-size $NGPU +# +#Parameters: +#+----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +#| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +#|----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +#| 0 | module.backbone.body.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12322 | -0.00050 | 0.07605 | +#| 1 | module.backbone.body.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 1024 | 0.00000 | 0.00000 | 3.12500 | 75.00000 | 7.81250 | 75.00000 | 0.06810 | -0.00616 | 0.02771 | +#| 2 | module.backbone.body.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 9216 | 0.00000 | 0.00000 | 7.81250 | 32.93457 | 6.25000 | 75.00000 | 0.02777 | 0.00070 | 0.01135 | +#| 3 | module.backbone.body.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 6.25000 | 75.00000 | 12.10938 | 75.00000 | 0.03366 | 0.00035 | 0.01453 | +#| 4 | module.backbone.body.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 1.56250 | 75.00000 | 13.28125 | 75.00000 | 0.05548 | -0.00382 | 0.02256 | +#| 5 | module.backbone.body.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 11.71875 | 75.00000 | 6.25000 | 75.00000 | 0.02841 | 0.00126 | 0.01292 | +#| 6 | 
module.backbone.body.layer1.1.conv2.weight | (64, 64, 3, 3) | 36864 | 9216 | 0.00000 | 0.00000 | 6.25000 | 26.46484 | 0.00000 | 75.00000 | 0.02650 | 0.00022 | 0.01178 | +#| 7 | module.backbone.body.layer1.1.conv3.weight | (256, 64, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 0.00000 | 75.00000 | 3.51562 | 75.00000 | 0.03090 | 0.00000 | 0.01370 | +#| 8 | module.backbone.body.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 7.03125 | 75.00000 | 0.00000 | 75.00000 | 0.02725 | 0.00012 | 0.01270 | +#| 9 | module.backbone.body.layer1.2.conv2.weight | (64, 64, 3, 3) | 36864 | 9216 | 0.00000 | 0.00000 | 0.00000 | 21.19141 | 0.00000 | 75.00000 | 0.02828 | -0.00049 | 0.01323 | +#| 10 | module.backbone.body.layer1.2.conv3.weight | (256, 64, 1, 1) | 16384 | 4096 | 0.00000 | 0.00000 | 0.00000 | 75.00000 | 1.17188 | 75.00000 | 0.02999 | -0.00228 | 0.01325 | +#| 11 | module.backbone.body.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 8192 | 0.00000 | 0.00000 | 3.90625 | 75.00000 | 0.00000 | 75.00000 | 0.04454 | 0.00213 | 0.01983 | +#| 12 | module.backbone.body.layer2.0.conv2.weight | (128, 128, 3, 3) | 147456 | 36864 | 0.00000 | 0.00000 | 0.00000 | 28.43018 | 0.00000 | 75.00000 | 0.02019 | 0.00093 | 0.00941 | +#| 13 | module.backbone.body.layer2.0.conv3.weight | (512, 128, 1, 1) | 65536 | 16384 | 0.00000 | 0.00000 | 0.00000 | 75.00000 | 28.90625 | 75.00000 | 0.03100 | 0.00207 | 0.01292 | +#| 14 | module.backbone.body.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 19661 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 15.62500 | 84.99985 | 0.01999 | 0.00119 | 0.00595 | +#| 15 | module.backbone.body.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 17.57812 | 84.99908 | 0.00000 | 84.99908 | 0.01914 | 0.00124 | 0.00620 | +#| 16 | module.backbone.body.layer2.1.conv2.weight | (128, 128, 3, 3) | 147456 | 22119 | 0.00000 | 0.00000 | 0.00000 | 53.41187 | 0.00000 | 84.99959 | 0.01481 | 0.00072 | 0.00531 | +#| 17 | module.backbone.body.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 0.00000 | 84.99908 | 38.28125 | 84.99908 | 0.02156 | 0.00029 | 0.00685 | +#| 18 | module.backbone.body.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 4.29688 | 84.99908 | 0.00000 | 84.99908 | 0.01961 | 0.00100 | 0.00688 | +#| 19 | module.backbone.body.layer2.2.conv2.weight | (128, 128, 3, 3) | 147456 | 22119 | 0.00000 | 0.00000 | 0.00000 | 41.84570 | 0.00000 | 84.99959 | 0.01564 | 0.00037 | 0.00562 | +#| 20 | module.backbone.body.layer2.2.conv3.weight | (512, 128, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 0.00000 | 84.99908 | 6.44531 | 84.99908 | 0.02155 | 0.00039 | 0.00747 | +#| 21 | module.backbone.body.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 4.10156 | 84.99908 | 0.00000 | 84.99908 | 0.02108 | 0.00067 | 0.00752 | +#| 22 | module.backbone.body.layer2.3.conv2.weight | (128, 128, 3, 3) | 147456 | 22119 | 0.00000 | 0.00000 | 0.00000 | 36.71265 | 0.00000 | 84.99959 | 0.01695 | 0.00009 | 0.00623 | +#| 23 | module.backbone.body.layer2.3.conv3.weight | (512, 128, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 0.00000 | 84.99908 | 24.80469 | 84.99908 | 0.02311 | -0.00029 | 0.00796 | +#| 24 | module.backbone.body.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 19661 | 0.00000 | 0.00000 | 0.19531 | 84.99985 | 0.00000 | 84.99985 | 0.02292 | 0.00031 | 0.00801 | +#| 25 | module.backbone.body.layer3.0.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 
47.98279 | 0.00000 | 84.99993 | 0.01155 | 0.00032 | 0.00421 | +#| 26 | module.backbone.body.layer3.0.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 6.34766 | 84.99985 | 0.01704 | 0.00087 | 0.00601 | +#| 27 | module.backbone.body.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 78644 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 4.39453 | 84.99985 | 0.01072 | 0.00059 | 0.00370 | +#| 28 | module.backbone.body.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 9.08203 | 84.99985 | 0.00000 | 84.99985 | 0.01141 | 0.00051 | 0.00392 | +#| 29 | module.backbone.body.layer3.1.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 46.10596 | 0.00000 | 84.99993 | 0.00989 | 0.00000 | 0.00357 | +#| 30 | module.backbone.body.layer3.1.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 1.36719 | 84.99985 | 0.01376 | -0.00064 | 0.00485 | +#| 31 | module.backbone.body.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 2.24609 | 84.99985 | 0.00000 | 84.99985 | 0.01204 | 0.00033 | 0.00413 | +#| 32 | module.backbone.body.layer3.2.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 37.43591 | 0.00000 | 84.99993 | 0.01000 | -0.00007 | 0.00365 | +#| 33 | module.backbone.body.layer3.2.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 1.85547 | 84.99985 | 0.01354 | -0.00039 | 0.00477 | +#| 34 | module.backbone.body.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.97656 | 84.99985 | 0.00000 | 84.99985 | 0.01287 | 0.00016 | 0.00452 | +#| 35 | module.backbone.body.layer3.3.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 36.71875 | 0.00000 | 84.99993 | 0.01004 | 0.00000 | 0.00368 | +#| 36 | module.backbone.body.layer3.3.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 6.64062 | 84.99985 | 0.01315 | -0.00024 | 0.00462 | +#| 37 | module.backbone.body.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.29297 | 84.99985 | 0.00000 | 84.99985 | 0.01321 | 0.00005 | 0.00468 | +#| 38 | module.backbone.body.layer3.4.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 38.03711 | 0.00000 | 84.99993 | 0.01016 | -0.00003 | 0.00370 | +#| 39 | module.backbone.body.layer3.4.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 8.00781 | 84.99985 | 0.01311 | -0.00017 | 0.00457 | +#| 40 | module.backbone.body.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.29297 | 84.99985 | 0.00000 | 84.99985 | 0.01289 | 0.00021 | 0.00455 | +#| 41 | module.backbone.body.layer3.5.conv2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 40.31677 | 0.00000 | 84.99993 | 0.00959 | 0.00013 | 0.00349 | +#| 42 | module.backbone.body.layer3.5.conv3.weight | (1024, 256, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 5.76172 | 84.99985 | 0.01286 | -0.00098 | 0.00458 | +#| 43 | module.backbone.body.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 52429 | 0.00000 | 0.00000 | 0.00000 | 89.99996 | 0.00000 | 89.99996 | 0.01248 | 0.00061 | 0.00369 | +#| 44 | module.backbone.body.layer4.0.conv2.weight | (512, 512, 3, 3) | 2359296 | 235930 | 0.00000 | 0.00000 | 0.00000 | 57.85141 | 0.00000 | 89.99998 | 0.00621 | 0.00005 | 0.00189 | +#| 
45 | module.backbone.body.layer4.0.conv3.weight | (2048, 512, 1, 1) | 1048576 | 104858 | 0.00000 | 0.00000 | 0.00000 | 89.99996 | 0.29297 | 89.99996 | 0.00865 | 0.00026 | 0.00259 | +#| 46 | module.backbone.body.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 209716 | 0.00000 | 0.00000 | 0.00000 | 89.99996 | 0.04883 | 89.99996 | 0.00611 | -0.00007 | 0.00184 | +#| 47 | module.backbone.body.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 104858 | 0.00000 | 0.00000 | 0.24414 | 89.99996 | 0.00000 | 89.99996 | 0.00817 | 0.00041 | 0.00238 | +#| 48 | module.backbone.body.layer4.1.conv2.weight | (512, 512, 3, 3) | 2359296 | 235930 | 0.00000 | 0.00000 | 0.00000 | 57.28302 | 0.00000 | 89.99998 | 0.00649 | 0.00001 | 0.00196 | +#| 49 | module.backbone.body.layer4.1.conv3.weight | (2048, 512, 1, 1) | 1048576 | 104858 | 0.00000 | 0.00000 | 0.00000 | 89.99996 | 0.24414 | 89.99996 | 0.00867 | -0.00013 | 0.00259 | +#| 50 | module.backbone.body.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 104858 | 0.00000 | 0.00000 | 0.04883 | 89.99996 | 0.00000 | 89.99996 | 0.00876 | 0.00056 | 0.00259 | +#| 51 | module.backbone.body.layer4.2.conv2.weight | (512, 512, 3, 3) | 2359296 | 235930 | 0.00000 | 0.00000 | 0.00000 | 59.87663 | 0.00000 | 89.99998 | 0.00634 | 0.00001 | 0.00189 | +#| 52 | module.backbone.body.layer4.2.conv3.weight | (2048, 512, 1, 1) | 1048576 | 157287 | 0.00000 | 0.00000 | 0.00000 | 84.99994 | 0.04883 | 84.99994 | 0.00883 | -0.00022 | 0.00311 | +#| 53 | module.backbone.fpn.inner_blocks.0.weight | (256, 256, 1, 1) | 65536 | 9831 | 0.00000 | 0.00000 | 9.37500 | 84.99908 | 0.00000 | 84.99908 | 0.00920 | 0.00016 | 0.00340 | +#| 54 | module.backbone.fpn.inner_blocks.1.weight | (256, 512, 1, 1) | 131072 | 19661 | 0.00000 | 0.00000 | 0.58594 | 84.99985 | 0.00000 | 84.99985 | 0.01087 | 0.00008 | 0.00403 | +#| 55 | module.backbone.fpn.inner_blocks.2.weight | (256, 1024, 1, 1) | 262144 | 39322 | 0.00000 | 0.00000 | 0.09766 | 84.99985 | 0.00000 | 84.99985 | 0.00976 | -0.00003 | 0.00358 | +#| 56 | module.backbone.fpn.inner_blocks.3.weight | (256, 2048, 1, 1) | 524288 | 78644 | 0.00000 | 0.00000 | 0.00000 | 84.99985 | 0.00000 | 84.99985 | 0.01260 | -0.00003 | 0.00470 | +#| 57 | module.backbone.fpn.layer_blocks.0.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 43.46161 | 0.00000 | 84.99993 | 0.00841 | -0.00003 | 0.00318 | +#| 58 | module.backbone.fpn.layer_blocks.1.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 50.99182 | 0.00000 | 84.99993 | 0.00608 | -0.00001 | 0.00230 | +#| 59 | module.backbone.fpn.layer_blocks.2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 50.06256 | 0.00000 | 84.99993 | 0.00655 | -0.00001 | 0.00248 | +#| 60 | module.backbone.fpn.layer_blocks.3.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 49.36371 | 0.00000 | 84.99993 | 0.00635 | -0.00003 | 0.00240 | +#| 61 | module.rpn.head.conv.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 41.99219 | 0.00000 | 84.99993 | 0.00580 | 0.00014 | 0.00215 | +#| 62 | module.rpn.head.cls_logits.weight | (3, 256, 1, 1) | 768 | 768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15008 | -0.01621 | 0.11047 | +#| 63 | module.rpn.head.bbox_pred.weight | (12, 256, 1, 1) | 3072 | 3072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03009 | -0.00090 | 0.01457 | +#| 64 | module.roi_heads.box_head.fc6.weight | (1024, 12544) | 12845056 | 1926759 | 0.00000 | 0.00000 | 0.00000 | 
0.00000 | 0.00000 | 85.00000 | 0.00312 | 0.00001 | 0.00115 | +#| 65 | module.roi_heads.box_head.fc7.weight | (1024, 1024) | 1048576 | 157287 | 1.66016 | 0.00000 | 0.00000 | 1.66016 | 0.00000 | 84.99994 | 0.01011 | -0.00015 | 0.00357 | +#| 66 | module.roi_heads.box_predictor.cls_score.weight | (91, 1024) | 93184 | 93184 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03252 | -0.00000 | 0.02113 | +#| 67 | module.roi_heads.box_predictor.bbox_pred.weight | (364, 1024) | 372736 | 372736 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01165 | 0.00011 | 0.00589 | +#| 68 | module.roi_heads.mask_head.mask_fcn1.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 40.37628 | 0.00000 | 84.99993 | 0.00778 | -0.00003 | 0.00290 | +#| 69 | module.roi_heads.mask_head.mask_fcn2.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 45.38727 | 0.00000 | 84.99993 | 0.00810 | -0.00004 | 0.00294 | +#| 70 | module.roi_heads.mask_head.mask_fcn3.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.00000 | 50.18463 | 1.95312 | 84.99993 | 0.00791 | -0.00003 | 0.00282 | +#| 71 | module.roi_heads.mask_head.mask_fcn4.weight | (256, 256, 3, 3) | 589824 | 88474 | 0.00000 | 0.00000 | 0.78125 | 61.83777 | 19.92188 | 84.99993 | 0.00762 | 0.00054 | 0.00259 | +#| 72 | module.roi_heads.mask_predictor.conv5_mask.weight | (256, 256, 2, 2) | 262144 | 39322 | 0.00000 | 0.00000 | 7.81250 | 76.00861 | 19.53125 | 84.99985 | 0.01107 | 0.00141 | 0.00371 | +#| 73 | module.roi_heads.mask_predictor.mask_fcn_logits.weight | (91, 256, 1, 1) | 23296 | 23296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04107 | -0.00201 | 0.02535 | +#| 74 | Total sparsity: | - | 44395200 | 6437593 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 85.49935 | 0.00000 | 0.00000 | 0.00000 | +#+----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +#Total sparsity: 85.50 +#Results: +# Baseline: + #IoU metric: bbox + # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.379 + # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.592 + # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.410 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.215 + # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.414 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.495 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.312 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.494 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.518 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.321 + # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.559 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.660 + #IoU metric: segm + # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.346 + # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.561 + # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.367 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.156 + # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.373 + # Average Precision (AP) @[ IoU=0.50:0.95 | 
area= large | maxDets=100 ] = 0.509 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.294 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.454 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.474 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.269 + # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.515 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631 + +#Post Pruning: + #IoU metric: bbox + # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352 + # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.558 + # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.376 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.188 + # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.389 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.460 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.302 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.481 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.504 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.300 + # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.546 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.646 + #IoU metric: segm + # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.320 + # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.525 + # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.338 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.134 + # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.346 + # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.482 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.284 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.439 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.459 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.250 + # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.499 + # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.618 + +version: 1 + +pruners: + + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.85 + weights: [ + module.roi_heads.box_head.fc6.weight, + module.roi_heads.box_head.fc7.weight + ] + + agp_pruner_75: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.75 + weights: [ + module.backbone.body.layer1.0.conv1.weight, + module.backbone.body.layer1.0.conv2.weight, + module.backbone.body.layer1.0.conv3.weight, + module.backbone.body.layer1.0.downsample.0.weight, + module.backbone.body.layer1.1.conv1.weight, + module.backbone.body.layer1.1.conv2.weight, + module.backbone.body.layer1.1.conv3.weight, + module.backbone.body.layer1.2.conv1.weight, + module.backbone.body.layer1.2.conv2.weight, + module.backbone.body.layer1.2.conv3.weight, + module.backbone.body.layer2.0.conv1.weight, + module.backbone.body.layer2.0.conv2.weight, + module.backbone.body.layer2.0.conv3.weight,] + + agp_pruner_85: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.85 + weights: [ + module.backbone.body.layer2.0.downsample.0.weight, + 
module.backbone.body.layer2.1.conv1.weight, + module.backbone.body.layer2.1.conv2.weight, + module.backbone.body.layer2.1.conv3.weight, + module.backbone.body.layer2.2.conv1.weight, + module.backbone.body.layer2.2.conv2.weight, + module.backbone.body.layer2.2.conv3.weight, + module.backbone.body.layer2.3.conv1.weight, + module.backbone.body.layer2.3.conv2.weight, + module.backbone.body.layer2.3.conv3.weight, + module.backbone.body.layer3.0.conv1.weight, + module.backbone.body.layer3.0.conv2.weight, + module.backbone.body.layer3.0.conv3.weight, + module.backbone.body.layer3.0.downsample.0.weight, + module.backbone.body.layer3.1.conv1.weight, + module.backbone.body.layer3.1.conv2.weight, + module.backbone.body.layer3.1.conv3.weight, + module.backbone.body.layer3.2.conv1.weight, + module.backbone.body.layer3.2.conv2.weight, + module.backbone.body.layer3.2.conv3.weight, + module.backbone.body.layer3.3.conv1.weight, + module.backbone.body.layer3.3.conv2.weight, + module.backbone.body.layer3.3.conv3.weight, + module.backbone.body.layer3.4.conv1.weight, + module.backbone.body.layer3.4.conv2.weight, + module.backbone.body.layer3.4.conv3.weight, + module.backbone.body.layer3.5.conv1.weight, + module.backbone.body.layer3.5.conv2.weight, + module.backbone.body.layer3.5.conv3.weight, + module.backbone.body.layer4.2.conv3.weight, + module.backbone.fpn.inner_blocks.0.weight, + module.backbone.fpn.inner_blocks.1.weight, + module.backbone.fpn.inner_blocks.2.weight, + module.backbone.fpn.inner_blocks.3.weight, + module.backbone.fpn.layer_blocks.0.weight, + module.backbone.fpn.layer_blocks.1.weight, + module.backbone.fpn.layer_blocks.2.weight, + module.backbone.fpn.layer_blocks.3.weight, + module.rpn.head.conv.weight, + module.roi_heads.mask_head.mask_fcn1.weight, + module.roi_heads.mask_head.mask_fcn2.weight, + module.roi_heads.mask_head.mask_fcn3.weight, + module.roi_heads.mask_head.mask_fcn4.weight, + module.roi_heads.mask_predictor.conv5_mask.weight, + ] + + agp_pruner_90: + class: AutomatedGradualPruner + initial_sparsity : 0.01 + final_sparsity: 0.90 + weights: [ + module.backbone.body.layer4.0.conv1.weight, + module.backbone.body.layer4.0.conv2.weight, + module.backbone.body.layer4.0.conv3.weight, + module.backbone.body.layer4.0.downsample.0.weight, + module.backbone.body.layer4.1.conv1.weight, + module.backbone.body.layer4.1.conv2.weight, + module.backbone.body.layer4.1.conv3.weight, + module.backbone.body.layer4.2.conv1.weight, + module.backbone.body.layer4.2.conv2.weight, + ] + + +policies: + - pruner: + instance_name : agp_pruner_75 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 + + - pruner: + instance_name : agp_pruner_85 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 + + - pruner: + instance_name : fc_pruner + starting_epoch: 0 + ending_epoch: 45 + frequency: 3 + + - pruner: + instance_name : agp_pruner_90 + starting_epoch: 0 + ending_epoch: 45 + frequency: 1 diff --git a/examples/object_detection_compression/requirements.txt b/examples/object_detection_compression/requirements.txt new file mode 100644 index 0000000..d4e6f64 --- /dev/null +++ b/examples/object_detection_compression/requirements.txt @@ -0,0 +1,2 @@ +Cython +pycocotools diff --git a/examples/object_detection_compression/transforms.py b/examples/object_detection_compression/transforms.py new file mode 100644 index 0000000..90a36b3 --- /dev/null +++ b/examples/object_detection_compression/transforms.py @@ -0,0 +1,52 @@ +# This code is originally from: +# 
https://github.com/pytorch/vision/tree/v0.4.2/references/detection +import random +import torch + +from torchvision.transforms import functional as F + + +def _flip_coco_person_keypoints(kps, width): + flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + flipped_data = kps[:, flip_inds] + flipped_data[..., 0] = width - flipped_data[..., 0] + # Maintain COCO convention that if visibility == 0, then x, y = 0 + inds = flipped_data[..., 2] == 0 + flipped_data[inds] = 0 + return flipped_data + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, prob): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + height, width = image.shape[-2:] + image = image.flip(-1) + bbox = target["boxes"] + bbox[:, [0, 2]] = width - bbox[:, [2, 0]] + target["boxes"] = bbox + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + if "keypoints" in target: + keypoints = target["keypoints"] + keypoints = _flip_coco_person_keypoints(keypoints, width) + target["keypoints"] = keypoints + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + image = F.to_tensor(image) + return image, target diff --git a/examples/object_detection_compression/utils.py b/examples/object_detection_compression/utils.py new file mode 100644 index 0000000..ce33966 --- /dev/null +++ b/examples/object_detection_compression/utils.py @@ -0,0 +1,328 @@ +# This code is originally from: +# https://github.com/pytorch/vision/tree/v0.4.2/references/detection +from __future__ import print_function + +from collections import defaultdict, deque +import datetime +import pickle +import time + +import torch +import torch.distributed as dist + +import errno +import os + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def collate_fn(batch): + return tuple(zip(*batch)) + + +def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): + + def f(x): + if x >= warmup_iters: + return 1 + alpha = float(x) / warmup_iters + return warmup_factor * (1 - alpha) + alpha + + return torch.optim.lr_scheduler.LambdaLR(optimizer, f) + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, 
**kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) -- GitLab