《PyTorch深度学习实践》12 RNN基础_使用RnnCell构造RNN

该博客通过实践代码介绍了如何使用PyTorch实现RNN,包括RNNCell的构成,参数解释,数据预处理,模型设计,损失函数与优化器的设置,并展示了训练过程及输出结果。博客详细讲解了RNN在处理序列数据时的工作原理。
摘要由CSDN通过智能技术生成

1. 说明

本系列博客记录B站课程《PyTorch深度学习实践》的实践代码课程链接请点我

2. 知识点

(1)RNN由多个RnnCell组成,RnnCell中是由线性层组成,且每个RnnCell是一摸一样的,即同一个RnnCell.
在这里插入图片描述
在这里插入图片描述
(2)对参数的理解

input_size = 4      # 输入的维度,例如我们输入为:hello, 看起来是5个,但只有4个不同的字符,按下面顺序,[1, 0, 0, 0]即可表示e, 所以输入维度为4
hidden_size = 4     # 输出的隐藏单元,自定义
batch_size = 1      # 表示一次输入几句话,如:这里只有一句话hello; 如果一次输入两句话, hello, how are you 则batch_size=2

3. 代码

# ---------------------------
# @Time     : 2022/4/25 15:02
# @Author   : lcq
# @File     : 12_RNN_basic.py
# @Function : 
# ---------------------------
import torch

input_size = 4      # 输入的维度,例如我们输入为:hello, 看起来是5个,但只有4个不同的字符,按下面顺序,[1, 0, 0, 0]即可表示e, 所以输入维度为4
hidden_size = 4     # 输出的隐藏单元,自定义
batch_size = 1      # 表示一次输入几句话,如:这里只有一句话hello; 如果一次输入两句话, hello, how are you 则batch_size=2

# 1.准备数据
idx2char = ['e', 'h', 'l', 'o']
X = [1, 0, 2, 2, 3]             # 对应上面的idx2char为: hello
Y = [3, 1, 2, 3, 2]             # 表示最终训练的结果为: 0hlol
on_hot_lookup = [[1, 0, 0, 0],  # 对应e
                 [0, 1, 0, 0],  # 对应h
                 [0, 0, 1, 0],  # 对应l
                 [0, 0, 0, 1]]  # 对应o
X_one_hot = [on_hot_lookup[x] for x in X]
print("X_one_hot = ", X_one_hot)
# 转换为tensor
X_inputs = torch.Tensor(X_one_hot).view(-1, batch_size, input_size)     # 转换维度:torch.Size([5, 1, 4]),
                                                                        # 之所以需要转换是因为,RNN输入的维度为(seqLen, batchSize, inputSize)
Y_label = torch.LongTensor(Y).view(-1, 1)                                   # torch.Size([5, 1])


# 2. 模型设计
class Model(torch.nn.Module):
    def __init__(self, inputSize, hiddenSize, batchSize):
        super(Model, self).__init__()
        self.batchSize = batchSize
        self.inputSize = inputSize
        self.hiddenSize = hiddenSize
        self.rnnCell = torch.nn.RNNCell(input_size=self.inputSize,
                                        hidden_size=self.hiddenSize)

    def forward(self, input, hidden):
        hidden = self.rnnCell(input, hidden)
        return hidden

    def init_hidden(self):
        return torch.zeros(self.batchSize, self.hiddenSize)


# 3. 构建损失函数和优化器
model = Model(inputSize=input_size, hiddenSize=hidden_size, batchSize=batch_size)
Loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.1)


# 4. train
def train():
    for epoch in range(15):
        loss = 0
        optimizer.zero_grad()
        hidden = model.init_hidden()
        print("Predicted string: ", end="")
        for input, label in zip(X_inputs, Y_label):     # 遍历每个字符,输入是“hello”, 第一次循环为:h; 第二次循环为:e
            hidden = model.forward(input=input, hidden=hidden)
            loss += Loss(hidden, label)     # 这里不能写 Loss(hidden, label).item(),
                                            # 因为一个rnn是由多个rnnCell组成,需要把每个RnnCell都加起来,最后的loss.item()才算是损失
            _, idx = hidden.max(dim=1)
            print(idx2char[idx.item()], end="")
        loss.backward()
        optimizer.step()
        print(', Epoch [%d/15] loss=%.4f' % (epoch, loss.item()))


train()

4. 结果

Predicted string: lleee, Epoch [0/15] loss=8.4630
Predicted string: lllll, Epoch [1/15] loss=6.9278
Predicted string: lllll, Epoch [2/15] loss=6.2207
Predicted string: lllll, Epoch [3/15] loss=5.5153
Predicted string: lllll, Epoch [4/15] loss=5.0224
Predicted string: ollol, Epoch [5/15] loss=4.6996
Predicted string: ollol, Epoch [6/15] loss=4.3447
Predicted string: ohlol, Epoch [7/15] loss=3.9433
Predicted string: ohlol, Epoch [8/15] loss=3.6049
Predicted string: ohlol, Epoch [9/15] loss=3.4140
Predicted string: ohlol, Epoch [10/15] loss=3.3032
Predicted string: ohlol, Epoch [11/15] loss=3.1829
Predicted string: ohlol, Epoch [12/15] loss=3.0435
Predicted string: ohlol, Epoch [13/15] loss=2.9066
Predicted string: ohlol, Epoch [14/15] loss=2.7866
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值