python-pytorch seq2seq+luong dot attention笔记1.0.0
可复用部分
主要将数据弄成如下格式:
seq_example = [“你认识我吗”, “你住在哪里”, “你知道我的名字吗”, “你是谁”, “你会唱歌吗”, “你是张学友吗”]
seq_answer = [“当然认识”, “我住在成都”, “我不知道”, “我是机器人”, “我不会”, “肯定不是”]
同时设定embedding_size 、vocab_size、 hidden_size、 seq_len,实现word2index、index2word、encoder_input、decoder_input、target_input
代码如下:
# def getAQ():
# ask=[]
# answer=[]
# with open("./data/flink.txt","r",encoding="utf-8") as f:
# lines=f.readlines()
# for line in lines:
# ask.append(line.split("----")[0])
# answer.append(line.split("----")[1].replace("\n",""))
# return answer,ask
# seq_answer,seq_example=getAQ()
import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]
# 所有词
example_cut = []
answer_cut = []
word_all = []
# 分词
for i in seq_example:
example_cut.append(list(jieba.cut(i)))
for i in seq_answer:
answer_cut.append(list(jieba.cut(i)))
# 所有词
for i in example_cut + answer_cut:
for word in i:
if word not in word_all:
word_all.append(word)
# 词语索引表
word2index = {
w: i+3 for i, w in enumerate(word_all)}
# 补全
word2index['PAD'] = 0
# 句子开始
word2index['SOS'] = 1
# 句子结束
word2index['EOS'] = 2
index2word = {
value: key for key, value in word2index.items()}
# 一些参数
vocab_size = len(word2index)
seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
print("vocab_size is",vocab_size,", seq_length is ",seq_length)
embedding_size = 128
num_classes = vocab_size
hidden_size = 256
batch_size=6
seq_len=