Below is an implementation of a GRU using basic Python syntax and a few basic torch data structures:
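For reference, these are the update equations that torch.nn.GRU implements (following the PyTorch documentation, with $\odot$ denoting element-wise multiplication):

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$

First, run the built-in nn.GRU as a reference: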
import torch
import torch.nn as nn

bs, T, i_size, h_size = 2, 3, 4, 5  # batch size, sequence length, input size, hidden size
input = torch.rand(bs, T, i_size)   # random input sequence
h_0 = torch.rand(bs, h_size)        # initial hidden state

gru = nn.GRU(i_size, h_size, batch_first=True)
# nn.GRU expects h_0 of shape (num_layers, bs, h_size), hence the unsqueeze(0)
output, h_n = gru(input, h_0.unsqueeze(0))
print(output)
print(h_n)
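With batch_first=True, output stacks the hidden state of every time step while h_n holds only the final one, so their shapes are (bs, T, h_size) and (num_layers, bs, h_size). A quick sanity check (added here for illustration, not part of the original script):

assert output.shape == (bs, T, h_size)  # (2, 3, 5)
assert h_n.shape == (1, bs, h_size)     # (1, 2, 5), single layer

Now the same forward pass written out by hand: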
def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    bs, T, i_size = input.shape
    h_size = initial_states.shape[-1]
    prev_h = initial_states                 # (bs, h_size)
    output = torch.zeros(bs, T, h_size)     # collects h_t for every time step

    # Expand the weight matrices along a batch dimension so torch.bmm can be used.
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)  # (bs, 3*h_size, i_size)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)  # (bs, 3*h_size, h_size)

    for t in range(T):
        x = input[:, t, :]  # (bs, i_size)
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)).squeeze(-1)       # (bs, 3*h_size)
        w_times_h = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)).squeeze(-1)  # (bs, 3*h_size)

        # reset gate
        r = torch.sigmoid(w_times_x[:, :h_size] + w_times_h[:, :h_size]
                          + b_ih[:h_size] + b_hh[:h_size])
        # update gate
        z = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h[:, h_size:2*h_size]
                          + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        # candidate hidden state; the reset gate only scales the recurrent term
        n = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size]
                       + r * (w_times_h[:, 2*h_size:3*h_size] + b_hh[2*h_size:3*h_size]))
        prev_h = (1 - z) * n + z * prev_h  # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
        output[:, t, :] = prev_h

    return output, prev_h
output_custom, h_n_custom = gru_forward(input, h_0, gru.weight_ih_l0, gru.weight_hh_l0,
                                        gru.bias_ih_l0, gru.bias_hh_l0)
# h_n is (1, bs, h_size) while h_n_custom is (bs, h_size); allclose broadcasts them
print(torch.allclose(output, output_custom))
print(torch.allclose(h_n, h_n_custom))
The results are as follows: both allclose calls at the end return True, confirming that the hand-written gru_forward matches nn.GRU.
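The slicing by h_size inside gru_forward relies on how PyTorch lays out GRU parameters: weight_ih_l0 stacks W_ir, W_iz and W_in along dimension 0 (and weight_hh_l0 and the two bias vectors follow the same reset/update/new order), so each gate occupies one h_size-sized chunk. Checking the shapes makes this concrete (an illustrative snippet, not part of the original post):

print(gru.weight_ih_l0.shape)  # torch.Size([15, 4]) == (3*h_size, i_size)
print(gru.weight_hh_l0.shape)  # torch.Size([15, 5]) == (3*h_size, h_size)
print(gru.bias_ih_l0.shape)    # torch.Size([15])    == (3*h_size,)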