Transformer模型训练代码实现及详解

本文来源: PyTorch官方教程

主体框架包括以下几个部分:
data.py: 负责数据预处理,包含字符切割、转换为token等;
model.py: 负责模型构建;
main.py: 主要脚本,负责训练模型;
generate.py: 负责用训练好的模型生成新文本。

以下对每个脚本中的代码进行详细解释:
data.py中包含两个主要类:
Dictionary和Corpus(语料库)

第一个类Dictionary负责构建word与index之间的转换关系

import os
from io import open
import torch

class Dictionary(object):
    def __init__(self):
        self.word2idx = {
   }		# 用于将字符转换为index
        self.idx2word = []		# 用于将index转换为字符 (生成文本时使用)

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)	# 把word添加进列表末端
            self.word2idx[word] = len(self.idx2word) - 1	# 生成word2idx字典,index为idx2word列表中的序号
       
       
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

第二个类定义了语料库:

class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()	#语料库的字典,包含所有训练集中的字符
        self.train = self.tokenize(os.path.join(path, 'train.txt'))	# 调用train数据集,并同时做tokenize
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))	# 调用valid数据集,并同时做tokenize
        self.test = self.tokenize(os.path.join(path, 'test.txt'))	# 调用test数据集,并同时做tokenize

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)		# 检查文件路径存在
        # Add words to the dictionary
        with open(path,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值