文章目录
RNN
网络结构
一个典型的单层RNN网络,结果如下所示。其实很简单就是一个神经元(蓝色模块),针对时序输入信号分别输入网络中,网络再每个时序进行输出,其中通过一个状态量记录时序信息。
多层的RNN,上一次的输出ht是下一层的输入。
关键公式
RNN的关键特点隐层状态变量 H t \boldsymbol{H}_{t} Ht,用于存储时序信息。
H t = ϕ ( X t W x h + H t − 1 W h h + b h ) \boldsymbol{H}_{t}=\phi\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x h}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h h}+\boldsymbol{b}_{h}\right) Ht=ϕ(XtWxh+Ht−1Whh+bh)
O t = H t W h q + b q \boldsymbol{O}_{t}=\boldsymbol{H}_{t} \boldsymbol{W}_{h q}+\boldsymbol{b}_{q} Ot=HtWhq+bq
多数网络会直接所有时序的 H \boldsymbol{H} H作为输出
网络维度
输入数据 X \boldsymbol{X} X维度表示[seq len, batch, h dim] =>[序列长度,batch, 序列的表示维度]=> [单词数, 句子数, 词维度]
刚开始接触RNN会觉得这个输入形式有点怪,
为什么batch维不是在0维度上
。这里需要理解RNN的运算特点,RNN网络每次送入的数据 X t \boldsymbol{X}_{t} Xt维度恰好就是[b, h dim],循环送入seq len次。
X t \boldsymbol{X}_{t} Xt:维度[bath, input_dim]
W x h \boldsymbol{W}_{x h} Wxh:维度[input_dim, hidden_len]
H t − 1 \boldsymbol{H}_{t-1} Ht−1:维度[batch, hidden_len]
W h h \boldsymbol{W}_{h h} Whh:维度[hidden_len, hidden_len]
b h \boldsymbol{b}_{h} bh:维度[hidden_len]
H t \boldsymbol{H}_{t} Ht:维度[batch, hidden_len]
输出层
O t \boldsymbol{O}_{t} Ot:维度[batch, out_dim]
W h q \boldsymbol{W}_{h q} Whq:维度[hidden_len, out_dim]
b q \boldsymbol{b}_{q} bq:维度[out_dim]
nn.RNN
构建一个 h_dim=10, hidden_len=20的2层RNN网络,同时我们打印出网络中的权重和偏差,以及shape
import torch
from torch import nn
from torch.nn import functional as F
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
print(rnn._parameters.keys())
print(rnn.weight_ih_l0.shape)
print(rnn.weight_hh_l0.shape)
print(rnn.bias_ih_l0.shape)
print(rnn.bias_hh_l0.shape)
odict_keys([‘weight_ih_l0’, ‘weight_hh_l0’, ‘bias_ih_l0’, ‘bias_hh_l0’, ‘weight_ih_l1’, ‘weight_hh_l1’, ‘bias_ih_l1’, ‘bias_hh_l1’])
torch.Size([20, 10])
torch.Size([20, 20])
torch.Size([20])
torch.Size([20])
结构上每一层四组参数,即两组W和b,没有最后输出层O的操作,故没有权重和偏差
。
可见Torch的RNN模型与表达式 H t = ϕ ( X t W x h + H t − 1 W h h + b h ) \boldsymbol{H}_{t}=\phi\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x h}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h h}+\boldsymbol{b}_{h}\right) Ht=ϕ(XtWxh+Ht−1Whh+bh)中的W和b完全对应,其中末尾的l0和l1表示layer层数编号。维度也刚好符合之前的推导。
- weight_ih_l0 如果对应