"""test"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Create the dictionary
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']  # 'S': start, 'E': end, 'P': pad
num_dict = {n: i for i, n in enumerate(char_arr)}

# Network hyperparameters
n_step = 5
n_hidden = 128
n_class = len(num_dict)
batch_size = len(seq_data)
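# --- A quick illustration (not part of the original script) of the vocabulary
# layout: the three special tokens come first, then 'a'..'z', so n_class is 29.
print(num_dict['S'], num_dict['E'], num_dict['P'], num_dict['a'])  # 0 1 2 3
print(n_class)  # 29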
# Prepare the data
def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []
    for seq in seq_data:
        for i in range(2):
            # Pad both words to n_step characters with 'P'
            seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))
        input = [num_dict[n] for n in seq[0]]
        output = [num_dict[n] for n in ('S' + seq[1])]  # decoder input starts with 'S'
        target = [num_dict[n] for n in (seq[1] + 'E')]  # decoder target ends with 'E'
        input_batch.append(np.eye(n_class)[input])      # one-hot encode
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)                     # targets stay as class indices
    return (Variable(torch.Tensor(input_batch)),
            Variable(torch.Tensor(output_batch)),
            Variable(torch.LongTensor(target_batch)))
input_batch, output_batch, target_batch = make_batch(seq_data)
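# --- A quick sanity check (not part of the original script). With n_step = 5
# and batch_size = 6, the batches should have the shapes noted below.
print(input_batch.shape)   # torch.Size([6, 5, 29]) -- (batch, seq_len, n_class)
print(output_batch.shape)  # torch.Size([6, 6, 29]) -- the 'S' prefix adds one step
print(target_batch.shape)  # torch.Size([6, 6])     -- class indices, not one-hot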
# Build the network
class Seq2Seq(nn.Module):
    """Key points:
    1. The network contains one encoder and one decoder; both use the same RNN
       structure, and a final fully-connected layer produces the predictions.
    2. Be familiar with the RNN input/output conventions.
    3. The essence of seq2seq: the encoder's final hidden state is fed to the
       decoder as its initial hidden state.
    """
    def __init__(self):
        super().__init__()
        # input_size is the per-step input dimension (one-hot over n_class);
        # hidden_size is the dimension of the hidden state. Note that dropout
        # only takes effect when num_layers > 1, so it is a no-op here.
        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # nn.RNN expects inputs of shape (seq_len, batch_size, n_class),
        # so transpose from (batch_size, seq_len, n_class)
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)
        _, enc_states = self.enc(enc_input, enc_hidden)  # keep only the final hidden state
        outputs, _ = self.dec(dec_input, enc_states)     # seed the decoder with it
        pred = self.fc(outputs)                          # (seq_len, batch_size, n_class)
        return pred
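# --- A minimal sketch (not part of the original script) of the hidden-state
# handoff described in point 3 of the docstring; the toy sizes (input_size=4,
# hidden_size=8) are assumptions for illustration only.
_enc = nn.RNN(input_size=4, hidden_size=8)
_dec = nn.RNN(input_size=4, hidden_size=8)
_, _h = _enc(torch.randn(5, 1, 4), torch.zeros(1, 1, 8))  # _h: (1, 1, 8)
_out, _ = _dec(torch.randn(6, 1, 4), _h)                  # decoder seeded with _h
print(_out.shape)  # torch.Size([6, 1, 8]) -- one hidden vector per decoder step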
# Training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    # Initial encoder hidden state: (num_layers, batch_size, n_hidden)
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))
    optimizer.zero_grad()
    pred = model(input_batch, hidden, output_batch)
    pred = pred.transpose(0, 1)  # back to (batch_size, seq_len, n_class)
    loss = 0
    for i in range(len(seq_data)):
        loss += loss_fun(pred[i], target_batch[i])
    if (epoch + 1) % 1000 == 0:
        print('Epoch: %d Cost: %f' % (epoch + 1, loss))
    loss.backward()
    optimizer.step()
# Test
def translate(word):
    # Pair the word with an all-'P' dummy target of the same length
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    # hidden shape: (1, 1, n_hidden)
    hidden = Variable(torch.zeros(1, 1, n_hidden))
    # output shape: (n_step + 1, 1, n_class)
    output = model(input_batch, hidden, output_batch)
    predict = output.data.max(2, keepdim=True)[1]  # most likely class per step
    decoded = [char_arr[i] for i in predict]
    end = decoded.index('E')  # assumes the model emits an 'E' token
    translated = ''.join(decoded[:end])
    return translated.replace('P', '')

print('girl ->', translate('girl'))
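# --- Additional checks (not part of the original script) on other training
# pairs; the outputs depend on how well the 5000 epochs converged.
print('man ->', translate('man'))
print('black ->', translate('black'))
print('up ->', translate('up'))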