From a9b28923c42c6d7c6f691716f597510ed0204c90 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 13 Jun 2018 10:42:54 +0300
Subject: [PATCH] Language model: replace the optimizer and LR-decay scheduler

Replace the original "homebrew" optimizer and LR-decay schedule with
PyTorch's SGD and ReduceLROnPlateau.
SGD with momentum=0 and weight_decay=0, and ReduceLROnPlateau with
patience=0 and factor=0.5 will give the same behavior as in the
original PyTorch example.

Having a standard optimizer and LR-decay schedule gives us the
flexibility to experiment with these during the training process.
---
 examples/word_language_model/main.py | 103 +++++++++++++++------------
 1 file changed, 56 insertions(+), 47 deletions(-)

diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py
index 4d50581..328eb45 100755
--- a/examples/word_language_model/main.py
+++ b/examples/word_language_model/main.py
@@ -16,6 +16,7 @@ from collections import OrderedDict
 import data
 import model
 
+# Distiller imports
 import os
 import sys
 script_dir = os.path.dirname(__file__)
@@ -24,8 +25,8 @@ if module_path not in sys.path:
     sys.path.append(module_path)
 import distiller
 import apputils
-from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
-import torchnet.meter as tnt
+from distiller.data_loggers import TensorBoardLogger, PythonLogger
+
 
 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')
 parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -58,7 +59,7 @@ parser.add_argument('--cuda', action='store_true',
                     help='use CUDA')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
-parser.add_argument('--save', type=str, default='model.pt',
+parser.add_argument('--save', type=str, default='checkpoint.pth.tar',
                     help='path to save the final model')
 parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -66,12 +67,17 @@ parser.add_argument('--onnx-export', type=str, default='',
                     help='path to export the final model in onnx format')
 
 # Distiller-related arguments
-SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'percentile']
+SUMMARY_CHOICES = ['sparsity', 'model', 'modules', 'png', 'percentile']
 parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                     help='print a summary of the model, and exit - options: ' +
                     ' | '.join(SUMMARY_CHOICES))
 parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                     help='configuration file for pruning the model (default is to use hard-coded schedule)')
+parser.add_argument('--momentum', default=0., type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--weight-decay', '--wd', default=0., type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
@@ -84,6 +90,11 @@ device = torch.device("cuda" if args.cuda else "cpu")
 
 
 def draw_lang_model_to_file(model, png_fname, dataset):
+    """Draw a language model graph to a PNG file.
+
+    Caveat: the PNG that is produced has some problems, which we suspect are due to
+    PyTorch issues related to RNN ONNX export.
+    """
     try:
         if dataset == 'wikitext2':
             batch_size = 20
@@ -92,16 +103,16 @@ def draw_lang_model_to_file(model, png_fname, dataset):
             hidden = model.init_hidden(batch_size)
             dummy_input = (dummy_input, hidden)
         else:
-            print("Unsupported dataset (%s) - aborting draw operation" % dataset)
+            msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
             return
         g = apputils.SummaryGraph(model, dummy_input)
         apputils.draw_model_to_file(g, png_fname)
-        print("Network PNG image generation completed")
+        msglogger.info("Network PNG image generation completed")
 
     except FileNotFoundError as e:
-        print("An error has occured while generating the network PNG image.")
-        print("Please check that you have graphviz installed.")
-        print("\t$ sudo apt-get install graphviz")
+        msglogger.info("An error has occured while generating the network PNG image.")
+        msglogger.info("Please check that you have graphviz installed.")
+        msglogger.info("\t$ sudo apt-get install graphviz")
         raise e
 
 ###############################################################################
@@ -229,18 +240,13 @@ def train(epoch, optimizer, compression_scheduler=None):
             regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch,
                                                                           minibatches_per_epoch=steps_per_epoch, loss=loss)
             loss += regularizer_loss
-            #losses['regularizer_loss'].add(regularizer_loss.item())
 
-        model.zero_grad()
-        #optimizer.zero_grad()
+        optimizer.zero_grad()
         loss.backward()
 
-
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        for p in model.parameters():
-            p.data.add_(-lr, p.grad.data)
-        #optimizer.step()
+        optimizer.step()
 
         total_loss += loss.item()
 
@@ -250,8 +256,10 @@ def train(epoch, optimizer, compression_scheduler=None):
         if batch % args.log_interval == 0 and batch > 0:
             cur_loss = total_loss / args.log_interval
             elapsed = time.time() - start_time
-            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} | '
-                    'loss {:5.2f} | ppl {:8.2f}'.format(
+            lr = optimizer.param_groups[0]['lr']
+            msglogger.info(
+                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} '
+                    '| loss {:5.2f} | ppl {:8.2f}'.format(
                 epoch, batch, len(train_data) // args.bptt, lr,
                 elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
             total_loss = 0
@@ -264,13 +272,12 @@ def train(epoch, optimizer, compression_scheduler=None):
                     ('Batch Time', elapsed * 1000)])
                 )
             steps_completed = batch + 1
-            #tflogger.log_training_progress(stats, epoch, steps_completed, total=steps_per_epoch, freq=args.log_interval)
             distiller.log_training_progress(stats, model.named_parameters(), epoch, steps_completed,
                                             steps_per_epoch, args.log_interval, [tflogger])
 
 
 def export_onnx(path, batch_size, seq_len):
-    print('The model is also exported in ONNX format at {}'.
+    msglogger.info('The model is also exported in ONNX format at {}'.
           format(os.path.realpath(args.onnx_export)))
     model.eval()
     dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
@@ -278,10 +285,7 @@ def export_onnx(path, batch_size, seq_len):
     torch.onnx.export(model, (dummy_input, hidden), path)
 
 
-# Loop over epochs.
-lr = args.lr
-best_val_loss = None
-
+# Distiller loggers
 msglogger = apputils.config_pylogger('logging.conf', None)
 tflogger = TensorBoardLogger(msglogger.logdir)
 tflogger.log_gradients = True
@@ -297,28 +301,31 @@ if args.summary:
             if param.dim() < 2:
                 # Skip biases
                 continue
-            bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()), largest=False, sorted=True)
+            bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()),
+                                    largest=False, sorted=True)
             threshold = bottomk.data[-1]
-            print("parameter %s: q = %.2f" %(name, threshold))
+            msglogger.info("parameter %s: q = %.2f" %(name, threshold))
     else:
         distiller.model_summary(model, None, which_summary, 'wikitext2')
-
     exit(0)
 
 compression_scheduler = None
 
 if args.compress:
-    # The main use-case for this sample application is CNN compression.  Compression
-    # requires a compression schedule configuration file in YAML.
+    # Create a CompressionScheduler and configure it from a YAML schedule file
     source = args.compress
     compression_scheduler = distiller.CompressionScheduler(model)
     distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger)
 
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
-#optimizer = optim.SparseAdam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
-
+optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                 momentum=args.momentum,
+                                 weight_decay=args.weight_decay)
+lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+                                                          patience=0, verbose=True, factor=0.5)
 
+# Loop over epochs.
 # At any point you can hit Ctrl + C to break out of training early.
+best_val_loss = float("inf")
 try:
     for epoch in range(0, args.epochs):
         epoch_start_time = time.time()
@@ -328,11 +335,12 @@ try:
         train(epoch, optimizer, compression_scheduler)
 
         val_loss = evaluate(val_data)
-        print('-' * 89)
-        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
-                'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
+        msglogger.info('-' * 89)
+        msglogger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
+                       'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                            val_loss, math.exp(val_loss)))
-        print('-' * 89)
+        msglogger.info('-' * 89)
+
         distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
 
         stats = ('Peformance/Validation/',
@@ -341,23 +349,24 @@ try:
                 ('Perplexity', math.exp(val_loss))]))
         tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1)
 
+        with open(args.save, 'wb') as f:
+            torch.save(model, f)
+
         # Save the model if the validation loss is the best we've seen so far.
-        if not best_val_loss or val_loss < best_val_loss:
-            with open(args.save, 'wb') as f:
+        if val_loss < best_val_loss:
+            with open(args.save+".best", 'wb') as f:
                 torch.save(model, f)
             best_val_loss = val_loss
-        else:
-            # Anneal the learning rate if no improvement has been seen in the validation dataset.
-            lr /= 4 #1.2
+        lr_scheduler.step(val_loss)
 
         if compression_scheduler:
             compression_scheduler.on_epoch_end(epoch)
 
 except KeyboardInterrupt:
-    print('-' * 89)
-    print('Exiting from training early')
+    msglogger.info('-' * 89)
+    msglogger.info('Exiting from training early')
 
-# Load the best saved model.
+# Load the last saved model.
 with open(args.save, 'rb') as f:
     model = torch.load(f)
     # after load the rnn params are not a continuous chunk of memory
@@ -366,10 +375,10 @@ with open(args.save, 'rb') as f:
 
 # Run on test data.
 test_loss = evaluate(test_data)
-print('=' * 89)
-print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
+msglogger.info('=' * 89)
+msglogger.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
     test_loss, math.exp(test_loss)))
-print('=' * 89)
+msglogger.info('=' * 89)
 
 if len(args.onnx_export) > 0:
     # Export the model in ONNX format.
-- 
GitLab