A PyTorch seq2seq model example

"""test"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # kept from the original; a no-op wrapper in modern PyTorch

# Build the character vocabulary

seq_data = [['man', 'woman'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]  # (word, translation) pairs

char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']  # 'S' = start, 'E' = end, 'P' = padding

num_dict = {n: i for i, n in enumerate(char_arr)}  # char -> index

# Network hyperparameters

n_step = 5               # maximum word length; shorter words are padded with 'P'
n_hidden = 128           # RNN hidden-state size
n_class = len(num_dict)  # vocabulary size (29)

batch_size = len(seq_data)
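# Quick sanity check of the encoding above (an illustration added here, not in
# the original post); the indices follow from the 'SEP' + alphabet ordering.
demo = [num_dict[c] for c in 'manPP']  # 'man' padded with 'P' to n_step = 5
print(demo)                            # [15, 3, 16, 2, 2]
print(np.eye(n_class)[demo].shape)     # (5, 29): one one-hot row per character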

# Prepare the data batches
def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []
    for seq in seq_data:
        # Pad both words of the pair to n_step characters
        for i in range(2):
            seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))
        input = [num_dict[n] for n in seq[0]]           # encoder input
        output = [num_dict[n] for n in ('S' + seq[1])]  # decoder input, shifted right
        target = [num_dict[n] for n in (seq[1] + 'E')]  # decoder target, shifted left
        input_batch.append(np.eye(n_class)[input])      # one-hot encode
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)                     # targets stay as class indices
    return (Variable(torch.Tensor(input_batch)),
            Variable(torch.Tensor(output_batch)),
            Variable(torch.LongTensor(target_batch)))

input_batch, output_batch, target_batch = make_batch(seq_data)
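# Shape check (an illustration, not in the original post). With 6 word pairs,
# n_step = 5, and a 29-character vocabulary:
print(input_batch.shape)   # torch.Size([6, 5, 29])  encoder input, one-hot
print(output_batch.shape)  # torch.Size([6, 6, 29])  decoder input: 'S' + padded word
print(target_batch.shape)  # torch.Size([6, 6])      class indices: padded word + 'E'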

# Build the network
class Seq2Seq(nn.Module):
    """Key points:
    1. The network has an encoder and a decoder built from RNNs of the same
       structure, with a final fully connected layer producing the predictions.
    2. Be familiar with the RNN input/output conventions.
    3. The essence of seq2seq: the encoder's final hidden state is fed to the
       decoder as its initial hidden state.
    """

    def __init__(self):
        super().__init__()
        # input_size is the size of each input vector (one-hot, hence n_class);
        # hidden_size is the dimension of the hidden state. Note that dropout
        # has no effect on a single-layer RNN (PyTorch emits a warning).
        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # nn.RNN expects input of shape (seq_len, batch_size, input_size), so transpose

        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)
        _, enc_states = self.enc(enc_input, enc_hidden)  # keep only the final hidden state
        outputs, _ = self.dec(dec_input, enc_states)     # the decoder starts from it
        pred = self.fc(outputs)  # (seq_len, batch_size, n_class) logits
        return pred
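# The only coupling between encoder and decoder is the hidden state. A minimal
# standalone sketch of that handoff with dummy zero tensors (illustration only;
# the names rnn_enc/rnn_dec are mine, not from the original post):
rnn_enc = nn.RNN(input_size=n_class, hidden_size=n_hidden)
rnn_dec = nn.RNN(input_size=n_class, hidden_size=n_hidden)
x_enc = torch.zeros(n_step, batch_size, n_class)      # (seq_len, batch, input_size)
x_dec = torch.zeros(n_step + 1, batch_size, n_class)  # 'S' + target is one step longer
h0 = torch.zeros(1, batch_size, n_hidden)             # (num_layers, batch, hidden)
_, h_enc = rnn_enc(x_enc, h0)   # h_enc (1, batch, n_hidden): summary of the input
out, _ = rnn_dec(x_dec, h_enc)  # the decoder is seeded with the encoder's final state
print(out.shape)                # torch.Size([6, 6, 128])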

# Training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    # Initial hidden state: (num_layers, batch_size, n_hidden)
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))
    optimizer.zero_grad()
    pred = model(input_batch, hidden, output_batch)  # teacher forcing: feed the true 'S' + target
    pred = pred.transpose(0, 1)  # back to (batch_size, seq_len, n_class)
    loss = 0
    for i in range(len(seq_data)):
        loss += loss_fun(pred[i], target_batch[i])  # (seq_len, n_class) logits vs (seq_len,) targets
    if (epoch + 1) % 1000 == 0:
        print('Epoch: %d Cost: %f' % (epoch + 1, loss))
    loss.backward()
    optimizer.step()
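# The per-sample loop works because nn.CrossEntropyLoss accepts (N, C) logits
# against (N,) index targets. A vectorized sketch (my addition, not from the
# original post): flattening all samples and time steps into one call gives the
# same gradients up to a constant factor of batch_size, since each loss_fun
# call above averages over its own time steps.
flat_loss = loss_fun(pred.reshape(-1, n_class), target_batch.reshape(-1))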

# Testing
def translate(word):
    # The decoder input here is only 'S' plus 'P' placeholders, so the model
    # must rely on the encoder's hidden state alone
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    # hidden shape: (1, 1, n_hidden)
    hidden = Variable(torch.zeros(1, 1, n_hidden))
    # output shape: (n_step + 1, 1, n_class) = (6, 1, n_class)
    output = model(input_batch, hidden, output_batch)
    predict = output.data.max(2, keepdim=True)[1]  # argmax over the class dimension
    decoded = [char_arr[i] for i in predict]
    end = decoded.index('E')  # cut at the first end-of-sequence token
    translated = ''.join(decoded[:end])
    return translated.replace('P', '')  # strip padding

print('girl ->', translate('girl'))
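# Note that translate() feeds 'P' placeholders as decoder input for every step,
# so the model leans entirely on the hidden state. A more conventional greedy
# decoder (my sketch, not from the original post) feeds each prediction back in:
def translate_greedy(word, max_len=n_step + 1):
    word = word + 'P' * (n_step - len(word))
    enc_in = torch.Tensor(np.eye(n_class)[[num_dict[c] for c in word]]).unsqueeze(1)  # (seq_len, 1, n_class)
    _, state = model.enc(enc_in, torch.zeros(1, 1, n_hidden))
    dec_in = torch.Tensor(np.eye(n_class)[[num_dict['S']]]).unsqueeze(1)  # start token
    decoded = []
    for _ in range(max_len):
        out, state = model.dec(dec_in, state)
        idx = model.fc(out)[-1].argmax(dim=-1).item()  # greedy pick at this step
        if char_arr[idx] == 'E':
            break
        decoded.append(char_arr[idx])
        dec_in = torch.Tensor(np.eye(n_class)[[idx]]).unsqueeze(1)  # feed prediction back
    return ''.join(decoded).replace('P', '')

print('king ->', translate_greedy('king'))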
