GRU, a Variant of LSTM
GRU is a gated recurrent network that simplifies LSTM: it replaces the input, forget, and output gates with just an update gate and a reset gate.
It also merges the cell state and the hidden state, leaving only the candidate hidden state and the hidden state at the current time step.
Related post: 【NLP】LSTM结构,原理,代码实现,序列池化-CSDN博客
Model Structure
Internal Structure
Compared with LSTM, GRU has a simpler structure, fewer parameters, and lower computational cost.
GRU can be understood by analogy with LSTM, since both are built on the same kind of gating mechanism.
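The "fewer parameters" point is easy to check against PyTorch's built-in modules; below is a minimal sketch with arbitrary toy sizes (the exact counts depend on input_size and hidden_size):

import torch.nn as nn

input_size, hidden_size = 3, 2
lstm = nn.LSTM(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)

def num_params(m):
    return sum(p.numel() for p in m.parameters())

print(num_params(lstm))  # 56: LSTM keeps 4 weight/bias groups (3 gates + cell candidate)
print(num_params(gru))   # 42: GRU keeps only 3 (update gate, reset gate, candidate)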
Reset Gate
Decides how much information from the previous time step is retained when it is merged with the current input.
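In the standard GRU formulation (consistent with the NumPy implementation below), the reset gate r_t, and the update gate z_t mentioned at the top, are both computed from the current input x_t and the previous hidden state h_{t-1}:

r_t = \sigma(W_r [x_t, h_{t-1}] + b_r)
z_t = \sigma(W_z [x_t, h_{t-1}] + b_z)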
Candidate Hidden State
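Its standard form (again matching the code below): the reset gate scales the previous hidden state before it is mixed with the current input,

\tilde{h}_t = \tanh(W_h [x_t, r_t \odot h_{t-1}] + b_h)

where \odot denotes element-wise multiplication.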
Final Hidden State
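The update gate z_t interpolates between the previous hidden state and the candidate. The code below uses the convention

h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

(some references, e.g. the PyTorch documentation, swap the roles of z_t and 1 - z_t).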
Code Implementation
Native NumPy Implementation
import numpy as np

class GRU:
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Initialize weight parameters
        # Update gate
        self.W_z = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_z = np.zeros(hidden_size)
        # Reset gate
        self.W_r = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_r = np.zeros(hidden_size)
        # Candidate hidden state
        self.W_h = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_h = np.zeros(hidden_size)

    def tanh(self, x):
        return np.tanh(x)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def forward(self, x, h_prev):
        # Concatenate the current input with the previous hidden state
        concat_input = np.concatenate([x, h_prev], axis=0)
        z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)  # update gate
        r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)  # reset gate
        # Candidate hidden state: the reset gate scales the previous hidden state
        concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)
        h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)
        # Interpolate between the previous hidden state and the candidate
        h_t = (1 - z_t) * h_prev + z_t * h_hat_t
        return h_t

# Test data
input_size = 3
hidden_size = 2
seq_len = 4
x = np.random.randn(seq_len, input_size)
gru = GRU(input_size, hidden_size)
h_t = np.zeros(hidden_size)  # initial hidden state
all_h = []
for t in range(seq_len):
    h_t = gru.forward(x[t, :], h_t)  # carry the hidden state across time steps
    all_h.append(h_t)
    print(h_t.shape)
all_h = np.array(all_h)
print(all_h.shape)
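With these toy sizes, each step prints h_t.shape as (2,), i.e. (hidden_size,), and the stacked outputs all_h have shape (4, 2), i.e. (seq_len, hidden_size).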
Implementation Based on PyTorch's GRUCell
import torch
import torch.nn as nn

class GRUcell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # A single GRU cell processes one time step per call
        self.gru_cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, x, h_prev=None):
        # When h_prev is None, nn.GRUCell uses a zero initial hidden state
        h_t = self.gru_cell(x, h_prev)
        return h_t

input_size = 3
hidden_size = 2
seq_len = 2
x = torch.randn(seq_len, input_size)
grucell = GRUcell(input_size, hidden_size)
h_t = None
for t in range(seq_len):
    h_t = grucell(x[t], h_t)  # carry the hidden state across time steps
    print(h_t)
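nn.GRUCell normally expects input of shape (batch_size, input_size); the unbatched (input_size,) input used above is accepted in recent PyTorch versions. The manual time-step loop here is exactly what nn.GRU, used next, performs internally.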
Implementation Based on the PyTorch GRU API
import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # nn.GRU consumes a whole sequence in one call
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, x):
        # out holds the hidden state of every time step;
        # the second return value (the final hidden state) is discarded here
        out, _ = self.gru(x)
        return out

input_size = 3
hidden_size = 2
seq_len = 4
batch_size = 5
# Default input layout of nn.GRU is (seq_len, batch_size, input_size)
x = torch.randn(seq_len, batch_size, input_size)
gru = GRU(input_size, hidden_size)
out = gru(x)
print(out)
print(out.shape)  # torch.Size([4, 5, 2])
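nn.GRU also accepts batch-first input; a minimal sketch with the same toy sizes:

import torch
import torch.nn as nn

input_size, hidden_size, seq_len, batch_size = 3, 2, 4, 5
# With batch_first=True the expected layout is (batch_size, seq_len, input_size)
gru_bf = nn.GRU(input_size, hidden_size, batch_first=True)
x_bf = torch.randn(batch_size, seq_len, input_size)
out_bf, h_n = gru_bf(x_bf)
print(out_bf.shape)  # torch.Size([5, 4, 2])
print(h_n.shape)     # (num_layers, batch_size, hidden_size) -> torch.Size([1, 5, 2])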