Japanese-Chinese Machine Translation Model with Transformer & PyTorch
A tutorial using Jupyter Notebook, PyTorch, Torchtext, and SentencePiece
Import required packages
First, make sure the packages below are installed on your system; install any that are missing.
import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import tqdm
import sentencepiece as spm
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(torch.cuda.get_device_name(0)) ## If you have a GPU, try running this code on your own machine
device
device(type='cpu')
Get the parallel dataset
In this tutorial, we will use the Japanese-English parallel dataset downloaded from JParaCrawl [http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl], which is described as the “largest publicly available English-Japanese parallel corpus created by NTT. It was created by largely crawling the web and automatically aligning parallel sentences.” You can also see the paper here.
# Read and process the data
df = pd.read_csv('./zh-ja/zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)
trainen = df[2].values.tolist() # extract the English column
trainja = df[3].values.tolist() # extract the Japanese column
After importing the Japanese sentences and their English counterparts, I deleted the last entry in the dataset because it has a missing value. In total, trainen and trainja each contain 5,973,071 sentences. For learning purposes, however, it is often recommended to sample the data first and make sure everything works as intended before using all the data at once, which saves time.
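The cleaning and sampling steps described above could look roughly like the sketch below. It is only an illustration under stated assumptions: the helper name clean_and_sample and the toy lists are hypothetical, and it assumes a missing value shows up as a non-string entry (for example a float NaN, which is how pandas usually represents it).

```python
import random

def clean_and_sample(src_sentences, tgt_sentences, k, seed=0):
    """Drop pairs with a missing side, then return up to k aligned pairs."""
    # Keep a pair only when both sides are real strings (NaN is a float).
    pairs = [(s, t) for s, t in zip(src_sentences, tgt_sentences)
             if isinstance(s, str) and isinstance(t, str)]
    rng = random.Random(seed)
    # Sampling whole pairs keeps source and target sentences aligned.
    sampled = rng.sample(pairs, min(k, len(pairs)))
    src, tgt = zip(*sampled) if sampled else ((), ())
    return list(src), list(tgt)

# Toy stand-ins for trainja / trainen (hypothetical data).
toy_ja = ["こんにちは", "ありがとう", float("nan"), "おはよう"]
toy_en = ["hello", "thank you", "goodbye", "good morning"]
small_ja, small_en = clean_and_sample(toy_ja, toy_en, k=2)
```

Sampling pairs rather than sampling each list separately is the important design choice here, because the two lists must stay index-aligned for training.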
Here is an example of a sentence contained in the dataset.
# Print sample data
print(trainen[500])
print(trainja[500])
You can also follow along with a different parallel dataset; just make sure the data can be processed into two lists of strings as shown above, one containing the Japanese sentences and the other the English sentences.
Prepare the tokenizers
Unlike English and other alphabetic languages, a Japanese sentence does not contain whitespace to separate words. We can use the tokenizers provided by JParaCrawl, which were created using SentencePiece for both Japanese and English. You can visit the JParaCrawl website to download them, or click here.
# Load the SentencePiece models
en_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.en.nopretok.model')
ja_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.ja.nopretok.model')
After the tokenizers are loaded, you can test them, for example, by executing the below code.
en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type=str)
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type=str)
Build the TorchText Vocab objects and convert the sentences into Torch tensors
Using the tokenizers and raw sentences, we then build the Vocab object imported from TorchText. This process can take a few seconds or minutes depending on the size of the dataset and the available computing power. The choice of tokenizer also affects the time needed to build the vocab; I tried several other tokenizers for Japanese, but SentencePiece seems to work well and fast enough for me.
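def build_vocab(sentences, tokenizer):
    # Count how often each subword piece occurs across all sentences.
    counter = Counter()
    for sentence in sentences:
        counter.update(tokenizer.encode(sentence, out_type=str))
    # Note: Vocab(counter, specials=...) is the older torchtext constructor;
    # newer torchtext releases changed this API.
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
ja_vocab = build_vocab(trainja, ja_tokenizer)
en_vocab = build_vocab(trainen, en_tokenizer)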
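To make the vocabulary step less opaque, here is a minimal pure-Python sketch of the same idea: count the tokens, reserve the first indices for the special symbols, and map everything else to an integer. The names build_token_index and lookup are hypothetical, the toy sentences are invented, and the ordering rule (descending frequency, then alphabetical) is an assumption that is not guaranteed to match what TorchText produces.

```python
from collections import Counter

SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']

def build_token_index(tokenized_sentences, specials=SPECIALS):
    """Return (stoi, itos) mappings for a list of token lists."""
    counter = Counter()
    for tokens in tokenized_sentences:
        counter.update(tokens)
    # Specials come first; remaining tokens are sorted by descending
    # frequency, with ties broken alphabetically (an assumed convention).
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered]
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos

def lookup(stoi, token):
    """Map a token to its index, falling back to <unk> for unseen tokens."""
    return stoi.get(token, stoi['<unk>'])

# Toy, already-tokenized sentences (hypothetical SentencePiece-style pieces).
toy = [["▁I", "▁like", "▁tea"], ["▁I", "▁like", "▁coffee"]]
stoi, itos = build_token_index(toy)
ids = [lookup(stoi, tok) for tok in ["▁I", "▁like", "▁water"]]
```

The fallback to the unknown-token index is what lets the model accept pieces at inference time that never appeared in the training data.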
After we have the vocabulary objects, we can then use the vocab and the tokenizer objects to build the tensors for our training data.
def data_process(ja, en):
    data = []
    for (raw_ja, raw_en) in zip(ja, en):
        # Tokenize each sentence and map every piece to its vocabulary index.
        ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        data.append((ja_tensor_, en_tensor_))
    return data
train_data = data_process(trainja, trainen)
Create the DataLoader object to be iterated during training
Here, I set BATCH_SIZE to 16 to prevent “cuda out of memory” errors, but the right value depends on factors such as the memory capacity of your machine and the size of the data, so feel free to change the batch size according to your needs. (Note: the PyTorch tutorial sets the batch size to 128 using the Multi30k German-English dataset.)
BATCH_SIZE = 16
PAD_IDX = ja_vocab['<pad>']
BOS_IDX = ja_vocab['<bos>']
EOS_IDX = ja_vocab['<eos>']
def generate_batch(data_batch):
    # Wrap every sentence in <bos> ... <eos>, then pad to the longest one in the batch.
    ja_batch, en_batch = [], []
    for (ja_item, en_item) in data_batch:
        ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return ja_batch, en_batch
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
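What the collate function does to a batch can be shown on plain Python lists. The sketch below is an illustration rather than the tutorial's implementation: add_bos_eos and pad_batch are hypothetical helpers, the index values are made up, and the transposed layout assumes the time-major shape (sequence length first, batch second) that pad_sequence returns by default.

```python
def add_bos_eos(ids, bos_idx, eos_idx):
    """Wrap a list of token ids in begin/end-of-sentence markers."""
    return [bos_idx] + list(ids) + [eos_idx]

def pad_batch(sequences, pad_idx):
    """Pad to the longest sequence and return rows indexed by time step."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_idx] * (max_len - len(seq)) for seq in sequences]
    # Transpose from (batch, max_len) to (max_len, batch).
    return [list(column) for column in zip(*padded)]

PAD, BOS, EOS = 1, 2, 3  # assumed special-token indices
batch = [add_bos_eos(ids, BOS, EOS) for ids in ([10, 11, 12], [20])]
time_major = pad_batch(batch, PAD)
for step, row in enumerate(time_major):
    print(step, row)
```

Each printed row holds one time step for all sentences in the batch, so the shorter sentence ends in padding indices that the loss function later ignores.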
Sequence-to-sequence Transformer
The code and explanatory text in this section (written in italics) are taken from the original PyTorch tutorial [https://pytorch.org/tutorials/beginner/translation_transformer.html]. I did not make any changes except for BATCH_SIZE and the name de_vocab, which is changed to ja_vocab.
Transformer is a Seq2Seq model introduced in the “Attention is all you need” paper for solving machine translation tasks. The Transformer model consists of an encoder block and a decoder block, each containing a fixed number of layers.
The encoder processes the input sequence by propagating it through a series of multi-head attention and feed-forward network layers. The output from the encoder, referred to as memory, is fed to the decoder along with the target tensors. The encoder and decoder are trained in an end-to-end fashion using the teacher forcing technique.
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # NHEAD is read from the module-level constant defined with the other hyperparameters.
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        # Projects decoder outputs to scores over the target vocabulary.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)
    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)
    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
Text tokens are represented by using token embeddings. Positional encoding is added to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Sinusoidal table: even columns use sine, odd columns use cosine.
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)
    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        # Scale the embeddings by sqrt(emb_size).
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
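The sinusoidal table built in PositionalEncoding can be checked by hand for a single position. The following sketch recomputes one row with only the math module, using the same formula as the class above; the function name positional_encoding_row is hypothetical and the tiny emb_size is chosen purely for readability.

```python
import math

def positional_encoding_row(pos, emb_size):
    """Return the positional encoding vector for one position."""
    row = [0.0] * emb_size
    for i in range(0, emb_size, 2):
        # Same frequency term as the tensor code: exp(-i * log(10000) / emb_size).
        den = math.exp(-i * math.log(10000) / emb_size)
        row[i] = math.sin(pos * den)          # even dimension
        if i + 1 < emb_size:
            row[i + 1] = math.cos(pos * den)  # odd dimension
    return row

first = positional_encoding_row(0, 4)   # position 0: sin(0) and cos(0) alternate
second = positional_encoding_row(1, 4)  # position 1 differs in every dimension
```

Because each dimension oscillates at a different frequency, every position receives a distinct vector, which is how word order reaches a model that otherwise has no notion of sequence.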
We create a subsequent word mask to stop a target word from attending to its subsequent words. We also create masks for masking source and target padding tokens.
def generate_square_subsequent_mask(sz):
    # Lower-triangular mask: 0.0 where attention is allowed, -inf where it is blocked.
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
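To see what the two kinds of masks contain, the sketch below rebuilds them with plain Python lists. It is an illustration only: causal_mask and padding_mask are hypothetical names, and the padding index of 1 is an assumption matching the order of the special tokens used earlier.

```python
NEG_INF = float('-inf')

def causal_mask(size):
    """0.0 where a position may attend (itself and earlier), -inf elsewhere."""
    return [[0.0 if col <= row else NEG_INF for col in range(size)]
            for row in range(size)]

def padding_mask(token_ids, pad_idx):
    """True marks padding positions that attention should ignore."""
    return [tok == pad_idx for tok in token_ids]

mask = causal_mask(3)
for row in mask:
    print(row)

pad_flags = padding_mask([2, 20, 3, 1, 1], pad_idx=1)
```

Row i of the causal mask describes what target position i is allowed to see, so the first row exposes only the first token and the last row exposes the whole prefix.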
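Define the model parameters and instantiate the model. Our server has very limited computing power, so the configuration below can be trained, but the results will probably be poor. To see what training can really achieve, run this code on your own machine with a GPU.
When you use your own GPU, set NUM_ENCODER_LAYERS and NUM_DECODER_LAYERS to 3 or higher, NHEAD to 8, and EMB_SIZE to 512.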
SRC_VOCAB_SIZE = len(ja_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 16
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)
# Xavier-initialize every weight matrix (parameters with more than one dimension).
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
transformer = transformer.to(device)
# Padding positions do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)
def train_epoch(model, train_iter, optimizer):
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(device)
        tgt = tgt.to(device)
        # Decoder input: the target without its last token.
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        optimizer.zero_grad()
        # Expected output: the target shifted left by one token.
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)
def evaluate(model, val_iter):
    model.eval()
    losses = 0
    # Iterate over the val_iter argument (the original referenced an undefined valid_iter).
    for idx, (src, tgt) in (enumerate(val_iter)):
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(val_iter)
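The slicing of tgt into tgt_input and tgt_out in the training loop is how teacher forcing is set up: the decoder is fed the gold target up to step t and is asked to predict the token at step t + 1. A small list-based sketch of that shift, with a hypothetical helper name and invented token ids:

```python
def shift_for_teacher_forcing(tgt_ids):
    """Split one target sequence into decoder input and expected output."""
    decoder_input = tgt_ids[:-1]   # everything except the last token
    expected_output = tgt_ids[1:]  # everything except the first token
    return decoder_input, expected_output

BOS, EOS = 2, 3  # assumed special-token indices
tgt = [BOS, 15, 27, 9, EOS]
decoder_input, expected_output = shift_for_teacher_forcing(tgt)

# Position by position: given the input prefix, predict the next token.
for step, (seen, target) in enumerate(zip(decoder_input, expected_output)):
    print(f"step {step}: last input token {seen} -> predict {target}")
```

Both halves have the same length, so the logits and the expected tokens line up one to one when they are flattened for the cross-entropy loss.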
Start training
Finally, after preparing the necessary classes and functions, we are ready to train our model. It goes without saying that the time needed to finish training can vary greatly depending on factors such as computing power, parameters, and the size of the dataset.
When I trained the model using the complete list of sentences from JParaCrawl, which has around 5.9 million sentences for each language, it took around 5 hours per epoch on a single NVIDIA GeForce RTX 3070 GPU.
Here is the code:
for epoch in tqdm.tqdm(range(1, NUM_EPOCHS+1)):
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer)
    end_time = time.time()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
           f"Epoch time = {(end_time - start_time):.3f}s"))
Try translating a Japanese sentence using the trained model
First, we create the functions to translate a new sentence. The steps are: take the Japanese sentence, tokenize it, convert it to tensors, run inference, and then decode the result back into a sentence, this time in English.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)
    # Encode the source once; the result is reused at every decoding step.
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        # Greedy choice: keep only the highest-scoring token.
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.item()
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    model.eval()
    tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)]+ [EOS_IDX]
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
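The control flow of greedy decoding is easier to follow without tensors. The sketch below reproduces only the loop structure, under loud assumptions: greedy_decode_sketch, toy_scores, and the SCRIPT table are hypothetical, and the scoring function is a hard-coded stand-in for the trained model, not a real one.

```python
def greedy_decode_sketch(next_token_scores, bos_idx, eos_idx, max_len):
    """Repeatedly append the highest-scoring token until <eos> or max_len."""
    ys = [bos_idx]
    for _ in range(max_len - 1):
        scores = next_token_scores(ys)
        # argmax over the vocabulary, the counterpart of torch.max(prob, dim=1)
        next_word = max(range(len(scores)), key=scores.__getitem__)
        ys.append(next_word)
        if next_word == eos_idx:
            break
    return ys

# Hypothetical stand-in for the model: a fixed "next token" table.
SCRIPT = {2: 4, 4: 5, 5: 3}

def toy_scores(prefix):
    scores = [0.0] * 6            # toy vocabulary of six ids
    scores[SCRIPT[prefix[-1]]] = 1.0
    return scores

full = greedy_decode_sketch(toy_scores, bos_idx=2, eos_idx=3, max_len=10)
truncated = greedy_decode_sketch(toy_scores, bos_idx=2, eos_idx=3, max_len=3)
```

Greedy decoding commits to one token per step and never revisits it, which keeps it simple and fast; the max_len bound is what stops generation if the end-of-sentence token never wins.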
Then, we can just call the translate function and pass the required parameters.
translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)
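The translate function returns the predicted subword pieces joined by spaces. Since SentencePiece marks word boundaries with the "▁" character (U+2581), one possible post-processing step is to glue the pieces back into ordinary text. The helper below is a hypothetical sketch that is not part of the original tutorial, and the example pieces are invented.

```python
WORD_BOUNDARY = '\u2581'  # the "▁" marker SentencePiece uses for whitespace

def pieces_to_text(pieces, specials=('<bos>', '<eos>', '<pad>')):
    """Join subword pieces into a readable sentence."""
    kept = [piece for piece in pieces if piece not in specials]
    # Concatenate the pieces, then turn boundary markers back into spaces.
    return ''.join(kept).replace(WORD_BOUNDARY, ' ').strip()

example = ['<bos>', '▁All', '▁resident', 's', '▁live', '▁in', '▁Japan', '.', '<eos>']
sentence = pieces_to_text(example)
```

When the tokenizer object is available, asking it to decode the pieces directly is another option; the manual version above only serves to show what the boundary marker means.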
trainen.pop(5)
trainja.pop(5)
Save the Vocab objects and trained model
Finally, after the training has finished, we will save the Vocab objects (en_vocab and ja_vocab) first, using Pickle.
import pickle
# Open the file where the data will be stored and dump the object into it;
# the with statement closes the file automatically.
with open('en_vocab.pkl', 'wb') as file:
    pickle.dump(en_vocab, file)
with open('ja_vocab.pkl', 'wb') as file:
    pickle.dump(ja_vocab, file)
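Reading the pickled objects back later is the mirror image of saving them. The round trip below is a sketch under assumptions: it uses a plain dictionary as a stand-in for a Vocab object and writes to a temporary directory so that it does not touch the files created above.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for a Vocab object.
toy_vocab = {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, '▁tea': 4}

path = os.path.join(tempfile.mkdtemp(), 'toy_vocab.pkl')

# Save: open in binary write mode and dump the object.
with open(path, 'wb') as file:
    pickle.dump(toy_vocab, file)

# Load: open in binary read mode and rebuild the object.
with open(path, 'rb') as file:
    restored_vocab = pickle.load(file)
```

Keep in mind that unpickling a real Vocab object requires the same class definition to be importable, so the library version used for loading should match the one used for saving, and pickle files should only be loaded from sources you trust.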
Lastly, we can also save the model for later use with the PyTorch save and load functions. Generally, there are two ways to save the model, depending on what we want to use it for later. The first is for inference only: we can load the model later and use it to translate from Japanese to English.
# save model for inference
torch.save(transformer.state_dict(), 'inference_model')
The second is also for inference, but it additionally lets us load the model later and resume training.
# save model + checkpoint to resume training later
torch.save({
    'epoch': NUM_EPOCHS,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}, 'model_checkpoint.tar')
Conclusion
That’s it!