新生学习编程的每一天

import torch
import torch.nn as nn
from torch.autograd import Variable
torch.manual_seed(777)
sentence = 'hihello'
x_str = sentence[:-1] # x: hihell
y_str = sentence[1:] # y: ihello
# 1. 获取字符集 去重
char_set = list(set(sentence))  # 例:['h', 'e', 'i', 'o', 'l']
# 2. 根据字符生成字典 例:{'h':0, 'i':1, 'e':2, 'i':3, 'l':4}
word2id = {w: i for i, w in enumerate(char_set)}
id2word = {i: w for i, w in enumerate(char_set)}
# 3. x,y 转码
x = [word2id[c] for c in x_str] # 例:[0,1,0,2,4,4]
y = [word2id[c] for c in y_str]
# 获取句长(时间步数)和词向量长(特征数)
seq_length = max(len(x), len(y)) #6
char_dim = len(char_set) #5
# 只对输入x进行one-hot处理,因为是多分类任务所以y保持一维整型
x_onehot = nn.functional.one_hot(torch.Tensor(x).to(torch.int64),
                                 num_classes=char_dim)
inputs = Variable(torch.Tensor(x_onehot).float())
labels = Variable(torch.Tensor(y)) # 标签只到转为整数编码形式即可
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=char_dim, hidden_size=char_dim,
                            num_layers=2, batch_first=True) #层的个数,
        self.flat = nn.Flatten()
        self.fc = nn.Linear(char_dim, char_dim)
    def forward(self, x):
        out, (h_out, c_out) = self.lstm(x)
        out = self.flat(out)
        out = self.fc(out)
        return out
lstm = LSTM()
# 声明多分类损失函数及优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=0.1)
# 开始循环训练
for epoch in range(200):
    optimizer.zero_grad()
    pred = lstm(inputs)
    loss = criterion(pred, labels.long())
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print("epoch: %d, loss: %1.3f" % (epoch + 1, loss.item()))
        # pred:(6,5) 每个输入字母在五个类别(i,h,e,l,o)上的预测概率
        idx = torch.argmax(pred, 1).data.numpy()
        # 对pred求每一行的最大值索引idx ,并转为字符
        result_str = [id2word[c] for c in idx]
        print("Predicted string: ", ''.join(result_str))
print("Learning finished!")
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值