基于Transformer实现机器翻译(NLP)

一、导入所需的库函数

首先,让我们确保我们的系统中安装了以下软件包,如果你发现有些软件包丢失了,一定要安装它们。

import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import tqdm
import sentencepiece as spm
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(torch.cuda.get_device_name(0)) ## 如果你有GPU,请在你自己的电脑上尝试运行这一套代码


device
device(type='cpu')

二、获取并行数据集

在本教程中,我们将使用从http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl下载的日语-英语并行数据集,该数据集被描述为“NTT创建的最大的公开可用的英语-日语并行语料库”。它主要是通过抓取网络并自动对齐平行句子创建的。”你也可以在这里看到说明。

# 从名为'zh-ja.bicleaner05.txt'的CSV文件中读取数据到DataFrame 'df'中,使用制表符作为分隔符,不使用C引擎解析
df = pd.read_csv('zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)

# 将df中第2列的值转换为列表并赋给变量'trainen';即将英文文本数据提取为训练集
trainen = df[2].values.tolist()  # 将数据转换为列表类型

# 将df中第3列的值转换为列表并赋给变量'trainja';即将日文文本数据提取为训练集
trainja = df[3].values.tolist()  # 将数据转换为列表类型

在导入所有日语和英语对应项之后,我删除了数据集中的最后一个数据,因为它有一个缺失的值。总的来说,trainen和trainja中的句子数量都是5,973,071,然而,出于学习目的,通常建议在一次使用所有数据之前对数据进行采样并确保一切正常工作,以节省时间。

下面是一个包含在数据集中的句子示例。

# 打印训练集中索引为500的英文句子
print(trainen[500])

# 打印训练集中索引为500的日文句子
print(trainja[500])
Chinese HS Code Harmonized Code System < HS编码 2905 无环醇及其卤化、磺化、硝化或亚硝化衍生物 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...
Japanese HS Code Harmonized Code System < HSコード 2905 非環式アルコール並びにそのハロゲン化誘導体、スルホン化誘導体、ニトロ化誘導体及びニトロソ化誘導体 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...

我们也可以使用不同的并行数据集来跟随本文,只要确保我们可以将数据处理成如上所示的两个字符串列表,其中包含日语和英语句子。

三、准备标记器

与英语或其他按字母顺序排列的语言不同,日语句子不包含空格来分隔单词。我们可以使用JParaCrawl提供的标记器,它是使用sentencepece为日语和英语创建的,您可以访问JParaCrawl网站下载它们,或者点击这里。

# 创建英文句子的分词器,加载提前训练好的SentencePiece模型'spm.en.nopretok.model'
en_tokenizer = spm.SentencePieceProcessor(model_file='spm.en.nopretok.model')

# 创建日文句子的分词器,加载提前训练好的SentencePiece模型'spm.ja.nopretok.model'
ja_tokenizer = spm.SentencePieceProcessor(model_file='spm.ja.nopretok.model')

加载标记器之后,您可以测试它们,例如,通过执行下面的代码

# 使用英文分词器对给定句子进行编码,并输出为字符串形式
en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type='str')
['▁All',
 '▁residents',
 '▁aged',
 '▁20',
 '▁to',
 '▁59',
 '▁years',
 '▁who',
 '▁live',
 '▁in',
 '▁Japan',
 '▁must',
 '▁enroll',
 '▁in',
 '▁public',
 '▁pension',
 '▁system',
 '.']
# 使用日文分词器对给定日文句子进行编码,并输出为字符串形式
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')
['▁',
 '年',
 '金',
 '▁日本',
 'に住んでいる',
 '20',
 '歳',
 '~',
 '60',
 '歳の',
 '全ての',
 '人は',
 '、',
 '公的',
 '年',
 '金',
 '制度',
 'に',
 '加入',
 'しなければなりません',
 '。']

四、建立TorchText词汇对象并将句子转换为Torch张量

使用标记器和原始句子,然后构建从TorchText导入的Vocab对象。根据数据集的大小和计算能力,这个过程可能需要几秒钟或几分钟。不同的标记器也会影响构建词汇所需的时间,我尝试了其他几个日语标记器,但sensenepece似乎工作得很好,对我来说足够快。

# 定义一个函数build_vocab,用于构建词汇表
def build_vocab(sentences, tokenizer):
    counter = Counter()
    for sentence in sentences:
        counter.update(tokenizer.encode(sentence, out_type=str))
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

# 通过build_vocab函数构建日文句子的词汇表ja_vocab,统计训练集中的日文句子并使用日文分词器进行编码
ja_vocab = build_vocab(trainja, ja_tokenizer)

# 通过build_vocab函数构建英文句子的词汇表en_vocab,统计训练集中的英文句子并使用英文分词器进行编码

在我们有了词汇表对象之后,我们可以使用词汇表和标记器对象来为我们的训练数据构建张量。

# 定义数据处理函数data_process,用于将原始日文和英文句子转换为张量形式的数据
def data_process(ja, en):
    data = []
    for (raw_ja, raw_en) in zip(ja, en):
        # 使用日文分词器对原始日文句子进行编码,并转换为张量形式
        ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        # 使用英文分词器对原始英文句子进行编码,并转换为张量形式
        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        data.append((ja_tensor_, en_tensor_))
    return data

# 调用数据处理函数data_process,将训练集的日文和英文句子处理成张量形式的训练数据
train_data = data_process(trainja, trainen)

五、创建要在训练期间迭代的DataLoader对象

在这里,我将BATCH_SIZE设置为16以防止“cuda内存不足”,但这取决于各种事情,例如您的机器内存容量,数据大小等,因此可以根据您的需要随意更改批大小(注意:PyTorch的教程使用Multi30k德语-英语数据集将批大小设置为128)。

# 定义批处理大小和特殊标记索引
BATCH_SIZE = 8
PAD_IDX = ja_vocab['<pad>']
BOS_IDX = ja_vocab['<bos>']
EOS_IDX = ja_vocab['<eos>']

# 定义生成批处理数据的函数generate_batch
def generate_batch(data_batch):
    ja_batch, en_batch = [], []
    for (ja_item, en_item) in data_batch:
        # 在日文句子前后分别添加起始和结束标记,并将其转化为张量形式
        ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
        # 在英文句子前后分别添加起始和结束标记,并将其转化为张量形式
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    # 对日文句子和英文句子进行填充操作,使它们具有相同的长度
    ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return ja_batch, en_batch

# 创建训练数据迭代器train_iter,用于对处理后的训练数据进行分批和混洗操作
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)

六、Sequence-to-sequence变压器

接下来的代码和文本解释(以斜体书写)来自原始的PyTorch教程https://pytorch.org/tutorials/beginner/translation_transformer.html。除了BATCH_SIZE和单词de_vocab被更改为ja_vocab之外,我没有做任何更改。

Transformer是在“Attention is all you need”论文中提出的用于解决机器翻译任务的Seq2Seq模型。变压器模型由编码器和解码器块组成,每个块包含固定数量的层。

编码器通过一系列多头注意和前馈网络层对输入序列进行传播处理。编码器的输出称为存储器,与目标张量一起馈送到解码器。编码器和解码器以端到端方式使用教师强迫技术进行培训。

from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        """
        初始化Seq2SeqTransformer类
        Args:
            num_encoder_layers: 编码器的层数
            num_decoder_layers: 解码器的层数
            emb_size: 嵌入维度大小
            src_vocab_size: 源语言词汇表大小
            tgt_vocab_size: 目标语言词汇表大小
            dim_feedforward: 前馈网络的维度
            dropout: 丢弃率
        """
        super(Seq2SeqTransformer, self).__init__()
        
        # 初始化编码器和解码器层
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # 初始化生成器(用于映射TransformerDecoder的输出到目标语言词汇空间)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        
        # 创建并初始化源语言和目标语言的token嵌入(TokenEmbedding)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        
        # 创建并初始化位置编码(PositionalEncoding)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """
        前向传播方法
        Args:
            src: 输入源语言句子
            trg: 输入目标语言句子
            src_mask: 源语言掩码
            tgt_mask: 目标语言掩码
            src_padding_mask: 源语言填充掩码
            tgt_padding_mask: 目标语言填充掩码
            memory_key_padding_mask: 内存填充掩码
        Returns:
            模型的输出
        """
        # 对源语言和目标语言进行嵌入和位置编码
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        
        # 使用编码器对源语言进行编码得到memory
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        
        # 使用解码器结合memory对目标语言进行解码得到输出
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        
        # 将输出映射到目标语言词汇空间
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """
        对源语言句子进行编码
        Args:
            src: 输入源语言句子
            src_mask: 源语言掩码
        Returns:
            编码后的表示
        """
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """
        对目标语言句子进行解码
        Args:
            tgt: 输入目标语言句子
            memory: 编码器的内部表示(memory)
            tgt_mask: 目标语言掩码
        Returns:
            解码后的输出
        """
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

文本标记通过使用标记嵌入表示。位置编码被添加到标记嵌入中以引入词序的概念。

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        """
        位置编码模块,用于向标记嵌入添加位置信息。
        
        参数:
            emb_size (int): 标记嵌入的大小
            dropout: 丢弃率
            maxlen (int): 输入序列的最大长度(默认为 5000)
        """
        super(PositionalEncoding, self).__init__()
        
        # 初始化位置编码
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        # 初始化丢弃层
        self.dropout = nn.Dropout(dropout)
        
        # 将位置编码注册为缓冲
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """
        PositionalEncoding 模块的前向传播。
        
        参数:
            token_embedding (Tensor): 输入的标记嵌入
        
        返回:
            Tensor: 添加位置编码后的标记嵌入
        """
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        """
        标记嵌入模块,用于将标记索引转换为嵌入。
        
        参数:
            vocab_size (int): 词汇表大小
            emb_size (int): 标记嵌入的大小
        """
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """
        TokenEmbedding 模块的前向传播。
        
        参数:
            tokens (Tensor): 输入的标记索引
        
        返回:
            Tensor: 经过嵌入层处理后的标记嵌入
        """
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

我们创建一个后续单词掩码来阻止目标单词关注它的后续单词。我们还创建遮罩,用于屏蔽源和目标填充令牌

# 定义一个函数生成方形的后续遮罩 mask
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# 创建用于Transformer模型的mask
def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    # 生成后续遮罩
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    
    # 创建源语言mask
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

    # 创建用于填充的源语言和目标语言mask
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Define model parameters and instantiate model. 这里我们服务器实在是计算能力有限,按照以下配置可以训练但是效果应该是不行的。如果想要看到训练的效果请使用你自己的带GPU的电脑运行这一套代码。

当你使用自己的GPU的时候,NUM_ENCODER_LAYERS 和 NUM_DECODER_LAYERS 设置为3或者更高,NHEAD设置8,EMB_SIZE设置为512。

# 定义超参数
SRC_VOCAB_SIZE = len(ja_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 16

# 初始化 Transformer 模型
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                  EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                  FFN_HID_DIM)

# 对模型参数进行初始化
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# 将模型移动到设备(device)
transformer = transformer.to(device)

# 定义交叉熵损失函数,忽略填充部分
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# 定义优化器
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)

# 训练一个epoch的函数
def train_epoch(model, train_iter, optimizer):
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)

# 评估函数
def evaluate(model, val_iter):
    model.eval()
    losses = 0
    for idx, (src, tgt) in enumerate(valid_iter):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(val_iter)

七、开始训练

最后,在准备好必要的类和函数之后,我们准备训练我们的模型。这是不言而喻的,但是完成训练所需的时间可能会有很大的不同,这取决于很多事情,比如计算能力、参数和数据集的大小。

当我使用JParaCrawl(每种语言大约有590万个句子)的完整句子列表来训练模型时,使用单个NVIDIA GeForce RTX 3070 GPU,每个epoch大约需要5个小时。

代码如下:

# 遍历每个epoch
for epoch in tqdm.tqdm(range(1, NUM_EPOCHS+1)):
    start_time = time.time()  # 记录当前epoch开始时间
    train_loss = train_epoch(transformer, train_iter, optimizer)  # 在训练集上训练模型
    end_time = time.time()  # 记录当前epoch结束时间

    # 打印训练结果
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
           f"Epoch time = {(end_time - start_time):.3f}s"))

八、试着用训练好的模型翻译一个日语句子

首先,我们创建翻译新句子的函数,包括获取日语句子、标记化、转换为张量、推理,然后将结果解码回句子,但这次是英语。

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    使用greedy decode方法生成目标序列
    Args:
    - model: Transformer模型
    - src: 源序列
    - src_mask: 源序列mask
    - max_len: 生成序列的最大长度
    - start_symbol: 起始符号

    Returns:
    - 生成的目标序列
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.item()
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys

def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    """
    翻译源语言文本为目标语言文本
    Args:
    - model: Transformer模型
    - src: 源语言文本
    - src_vocab: 源语言词汇表
    - tgt_vocab: 目标语言词汇表
    - src_tokenizer: 源语言文本的tokenizer

    Returns:
    - 翻译后的目标语言文本
    """
    model.eval()
    tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)]+ [EOS_IDX]
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")

然后,我们可以调用翻译函数并传递所需的参数。

translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)


trainen.pop(5)


trainja.pop(5)
' ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁'
'Chinese HS Code Harmonized Code System < HS编码 8515 : 电气(包括电热气体)、激光、其他光、光子束、超声波、电子束、磁脉冲或等离子弧焊接机器及装置,不论是否 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...'

'Japanese HS Code Harmonized Code System < HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)、レーザーその他の光子ビーム式、超音波式、電子ビーム式、 HS Code List (Harmonized System Code) for US, UK, EU, China, India, France, Japan, Russia, Germany, Korea, Canada ...'

九、保存Vocab对象和训练好的模型

最后,在训练完成后,我们将首先使用Pickle保存Vocab对象(en_vocab和ja_vocab)。

import pickle
# open a file, where you want to store the data
file = open('en_vocab.pkl', 'wb')
# dump information to that file
pickle.dump(en_vocab, file)
file.close()
file = open('ja_vocab.pkl', 'wb')
pickle.dump(ja_vocab, file)
file.close()

最后,我们还可以使用PyTorch保存和加载函数保存模型以供以后使用。通常,有两种保存模型的方法,这取决于我们以后想要使用它们的目的。第一个仅用于推理,我们可以稍后加载模型并使用它从日语翻译成英语。

# save model for inference
torch.save(transformer.state_dict(), 'inference_model')

第二个也是用于推理的,但当我们稍后想要加载模型并想要恢复训练时也是如此。

# save model + checkpoint to resume training later
torch.save({
  'epoch': NUM_EPOCHS,
  'model_state_dict': transformer.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
  'loss': train_loss,
  }, 'model_checkpoint.tar')

  • 30
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值