B站 刘二大人 传送门 循环神经网络(基础篇)
课件链接:https://pan.baidu.com/s/1vZ27gKp8Pl-qICn_p2PaSw
提取码:cxe4
本节模型为将输入“hello”训练输出为“ohlol”,用循环神经网络实现。本节老师讲了cell,rnn和embedding三种简单模型,为方便测试,我给每个模型分别定义了函数。下面上开始的数据处理思路图和代码。
搭配视频学习效果最佳。
首先是注释内容和数据集准备代码块。
'''
训练RNN模型使得 "hello" -> "ohlol"
输入为"hello",可设置字典 e -> 0 h -> 1 l -> 2 o -> 3 hello对应为 10223 one-hot编码有下面对应关系
h 1 0100 o 3
e 0 1000 h 1
l 2 0010 l 2
l 2 0010 o 3
o 3 0001 l 2
输入有“helo”四个不同特征于是input_size = 4
hidden_size = 4 batch_size = 1
RNN模型维度的确认至关重要:
rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,num_layers=num_layers)
outputs, hidden_outs = rnn(inputs, hiddens):
inputs of shape 𝑠𝑒𝑞𝑆𝑖𝑧𝑒, 𝑏𝑎𝑡𝑐ℎ, 𝑖𝑛𝑝𝑢𝑡_𝑠𝑖𝑧𝑒
hiddens of shape 𝑛𝑢𝑚𝐿𝑎𝑦𝑒𝑟𝑠, 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
outputs of shape 𝑠𝑒𝑞𝑆𝑖𝑧𝑒, 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
hidden_outs of shape 𝑠𝑒𝑞𝑆𝑖𝑧𝑒, 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
cell = torch.nn.RNNcell(input_size=input_size, hidden_size=hidden_size)
output, hidden_out = cell(input, hidden):
input of shape 𝑏𝑎𝑡𝑐ℎ, 𝑖𝑛𝑝𝑢𝑡_𝑠𝑖𝑧𝑒
hidden of shape 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
output of shape 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
hidden_out of shape 𝑏𝑎𝑡𝑐ℎ, ℎ𝑖𝑑𝑑𝑒𝑛_𝑠𝑖𝑧𝑒
其中,seqSize:输入个数 batch:批量大小 input_size:特征维数 numLayers:网络层数 hidden_size:隐藏层维数
'''
import torch
idx2char = ['e', 'h', 'l', 'o'] #方便最后输出结果
x_data = [1, 0, 2, 2, 3] #输入向量
y_data = [3, 1, 2, 3, 2] #标签
one_hot_lookup = [ [1, 0, 0, 0], #查询ont hot编码 方便转换
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0