    ADC (Automatic Deep Compression) example + features, tests, bug fixes (#28) · 718f777b
    Neta Zmora authored
    This is a merge of the ADC branch and master.
    ADC (using a DDPG RL agent to compress image classifiers) is still WiP and requires
    an unreleased version of Coach (https://github.com/NervanaSystems/coach).
    
    Small features in this commit:
    - Add model_find_module(): find a module object given its name
    - Add channel ranking and pruning: pruning/ranked_structures_pruner.py
    - Add a CIFAR10 VGG16 model: models/cifar10/vgg_cifar.py
    - Thinning: move some log messages to the 'debug' level, because they are
    not usually interesting.
    - Add a function to print nicely formatted integers: distiller/utils.py
    - Add sensitivity analysis for channel removal
    - compress_classifier.py: handle keyboard interrupts
    - compress_classifier.py: fix the re-raising of exceptions so that they keep their call stack
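    Two of the helpers listed above are small enough to sketch. This is a minimal sketch, not necessarily the committed code: it assumes model_find_module() matches against the dotted names yielded by torch.nn.Module.named_modules(). The integer formatter's real name is not given in this commit, so `pretty_int` is a hypothetical name, and it assumes "nicely formatted" means thousands separators.

```python
def model_find_module(model, module_to_find):
    """Return the sub-module of `model` whose dotted name equals
    `module_to_find` (e.g. 'layer1.0.conv1'), or None if none matches.

    Only `named_modules()` is used, which every torch.nn.Module provides.
    """
    for name, module in model.named_modules():
        if name == module_to_find:
            return module
    return None


def pretty_int(i):
    """Hypothetical helper: format an integer with thousands separators,
    e.g. 1234567 -> '1,234,567'."""
    return "{:,}".format(i)
```

    Note that when a model is wrapped in torch.nn.DataParallel, every name yielded by named_modules() gains a `module.` prefix (for example `module.layer1.0.conv1`), so the lookup name must include it.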
    
    - Added tests:
    -- test_summarygraph.py: test_simplenet(), a regression test for a bug that occurs when taking the predecessor of the first node in a graph
    -- test_ranking.py: test_ch_ranking, test_ranked_channel_pruning
    -- test_model_summary.py: test_png_generation, test_summary (sparsity/compute/model/modules)
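    The kind of channel ranking that test_ch_ranking exercises can be sketched in plain Python. This is a minimal sketch, not the API of pruning/ranked_structures_pruner.py: it assumes channels are ranked by L1 magnitude, that a conv weight is given as nested lists shaped [filters][channels][kh][kw], and `channel_l1_norms` / `channels_to_prune` are hypothetical names.

```python
def channel_l1_norms(weights):
    """Return the L1 norm of each input channel of a conv weight.

    `weights` is a nested list shaped [filters][channels][kh][kw]; the norm of
    channel c sums |w| over every filter and kernel position that reads channel c.
    """
    num_channels = len(weights[0])
    norms = [0.0] * num_channels
    for filt in weights:
        for c, kernel in enumerate(filt):
            norms[c] += sum(abs(v) for row in kernel for v in row)
    return norms


def channels_to_prune(weights, fraction_to_prune):
    """Return the sorted indices of the lowest-L1 channels.

    The number of channels selected is fraction_to_prune * num_channels,
    rounded down.
    """
    norms = channel_l1_norms(weights)
    num_to_prune = int(fraction_to_prune * len(norms))
    # Rank channel indices by ascending L1 norm (sorted() is stable, so ties keep index order).
    ranked = sorted(range(len(norms)), key=lambda c: norms[c])
    return sorted(ranked[:num_to_prune])
```

    The sketch only selects channel indices. Physically removing input channel c of a convolution also requires removing output filter c of the layer that feeds it, which is the kind of restructuring that thinning performs.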
    
    - Bug fixes in this commit:
    -- Thinning bug fix: handle zero-sized 'indices' tensor
    During the thinning process, the 'indices' tensor can become zero-sized,
    in which case its length is undefined. Therefore, we need to check for this
    situation when counting the number of elements in 'indices'.
    -- Language model: adjust main.py to new distiller.model_summary API
test_model_summary.py
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import torch
import os
import sys
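# Make the repository root (the parent of this tests directory) importable,
# regardless of the directory pytest is launched from.
module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))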
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller
import pytest
import common  # common test code
import apputils

# Logging configuration
logging.basicConfig(level=logging.INFO)
fh = logging.FileHandler('test.log')
logger = logging.getLogger()
logger.addHandler(fh)


def test_png_generation():
    DATASET = "cifar10"
    ARCH = "resnet20_cifar"
    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
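    # Draw the model graph to a PNG twice: the 4th argument toggles whether
    # parameter (weight) nodes are drawn. The second call overwrites model.png.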
    apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, True)
    apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, False)


def test_negative():
    DATASET = "cifar10"
    ARCH = "resnet20_cifar"
    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)

    with pytest.raises(ValueError):
        # png is not a supported summary type, so we expect this to fail with a ValueError
        distiller.model_summary(model, what='png', dataset=DATASET)


def test_summary():
    DATASET = "cifar10"
    ARCH = "resnet20_cifar"
    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)

    distiller.model_summary(model, what='sparsity', dataset=DATASET)
    distiller.model_summary(model, what='compute', dataset=DATASET)
    distiller.model_summary(model, what='model', dataset=DATASET)
    distiller.model_summary(model, what='modules', dataset=DATASET)
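test_summary and test_negative together pin down a contract: the four summary types 'sparsity', 'compute', 'model' and 'modules' are accepted, and anything else, including 'png', raises ValueError (PNG output goes through apputils.draw_img_classifier_to_file instead). A minimal, hypothetical sketch of a dispatcher that honors this contract follows; `model_summary_sketch` and its placeholder reports are illustrative names, not Distiller's implementation.

```python
# Hypothetical sketch: a dispatch table keyed by summary type, in the spirit of
# distiller.model_summary(model, what, dataset). The real function produces
# actual reports; these placeholders only return a label.

def _placeholder_report(kind):
    # Each call creates a new closure bound to its own `kind`.
    def report(model, dataset):
        return "%s summary (dataset=%s)" % (kind, dataset)
    return report


_SUMMARIES = {kind: _placeholder_report(kind)
              for kind in ('sparsity', 'compute', 'model', 'modules')}


def model_summary_sketch(model, what, dataset=None):
    """Dispatch on `what`; unknown summary types raise ValueError,
    which is what test_negative expects for what='png'."""
    try:
        report = _SUMMARIES[what]
    except KeyError:
        raise ValueError("Unsupported summary type %r; expected one of: %s"
                         % (what, ", ".join(sorted(_SUMMARIES)))) from None
    return report(model, dataset)
```

Keeping the accepted types in a single dictionary means the ValueError message can list the valid choices, and adding a summary type is a one-line change.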