使用RNN模型构建字符串批量转换功能(seq2seq)

使用RNN Module构建的一个字符串转换功能:

import torch
import torch.optim as optim

class Model(torch.nn.Module):
    """Character-level RNN classifier.

    Maps a one-hot encoded input sequence to per-timestep hidden states,
    which are used directly as class logits (hidden_size == vocab size in
    the demo below).

    Args:
        input_size: feature dimension of each sequence element (vocab size).
        hidden_size: hidden-state dimension (also the number of output
            classes in this usage).
        batch_size: default batch size used to build the zero initial
            hidden state; can be overridden per call via forward(**kwargs).
        num_layers: number of stacked RNN layers.
    """
    def __init__(self, input_size, hidden_size, batch_size, num_layers):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        # The RNN cell is reused across timesteps (weight sharing).
        # BUG FIX: the original omitted num_layers here, so the module was
        # always 1 layer while forward() built a (num_layers, B, H) initial
        # hidden state — a shape mismatch for any num_layers > 1.
        self.rnn = torch.nn.RNN(
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                num_layers=self.num_layers)

    def forward(self, input, **kwargs):
        """Run a full sequence through the RNN.

        Args:
            input: tensor of shape (seq_len, batch, input_size).
            **kwargs: optional 'batch_size' to override the stored batch
                size for the zero initial hidden state (e.g. at inference).

        Returns:
            Tensor of shape (seq_len * batch, hidden_size): per-timestep
            outputs flattened for CrossEntropyLoss.
        """
        if 'batch_size' in kwargs:
            self.batch_size = kwargs['batch_size']
        # Zero initial hidden state: (num_layers, batch, hidden_size).
        hidden = torch.zeros(
                self.num_layers,
                self.batch_size,
                self.hidden_size)
        out, _ = self.rnn(input, hidden)
        return out.view(-1, self.hidden_size)



if __name__ == "__main__":
    # Demo: train the RNN to map one fixed input string to a target string
    # of the same length (per-character classification), then run greedy
    # decoding on a user-supplied string.

    num_layers = 1 # number of stacked RNN layers

    #idx2char = ['e','h','l','o','n','a','b','c'] # (earlier, smaller vocabulary)
    # Vocabulary: A-Z, 0-9, arithmetic symbols and space.
    idx2char = [chr(x) for x in range(ord('A'),ord('Z')+1)] + [chr(x) for x in range(ord('0'),ord('9')+1)] + ['+', '-', '*', '/', '=', ' ']

    input_size = len(idx2char) # feature dimension of each input element (vocab size)
    hidden_size = len(idx2char) # hidden-state dimension (doubles as class count)

    print(idx2char)

    # Input and label data: one training pair; x and y must be equal length
    # so outputs and labels line up for CrossEntropyLoss.
    #x_data = [1,0,5,2,2,3,2,2,4,5] #hellollnnaa
    #y_data = [3,1,4,2,3,2,3,3,5,4] #ohlolooaann

    x_data = ["xihuanliaojiexuexilehuatuan"]
    y_data = ["hifuanliaogaihaxolaofuatuen"]

    batch_size = len(x_data) # batch size
    seq_len = len(x_data[0]) # sequence length of each batch


    # Map characters to vocabulary indices (upper-cased to match idx2char).
    x_data = [idx2char.index(x) for item in x_data for x in item.upper() ]
    y_data = [idx2char.index(x)  for item in y_data for x in item.upper()]
    print(x_data, y_data)

    # One-hot lookup table for the vocabulary: an identity matrix, one row
    # per vocabulary entry.
    one_hot_lookup = torch.diag(torch.ones(input_size,dtype=torch.int32))
    """
    one_hot_lookup = [
            [1,0,0,0,0,0],
            [0,1,0,0,0,0],
            [0,0,1,0,0,0],
            [0,0,0,1,0,0],
            [0,0,0,0,1,0],
            [0,0,0,0,0,1],
            ]
    #x_one_hot = [one_hot_lookup[x] for x in x_data]
    """

    # Row-index the identity matrix to one-hot encode the whole sequence.
    x_one_hot = one_hot_lookup[x_data]
    print(x_one_hot)

    # RNN expects (seq_len, batch, input_size); labels stay as class indices.
    inputs = (x_one_hot.float()).view(seq_len, batch_size, input_size)
    labels = torch.LongTensor(y_data)


    model = Model(input_size, hidden_size, batch_size, num_layers)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.01)

    # Training loop (the original comment said "test", but this is training).
    for epoch in range(100):
        optimizer.zero_grad()# reset accumulated gradients
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()# backward pass
        optimizer.step()# update parameters

        # Greedy decode: pick the most likely class per timestep.
        _,idx =outputs.max(dim=1)
        print("EPOCH: ", epoch+1, loss.item(), end=" ")
        print("Predicted String: ", end=" ")
        print("".join([idx2char[x] for x in idx]))


    batch_size = 1
    myinput = input("请输入你要转换的序列:")
    test_x_data = [idx2char.index(x) for x in myinput.upper()]

    # Inference on new data
    with torch.no_grad(): # no gradient computation needed
        x_one_hot = one_hot_lookup[test_x_data]
        inp = (x_one_hot.float()).view(len(test_x_data), batch_size, input_size)
        outputs = model(inp, **{"batch_size":batch_size})
        _,idx = outputs.max(dim=1)
        print("".join([idx2char[x] for x in idx]),end="")

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值