python实现GRU并使用torch.nn.GRU验证正确性

这篇博客展示了如何使用Python基础语法和PyTorch库手动实现门控循环单元(GRU),并将其结果与PyTorch内置的GRU模块进行对比。通过allclose函数验证,两者计算结果一致,证明手动实现的GRU是正确的。
摘要由CSDN通过智能技术生成

用python基本语法以及torch的一些基本数据结构对GRU进行实现如下:

import torch
import torch.nn as nn

bs,T,i_size,h_size=2,3,4,5

input=torch.rand(bs,T,i_size)

h_0=torch.rand(bs,h_size)

gru=nn.GRU(i_size,h_size,batch_first=True)
output,h_n=gru(input,h_0.unsqueeze(0))
print(output)
print(h_n)

def gru_forward(input,initial_states,w_ih,w_hh,b_ih,b_hh):
    bs,T,i_size=input.shape
    h_size=initial_states.shape[-1]
    h_0=initial_states

    batch_w_ih=w_ih.unsqueeze(0).tile(bs,1,1)
    batch_w_hh=w_hh.unsqueeze(0).tile(bs,1,1)


    prev_h=h_0

    for t in range(T):
        x=input[:,t,:]
        w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1)).squeeze(-1)
        w_times_h=torch.bmm(batch_w_hh,prev_h.unsqueeze(-1)).squeeze(-1)
        r=torch.sigmoid(w_times_x[:,:h_size]+w_times_h[:,:h_size]+b_ih[:h_size]+b_hh[:h_size])
        z=torch.sigmoid(w_times_x[:,h_size:2*h_size]+w_times_h[:,h_size:2*h_size]+\
            b_ih[h_size:2*h_size]+b_hh[h_size:2*h_size])
        n=torch.tanh(w_times_x[:,2*h_size:3*h_size]+b_ih[2*h_size:3*h_size]+\
            r*(w_times_h[:,2*h_size:3*h_size]+b_hh[2*h_size:3*h_size]))
        prev_h=(1-z)*n+z*prev_h
        output[:,t,:]=prev_h
    return output,prev_h

output_custom,h_n_custom=gru_forward(input,h_0,gru.weight_ih_l0,gru.weight_hh_l0,gru.bias_ih_l0,gru.bias_hh_l0)

print(torch.allclose(output,output_custom))
print(torch.allclose(h_n,h_n_custom))

结果如下:
在这里插入图片描述
最后两个allclose函数返回的都是TRUE,说明计算结果一致

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值