From a9b28923c42c6d7c6f691716f597510ed0204c90 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 13 Jun 2018 10:42:54 +0300 Subject: [PATCH] Language model: replace the optimizer and LR-decay scheduler Replace the original "homebrew" optimizer and LR-decay schedule with PyTorch's SGD and ReduceLROnPlateau. SGD with momentum=0 and weight_decay=0, and ReduceLROnPlateau with patience=0 and factor=0.5 will give the same behavior as in the original PyTorch example. Having a standard optimizer and LR-decay schedule gives us the flexibility to experiment with these during the training process. --- examples/word_language_model/main.py | 103 +++++++++++++++------------ 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py index 4d50581..328eb45 100755 --- a/examples/word_language_model/main.py +++ b/examples/word_language_model/main.py @@ -16,6 +16,7 @@ from collections import OrderedDict import data import model +# Distiller imports import os import sys script_dir = os.path.dirname(__file__) @@ -24,8 +25,8 @@ if module_path not in sys.path: sys.path.append(module_path) import distiller import apputils -from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector -import torchnet.meter as tnt +from distiller.data_loggers import TensorBoardLogger, PythonLogger + parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model') parser.add_argument('--data', type=str, default='./data/wikitext-2', @@ -58,7 +59,7 @@ parser.add_argument('--cuda', action='store_true', help='use CUDA') parser.add_argument('--log-interval', type=int, default=200, metavar='N', help='report interval') -parser.add_argument('--save', type=str, default='model.pt', +parser.add_argument('--save', type=str, default='checkpoint.pth.tar', help='path to save the final model') parser.add_argument('--resume', default='', type=str, 
metavar='PATH', help='path to latest checkpoint (default: none)') @@ -66,12 +67,17 @@ parser.add_argument('--onnx-export', type=str, default='', help='path to export the final model in onnx format') # Distiller-related arguments -SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'percentile'] +SUMMARY_CHOICES = ['sparsity', 'model', 'modules', 'png', 'percentile'] parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES, help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES)) parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', help='configuration file for pruning the model (default is to use hard-coded schedule)') +parser.add_argument('--momentum', default=0., type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=0., type=float, + metavar='W', help='weight decay (default: 0.)') + args = parser.parse_args() # Set the random seed manually for reproducibility. @@ -84,6 +90,11 @@ device = torch.device("cuda" if args.cuda else "cpu") def draw_lang_model_to_file(model, png_fname, dataset): + """Draw a language model graph to a PNG file. + + Caveat: the PNG that is produced has some problems, which we suspect are due to + PyTorch issues related to RNN ONNX export.
+ """ try: if dataset == 'wikitext2': batch_size = 20 @@ -92,16 +103,16 @@ def draw_lang_model_to_file(model, png_fname, dataset): hidden = model.init_hidden(batch_size) dummy_input = (dummy_input, hidden) else: - print("Unsupported dataset (%s) - aborting draw operation" % dataset) + msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset) return g = apputils.SummaryGraph(model, dummy_input) apputils.draw_model_to_file(g, png_fname) - print("Network PNG image generation completed") + msglogger.info("Network PNG image generation completed") except FileNotFoundError as e: - print("An error has occured while generating the network PNG image.") - print("Please check that you have graphviz installed.") - print("\t$ sudo apt-get install graphviz") + msglogger.info("An error has occured while generating the network PNG image.") + msglogger.info("Please check that you have graphviz installed.") + msglogger.info("\t$ sudo apt-get install graphviz") raise e ############################################################################### @@ -229,18 +240,13 @@ def train(epoch, optimizer, compression_scheduler=None): regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch, loss=loss) loss += regularizer_loss - #losses['regularizer_loss'].add(regularizer_loss.item()) - model.zero_grad() - #optimizer.zero_grad() + optimizer.zero_grad() loss.backward() - # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - for p in model.parameters(): - p.data.add_(-lr, p.grad.data) - #optimizer.step() + optimizer.step() total_loss += loss.item() @@ -250,8 +256,10 @@ def train(epoch, optimizer, compression_scheduler=None): if batch % args.log_interval == 0 and batch > 0: cur_loss = total_loss / args.log_interval elapsed = time.time() - start_time - print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} | ' - 'loss {:5.2f} | ppl {:8.2f}'.format( + lr = optimizer.param_groups[0]['lr'] + msglogger.info( + '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} ' + '| loss {:5.2f} | ppl {:8.2f}'.format( epoch, batch, len(train_data) // args.bptt, lr, elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) total_loss = 0 @@ -264,13 +272,12 @@ def train(epoch, optimizer, compression_scheduler=None): ('Batch Time', elapsed * 1000)]) ) steps_completed = batch + 1 - #tflogger.log_training_progress(stats, epoch, steps_completed, total=steps_per_epoch, freq=args.log_interval) distiller.log_training_progress(stats, model.named_parameters(), epoch, steps_completed, steps_per_epoch, args.log_interval, [tflogger]) def export_onnx(path, batch_size, seq_len): - print('The model is also exported in ONNX format at {}'. + msglogger.info('The model is also exported in ONNX format at {}'. format(os.path.realpath(args.onnx_export))) model.eval() dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device) @@ -278,10 +285,7 @@ def export_onnx(path, batch_size, seq_len): torch.onnx.export(model, (dummy_input, hidden), path) -# Loop over epochs. 
-lr = args.lr -best_val_loss = None - +# Distiller loggers msglogger = apputils.config_pylogger('logging.conf', None) tflogger = TensorBoardLogger(msglogger.logdir) tflogger.log_gradients = True @@ -297,28 +301,31 @@ if args.summary: if param.dim() < 2: # Skip biases continue - bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()), largest=False, sorted=True) + bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()), + largest=False, sorted=True) threshold = bottomk.data[-1] - print("parameter %s: q = %.2f" %(name, threshold)) + msglogger.info("parameter %s: q = %.2f" %(name, threshold)) else: distiller.model_summary(model, None, which_summary, 'wikitext2') - exit(0) compression_scheduler = None if args.compress: - # The main use-case for this sample application is CNN compression. Compression - # requires a compression schedule configuration file in YAML. + # Create a CompressionScheduler and configure it from a YAML schedule file source = args.compress compression_scheduler = distiller.CompressionScheduler(model) distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger) -optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) -#optimizer = optim.SparseAdam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) - +optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) +lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', + patience=0, verbose=True, factor=0.5) +# Loop over epochs. # At any point you can hit Ctrl + C to break out of training early. 
+best_val_loss = float("inf") try: for epoch in range(0, args.epochs): epoch_start_time = time.time() @@ -328,11 +335,12 @@ try: train(epoch, optimizer, compression_scheduler) val_loss = evaluate(val_data) - print('-' * 89) - print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | ' - 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), + msglogger.info('-' * 89) + msglogger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | ' + 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))) - print('-' * 89) + msglogger.info('-' * 89) + distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger]) stats = ('Peformance/Validation/', @@ -341,23 +349,24 @@ try: ('Perplexity', math.exp(val_loss))])) tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1) + with open(args.save, 'wb') as f: + torch.save(model, f) + # Save the model if the validation loss is the best we've seen so far. - if not best_val_loss or val_loss < best_val_loss: - with open(args.save, 'wb') as f: + if val_loss < best_val_loss: + with open(args.save+".best", 'wb') as f: torch.save(model, f) best_val_loss = val_loss - else: - # Anneal the learning rate if no improvement has been seen in the validation dataset. - lr /= 4 #1.2 + lr_scheduler.step(val_loss) if compression_scheduler: compression_scheduler.on_epoch_end(epoch) except KeyboardInterrupt: - print('-' * 89) - print('Exiting from training early') + msglogger.info('-' * 89) + msglogger.info('Exiting from training early') -# Load the best saved model. +# Load the last saved model. with open(args.save, 'rb') as f: model = torch.load(f) # after load the rnn params are not a continuous chunk of memory @@ -366,10 +375,10 @@ with open(args.save, 'rb') as f: # Run on test data. 
test_loss = evaluate(test_data) -print('=' * 89) -print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( +msglogger.info('=' * 89) +msglogger.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( test_loss, math.exp(test_loss))) -print('=' * 89) +msglogger.info('=' * 89) if len(args.onnx_export) > 0: # Export the model in ONNX format. -- GitLab