seq2seq原理
1.seq2seq = encoder + decoder
encoder负责对输入句子的理解,转化为context vector,decoder负责对理解后的句子的向量进行处理,解码,获得输出。
在encoder的过程中得到的context vector作为decoder的输入,那么这样一个输入,怎么能够得到多个输出呢?
就是将当前一步的输出,作为下一个单元的输入,然后得到结果
outputs = []
while True:
output = decoder(output)
outputs.append(output)
循环什么时候停止?
在训练数据集中,可以再输入的最后面添加一个结束符<END>,如果遇到该结束符,则可以终止循环
outputs = []
while outputs != "<END>":
output = decoder(output)
outputs.append(output)
这个结束符只是一个标记,很多人也会使用<EOS>(end of sentence)
总之:seq2seq模型中的encoder接受一个长度为M的序列,得到一个context vector,之后decoder把这个context vector转化为长度为N的序列作为输出,从而构成一个M to N的模型,能够处理很多不定长输入输出的问题。
seq2seq模型的实现
需求:完成一个模型,实现往模型输入一串数字,输出这串数字+0
如:输入123456 输出1234560
实现流程:
1.文本转化为序列(数字序列,torch.LongTensor)
2.使用序列,准备数据集,准备Dataloader
3.完成编码器
4.完成解码器
5.完成seq2seq模型
6.完成模型训练的逻辑,进行训练
7.完成模型评估的逻辑,进行模型评估
第一步代码实现:该部分在文本序列化中讲过 这里复习一遍
import config
class Num_sequence():
PAD_TAG = "PAD"
PAD = 0
UNK_TAG = "UNK"
UNK = 1
SOS_TAG = "SOS" # start of sequence
EOS_TAG = "EOS" # end of seq
SOS = 2
EOS = 3
def __init__(self):
self.dict = {self.PAD_TAG: self.PAD,
self.UNK_TAG: self.UNK,
self.SOS_TAG: self.SOS,
self.EOS_TAG: self.EOS}
for i in range(10):
self.dict[str(i)] = len(self.dict)
self.inverse_dict = dict(zip(self.dict.values(), self.dict.keys()))
def transform(self, sentence, max_len, add_eos=False):
""" 把sentence转化为数字序列"""
if len(sentence) > max_len: # 句子长度比max_len长时
sentence = sentence[:max_len] # 保留前部分
sentence_len = len(sentence)
if add_eos:
sentence = sentence+[self.EOS_TAG]
if len(sentence) < max_len: # PAD
sentence = sentence + [self.PAD_TAG]*(max_len-sentence_len)
result = [self.dict.get(i, self.UNK) for i in sentence]
return result
def inverse_transform(self, indices):
""" 把序列转化为数字(字符串)"""
return [self.inverse_dict.get(i, self.UNK_TAG) for i in indices]
if __name__ == '__main__':
num_sequence = Num_sequence()
print(num_sequence.dict)
第二步 代码实现 该部分较为常见 值得注意的是collate_fn部分 有个batch排序的操作 然后将四个目标值 zip* 起来 zip*就是把相同位置对应的值一同返回到一个列表中 对应每一个batch中包含的内容 然后转化为LongTensor
"""
准备dataset 和 dataloader
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import numpy as np
import config
class NumDataset(Dataset):
def __init__(self):
# 使用numpy随机创造一批数
self.data = np.random.randint(0, 1e8, size=[500000])
def __getitem__(self, index):
input = list(str(self.data[index]))
label = input + ["0"]
input_length = len(input)
label_length = len(label)
return input, label, input_length, label_length
def __len__(self):
return self.data.shape[0]
def collate_fn(batch):
"""
把batch里的值按照某一个值排序
:param batch: [input, label, input_length, label_length,...input, label, input_length, label_length]
:return:
"""
batch = sorted(batch, key=lambda x:x[3], reverse=True)
input, target, input_length, target_length = list(zip(*batch))
# 把input转化为序列
input = torch.LongTensor([config.num_sequence.transform(i, max_len=config.max_len) for i in input])
target = torch.LongTensor([config.num_sequence.transform(i, max_len=config.max_len+1) for i in target])
input_length = torch.LongTensor(input_length)
target_length = torch.LongTensor(target_length)
return input, target, input_length, target_length
train_data_loader = DataLoader(NumDataset(), batch_size=config.train_batch_size, shuffle=True, collate_fn=collate_fn)
if __name__ == '__main__':
# num_dataset = NumDataset()
# print(num_dataset[0])
# print(num_dataset.data[:10])
# print(len(num_dataset))
for input, target, input_length, target_length in train_data_loader:
print(input)
print(target)
print(input_length)
print(target_length)
break
运行结果 34567见下篇