Commit a75a7483 authored by dsun18

Final commit

parent 712444d8
Merge request !1: Final commit
# Pycharm files
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
lightning_logs/
exps/
*.py[cod]
*$py.class
......
# CS547 Project - Transformer

Requirements: numpy, pytorch, torchvision, pytorch_lightning, sentencepiece, sacrebleu, tqdm

- Run `python3 get_wmt2014.py` to prepare the WMT2014 dataset.
- Modify `common.py` to set the dataset paths and model hyperparameters used below.
- Run `python3 train.py` to train the model.
- Run `python3 test.py` to get the BLEU score.
- Run `python3 translate.py` to start a prompt for live translation.
EXP_DIR = 'exp_wmt14/'
SP_MODEL = 'data/wmt14/wmtende.model'
TRAIN_PATH = 'data/wmt14/train_clean'
VAL_PATH = 'data/wmt14/valid'
TEST_PATH = 'data/wmt14/test'
SRC_LANG = 'en'
TGT_LANG = 'de'
MAX_SEQ_LEN = 200
ENC_NUM_LAYERS = 6
DEC_NUM_LAYERS = 6
D_MODEL = 512
D_FF = 2048
NUM_HEADS = 8
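
For orientation, the data loaders below expect parallel files named `<path>.<language code>`; a quick sketch (illustration only, assuming the constants above are importable as `common`) lists the files implied by these settings:

# Illustration: the parallel files the WMT2014 loaders will look for, given the
# constants above (each base path gets '.' + language code appended).
from common import TRAIN_PATH, VAL_PATH, TEST_PATH, SRC_LANG, TGT_LANG
for base in (TRAIN_PATH, VAL_PATH, TEST_PATH):
    print(f'{base}.{SRC_LANG}  {base}.{TGT_LANG}')
# e.g. data/wmt14/train_clean.en  data/wmt14/train_clean.de  (and likewise for valid/test)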
#%%
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
import codecs
import sentencepiece as spm
from utils import subsequent_mask
#%%
class SyntheticDataset(Dataset):
    def __init__(self, vocab, seq_len, length):
        super(SyntheticDataset, self).__init__()
        src = torch.from_numpy(np.random.randint(1, vocab, size=(length, seq_len)))
        src[:, 0] = 1
        self.src = src
...@@ -23,8 +31,122 @@ class SynthesisDataset(Dataset):
    def __getitem__(self, idx):
        return (self.src[idx], self.src_mask[idx], self.tgt[idx], self.tgt_y[idx], self.tgt_mask[idx], self.ntokens[idx])

class SyntheticDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=30, vocab=11, seq_len=10, length=(6000, 150, 300), num_workers=1):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.vocab = vocab
        self.seq_len = seq_len
        self.train_len, self.val_len, self.test_len = length

    def setup(self, stage=None):
        self.train = SyntheticDataset(self.vocab, self.seq_len, self.train_len)
        self.val = SyntheticDataset(self.vocab, self.seq_len, self.val_len)
        self.test = SyntheticDataset(self.vocab, self.seq_len, self.test_len)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
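
A minimal sketch of exercising this synthetic data module, mirroring the standalone check that the previous version of this file ran under `if __name__ == '__main__'`:

#%%
if __name__ == '__main__':
    # Minimal sketch: build a tiny synthetic split and inspect one batch.
    dm = SyntheticDataModule(batch_size=2, vocab=11, seq_len=10, length=(4, 2, 2), num_workers=0)
    dm.setup()
    for (src, src_mask, tgt, tgt_y, tgt_mask, ntokens) in dm.train_dataloader():
        print(src.shape, src_mask.shape, tgt.shape, tgt_y.shape, tgt_mask.shape, ntokens)
        break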
#%%
class WMT2014Dataset(Dataset):
def __init__(self, sl_path, tl_path, sp, max_seq_len):
super(WMT2014Dataset, self).__init__()
self.sl_sentences, self.tl_sentences = [], []
pad_idx, sos_idx, eos_idx = sp['<blank>'], sp['<s>'], sp['</s>']
with codecs.open(sl_path, 'r', 'utf-8') as f:
for l in f:
tmp = l.split()
assert len(tmp) > 0
self.sl_sentences.append(torch.LongTensor(sp.piece_to_id(tmp)[:max_seq_len]))
with codecs.open(tl_path, 'r', 'utf-8') as f:
for l in f:
tmp = l.split()
assert len(tmp) > 0
self.tl_sentences.append(torch.LongTensor(([sos_idx] + sp.piece_to_id(tmp) + [eos_idx])[:max_seq_len]))
assert len(self.sl_sentences) == len(self.tl_sentences)
def __len__(self):
assert len(self.sl_sentences) == len(self.tl_sentences)
return len(self.sl_sentences)
def __getitem__(self, idx):
return (self.sl_sentences[idx], self.tl_sentences[idx])
#%%
def WMT2014_collate(pad_idx):
def inner(batch):
src = pad_sequence([t[0] for t in batch], batch_first=True, padding_value=pad_idx)
tmptgt = pad_sequence([t[1] for t in batch], batch_first=True, padding_value=pad_idx)
src_mask = (src != pad_idx).unsqueeze(-2)
tgt = tmptgt[:, :-1]
tgt_y = tmptgt[:, 1:]
tgt_mask = (tgt != pad_idx).unsqueeze(-2)
tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
ntokens = (tgt_y != pad_idx).data.sum(axis=1)
return (src, src_mask, tmptgt, tgt_mask, ntokens)
return inner
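
A small sketch of what the collate function produces for a toy batch (the pad index 0 here is chosen purely for illustration):

#%%
if __name__ == '__main__':
    # Sketch: two variable-length (src, tgt) sentence pairs, padded and masked.
    toy = [
        (torch.LongTensor([5, 6, 7]), torch.LongTensor([1, 8, 9, 2])),
        (torch.LongTensor([5, 6]), torch.LongTensor([1, 8, 2])),
    ]
    src, src_mask, tgt, tgt_mask, ntokens = WMT2014_collate(pad_idx=0)(toy)
    print(src.shape, src_mask.shape, tgt.shape, tgt_mask.shape, ntokens)
    # torch.Size([2, 3]) torch.Size([2, 1, 3]) torch.Size([2, 4]) torch.Size([2, 3, 3]) tensor([3, 2])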
#%%
class WMT2014DataModule(pl.LightningDataModule):
def __init__(self, model_path, train_path, val_path, test_path, sl, tl, max_seq_len=2048, batch_size=128, num_workers=1):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.sp = spm.SentencePieceProcessor(model_file=model_path)
self.pad_idx, self.sos_idx, self.eos_idx = self.sp['<blank>'], self.sp['<s>'], self.sp['</s>']
self.vocabSize = self.sp.GetPieceSize()
self.train_path, self.val_path, self.test_path = train_path, val_path, test_path
self.sl, self.tl = sl, tl
self.max_seq_len = max_seq_len
self.train, self.val, self.test = None, None, None
self.initialized = False
def setup(self, stage=None):
if not self.initialized:
print('Initializing data module')
if self.train_path:
self.train = WMT2014Dataset(self.train_path + '.' + self.sl, self.train_path + '.' + self.tl, self.sp, self.max_seq_len)
if self.val_path:
self.val = WMT2014Dataset(self.val_path + '.' + self.sl, self.val_path + '.' + self.tl, self.sp, self.max_seq_len)
if self.test_path:
self.test = WMT2014Dataset(self.test_path + '.' + self.sl, self.test_path + '.' + self.tl, self.sp, self.max_seq_len)
self.initialized = True
print('Data module initialized')
def train_dataloader(self):
return DataLoader(
self.train,
collate_fn=WMT2014_collate(self.pad_idx),
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
pin_memory=True,
drop_last=True
)
def val_dataloader(self):
return DataLoader(
self.val,
collate_fn=WMT2014_collate(self.pad_idx),
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
def test_dataloader(self):
return DataLoader(
self.test,
collate_fn=WMT2014_collate(self.pad_idx),
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
import os
from torchvision.datasets.utils import download_file_from_google_drive, extract_archive
if __name__ == '__main__':
TEST_ID = '1PcMHcW8NWVq4544IMuwNQ4REWYKO0mYj'
FULL_ID = '1Spk9gPIdB3wLLosAh652T6JJdmTM-qtq'
if not os.path.exists('data'):
os.makedirs('data')
if not os.path.exists('data/wmt14'):
os.makedirs('data/wmt14')
download_file_from_google_drive(
file_id=FULL_ID,
root='./data/',
filename='tmp.tar.gz'
)
extract_archive('./data/tmp.tar.gz', './data/wmt14/', remove_finished=True)
...@@ -20,12 +20,13 @@ class PositionalEncoding(nn.Module):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.div(math.log(10000.0), d_model)))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
...@@ -34,17 +35,18 @@ class PositionalEncoding(nn.Module):

class AttentionHead(nn.Module):
    def __init__(self, d_model=512, d_k=64, d_v=64):
        super(AttentionHead, self).__init__()
        self.sq_d_k = d_k ** 0.5
        self.linear_q = nn.Linear(d_model, d_k)
        self.linear_k = nn.Linear(d_model, d_k)
        self.linear_v = nn.Linear(d_model, d_v)

    def forward(self, query, key, value, mask=None):
        Q, K, V = self.linear_q(query), self.linear_k(key), self.linear_v(value)
        out = torch.div(torch.matmul(Q, K.transpose(-2, -1)), self.sq_d_k)
        if mask is not None:
            out = out.masked_fill(mask == 0, -np.inf)
        out = torch.softmax(out, dim=-1)
        return torch.matmul(out, V)
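
For reference, `AttentionHead.forward` above computes scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V; a quick shape check with arbitrary inputs (illustration only):

if __name__ == '__main__':
    # Illustration: one attention head maps (batch, seq, d_model) -> (batch, seq, d_v).
    head = AttentionHead(d_model=512, d_k=64, d_v=64)
    x = torch.randn(2, 10, 512)
    print(head(x, x, x).shape)  # torch.Size([2, 10, 64])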
class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model=512, d_k=64, d_v=64, h=8):
...@@ -86,6 +88,7 @@ class Encoder(nn.Module):
class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1, eps=1e-5):
        super().__init__()
        assert h != 0
        d_k = d_v = d_model // h
        self.self_attn = MultiHeadedAttention(d_model, d_k, d_v, h)
        self.feedforward = PositionwiseFeedForward(d_model, d_ff, dropout)
...@@ -94,7 +97,7 @@ class EncoderLayer(nn.Module):
    def forward(self, src, src_mask=None):
        x = self.residual1(src, lambda x: self.self_attn(x, x, x, src_mask))
        return self.residual2(x, self.feedforward)

class Decoder(nn.Module):
    def __init__(self, num_layers=6, d_model=512, h=8, d_ff=2048, dropout=0.1, eps=1e-5):
...@@ -110,6 +113,7 @@ class Decoder(nn.Module):
class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, h=8, d_ff=2048, dropout=0.1, eps=1e-5):
        super().__init__()
        assert h != 0
        d_k = d_v = d_model // h
        self.size = d_model
        self.self_attn = MultiHeadedAttention(d_model, d_k, d_v, h)
......
...@@ -29,3 +29,6 @@ class Transformer(nn.Module):
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def project(self, x):
        return self.proj(x)
...@@ -5,45 +5,61 @@ import pytorch_lightning as pl
from model import Transformer
from optimizer import NoamOpt, LabelSmoothing
from utils import greedy_decode, beam_search

class TransformerModule(pl.LightningModule):
    def __init__(self, src_vocab, tgt_vocab, pad_idx, sos_idx, eos_idx,
                 enc_num_layers=6, dec_num_layers=6,
                 d_model=512, d_ff=2048, h=8, max_seq_len=2048,
                 smoothing=0.1, factor=1, warmup=400, lr=0.0007, beta1=0.9, beta2=0.98,
                 dropout=0.1, eps=1e-5):
        super(TransformerModule, self).__init__()
        self.factor = factor
        self.warmup = warmup
        self.lr = lr
        self.betas = (beta1, beta2)
        self.eps = eps
        self.pad_idx, self.sos_idx, self.eos_idx = pad_idx, sos_idx, eos_idx
        self.max_seq_len = max_seq_len
        self.model = Transformer(src_vocab, tgt_vocab, enc_num_layers, dec_num_layers,
                                 d_model, h, d_ff, max_seq_len, dropout, eps)
        self.criterion = LabelSmoothing(tgt_vocab, pad_idx, smoothing)
        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, src_mask):
        return beam_search(self.model, src, src_mask, self.max_seq_len, self.sos_idx, self.eos_idx)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.trainer.train_loop.running_loss.append(loss)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log('val_loss', self.shared_step(batch))

    def test_step(self, batch, batch_idx):
        self.log('test_loss', self.shared_step(batch))

    def shared_step(self, batch):
        src, src_mask, tgt, tgt_mask, ntokens = batch
        out = self.model(src, tgt[:, :-1], src_mask, tgt_mask)
        return torch.div(self.criterion(out.contiguous().view(-1, out.size(-1)), tgt[:, 1:].contiguous().view(-1)),
                         ntokens.sum())

    def configure_optimizers(self):
        optimizer = NoamOpt(
            self.parameters(),
            size=self.model.d_model,
            factor=self.factor,
            warmup=self.warmup,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps
        )
        return optimizer
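
`NoamOpt` is imported from `optimizer.py`, whose implementation is not shown in this view; assuming it follows the usual warmup schedule from the original Transformer paper, the rate it produces would look roughly like this sketch (an assumption, not the repository's code):

# Hedged sketch of the presumed Noam schedule (not the actual optimizer.py code):
# rate = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
def noam_rate(step, d_model=512, factor=1.0, warmup=400):
    step = max(step, 1)  # guard against step 0
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)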
...@@ -36,7 +36,7 @@ class LabelSmoothing(nn.Module):
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(torch.div(self.smoothing, (self.size - 2)))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
......
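
A worked illustration of the smoothing arithmetic in `LabelSmoothing.forward` above, assuming `self.confidence` is `1 - smoothing` (its initialization is elided here): with `size=5`, `smoothing=0.1`, `padding_idx=0` and a target class of 2, the off-target entries get `0.1 / 3 ≈ 0.0333`:

# Worked illustration (assumes confidence = 1 - smoothing = 0.9; values chosen arbitrarily).
import torch
size, smoothing, padding_idx, confidence = 5, 0.1, 0, 0.9
target = torch.LongTensor([2])
true_dist = torch.full((1, size), smoothing / (size - 2))
true_dist.scatter_(1, target.unsqueeze(1), confidence)
true_dist[:, padding_idx] = 0
print(true_dist)  # tensor([[0.0000, 0.0333, 0.9000, 0.0333, 0.0333]])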
test.py 0 → 100644
#%%
import os
import codecs
import pickle as pkl
from tqdm import tqdm
from argparse import ArgumentParser
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import sentencepiece as spm
import sacrebleu
from module import TransformerModule
from utils import beam_search
from common import *
#%%
def main(hparams):
sp = spm.SentencePieceProcessor(model_file=hparams.sp_model)
pad_idx, sos_idx, eos_idx = sp['<blank>'], sp['<s>'], sp['</s>']
V = sp.GetPieceSize()
model = TransformerModule.load_from_checkpoint(
hparams.ckpt,
src_vocab=V, tgt_vocab=V,
pad_idx=pad_idx, sos_idx=sos_idx, eos_idx=eos_idx,
enc_num_layers=hparams.enc_num_layers, dec_num_layers=hparams.dec_num_layers,
d_model=hparams.d_model, d_ff=hparams.d_ff,
h=hparams.num_heads, max_seq_len=hparams.max_seq_len,
dropout=0.0, eps=1e-5
).to(hparams.device)
model.eval()
total = 0
with codecs.open(hparams.path + '.' + hparams.tl, 'r', 'utf-8') as f:
for l in f:
total += 1
refs = []
with codecs.open(hparams.path + '.' + hparams.tl, 'r', 'utf-8') as f:
for l in tqdm(f, total=total):
refs.append(sp.decode(l.split()))
mts = []
with codecs.open(hparams.path + '.' + hparams.sl, 'r', 'utf-8') as f:
for l in tqdm(f, total=total):
src = Variable(torch.LongTensor([sp.piece_to_id(l.split())[:hparams.max_seq_len]])).to(hparams.device)
src_mask = Variable((src != pad_idx).unsqueeze(-2)).to(hparams.device)
mts.append(sp.decode(
beam_search(model.model, src, src_mask, hparams.max_seq_len, sp['<s>'], sp['</s>'], 10)[0].cpu().numpy().tolist()
))
bleu = sacrebleu.corpus_bleu(mts, [refs])
print(bleu.score)
with open(hparams.save, 'wb') as f:
pkl.dump((mts, refs), f)
# %%
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--device', type=str, default="cpu")
parser.add_argument('--ckpt', type=str, default=os.path.join(EXP_DIR, 'last.ckpt'))
parser.add_argument('--save', type=str, default='test_result.pkl')
parser.add_argument('--sp_model', type=str, default=SP_MODEL)
parser.add_argument('--path', type=str, default=TEST_PATH)
parser.add_argument('--sl', type=str, default=SRC_LANG)
parser.add_argument('--tl', type=str, default=TGT_LANG)
parser.add_argument('--max_seq_len', type=int, default=MAX_SEQ_LEN)
parser.add_argument('--enc_num_layers', type=int, default=ENC_NUM_LAYERS)
parser.add_argument('--dec_num_layers', type=int, default=DEC_NUM_LAYERS)
parser.add_argument('--d_model', type=int, default=D_MODEL)
parser.add_argument('--d_ff', type=int, default=D_FF)
parser.add_argument('--num_heads', type=int, default=NUM_HEADS)
args = parser.parse_args()
main(args)
#%%
import os
from argparse import ArgumentParser
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import sentencepiece as spm
from utils import greedy_decode
from module import TransformerModule
from dataloader import SyntheticDataModule, WMT2014DataModule
from common import *
#%%
def main(hparams):
data = WMT2014DataModule(
hparams.sp_model,
hparams.train_path,
hparams.val_path,
None,
hparams.sl, hparams.tl,
max_seq_len=hparams.max_seq_len,
batch_size=hparams.batch_size,
num_workers=hparams.num_workers
)
    V = data.vocabSize
    sp = data.sp
model = TransformerModule(
src_vocab=V, tgt_vocab=V,
pad_idx=sp['<blank>'], sos_idx=sp['<s>'], eos_idx=sp['</s>'],
enc_num_layers=hparams.enc_num_layers, dec_num_layers=hparams.dec_num_layers,
d_model=hparams.d_model, d_ff=hparams.d_ff,
h=hparams.num_heads, max_seq_len=hparams.max_seq_len,
smoothing=hparams.smoothing, factor=hparams.noam_factor,
warmup=hparams.noam_warmup, lr=hparams.noam_lr,
beta1=hparams.noam_beta1, beta2=hparams.noam_beta2,
dropout=hparams.dropout, eps=hparams.eps
)
    if not os.path.exists(hparams.exp_dir):
        os.makedirs(hparams.exp_dir)
checkpoint_callback = ModelCheckpoint(
dirpath=hparams.exp_dir,
monitor='val_loss',
mode='min',
filename='transformer-{epoch:d}-{step:d}-{train_loss:.2f}-{val_loss:.2f}',
save_last=True
)
# lr_monitor = LearningRateMonitor(logging_interval='step')
logger = TensorBoardLogger('lightning_logs', name='Transformer', default_hp_metric=False)
trainer = pl.Trainer(
gpus=hparams.gpus,
accelerator=hparams.accelerator,
precision=hparams.precision,
max_epochs=hparams.epoch,
resume_from_checkpoint=hparams.ckpt,
track_grad_norm=2,
gradient_clip_val=10.0,
accumulate_grad_batches=hparams.accumulate_grad_batches,
val_check_interval=0.1,
callbacks=[checkpoint_callback],
progress_bar_refresh_rate=1,
logger=logger,
log_every_n_steps=1,
log_gpu_memory='all',
)
trainer.logger.log_hyperparams(hparams)
trainer.fit(model, datamodule=data)
# %%
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--exp_dir', type=str, default=EXP_DIR)
parser.add_argument('--gpus', type=int, default=None)
parser.add_argument('--accelerator', type=str, default=None)
parser.add_argument('--precision', type=int, default=32)
parser.add_argument('--epoch', type=int, default=32)
parser.add_argument('--ckpt', type=str, default=None)
parser.add_argument('--sp_model', type=str, default=SP_MODEL)
parser.add_argument('--train_path', type=str, default=TRAIN_PATH)
parser.add_argument('--val_path', type=str, default=VAL_PATH)
parser.add_argument('--sl', type=str, default=SRC_LANG)
parser.add_argument('--tl', type=str, default=TGT_LANG)
    parser.add_argument('--max_seq_len', type=int, default=MAX_SEQ_LEN)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--accumulate_grad_batches', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--enc_num_layers', type=int, default=ENC_NUM_LAYERS)
parser.add_argument('--dec_num_layers', type=int, default=DEC_NUM_LAYERS)
parser.add_argument('--d_model', type=int, default=D_MODEL)
parser.add_argument('--d_ff', type=int, default=D_FF)
parser.add_argument('--num_heads', type=int, default=NUM_HEADS)
parser.add_argument('--smoothing', type=float, default=0.1)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--eps', type=float, default=1e-9)
parser.add_argument('--noam_factor', type=float, default=1)
parser.add_argument('--noam_warmup', type=float, default=400)
parser.add_argument('--noam_lr', type=float, default=0.0007)
parser.add_argument('--noam_beta1', type=float, default=0.9)
parser.add_argument('--noam_beta2', type=float, default=0.98)
args = parser.parse_args()
    main(args)
#%%
import os
import codecs
import pickle as pkl
from tqdm import tqdm
from argparse import ArgumentParser
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import sentencepiece as spm
import sacrebleu
from module import TransformerModule
from utils import beam_search
from common import *
#%%
def main(hparams):
sp = spm.SentencePieceProcessor(model_file=hparams.sp_model)
pad_idx, sos_idx, eos_idx = sp['<blank>'], sp['<s>'], sp['</s>']
V = sp.GetPieceSize()
model = TransformerModule.load_from_checkpoint(
hparams.ckpt,
src_vocab=V, tgt_vocab=V,
pad_idx=pad_idx, sos_idx=sos_idx, eos_idx=eos_idx,
enc_num_layers=hparams.enc_num_layers, dec_num_layers=hparams.dec_num_layers,
d_model=hparams.d_model, d_ff=hparams.d_ff,
h=hparams.num_heads, max_seq_len=hparams.max_seq_len,
dropout=0.0, eps=1e-5
)
model.eval()
try:
while True:
_in = input(">>> ")
if not _in:
break
# print('Input:', _in)
src = Variable(torch.LongTensor([sp.encode(_in)[:hparams.max_seq_len]]))
src_mask = Variable((src != pad_idx).unsqueeze(-2))
print('Encoded:', src)
result = beam_search(model.model, src, src_mask, hparams.max_seq_len, sp['<s>'], sp['</s>'], 10)[0].numpy().tolist()
print('Beam:', result)
print('Result:', sp.decode(result))
print()
except KeyboardInterrupt as e:
pass
except EOFError as e:
pass
except Exception as e:
print(f"Error: {e}")
exit(1)
print("\nExiting...")
# %%
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--ckpt', type=str, default=os.path.join(EXP_DIR, 'last.ckpt'))
parser.add_argument('--sp_model', type=str, default=SP_MODEL)
parser.add_argument('--max_seq_len', type=int, default=MAX_SEQ_LEN)
parser.add_argument('--enc_num_layers', type=int, default=ENC_NUM_LAYERS)
parser.add_argument('--dec_num_layers', type=int, default=DEC_NUM_LAYERS)
parser.add_argument('--d_model', type=int, default=D_MODEL)
parser.add_argument('--d_ff', type=int, default=D_FF)
parser.add_argument('--num_heads', type=int, default=NUM_HEADS)
args = parser.parse_args()
main(args)
#%%
import numpy as np
import torch

#%%
def subsequent_mask(size):
    mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0
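
For illustration, `subsequent_mask(3)` is a lower-triangular boolean mask, so each target position may attend only to itself and earlier positions:

if __name__ == '__main__':
    # Illustration: each row permits attention to the current and earlier positions only.
    print(subsequent_mask(3))
    # tensor([[[ True, False, False],
    #          [ True,  True, False],
    #          [ True,  True,  True]]])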
#%%
def greedy_decode(model, src, src_mask, max_len, start_symbol, stop_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(1, max_len):
        out = model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
        prob = model.proj(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src.data)], dim=1)
        if next_word.item() == stop_symbol:
            break
    return ys
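
A usage sketch, mirroring the synthetic sanity check that the previous `train.py` ran: greedy-decode a toy source sequence with a small, untrained model just to exercise the loop (all hyperparameters and symbol ids here are arbitrary):

#%%
if __name__ == '__main__':
    # Sketch: greedy-decode a toy sequence with a tiny untrained model (output is meaningless).
    from module import TransformerModule
    module = TransformerModule(src_vocab=11, tgt_vocab=11, pad_idx=0, sos_idx=1, eos_idx=2,
                               enc_num_layers=2, dec_num_layers=2, max_seq_len=10)
    module.eval()
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, 10)
    print(greedy_decode(module.model, src, src_mask, max_len=10, start_symbol=1, stop_symbol=2))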
#%%
def beam_search(model, src, src_mask, max_len, start_symbol, stop_symbol, k=5, alpha=0.7):
memory = model.encode(src, src_mask)
ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
out = model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
log_scores, idx = model.proj(out[:, -1]).data.topk(k)
outputs = torch.zeros(k, max_len).type_as(src.data)
outputs[:, 0] = start_symbol
outputs[:, 1] = idx[0]
memorys = torch.zeros(k, memory.size(-2), memory.size(-1))
memorys[:, :] = memory[0]
for i in range(2, max_len):
out = model.decode(memory, src_mask, outputs[:, :i], subsequent_mask(i).type_as(src.data))
log_probs, idx = model.proj(out[:, -1]).data.topk(k)
log_probs = log_probs + log_scores.transpose(0, 1)
k_probs, k_ix = log_probs.view(-1).topk(k)
assert k != 0
row = k_ix // k
col = k_ix % k
outputs[:, :i] = outputs[row, :i]
outputs[:, i] = idx[row, col].type_as(src.data)
log_scores = k_probs.unsqueeze(0)
if (((outputs == stop_symbol).sum(axis=1)) > 0).sum() == k:
break
_, sentence_len = (((outputs == stop_symbol).cumsum(axis=1) == 1) & (outputs == stop_symbol)).max(axis=1)
sentence_len[sentence_len == 0] = max_len - 1
sentence_len += 1
_, ind = torch.max(log_scores * (torch.div(1, (sentence_len.type_as(log_scores) ** alpha))), axis=1)
return outputs[ind, :sentence_len[ind]]
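
The final ranking divides each beam's accumulated log-probability by `sentence length ** alpha` (alpha = 0.7 by default), the usual length penalty, so longer hypotheses are not ruled out merely for summing more negative log-probabilities. A quick numeric illustration:

if __name__ == '__main__':
    # Illustration of the length penalty: with alpha = 0.7, a longer hypothesis with a
    # lower total log-probability can still outrank a shorter one after normalization.
    alpha = 0.7
    for total_logp, length in ((-4.0, 5), (-6.0, 12)):
        print(total_logp / (length ** alpha))
    # -1.2965...  (length 5)
    # -1.0537...  (length 12 ranks higher)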
# %%