# Implementing a GRU Network Line by Line
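Before diving into the code, recall the GRU cell equations this implementation follows (the same convention as PyTorch's `nn.GRU`, where the reset gate $r_t$ scales only the hidden-state part of the candidate):

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$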
```python
import torch
import torch.nn as nn

def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    prev_h = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 3

    # Expand the weights with a batch dim and tile them batch_size times
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)

    output = torch.zeros(bs, T, h_size)  # output hidden-state sequence of the GRU

    for t in range(T):
        x = input[:, t, :]  # input feature vector of the GRU cell at time t
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_x = w_times_x.squeeze(-1)  # [bs, 3*h_size]
        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [bs, 3*h_size]

        # reset gate
        r_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size]
                            + b_ih[:h_size] + b_hh[:h_size])
        # update gate
        z_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size]
                            + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        # candidate state: the reset gate scales only the hidden-state contribution
        n_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size]
                         + r_t * (w_times_h_prev[:, 2*h_size:3*h_size] + b_hh[2*h_size:3*h_size]))
        # convex-combination update yields the newest hidden state at time t
        prev_h = (1 - z_t) * n_t + z_t * prev_h
        output[:, t, :] = prev_h

    return output, prev_h
```
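Here `w_ih` has shape `[3*h_size, i_size]` and `w_hh` has shape `[3*h_size, h_size]`; the three chunks along the first dimension are the reset-gate, update-gate, and candidate weights, in that order, matching PyTorch's parameter layout. Tiling the weights and using `torch.bmm` makes the per-sample batching explicit, but the same products can be computed without copying the weights, since a plain matrix multiplication already broadcasts over the batch. A minimal sketch of that alternative (standalone toy shapes, not the original approach):

```python
import torch

bs, i_size, h_size = 2, 4, 5
x = torch.randn(bs, i_size)             # one timestep of input
w_ih = torch.randn(3 * h_size, i_size)  # stacked r/z/n input-to-hidden weights

# x @ w_ih.T broadcasts over the batch: no unsqueeze/tile/bmm needed
w_times_x = x @ w_ih.T                  # [bs, 3*h_size]
assert w_times_x.shape == (bs, 3 * h_size)
```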
```python
# Test the correctness of the function
bs, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(bs, T, i_size)  # input sequence
h0 = torch.randn(bs, h_size)        # initial hidden state

# Call the official PyTorch GRU API
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
output, h_final = gru_layer(input, h0.unsqueeze(0))  # nn.GRU expects h0 of shape [1, bs, h_size]
print(output)

# Inspect parameter names and shapes:
# for k, v in gru_layer.named_parameters():
#     print(k, v.shape)

# Call the custom gru_forward function, reusing the official layer's parameters
output_custom, h_final_custom = gru_forward(input, h0,
                                            gru_layer.weight_ih_l0, gru_layer.weight_hh_l0,
                                            gru_layer.bias_ih_l0, gru_layer.bias_hh_l0)
print(output_custom)
```
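Printing the two tensors lets you compare them by eye; a `torch.allclose` check (a small addition here for illustration) makes the agreement explicit. Note that `nn.GRU` returns the final hidden state with a leading num-layers dimension, so it is squeezed before comparing:

```python
# The custom forward should match nn.GRU up to floating-point noise
print(torch.allclose(output, output_custom, atol=1e-6))               # expected: True
print(torch.allclose(h_final.squeeze(0), h_final_custom, atol=1e-6))  # expected: True
```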