1.知识点
- GRU只有两个门--重置门、更新门
- 初始状态只需要提供h0
- ht里“ * ”这个符号,代表着逐个元素相乘,不是矩阵相乘;其他是矩阵相乘
-
2.使用pytorch看参数个数
-
import torch import torch.nn as nn lstm_layer = nn.LSTM(3,5) #输入特征为3,隐含特征为5的特征量 gru_layer = nn.GRU(3,5) #同样的隐含大小 sum(p.numel() for p in lstm_layer.paramenters()) #计算lstm总参数量,调用paramenters()函数,在对其进行参数枚举(p代表参数),p.numel()计算每个参数p的所有元素进行统计----》200 sum(p.numel() for p in gru_layer.paramenters()) ----->150 ## GRU的参数量是LSTM的0.75
3.GRU网络代码
-
#准备工作 def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh): #定义函数,前向运算 initial_states提供初始状态 w权重--表示大的矩阵 b_ih偏置项 prev_h = initial_states #h t=0时刻的初始值 bs, T, i_size = input.shape #对input进行拆解 h_size = w_ih.shape[0] // 3 #公式里每一个 h_size(隐含神经元个数) 配三个权重 #计算公式里的W ---由于目前提供的w是二维张量,而input和initial_states都是三维张量(带有batch这个维度)-----》对两个w进行扩充维度 #对权重扩维,复制成batch_size倍 batch_w_ih = w_ih.unsqueeze(0).tile(bs,1,1) #从第0维起扩一维,tile(bs,1,1)复制一下,第一维扩大到bs倍,其他保持不变 batch_w_hh = w_hh.unsqueeze(0).tile(bs,1,1) output = torch.zeros(bs, T, h_size) #对输出初始化 GRU网络的输出状态序列 #计算 for t in range(T): #for循环对每一时刻进行迭代更新 x=input(:,t,:) #step1:找到当前时刻的输入 t时刻GRU cell的输入特征向量,大小为[bs,i_size] w_time_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)) #扩成3维---[bs,3*h_size,1] w_time_x = w_time_x.squeeze(-1) #.squeeze(-1)去掉维度为1的维度---[bs,3*h_size] w_time_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)) w_time_h_prev = w_time_h_prev.squeeze(-1) #计算重置门和更新门 r_t = torch.sigmoid(w_time_x[:, :h_size]+w_time_h_prev[:, :h_size]+b_ih[:h_size]+b_hh[:h_size]) z_t = torch.sigmoid(w_time_x[:, h_size:2*h_size]+w_time_h_prev[:, h_size:2*h_size]+b_ih[h_size:2*h_size]+b_hh[h_size:2*h_size]) #计算候选状态nt n_t = torch.tanh(w_time_x[:, 2*h_size:3*h_size]+b_ih[2*h_size:3*h_size]+\ r_t*(w_time_h_prev[:,2*h_size:3*h_size]+b_hh[2*h_size:3*h_size])) #候选状态 prev_h = (1-z_t)*n_t+z_t*prev_h #增量更新得到当前时刻最新隐含状态 output[:, t, :] = prev_h #把 prev_h 喂入到output矩阵中 return output, prev_h
4.测试函数的正确性
-
#测试函数的正确性--用pytprch的官方api测试 bs, T, i_size, h_size = 2, 3, 4, 5 input = torch.randn(bs, T, i_size) #输入序列 h0 = torch.randn(bs, h_size) #初始值,不需要训练 #定义GRU量 调用官方GRU API gru_layer = nn.GRU(i_size, h_size ,batch_first=True) output, h_final = gru_layer(input, h0.unsqueeze(0)) print(output) for k, v in gru_layer.named_parameters(): print(k, v.shape) #调用自定义的gru_forward函数 output_custom, h_final_custom = gru_forward(input,h0,gru_layer.weight_ih_10,gru_layer.weight_hh_10,gru_layer.bias_ih_10,gru_layer.bias_hh_10) print(torch.allclose(output,output_custom))#allclose对比两个浮点型张量是否非常接近 print(torch.allclose(h_final,h_final_custom_custom))
注:根据b站up讲解所得代码记录 课程31