Skip to content
Snippets Groups Projects
  • Guy Jacob's avatar
    a43b9f10
    Generalize the loss value returned from before_backward_pass callbacks (#38) · a43b9f10
    Guy Jacob authored
    * Instead of a single additive value (which so far represented only the
      regularizer loss), callbacks return a new overall loss
    * Policy callbacks also return the individual loss components used to
      calculate the new overall loss.
    * Add boolean flag to the Scheduler's callback so applications can choose
      if they want to get individual loss components, or just the new overall
      loss
    * In compress_classifier.py, log the individual loss components
    * Add test for the loss-from-callback flow
    Generalize the loss value returned from before_backward_pass callbacks (#38)
    Guy Jacob authored
    * Instead of a single additive value (which so far represented only the
      regularizer loss), callbacks return a new overall loss
    * Policy callbacks also return the individual loss components used to
      calculate the new overall loss.
    * Add boolean flag to the Scheduler's callback so applications can choose
      if they want to get individual loss components, or just the new overall
      loss
    * In compress_classifier.py, log the individual loss components
    * Add test for the loss-from-callback flow
main.py 15.56 KiB
# This code is originally from:
#   https://github.com/pytorch/examples/blob/master/word_language_model/main.py
# It contains a fix as per:
#   https://github.com/pytorch/examples/issues/214
# It contains code to support compression (distiller)
#
# coding: utf-8
import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx
from collections import OrderedDict
import data
import model

# Distiller imports
import os
import sys
script_dir = os.path.dirname(__file__)
module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller
import apputils
from distiller.data_loggers import TensorBoardLogger, PythonLogger


# Command-line interface.  The first group of options configures the language
# model and its training loop; the second group is specific to Distiller.
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)')
parser.add_argument('--emsize', type=int, default=200,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=200,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=20,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='checkpoint.pth.tar',
                    help='path to save the final model')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--onnx-export', type=str, default='',
                    help='path to export the final model in onnx format')

# Distiller-related arguments
SUMMARY_CHOICES = ['sparsity', 'model', 'modules', 'png', 'percentile']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                    help='print a summary of the model, and exit - options: ' +
                    ' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                    help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--momentum', default=0., type=float, metavar='M',
                    help='momentum')
# The help text states the real default (0); it previously claimed 1e-4.
parser.add_argument('--weight-decay', '--wd', default=0., type=float,
                    metavar='W', help='weight decay (default: 0)')

args = parser.parse_args()

# Seed PyTorch's random number generator from --seed for reproducibility.
torch.manual_seed(args.seed)

# A CUDA device is present but was not requested: tell the user.
if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# The target device is selected solely by the --cuda flag.
if args.cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def draw_lang_model_to_file(model, png_fname, dataset):
    """Draw a language model graph to a PNG file.

    Caveat: the PNG that is produced has some problems, which we suspect are due to
    PyTorch issues related to RNN ONNX export.

    Args:
        model: the model to draw; it must provide ``init_hidden(batch_size)``.
        png_fname: name of the PNG file handed to ``apputils.draw_model_to_file``.
        dataset: dataset name.  Only 'wikitext2' is supported; for any other
            value a message is logged and the function returns without drawing.

    Raises:
        FileNotFoundError: re-raised, after logging a hint to install graphviz,
            when graph generation fails with this error.
    """
    if dataset != 'wikitext2':
        msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
        return

    batch_size = 20
    seq_len = 35
    # Create the dummy tokens on the device the model actually lives on, rather
    # than unconditionally on 'cuda', so that CPU-only runs work as well.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_tokens = torch.zeros(seq_len, batch_size, dtype=torch.long, device=model_device)
    hidden = model.init_hidden(batch_size)
    dummy_input = (dummy_tokens, hidden)

    try:
        g = apputils.SummaryGraph(model, dummy_input)
        apputils.draw_model_to_file(g, png_fname)
    except FileNotFoundError:
        msglogger.info("An error has occured while generating the network PNG image.")
        msglogger.info("Please check that you have graphviz installed.")
        msglogger.info("\t$ sudo apt-get install graphviz")
        # Bare raise keeps the original traceback intact.
        raise
    else:
        msglogger.info("Network PNG image generation completed")

###############################################################################
# Load data
###############################################################################

# Build the corpus object from the location given by --data.
# Plain print() is used here; msglogger is only created further down the script.
print("Preparing data (this may take several seconds)...")
corpus = data.Corpus(args.data)
print("Done")
# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e.g. 'g' on 'f' cannot be learned, but this allows more
# efficient batch processing.

def batchify(data, bsz):
    """Arrange a sequence into ``bsz`` columns and move it to the global device.

    Trailing elements that do not fill a complete row of ``bsz`` entries are
    dropped.  Assumes ``data`` is a 1-D tensor (TODO confirm against
    ``data.Corpus``); the result then has ``data.size(0) // bsz`` rows and
    ``bsz`` columns.
    """
    # Number of leading elements that divide evenly into bsz columns.
    usable = (data.size(0) // bsz) * bsz
    trimmed = data.narrow(0, 0, usable)
    # One row per column first, then transpose so that each column holds a
    # contiguous stretch of the original sequence.
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)

# Validation and test data use a fixed batch size; training uses --batch_size.
eval_batch_size = 10
train_data, val_data, test_data = (
    batchify(split, width)
    for split, width in (
        (corpus.train, args.batch_size),
        (corpus.valid, eval_batch_size),
        (corpus.test, eval_batch_size),
    )
)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)

if args.resume:
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code - only resume from checkpoints you trust.
    with open(args.resume, 'rb') as f:
        # map_location places the restored model on the device selected via
        # --cuda, so that it matches the batches produced by batchify() and a
        # checkpoint saved on a GPU can still be resumed on a CPU-only host.
        model = torch.load(f, map_location=device)
    # after load the rnn params are not a continuous chunk of memory
    # this makes them a continuous chunk, and will speed up forward pass
    model.rnn.flatten_parameters()
else:
    # NOTE: from here on the name `model` refers to the network instance and
    # no longer to the imported `model` module.
    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)


criterion = nn.CrossEntropyLoss()

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    """Return a copy of the hidden state that is detached from its history.

    A tensor is detached directly; any other iterable is processed element by
    element (recursively) and returned as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()


# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivision of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i):
    """Return the input chunk starting at row ``i`` and its flattened targets.

    The chunk holds at most ``args.bptt`` rows, and fewer when not enough rows
    remain after ``i``.  The targets are the rows one position further along
    ``source``, flattened with ``view(-1)``.
    """
    rows_left = len(source) - 1 - i
    span = min(args.bptt, rows_left)
    stop = i + span
    return source[i:stop], source[i + 1:stop + 1].view(-1)


def evaluate(data_source):
    """Compute the length-weighted average loss of the global model on ``data_source``.

    The model is switched to evaluation mode and run without gradient tracking
    over consecutive chunks of ``args.bptt`` rows.  Each chunk's criterion value
    is weighted by the chunk length; the weighted sum is divided by
    ``len(data_source)``.
    """
    # Evaluation mode turns dropout off.
    model.eval()
    loss_sum = 0.
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        # The range below was fixed as per: https://github.com/pytorch/examples/issues/214
        for start in range(0, data_source.size(0), args.bptt):
            inputs, targets = get_batch(data_source, start)
            logits, state = model(inputs, state)
            chunk_loss = criterion(logits.view(-1, vocab_size), targets).item()
            loss_sum += len(inputs) * chunk_loss
            state = repackage_hidden(state)
    return loss_sum / len(data_source)


def train(epoch, optimizer, compression_scheduler=None):
    """Run one training pass over the global ``train_data``.

    Args:
        epoch: index of the current epoch; forwarded to the scheduler callbacks
            and to the loggers.
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per chunk.
        compression_scheduler: optional scheduler.  When truthy, its
            ``on_minibatch_begin``, ``before_backward_pass`` and
            ``on_minibatch_end`` callbacks are invoked around every chunk.

    Progress is reported through ``msglogger`` and
    ``distiller.log_training_progress`` every ``args.log_interval`` chunks.
    """
    # Training mode turns dropout on.
    model.train()

    num_rows = train_data.size(0)
    steps_per_epoch = math.ceil(num_rows / args.bptt)

    running_loss = 0.
    interval_start = time.time()
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(args.batch_size)
    # The range below was fixed as per: https://github.com/pytorch/examples/issues/214
    for step, offset in enumerate(range(0, num_rows, args.bptt)):
        inputs, targets = get_batch(train_data, offset)
        # Cut the hidden state loose from the previous chunk; otherwise
        # backpropagation would reach all the way back to the start of the data.
        state = repackage_hidden(state)

        if compression_scheduler:
            compression_scheduler.on_minibatch_begin(
                epoch, minibatch_id=step, minibatches_per_epoch=steps_per_epoch)
        logits, state = model(inputs, state)
        loss = criterion(logits.view(-1, vocab_size), targets)

        if compression_scheduler:
            # The scheduler may replace the loss before the backward phase
            # (e.g. to add a regularization term).
            loss = compression_scheduler.before_backward_pass(
                epoch, minibatch_id=step, minibatches_per_epoch=steps_per_epoch,
                loss=loss, return_loss_components=False)

        optimizer.zero_grad()
        loss.backward()

        # Gradient-norm clipping guards against exploding gradients in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        running_loss += loss.item()

        if compression_scheduler:
            compression_scheduler.on_minibatch_end(
                epoch, minibatch_id=step, minibatches_per_epoch=steps_per_epoch)

        # Report only on every log_interval-th chunk, and never on the first one.
        if step % args.log_interval != 0 or step <= 0:
            continue

        mean_loss = running_loss / args.log_interval
        elapsed = time.time() - interval_start
        lr = optimizer.param_groups[0]['lr']
        msglogger.info(
            '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} '
            '| loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, step, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, mean_loss, math.exp(mean_loss)))
        running_loss = 0
        interval_start = time.time()

        metrics = OrderedDict()
        metrics['Loss'] = mean_loss
        metrics['Perplexity'] = math.exp(mean_loss)
        metrics['LR'] = lr
        metrics['Batch Time'] = elapsed * 1000
        stats = ('Peformance/Training/', metrics)
        distiller.log_training_progress(stats, model.named_parameters(), epoch, step + 1,
                                        steps_per_epoch, args.log_interval, [tflogger])


def export_onnx(path, batch_size, seq_len):
    """Export the global model to ONNX, traced with an all-zero dummy input.

    Args:
        path: destination handed to ``torch.onnx.export``.
        batch_size: number of columns of the dummy token tensor; also passed to
            ``model.init_hidden``.
        seq_len: number of rows of the dummy token tensor.
    """
    model.eval()
    dummy_input = torch.zeros(seq_len, batch_size, dtype=torch.long, device=device)
    hidden = model.init_hidden(batch_size)
    torch.onnx.export(model, (dummy_input, hidden), path)
    # Report the location that was actually written (the `path` argument, not
    # the command-line value), and only once the export has succeeded.
    msglogger.info('The model is also exported in ONNX format at {}'.
          format(os.path.realpath(path)))


# Distiller loggers
# msglogger is the logger used for the textual progress messages in this script.
# NOTE(review): 'logging.conf' is presumably resolved relative to the current
# working directory - confirm against apputils.config_pylogger.
msglogger = apputils.config_pylogger('logging.conf', None)
# The TensorBoard logger is pointed at the directory exposed as msglogger.logdir.
tflogger = TensorBoardLogger(msglogger.logdir)
# Presumably makes the TensorBoard logger record gradients too - TODO confirm.
tflogger.log_gradients = True
pylogger = PythonLogger(msglogger)

# --summary prints or draws the requested model summary and then terminates the
# script without training.
if args.summary:
    which_summary = args.summary
    if which_summary == 'png':
        draw_lang_model_to_file(model, 'rnn.png', 'wikitext2')
    elif which_summary == 'percentile':
        percentile = 0.9
        for name, param in model.state_dict().items():
            if param.dim() < 2:
                # Skip biases
                continue
            # Collect the smallest `percentile` fraction of absolute values in
            # ascending order; the last of them is the reported threshold.
            bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()),
                                    largest=False, sorted=True)
            threshold = bottomk.data[-1]
            msglogger.info("parameter %s: q = %.2f" %(name, threshold))
    else:
        distiller.model_summary(model, which_summary, 'wikitext2')
    # sys.exit() instead of the bare exit() helper: exit() is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    sys.exit(0)

# Compression is opt-in: without --compress no scheduler is created.
compression_scheduler = None
if args.compress:
    # Build a CompressionScheduler configured from the YAML schedule file.
    source = args.compress
    compression_scheduler = distiller.config.file_config(model, None, source)

# Plain SGD over all model parameters, driven by the command-line settings.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
# Halve the learning rate whenever the monitored value stops decreasing.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=0,
    verbose=True,
)

# Loop over epochs.
# At any point you can hit Ctrl + C to break out of training early.
# Lowest validation loss seen so far; +inf guarantees the first epoch counts as best.
best_val_loss = float("inf")
try:
    for epoch in range(0, args.epochs):
        epoch_start_time = time.time()
        # The scheduler callbacks bracket the whole epoch (begin here, end at the bottom).
        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch)

        train(epoch, optimizer, compression_scheduler)

        val_loss = evaluate(val_data)
        msglogger.info('-' * 89)
        msglogger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
                       'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                           val_loss, math.exp(val_loss)))
        msglogger.info('-' * 89)

        distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])

        # Validation statistics go to the TensorBoard logger only.
        stats = ('Peformance/Validation/',
            OrderedDict([
                ('Loss', val_loss),
                ('Perplexity', math.exp(val_loss))]))
        tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1)

        # The latest model is written to --save after every epoch; this is the
        # file that is loaded again once training has finished.
        with open(args.save, 'wb') as f:
            torch.save(model, f)

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open(args.save+".best", 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        # The plateau scheduler is stepped with this epoch's validation loss.
        lr_scheduler.step(val_loss)

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch)

# Ctrl + C stops training early; the script then continues with the code below.
except KeyboardInterrupt:
    msglogger.info('-' * 89)
    msglogger.info('Exiting from training early')

# Load the last saved model.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints you trust.
with open(args.save, 'rb') as f:
    # map_location places the model on the device selected via --cuda, so that
    # it matches the device of the test batches.
    model = torch.load(f, map_location=device)
# after load the rnn params are not a continuous chunk of memory
# this makes them a continuous chunk, and will speed up forward pass
model.rnn.flatten_parameters()

# Run on test data.
test_loss = evaluate(test_data)
msglogger.info('=' * 89)
msglogger.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
msglogger.info('=' * 89)

if args.onnx_export:
    # Export the model in ONNX format.
    export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt)