GRU
RNN存在的问题:梯度较容易出现衰减或爆炸(BPTT)
⻔控循环神经⽹络:捕捉时间序列中时间步距离较⼤的依赖关系
RNN:
H t = ϕ ( X t W x h + H t − 1 W h h + b h ) H_{t} = ϕ(X_{t}W_{xh} + H_{t-1}W_{hh} + b_{h}) Ht=ϕ(XtWxh+Ht−1Whh+bh)
GRU:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) H ~ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t R_{t} = σ(X_tW_{xr} + H_{t−1}W_{hr} + b_r)\\ Z_{t} = σ(X_tW_{xz} + H_{t−1}W_{hz} + b_z)\\ \widetilde{H}_t = tanh(X_tW_{xh} + (R_t ⊙H_{t−1})W_{hh} + b_h)\\ H_t = Z_t⊙H_{t−1} + (1−Z_t)⊙\widetilde{H}_t Rt=σ(XtWxr+Ht−1Whr+br)Zt=σ(XtWxz+Ht−1Whz+bz)H
t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh)Ht=Zt⊙Ht−1+(1−Zt)⊙H
t
• 重置⻔有助于捕捉时间序列⾥短期的依赖关系;
• 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。
载入数据集
import os
os.listdir('/home/kesci/input')
['d2lzh1981', 'houseprices2807', 'jaychou_lyrics4703', 'd2l_jay9460']
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F