基于transformer实现机器翻译

WA2114176李毅恒

日中机器翻译模型与Transformer & PyTorch
本教程使用Jupyter Notebook、PyTorch、Torchtext和SentencePiece。

导入所需的包


首先,确保我们的系统中已安装以下包,如果发现缺少某些包,请确保安装它们。

import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import tqdm
import sentencepiece as spm
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(torch.cuda.get_device_name(0)) ## 如果你有GPU,请在你自己的电脑上尝试运行这一套代码



获取并行数据集


在这个教程中,我们将使用从JParaCrawl![http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl]下载的日英平行数据集,它被描述为“由NTT创建的最大的公开可用英日平行语料库”。这个数据集是通过大规模爬取网页并自动对齐平行句子而创建的。你也可以在这里查看相关论文。

# 读取数据文件,并指定分隔符为制表符
df = pd.read_csv('./zh-ja/zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)
 
# 提取中文和日文文本列
trainen = df[2].values.tolist()  # 中文文本列表
trainja = df[3].values.tolist()  # 日文文本列表
# trainen.pop(5972)
# trainja.pop(5972)


导入所有的日语及其对应的英语之后,我删除了数据集中的最后一条数据,因为它缺少值。总的来说,trainen和trainja中的句子总数为5,973,071。但是,为了学习目的,通常建议抽样数据,并确保一切按预期工作,然后再一次性使用所有数据,以节省时间。


这里是数据集中的一对示例句子:

print(trainen[500])
print(trainja[500])


Output[0]

Chinese HS Code Harmonized Code System < HS编码 2905 无环醇及其卤化、磺化、硝化或亚硝化衍生物 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...
Japanese HS Code Harmonized Code System < HSコード 2905 非環式アルコール並びにそのハロゲン化誘導体、スルホン化誘導体、ニトロ化誘導体及びニトロソ化誘導体 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...


准备分词器


与英语或其他字母语言不同,日语句子中不包含分隔单词的空白。我们可以使用JParaCrawl提供的分词器,这些分词器是使用SentencePiece为日语和英语创建的,你可以访问JParaCrawl网站下载它们,或者点击这里。

en_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.en.nopretok.model')
ja_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.ja.nopretok.model')


加载了分词器之后,你可以通过执行以下代码进行测试。

en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type='str')


Output[1]

['▁All',
 '▁residents',
 '▁aged',
 '▁20',
 '▁to',
 '▁59',
 '▁years',
 '▁who',
 '▁live',
 '▁in',
 '▁Japan',
 '▁must',
 '▁enroll',
 '▁in',
 '▁public',
 '▁pension',
 '▁system',
 '.']

ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')


Output[2]

['▁',
 '年',
 '金',
 '▁日本',
 'に住んでいる',
 '20',
 '歳',
 '~',
 '60',
 '歳の',
 '全ての',
 '人は',
 '、',
 '公的',
 '年',
 '金',
 '制度',
 'に',
 '加入',
 'しなければなりません',
 '。']


 构建TorchText的词汇表对象并将句子转换为Torch张量


使用分词器和原始句子,然后我们从TorchText导入Vocab对象进行构建。这个过程可能需要几秒钟或几分钟,具体取决于数据集的大小和计算能力。不同的分词器也会影响构建词汇表所需的时间,我尝试了其他几种日语分词器,但SentencePiece似乎对我来说工作得很好且足够快。

def build_vocab(sentences, tokenizer):
    counter = Counter()
    for sentence in sentences:
        counter.update(tokenizer.encode(sentence, out_type=str))
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
 
# 示例数据 trainja 和 trainen 分别为日语和英语句子列表
ja_vocab = build_vocab(trainja, ja_tokenizer)  # 构建日语词汇表
en_vocab = build_vocab(trainen, en_tokenizer)  # 构建英语词汇表


当我们有了词汇表对象之后,可以利用词汇表和分词器对象为我们的训练数据构建张量。

def data_process(ja, en):
    data = []
    for (raw_ja, raw_en) in zip(ja, en):
        # 将日语句子转换为张量
        ja_tensor = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
                                 dtype=torch.long)
        # 将英语句子转换为张量
        en_tensor = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
                                 dtype=torch.long)
        # 将日语和英语张量数据对加入到数据列表中
        data.append((ja_tensor, en_tensor))
    return data
 
# 将训练数据处理为张量数据集
train_data = data_process(trainja, trainen)


创建用于训练迭代的DataLoader对象


在这里,我将批量大小(BATCH_SIZE)设置为8,以避免“cuda out of memory”错误。实际的批量大小可以根据你的机器内存容量、数据集大小等因素进行调整。根据PyTorch教程使用Multi30k德英数据集的做法,批量大小设置为128。

BATCH_SIZE = 8
PAD_IDX = ja_vocab['<pad>']
BOS_IDX = ja_vocab['<bos>']
EOS_IDX = ja_vocab['<eos>']
 
def generate_batch(data_batch):
    ja_batch, en_batch = [], []
    for (ja_item, en_item) in data_batch:
        # 添加起始标记 <bos> 和结束标记 <eos>,并拼接成张量序列
        ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    
    # 对日语和英语批量数据进行填充,使它们的长度保持一致
    ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    
    return ja_batch, en_batch
 
# 创建训练数据迭代器,每次返回一个批量的数据
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)

Sequence-to-sequence Transformer


Transformer是一个在“Attention is all you need”论文中提出的Seq2Seq模型,用于解决机器翻译任务。Transformer模型由一个编码器和一个解码器组成,每个都包含固定数量的层。


编码器通过传播输入序列,通过一系列的多头注意力和前馈网络层来处理输入序列。编码器的输出,被称为记忆,与目标张量一起被送入解码器。编码器和解码器使用教师强制(teacher forcing)技术以端到端的方式进行训练。
 

from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
 
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        
        # Transformer 编码器层
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # Transformer 解码器层
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
 
        # 线性层作为生成器,将解码器输出映射到目标词汇表的维度
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        
        # 源语言和目标语言的词嵌入层
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        
        # 位置编码层
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
 
    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        # 对源语言和目标语言的输入进行位置编码
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        
        # Transformer 编码器处理源语言序列
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        
        # Transformer 解码器处理目标语言序列
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        
        # 通过生成器映射到目标词汇表的维度
        return self.generator(outs)
 
    def encode(self, src: Tensor, src_mask: Tensor):
        # 编码器的编码过程,返回编码器的输出(编码器最后一层的输出)
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)
 
    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        # 解码器的解码过程,返回解码器的输出(解码器最后一层的输出)
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

文本标记通过使用标记嵌入来表示。位置编码通过添加到标记嵌入中引入了单词顺序的概念。

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        # 计算位置编码
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
 
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)
 
    def forward(self, token_embedding: Tensor):
        # 返回加上位置编码后的结果,并应用 dropout
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0), :])
 
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    
    def forward(self, tokens: Tensor):
        # 返回词嵌入乘以 sqrt(词嵌入维度) 的结果
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

我们创建了一个后续单词掩码,以阻止目标单词关注其后续的单词。同时,我们还创建了掩码,用于屏蔽源语言和目标语言的填充标记。
 

def generate_square_subsequent_mask(sz):
    """
    生成一个用于 Transformer 解码器的目标序列遮蔽,确保在每个时间步只能看到之前的信息。
    Args:
    - sz (int): 序列的长度
    Returns:
    - mask (Tensor): 形状为 (sz, sz) 的上三角遮蔽矩阵,对角线及以下元素为-∞,其余元素为0
    """
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
 
def create_mask(src, tgt):
    """
    创建用于 Transformer 模型的遮蔽张量。
    Args:
    - src (Tensor): 源序列张量
    - tgt (Tensor): 目标序列张量
    Returns:
    - src_mask (Tensor): 源序列的填充遮蔽张量,形状为 (src_seq_len, src_seq_len)
    - tgt_mask (Tensor): 目标序列的上三角遮蔽张量,形状为 (tgt_seq_len, tgt_seq_len)
    - src_padding_mask (Tensor): 源序列的填充遮蔽张量,形状为 (src_seq_len, batch_size)
    - tgt_padding_mask (Tensor): 目标序列的填充遮蔽张量,形状为 (tgt_seq_len, batch_size)
    """
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
 
    # 生成目标序列的遮蔽张量
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    
    # 源序列的遮蔽张量为全零矩阵,不遮蔽任何内容
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
 
    # 生成源序列和目标序列的填充遮蔽张量
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
 
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Define model parameters and instantiate model. 这里我们服务器实在是计算能力有限,按照以下配置可以训练但是效果应该是不行的。如果想要看到训练的效果请使用你自己的带GPU的电脑运行这一套代码。

当你使用自己的GPU的时候,NUM_ENCODER_LAYERS 和 NUM_DECODER_LAYERS 设置为3或者更高,NHEAD设置8,EMB_SIZE设置为512。

SRC_VOCAB_SIZE = len(ja_vocab)  # 源语言词汇表大小
TGT_VOCAB_SIZE = len(en_vocab)  # 目标语言词汇表大小
EMB_SIZE = 512  # 词嵌入维度
NHEAD = 8  # 注意力头的数量
FFN_HID_DIM = 512  # FeedForward 层隐藏层维度
BATCH_SIZE = 16  # 批量大小
NUM_ENCODER_LAYERS = 3  # 编码器层数
NUM_DECODER_LAYERS = 3  # 解码器层数
NUM_EPOCHS = 16  # 训练轮数
 
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)
 
# 使用 Xavier 初始化网络参数
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
 
transformer = transformer.to(device)  # 将模型移动到GPU上(如果可用)
 
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # 定义损失函数,忽略填充索引
 
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)
 
def train_epoch(model, train_iter, optimizer):
    """
    训练模型的一个epoch。
    Args:
    - model (Seq2SeqTransformer): Transformer 模型
    - train_iter (DataLoader): 训练数据迭代器
    - optimizer (torch.optim.Adam): 优化器
    Returns:
    - float: 平均损失值
    """
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(device)
        tgt = tgt.to(device)
 
        tgt_input = tgt[:-1, :]
 
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
 
        logits = model(src, tgt_input, src_mask, tgt_mask,
                      src_padding_mask, tgt_padding_mask, src_padding_mask)
 
        optimizer.zero_grad()
 
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
 
        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)
 
 
def evaluate(model, val_iter):
    """
    评估模型在验证集上的表现。
    Args:
    - model (Seq2SeqTransformer): Transformer 模型
    - val_iter (DataLoader): 验证数据迭代器
    Returns:
    - float: 平均验证损失值
    """
    model.eval()
    losses = 0
    for idx, (src, tgt) in (enumerate(valid_iter)):
        src = src.to(device)
        tgt = tgt.to(device)
 
        tgt_input = tgt[:-1, :]
 
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
 
        logits = model(src, tgt_input, src_mask, tgt_mask,
                      src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(val_iter)

开始训练


在准备必要的类和函数之后,我们终于准备好训练我们的模型了。不用说,完成训练所需的时间可能会因计算能力、参数和数据集大小等因素而有很大的不同。

当我使用包含大约590万条句子的JParaCrawl的完整句子列表来训练模型时,使用单个NVIDIA GeForce RTX 3070 GPU,每个epoch大约需要5个小时。

下面是代码:

for epoch in tqdm.tqdm(range(1, NUM_EPOCHS+1)):
  start_time = time.time()  # 记录每个epoch的开始时间
  train_loss = train_epoch(transformer, train_iter, optimizer)  # 执行一个epoch的训练并计算损失
  end_time = time.time()  # 记录每个epoch的结束时间
  print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s"))  # 打印当前epoch的训练损失和耗时

尝试使用训练好的模型翻译日语句子


首先,我们创建函数来翻译新句子,包括获取日语句子、分词、转换成张量、推理,然后将结果解码回句子,但这次是英语。

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    使用贪婪解码方法生成目标语言序列。
    Args:
    - model (Seq2SeqTransformer): Transformer模型对象。
    - src (Tensor): 源语言输入序列张量,形状为(seq_len, batch_size)。
    - src_mask (Tensor): 源语言输入序列的mask张量,形状为(seq_len, seq_len)。
    - max_len (int): 生成序列的最大长度。
    - start_symbol (int): 目标语言序列的起始符号索引。
    Returns:
    - ys (Tensor): 生成的目标语言序列张量,形状为(seq_len, 1)。
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    
    for i in range(max_len-1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        
        if next_word == EOS_IDX:
            break
    
    return ys
 
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    """
    将源语言文本翻译为目标语言文本。
    Args:
    - model (Seq2SeqTransformer): Transformer模型对象。
    - src (str): 源语言文本字符串。
    - src_vocab (Vocab): 源语言词汇表对象。
    - tgt_vocab (Vocab): 目标语言词汇表对象。
    - src_tokenizer (Tokenizer): 源语言文本分词器。
    Returns:
    - translation (str): 翻译后的目标语言文本字符串。
    """
    model.eval()
    tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)] + [EOS_IDX]
    num_tokens = len(tokens)
    src = torch.LongTensor(tokens).reshape(num_tokens, 1)
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")


然后,我们可以直接调用翻译函数并传递所需的参数。
 

translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)
trainen.pop(5)


Output[3]

'美国 设施: 停车场, 24小时前台, 健身中心, 报纸, 露台, 禁烟客房, 干洗, 无障碍设施, 免费停车, 上网服务, 电梯, 快速办理入住/退房手续, 保险箱, 暖气, 传真/复印, 行李寄存, 无线网络, 免费无线网络连接, 酒店各处禁烟, 空调, 阳光露台, 自动售货机(饮品), 自动售货机(零食), 每日清洁服务, 内部停车场, 私人停车场, WiFi(覆盖酒店各处), 停车库, 无障碍停车场, 简短描述Gateway Hotel Santa Monica酒店距离海滩2英里(3.2公里),提供24小时健身房。每间客房均提供免费WiFi,客人可以使用酒店的免费地下停车场。'

trainja.pop(5)


Output[4]

'アメリカ合衆国 施設・設備: 駐車場, 24時間対応フロント, フィットネスセンター, 新聞, テラス, 禁煙ルーム, ドライクリーニング, バリアフリー, 無料駐車場, インターネット, エレベーター, エクスプレス・チェックイン / チェックアウト, セーフティボックス, 暖房, FAX / コピー, 荷物預かり, Wi-Fi, 無料Wi-Fi, 全館禁煙, エアコン, サンテラス, 自販機(ドリンク類), 自販機(スナック類), 客室清掃サービス(毎日), 敷地内駐車場, 専用駐車場, Wi-Fi(館内全域), 立体駐車場, 障害者用駐車場, 短い説明Gateway Hotel Santa Monicaはビーチから3.2kmの場所に位置し、24時間利用可能なジム、無料Wi-Fi付きのお部屋、無料の地下駐車場を提供しています。'

保存Vocab对象和训练好的模型

 
最后,训练完成后,我们首先使用Pickle保存Vocab对象(en_vocab和ja_vocab)。

import pickle
 
# 打开一个文件,用于存储数据
file = open('en_vocab.pkl', 'wb')
 
# 将词汇表对象 en_vocab 序列化并写入文件
pickle.dump(en_vocab, file)
 
# 关闭文件
file.close()
 
# 打开一个文件,用于存储数据
file = open('ja_vocab.pkl', 'wb')
 
# 将词汇表对象 ja_vocab 序列化并写入文件
pickle.dump(ja_vocab, file)
 
# 关闭文件
file.close()


最后,我们还可以使用PyTorch的保存和加载函数来保存模型以备后用。通常,根据我们以后想要用模型来做什么,有两种保存模型的方法。第一种是仅用于推理,我们可以在以后加载模型并使用它来从日语翻译成英语。
 

# save model for inference
torch.save(transformer.state_dict(), 'inference_model')


第二种也是用于推理,但同时也适用于当我们以后想要加载模型并希望恢复训练时。
 

# 使用torch.save()保存以下内容到'model_checkpoint.tar'文件中:
torch.save({
  'epoch': NUM_EPOCHS,
  'model_state_dict': transformer.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
  'loss': train_loss,
  }, 'model_checkpoint.tar')


That’s it!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值