Commit a9b28923 authored by Neta Zmora

Language model: replace the optimizer and LR-decay scheduler

Replace the original "homebrew" optimizer and LR-decay schedule with
PyTorch's SGD and ReduceLROnPlateau.
SGD with momentum=0 and weight_decay=0, and ReduceLROnPlateau with
patience=0 and factor=0.5, give the same behavior as the original
PyTorch example.

Having a standard optimizer and LR-decay schedule gives us the
flexibility to experiment with both during training.
parent d6ffeaf7
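
For reference, a minimal sketch of the new setup, assembled from the hunks below (`model`, `args`, `loss` and `val_loss` are the example's existing names; this is an illustration, not the exact diff):

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,          # 0. by default
                                weight_decay=args.weight_decay)  # 0. by default
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                              patience=0, verbose=True,
                                                              factor=0.5)

    # Inside the training loop, the manual parameter update is replaced by:
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
    optimizer.step()

    # Once per epoch, after evaluating on the validation set:
    lr_scheduler.step(val_loss)

Momentum and weight decay can then be varied from the command line through the new --momentum and --weight-decay (--wd) arguments.
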
@@ -16,6 +16,7 @@ from collections import OrderedDict
import data
import model
# Distiller imports
import os
import sys
script_dir = os.path.dirname(__file__)
@@ -24,8 +25,8 @@ if module_path not in sys.path:
sys.path.append(module_path)
import distiller
import apputils
from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
import torchnet.meter as tnt
from distiller.data_loggers import TensorBoardLogger, PythonLogger
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -58,7 +59,7 @@ parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
parser.add_argument('--save', type=str, default='checkpoint.pth.tar',
help='path to save the final model')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
@@ -66,12 +67,17 @@ parser.add_argument('--onnx-export', type=str, default='',
help='path to export the final model in onnx format')
# Distiller-related arguments
SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'percentile']
SUMMARY_CHOICES = ['sparsity', 'model', 'modules', 'png', 'percentile']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--momentum', default=0., type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=0., type=float,
metavar='W', help='weight decay (default: 0.)')
args = parser.parse_args()
# Set the random seed manually for reproducibility.
@@ -84,6 +90,11 @@ device = torch.device("cuda" if args.cuda else "cpu")
def draw_lang_model_to_file(model, png_fname, dataset):
"""Draw a language model graph to a PNG file.
Caveat: the PNG that is produced has some problems, which we suspect are due to
PyTorch issues related to RNN ONNX export.
"""
try:
if dataset == 'wikitext2':
batch_size = 20
@@ -92,16 +103,16 @@ def draw_lang_model_to_file(model, png_fname, dataset):
hidden = model.init_hidden(batch_size)
dummy_input = (dummy_input, hidden)
else:
print("Unsupported dataset (%s) - aborting draw operation" % dataset)
msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
g = apputils.SummaryGraph(model, dummy_input)
apputils.draw_model_to_file(g, png_fname)
print("Network PNG image generation completed")
msglogger.info("Network PNG image generation completed")
except FileNotFoundError as e:
print("An error has occured while generating the network PNG image.")
print("Please check that you have graphviz installed.")
print("\t$ sudo apt-get install graphviz")
msglogger.info("An error has occured while generating the network PNG image.")
msglogger.info("Please check that you have graphviz installed.")
msglogger.info("\t$ sudo apt-get install graphviz")
raise e
###############################################################################
@@ -229,18 +240,13 @@ def train(epoch, optimizer, compression_scheduler=None):
regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch,
minibatches_per_epoch=steps_per_epoch, loss=loss)
loss += regularizer_loss
#losses['regularizer_loss'].add(regularizer_loss.item())
model.zero_grad()
#optimizer.zero_grad()
optimizer.zero_grad()
loss.backward()
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
for p in model.parameters():
p.data.add_(-lr, p.grad.data)
#optimizer.step()
optimizer.step()
total_loss += loss.item()
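
(Aside: the equivalence claimed in the commit message can be sanity-checked with a small standalone snippet; nothing below belongs to the diff, and the tensor values are arbitrary.)

    import torch

    # A vanilla SGD step (momentum=0, weight_decay=0) applies p <- p - lr * p.grad,
    # i.e. exactly the removed manual update p.data.add_(-lr, p.grad.data).
    lr = 0.1
    p_manual = torch.randn(3)
    p_sgd = p_manual.clone().requires_grad_(True)

    grad = torch.randn(3)
    p_sgd.grad = grad.clone()

    # Manual update, as in the code being removed
    p_manual.add_(grad, alpha=-lr)

    # The same update via torch.optim.SGD
    torch.optim.SGD([p_sgd], lr=lr, momentum=0., weight_decay=0.).step()

    assert torch.allclose(p_manual, p_sgd)
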
@@ -250,8 +256,10 @@ def train(epoch, optimizer, compression_scheduler=None):
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
lr = optimizer.param_groups[0]['lr']
msglogger.info(
'| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} '
'| loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // args.bptt, lr,
elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
total_loss = 0
@@ -264,13 +272,12 @@ def train(epoch, optimizer, compression_scheduler=None):
('Batch Time', elapsed * 1000)])
)
steps_completed = batch + 1
#tflogger.log_training_progress(stats, epoch, steps_completed, total=steps_per_epoch, freq=args.log_interval)
distiller.log_training_progress(stats, model.named_parameters(), epoch, steps_completed,
steps_per_epoch, args.log_interval, [tflogger])
def export_onnx(path, batch_size, seq_len):
print('The model is also exported in ONNX format at {}'.
msglogger.info('The model is also exported in ONNX format at {}'.
format(os.path.realpath(args.onnx_export)))
model.eval()
dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
@@ -278,10 +285,7 @@ def export_onnx(path, batch_size, seq_len):
torch.onnx.export(model, (dummy_input, hidden), path)
# Loop over epochs.
lr = args.lr
best_val_loss = None
# Distiller loggers
msglogger = apputils.config_pylogger('logging.conf', None)
tflogger = TensorBoardLogger(msglogger.logdir)
tflogger.log_gradients = True
@@ -297,28 +301,31 @@ if args.summary:
if param.dim() < 2:
# Skip biases
continue
bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()), largest=False, sorted=True)
bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()),
largest=False, sorted=True)
threshold = bottomk.data[-1]
print("parameter %s: q = %.2f" %(name, threshold))
msglogger.info("parameter %s: q = %.2f" %(name, threshold))
else:
distiller.model_summary(model, None, which_summary, 'wikitext2')
exit(0)
compression_scheduler = None
if args.compress:
# The main use-case for this sample application is CNN compression. Compression
# requires a compression schedule configuration file in YAML.
# Create a CompressionScheduler and configure it from a YAML schedule file
source = args.compress
compression_scheduler = distiller.CompressionScheduler(model)
distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
#optimizer = optim.SparseAdam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
patience=0, verbose=True, factor=0.5)
# Loop over epochs.
# At any point you can hit Ctrl + C to break out of training early.
best_val_loss = float("inf")
try:
for epoch in range(0, args.epochs):
epoch_start_time = time.time()
@@ -328,11 +335,12 @@ try:
train(epoch, optimizer, compression_scheduler)
val_loss = evaluate(val_data)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
msglogger.info('-' * 89)
msglogger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
msglogger.info('-' * 89)
distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
stats = ('Performance/Validation/',
@@ -341,23 +349,24 @@ try:
('Perplexity', math.exp(val_loss))]))
tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1)
with open(args.save, 'wb') as f:
torch.save(model, f)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
with open(args.save, 'wb') as f:
if val_loss < best_val_loss:
with open(args.save+".best", 'wb') as f:
torch.save(model, f)
best_val_loss = val_loss
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4 #1.2
lr_scheduler.step(val_loss)
if compression_scheduler:
compression_scheduler.on_epoch_end(epoch)
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')
msglogger.info('-' * 89)
msglogger.info('Exiting from training early')
# Load the best saved model.
# Load the last saved model.
with open(args.save, 'rb') as f:
model = torch.load(f)
# after load the rnn params are not a continuous chunk of memory
@@ -366,10 +375,10 @@ with open(args.save, 'rb') as f:
# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
msglogger.info('=' * 89)
msglogger.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
msglogger.info('=' * 89)
if len(args.onnx_export) > 0:
# Export the model in ONNX format.