本文来源: PyTorch官方教程
主体框架包括以下几个部分:
data.py: 负责数据预处理,包含字符切割、转换为token等;
model.py: 负责模型构建;
main.py: 主要脚本,负责训练模型;
generate.py: 负责用训练好的模型生成新文本。
以下对每个脚本中的代码进行详细解释:
data.py中包含两个主要类:
Dictionary和Corpus(语料库)
第一个类Dictionary负责构建word与index之间的转换关系
import os
from io import open
import torch
class Dictionary(object):
def __init__(self):
self.word2idx = {
} # 用于将字符转换为index
self.idx2word = [] # 用于将index转换为字符 (生成文本时使用)
def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word) # 把word添加进列表末端
self.word2idx[word] = len(self.idx2word) - 1 # 生成word2idx字典,index为idx2word列表中的序号
return self.word2idx[word]
def __len__(self):
return len(self.idx2word)
第二个类定义了语料库:
class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary() #语料库的字典,包含所有训练集中的字符
self.train = self.tokenize(os.path.join(path, 'train.txt')) # 调用train数据集,并同时做tokenize
self.valid = self.tokenize(os.path.join(path, 'valid.txt')) # 调用valid数据集,并同时做tokenize
self.test = self.tokenize(os.path.join(path, 'test.txt')) # 调用test数据集,并同时做tokenize
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path) # 检查文件路径存在
# Add words to the dictionary
with open(path,