基于 Transformer 和 PyTorch 的日汉机器翻译模型

本次实验为基于Transformer的机器翻译任务

1.导入所需的包

首先,让我们确保在系统中安装了以下软件包,如果发现缺少某些软件包,请确保安装它们。

import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import tqdm
import sentencepiece as spm
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')

2.获取平行语料库

在本教程中,我们将使用从 JParaCrawl 下载的日英平行语料库![http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl],该语料库被描述为“由 NTT 创建的最大公开可用的英日平行语料库。它主要是通过抓取网络和自动对齐平行句子创建的”。您也可以在这里查看论文。

df = pd.read_csv('./zh-ja/zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)# 使用pandas读取CSV文件,数据文件中的字段是以制表符'\t'分隔的,不包含表头信息
trainen = df[2].values.tolist()#[:10000]# 将数据框中的第三列(索引为2,通常为英文句子)转换为列表
trainja = df[3].values.tolist()#[:10000]# 将数据框中的第四列(索引为3,通常为日文句子)转换为列表
# trainen.pop(5972)
# trainja.pop(5972)

导入所有日语及其对应的英语后,我删除了数据集中最后一个数据,因为它缺少值。trainen 和 trainja 中的句子总数为 5,973,071,但是,出于学习目的,通常建议对数据进行采样,并确保一切按预期工作,然后再一次使用所有数据,以节省时间。

以下是数据集中包含的句子示例。

print(trainen[500])
print(trainja[500])
Chinese HS Code Harmonized Code System < HS编码 2905 无环醇及其卤化、磺化、硝化或亚硝化衍生物 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...
Japanese HS Code Harmonized Code System < HSコード 2905 非環式アルコール並びにそのハロゲン化誘導体、スルホン化誘導体、ニトロ化誘導体及びニトロソ化誘導体 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...

我们也可以使用不同的平行语料库来学习本文,只需确保我们可以像上面显示的那样将数据处理成两个字符串列表,分别包含日语和英语句子。

3.准备分词器

与英语或其他字母语言不同,日语句子不包含用于分隔单词的空格。我们可以使用 JParaCrawl 提供的分词器,该分词器是使用 SentencePiece 为日语和英语创建的,您可以访问 JParaCrawl 网站下载它们,或单击此处。

en_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.en.nopretok.model')
# 初始化一个SentencePiece Processor对象用于英文文本处理
# 'enja_spm_models/spm.en.nopretok.model'是SentencePiece模型文件的路径,
# 这个模型是之前通过对英文文本训练得到的,用于将英文文本切分成子词
ja_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.ja.nopretok.model')
['▁All',
 '▁residents',
 '▁aged',
 '▁20',
 '▁to',
 '▁59',
 '▁years',
 '▁who',
 '▁live',
 '▁in',
 '▁Japan',
 '▁must',
 '▁enroll',
 '▁in',
 '▁public',
 '▁pension',
 '▁system',
 '.']
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type=str)
['▁',
 '年',
 '金',
 '▁日本',
 'に住んでいる',
 '20',
 '歳',
 '~',
 '60',
 '歳の',
 '全ての',
 '人は',
 '、',
 '公的',
 '年',
 '金',
 '制度',
 'に',
 '加入',
 'しなければなりません',
 '。']

4.构建 TorchText 词汇表对象并将句子转换为 Torch 张量

使用分词器和原始句子,我们然后构建从 TorchText 导入的 Vocab 对象。此过程可能需要几秒钟或几分钟,具体取决于我们数据集的大小和计算能力。不同的分词器也会影响构建词汇表所需的时间,我尝试了其他几种日语分词器,但 SentencePiece 似乎运行良好且对我来说足够快。

# 定义一个函数用于根据句子列表和tokenizer构建词汇表
def build_vocab(sentences, tokenizer):
  counter = Counter()# 使用Counter来统计所有子词出现的频次
  for sentence in sentences:
    # 使用tokenizer对每个句子进行编码,并更新到counter中
    # out_type=str确保编码后的结果是字符串形式,便于计数
    counter.update(tokenizer.encode(sentence, out_type=str))
  return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
ja_vocab = build_vocab(trainja, ja_tokenizer)# 利用上述函数和日文训练句子列表及日文tokenizer构建日文词汇表
en_vocab = build_vocab(trainen, en_tokenizer)#构建英文词汇表

在获得词汇表对象后,我们就可以使用词汇表和分词器对象为我们的训练数据构建张量。

def data_process(ja, en):
  data = []# 初始化一个空列表,用于存储处理后的数据对
# 使用zip函数同时遍历日文句子列表和英文句子列表
  for (raw_ja, raw_en) in zip(ja, en):
    # 对于每一对句子,使用ja_tokenizer对日文句子进行编码
    # rstrip("\n")移除句子末尾的换行符,out_type=str保证编码结果为字符串
    # 然后将每个子词(token)映射到其在词汇表中的索引,并转换为Long类型的张量
    ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
                            dtype=torch.long)
    en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
                            dtype=torch.long)
    data.append((ja_tensor_, en_tensor_))# 将处理后的日文和英文张量作为元组加入到data列表中
  return data
train_data = data_process(trainja, trainen)# 调用data_process函数处理训练集中的日文和英文句子,生成训练数据集

5.创建 DataLoader 对象以便在训练期间进行迭代

在这里,我将 BATCH_SIZE 设置为 16 以防止出现“cuda 内存不足”,但这取决于各种因素,例如您的机器内存容量、数据大小等,因此请根据您的需要随意更改批处理大小(注意:PyTorch 的教程使用 Multi30k 德语-英语 数据集将批大小设置为 128。)

BATCH_SIZE = 8# 定义批次大小
PAD_IDX = ja_vocab['<pad>']# 填充符号的索引
BOS_IDX = ja_vocab['<bos>']# 句子开始符号的索引
EOS_IDX = ja_vocab['<eos>']# 句子结束符号的索引
def generate_batch(data_batch):
  ja_batch, en_batch = [], []# 初始化批次列表
  for (ja_item, en_item) in data_batch:
    # 对于每个句子,在开头添加句子开始符号<BOS>,结尾添加句子结束符号<EOS>
    # 并使用torch.cat将它们连接起来形成新的张量
    ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
    en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
  # 使用pad_sequence函数对所有句子进行填充,确保每个批次内的所有句子长度一致
  # 短的句子会用PAD_IDX填充到最长句子的长度
  ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
  en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
  return ja_batch, en_batch
# 创建DataLoader用于迭代训练数据
# 指定了批次大小、是否打乱数据、以及自定义的生成批次函数
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)

6.序列到序列 Transformer

接下来的几段代码和文字解释(用斜体表示)取自 PyTorch 官方教程 Language Translation with nn.Transformer and torchtext — PyTorch Tutorials 2.3.0+cu121 documentation除了 BATCH_SIZE 和单词 de_vocab(已更改为 ja_vocab)之外,我没有进行任何更改。

Transformer 是“Attention is all you need”论文中介绍的一种 Seq2Seq 模型,用于解决机器翻译任务。Transformer 模型由编码器和解码器块组成,每个块包含固定数量的层。

编码器通过一系列多头注意力和前馈网络层传播输入序列来处理输入序列。编码器的输出称为内存,与目标张量一起馈送到解码器。编码器和解码器使用强制教学技术以端到端的方式进行训练。

from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # 定义编码器层
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        # 使用定义的层构建Transformer编码器
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        # 同样定义解码器层
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)# 构建Transformer解码器
        # 用于从模型输出到目标词典概率的线性层
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        # 定义源语言和目标语言的词嵌入层
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        # 位置编码层,给每个位置的词嵌入添加位置信息
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        # 对源语言和目标语言的输入应用位置编码和词嵌入
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)# 编码阶段:使用编码器处理源语言序列
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask) # 解码阶段:使用解码器处理目标语言序列,同时考虑编码器的输出、掩码等
        return self.generator(outs)# 最后通过一个线性层映射到目标词汇表大小的logits

    def encode(self, src: Tensor, src_mask: Tensor):# 提供单独的编码方法
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):# 提供单独的解码方法
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

文本标记通过使用标记嵌入来表示。位置编码被添加到标记嵌入中以引入词序的概念。

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()# 计算位置编码中的sin和cos部分,这有助于模型理解序列中词语的位置信息
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)# 偶数索引位置使用sin函数
        pos_embedding[:, 1::2] = torch.cos(pos * den)# 奇数索引位置使用cos函数
        pos_embedding = pos_embedding.unsqueeze(-2)# 在序列长度这一维度前增加一个维度以匹配后续操作

        self.dropout = nn.Dropout(dropout)# 定义dropout层以防止过拟合
        self.register_buffer('pos_embedding', pos_embedding)# 将位置嵌入注册为持久的缓冲区,不会被优化器更新

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])# 将位置编码加到词嵌入上,然后应用dropout

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()# 初始化一个 Embedding 层,用于将词汇表中的索引转换成词嵌入向量
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        # 将输入的词语索引(tokens)转换为词嵌入向量,并乘以嵌入维度的平方根,
        # 这一操作通常是为了保持梯度传递时的数值稳定性
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

我们创建了一个后续词掩码,以阻止目标词关注其后续词。我们还创建了掩码,用于屏蔽源和目标填充标记。

def generate_square_subsequent_mask(sz):
    # 生成一个上三角矩阵,对角线及上方元素为1,下方元素为0
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    # 将mask转换为float类型,并将原值为0的位置替换为负无穷大(用于后续softmax时相应位置输出为0)
    # 将原值为1的位置替换为0,这样在加到attention得分上不会有影响
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
  src_seq_len = src.shape[0]# 获取源序列的长度
  tgt_seq_len = tgt.shape[0]# 获取目标序列的长度

  tgt_mask = generate_square_subsequent_mask(tgt_seq_len)# 为目标序列生成后续掩码(用于掩盖未来时刻的信息)
  src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)# 创建源序列的自注意力掩码,这里全为False,意味着不屏蔽任何元素(如果模型架构需要的话,可以进行调整)

  src_padding_mask = (src == PAD_IDX).transpose(0, 1)
  tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
  return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask# 返回所有类型的掩码

Define model parameters and instantiate model. 这里我们服务器实在是计算能力有限,按照以下配置可以训练但是效果应该是不行的。如果想要看到训练的效果请使用你自己的带GPU的电脑运行这一套代码。

当你使用自己的GPU的时候,NUM_ENCODER_LAYERS 和 NUM_DECODER_LAYERS 设置为3或者更高,NHEAD设置8,EMB_SIZE设置为512。

SRC_VOCAB_SIZE = len(ja_vocab)# 定义源语言词汇表大小
TGT_VOCAB_SIZE = len(en_vocab)# 定义目标语言词汇表大小
EMB_SIZE = 512# 设置词嵌入维度
NHEAD = 8# 设置多头注意力中头的数量
FFN_HID_DIM = 512# 设置前馈网络的隐藏层维度
BATCH_SIZE = 16# 批次大小
NUM_ENCODER_LAYERS = 3# 编码器层数
NUM_DECODER_LAYERS = 3# 解码器层数
NUM_EPOCHS = 16# 总训练轮数
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)# 实例化Seq2SeqTransformer模型
# 使用Xavier初始化对模型参数进行初始化,有利于权重初始化的均匀分布
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# 将模型移到预设的计算设备上,如GPU
transformer = transformer.to(device)
# 定义损失函数,忽略PAD_IDX对应的位置,以避免填充部分对损失计算的影响
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# 初始化Adam优化器,配置学习率、betas值和epsilon值
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)
def train_epoch(model, train_iter, optimizer):
  model.train() # 将模型设置为训练模式
  losses = 0  # 初始化损失总和
  for idx, (src, tgt) in  enumerate(train_iter):
      src = src.to(device)# 将数据移至计算设备
      tgt = tgt.to(device)

      tgt_input = tgt[:-1, :]# 准备目标序列的输入(去掉序列的第一个元素,用于预测)
        # 为当前批次创建掩码
      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask,
                                src_padding_mask, tgt_padding_mask, src_padding_mask)# 模型前向传播得到logits

      optimizer.zero_grad()# 清零梯度

      tgt_out = tgt[1:,:]# 截取目标序列的真实输出(去掉序列的第一个元素)
      loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))# 计算损失
      loss.backward()# 反向传播

      optimizer.step()# 更新模型参数
      losses += loss.item()# 累加损失
  return losses / len(train_iter)# 返回平均损失


def evaluate(model, val_iter):
  model.eval()# 将模型设置为评估模式
  losses = 0# 初始化损失总和
  for idx, (src, tgt) in (enumerate(valid_iter)):
    src = src.to(device)
    tgt = tgt.to(device)

    tgt_input = tgt[:-1, :]# 准备目标序列的输入

    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)# 创建掩码
    # 模型前向传播
    logits = model(src, tgt_input, src_mask, tgt_mask,
                              src_padding_mask, tgt_padding_mask, src_padding_mask)
    tgt_out = tgt[1:,:]# 截取目标序列的真实输出
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))# 计算损失
    losses += loss.item()# 累加损失
  return losses / len(val_iter)# 返回平均损失

7.开始训练

最后,在准备好必要的类和函数之后,我们就可以开始训练模型了。毫无疑问,完成训练所需的时间可能会有很大差异,这取决于很多因素,例如计算能力、参数和数据集的大小。

当我使用来自 JParaCrawl 的完整句子列表(每种语言约有 590 万个句子)训练模型时,使用单个 NVIDIA GeForce RTX 3070 GPU,每个 epoch 大约需要 5 个小时。

以下是代码:

# 使用tqdm库来显示训练进度条,使训练过程可视化
for epoch in tqdm.tqdm(range(1, NUM_EPOCHS+1)):
  start_time = time.time()# 记录当前轮次开始的时间点
  train_loss = train_epoch(transformer, train_iter, optimizer)# 执行一个训练轮次,并返回该轮次的平均训练损失
  end_time = time.time()# 记录当前轮次结束的时间点
  print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "

8.使用训练好的模型尝试翻译日语句子

首先,我们创建用于翻译新句子的函数,包括获取日语句子、分词、转换为张量、推理,然后将结果解码回句子等步骤,但这次是用英语。

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # 将源序列和掩码移到指定设备
    src = src.to(device)
    src_mask = src_mask.to(device)
    # 使用模型的编码器得到编码后的记忆向量
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)# 初始化预测序列,开始符号
    # 迭代最大长度减去起始符号的位置
    for i in range(max_len-1):
        memory = memory.to(device)# 确保记忆向量在目标设备上
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)# 创建源序列的记忆掩码,初始为全False,表示没有需要遮盖的部分
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(device)# 生成目标序列的自注意力掩码,确保只关注之前的词
        out = model.decode(ys, memory, tgt_mask)# 使用解码器得到输出
        out = out.transpose(0, 1)# 转置输出以适配生成过程
        prob = model.generator(out[:, -1])# 通过生成器得到下一个词的概率分布
        _, next_word = torch.max(prob, dim = 1)# 找到概率最大的下一个词
        next_word = next_word.item()
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)# 将下一个词添加到预测序列中
        if next_word == EOS_IDX:
          break# 如果遇到结束符号,则停止生成
    return ys
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    model.eval()# 将模型设置为评估模式
    # 对源文本进行预处理,添加开始和结束符号,并转换为索引表示
    tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)]+ [EOS_IDX]
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)# 创建源序列的掩码
    tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()# 使用贪婪解码生成目标序列的索引表示
    return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")# 将目标序列的索引转换回单词,并去除开始和结束符号,然后拼接成翻译文本

然后,我们可以调用translate函数并传递所需的参数

translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)
' ▁H S 代 码 ▁85 15 ▁ 用 焊 接 设 备 ( 包 括 电 气 加 热 气 体 ) 。 '
trainen.pop(5)
'Chinese HS Code Harmonized Code System < HS编码 8515 : 电气(包括电热气体)、激光、其他光、光子束、超声波、电子束、磁脉冲或等离子弧焊接机器及装置,不论是否 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...'
trainja.pop(5)
'Japanese HS Code Harmonized Code System < HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)、レーザーその他の光子ビーム式、超音波式、電子ビーム式、 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...'

9.保存词汇表对象和训练好的模型

最后,训练完成后,我们将首先使用 Pickle 保存词汇表对象(en_vocab 和 ja_vocab)。

import pickle# 导入pickle模块,用于对象的序列化和反序列化
# open a file, where you want to store the data
file = open('en_vocab.pkl', 'wb')# 打开一个文件(二进制写模式'wb'),准备将数据存储到这个文件中
# dump information to that file
pickle.dump(en_vocab, file)# 使用pickle的dump方法将en_vocab对象的数据写入到刚刚打开的文件中
file.close()# 关闭文件,释放系统资源
file = open('ja_vocab.pkl', 'wb')# 再次打开一个新文件(二进制写模式'wb'),这次用于存储日语词汇表
pickle.dump(ja_vocab, file)# 使用pickle.dump将ja_vocab对象的数据写入文件
file.close()

最后,我们还可以使用 PyTorch 的保存和加载函数保存模型以供以后使用。通常,根据我们以后想用模型做什么,有两种保存模型的方法。第一种方法仅用于推理,我们可以稍后加载模型并使用它将日语翻译成英语。

# save model for inference
torch.save(transformer.state_dict(), 'inference_model')

第二种方法也用于推理,但也用于我们希望稍后加载模型并希望继续训练的情况。

# save model for inference
torch.save(transformer.state_dict(), 'inference_model')
torch.save({
    'epoch': NUM_EPOCHS,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
    }, 'model_checkpoint.tar')

  • 38
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值