【NLP】GRU基本结构原理,代码实现

LSTM变种GRU

GRU是LSTM改进的门控循环神经网络,将输入门,遗忘门,输出门变成更新门和重置门。

将细胞状态和隐藏状态合并,只有当前时刻候选状态和当前时刻隐藏状态。

【NLP】LSTM结构,原理,代码实现,序列池化-CSDN博客

模型结构

在这里插入图片描述

内部结构
在这里插入图片描述

相较于LSTM,GRU的结构更加简洁,参数更少,计算效率更高

可以类比LSTM理解GRU,同样都是门控机制

重置门

在这里插入图片描述

决定了保留多上一个时间步的信息和当前的信息合并输入

候选门

在这里插入图片描述

最终隐藏状态

在这里插入图片描述

代码实现

原生代码实现

import numpy as np

class GRU():
    def __init__(self,input_size,hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 初始化权重参数
        # 跟新门
        self.W_z = np.random.randn(hidden_size,hidden_size+input_size)
        self.b_z = np.zeros(hidden_size)

        # 重置门
        self.W_r = np.random.randn(hidden_size,hidden_size+input_size)
        self.b_r = np.zeros(hidden_size)

        # 候选隐藏状态
        self.W_h = np.random.randn(hidden_size,hidden_size+input_size)
        self.b_h = np.zeros(hidden_size)

    def tanh(self,x):
        return np.tanh(x)

    def sigmoid(self,x):
        return 1/(1+np.exp(-x))

    def forward(self,x):
        # 初始化隐藏状态
        h_prev = np.zeros((self.hidden_size,))
        concat_input = np.concatenate([x,h_prev],axis=0)

        z_t = self.sigmoid(np.dot(self.W_z,concat_input)+self.b_z)
        r_t = self.sigmoid(np.dot(self.W_r,concat_input)+self.b_r)

        concat_reset_input = np.concatenate([x,r_t*h_prev],axis=0)
        h_hat_t = self.tanh(np.dot(self.W_h,concat_reset_input)+self.b_h)

        h_t = (1-z_t)*h_prev + z_t*h_hat_t

        return h_t

# 测试数据
input_size = 3
hidden_size = 2
seq_len = 4
x = np.random.randn(seq_len,input_size)

gru = GRU(input_size, hidden_size)
all_h = []
for t in range(seq_len):
    h_t = gru.forward(x[t,:])
    all_h.append(h_t)
    print(h_t.shape)

all_h = np.array(all_h)
print(all_h.shape)

基于PyTorch的GURcell

import torch
import torch.nn as nn
import numpy as np

class GRUcell(nn.Module):
    def __init__(self,input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.gru_cell = nn.GRUCell(input_size,hidden_size)

    def forward(self,x):
        h_t = self.gru_cell(x)
        return h_t


input_size = 3
hidden_size = 2
seq_len = 2

x = torch.randn(seq_len,input_size)
grucell = GRUcell(input_size, hidden_size)
for t in range(seq_len):
    out = grucell(x[t])
    print(out)

基于PyTorch的GRUapi实现

import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self,input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.gru = nn.GRU(input_size,hidden_size)

    def forward(self,x):
        out,_ = self.gru(x)

        return out

input_size = 3
hidden_size = 2
seq_len = 4
bach_size = 5

x = torch.randn(seq_len,bach_size,input_size)

gru = GRU(input_size,hidden_size)
out = gru(x)
print(out)
print(out.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值