Notes 8.15: GRU

# Implementing a GRU network line by line
import torch
import torch.nn as nn


def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    prev_h = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 3

    # Expand the weights with a batch dimension and replicate them batch_size times
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)

    output = torch.zeros(bs, T, h_size)  # output hidden-state sequence of the GRU

    for t in range(T):
        x = input[:, t, :]  # input feature vector for the GRU cell at time t
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_x = w_times_x.squeeze(-1)  # [bs, 3*h_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [bs, 3*h_size]

        r_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])  # reset gate
        z_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size] + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])  # update gate
        n_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size] + r_t * (w_times_h_prev[:, 2*h_size:3*h_size] + b_hh[2*h_size:3*h_size]))  # candidate state; the reset gate scales the recurrent term
        prev_h = (1 - z_t) * n_t + z_t * prev_h  # blend candidate state and previous state to get the new hidden state
        output[:, t, :] = prev_h

    return output, prev_h
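
# For reference, these are the GRU cell equations (as documented for torch.nn.GRU)
# that gru_forward reproduces step by step; the gates are stacked in the order r, z, n:
#   r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
#   z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
#   n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
#   h_t = (1 - z_t) * n_t + z_t * h_{t-1}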

# Test the correctness of the function
bs, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(bs, T, i_size)  # input sequence
h0 = torch.randn(bs, h_size)

# Call the official PyTorch GRU API
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
output, h_final = gru_layer(input, h0.unsqueeze(0))
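# With batch_first=True, output has shape [bs, T, h_size] and h_final has shape [num_layers, bs, h_size] = [1, bs, h_size].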
print(output)
# for k,v in gru_layer.named_parameters():
#     print(k,v.shape)
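# For reference, the relevant parameter shapes of nn.GRU are:
#   weight_ih_l0: [3*h_size, i_size]    weight_hh_l0: [3*h_size, h_size]
#   bias_ih_l0:   [3*h_size]            bias_hh_l0:   [3*h_size]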

# Call the custom gru_forward function
output_custom, h_final_custom = gru_forward(input, h0, gru_layer.weight_ih_l0, gru_layer.weight_hh_l0, gru_layer.bias_ih_l0, gru_layer.bias_hh_l0)
print(output_custom)
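
# A quick numerical check that the custom forward pass agrees with the official GRU.
# This is a minimal sketch; the tolerance value is an assumption, not part of the original note.
print(torch.allclose(output, output_custom, atol=1e-6))               # should print True
print(torch.allclose(h_final.squeeze(0), h_final_custom, atol=1e-6))  # should print True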
