GRU:门控循环单元
理论
解决长期记忆和反向传播中的梯度等问题
GRU是LSTM的一种变体,它比LSTM的结构更加简单,而且效果也很好
GRU只有两个门:更新门和重置门
GRU将单元状态与输出合并为一个状态h
z
t
z_t
zt和
r
t
r_t
rt分别表示更新门和重置门
更新门
z
t
z_t
zt
控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多。
z
t
=
σ
(
W
z
⋅
[
h
t
−
1
,
x
t
]
)
z_t=\sigma{(W_{z} \cdot{[h_{t-1},x_t]})}
zt=σ(Wz⋅[ht−1,xt])
重置门
r
t
r_t
rt
控制前一状态有多少信息被写入到当前状态
h
t
h_t
ht上,重置门越小,前一状态的信息被写入的越少。
r
t
=
σ
(
W
r
⋅
[
h
t
−
1
,
x
t
]
)
r_t=\sigma{(W_r\cdot{[h_{t-1},x_t]})}
rt=σ(Wr⋅[ht−1,xt])
h
~
t
=
t
a
n
h
(
W
⋅
[
r
t
∗
h
t
−
1
,
x
t
]
)
\tilde{h}_t=tanh(W\cdot{[r_t*h_{t-1},x_t]})
h~t=tanh(W⋅[rt∗ht−1,xt])
h
t
=
(
1
−
z
t
)
∗
h
~
t
+
z
t
∗
h
t
−
1
h_t=(1-z_t)*\tilde{h}_t+z_t*h_{t-1}
ht=(1−zt)∗h~t+zt∗ht−1
LSTM与GRU对比
- GRU的参数更少,因而训练稍快或需要更少的数据来泛化。
- 如果你有足够的数据,LSTM的强大表达能力可能产生更好结果。
实践
从零实现GRU
class My_GRU(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
self.gates = nn.Linear(input_size+hidden_size, hidden_size*2)
# 用于计算candidate hidden state
self.hidden_transform = nn.Linear(input_size+hidden_size, hidden_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.output = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, output_size)
)
for param in self.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
def forward(self, x):
batch_size = x.size(0)
seq_len = x.size(1)
h = torch.zeros(batch_size, self.hidden_size).to(x.device)
y_list = []
for i in range(seq_len):
update_gate, reset_gate = self.gates(torch.cat([x[:, i, :], h], dim=-1)).chunk(2, -1)
update_gate, reset_gate = (self.sigmoid(gate) for gate in (update_gate, reset_gate))
candidate_hidden = self.tanh(self.hidden_transform(torch.cat([x[:, i, :], reset_gate * h], dim=-1)))
h = (1-update_gate) * h + update_gate * candidate_hidden
y_list.append(self.output(h))
return torch.stack(y_list, dim=1), h
Pytorch实现GRU
参数
输入
输出
gru = nn.GRU(input_size=input_size,hidden_size=hidden_size,num_layers=1).to(device)
out_linear = nn.Sequential(nn.Linear(hidden_size, 1),nn.LeakyReLU()).to(device)
optimizer = torch.optim.Adam(list(gru.parameters()) + list(out_linear.parameters()), lr)