李沐-GRU的从0开始实现


#coding=utf-8
import torch
from torch import nn
from d2l import torch as d2l

batch_size=32
num_steps=35
train_iter,vocab=d2l.load_data_time_machine(batch_size,num_steps)

#初始化模型参数
def get_params(vocab_size,num_hiddens,decice):
    num_inputs=num_outs=vocab_size

    # 正态随机分布张量
    def normal(shape):
        return torch.randn(size=shape,device=decice)*0.01

    def three():
        return (normal((num_inputs,num_hiddens)),#正态随机分布张量,调用上面函数
                normal((num_hiddens,num_hiddens)),
                torch.zeros(num_hiddens,device=decice))

    W_xz,W_hz,b_z = three() # 更新门参数
    W_xr,W_hr,b_r = three() # 重置门参数
    W_xh,W_hh,b_h = three() # 候选状态参数

    #输出层参数
    W_hq=normal((num_hiddens,num_outs))
    b_q=torch.zeros(num_outs,device=decice)

    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    #进行梯度更新
    for param in params:
        param.requires_grad_(True)
    return params

#定义隐藏状态的初始化函数
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)

#定义GRU模型
def gru(inputs,state,params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q=params
    #解包操作,将state元祖解开包复制给H
    H,=state
    #列表
    outputs=[]
    for X in inputs:
        #@是点乘,以下四个是门控循环定义公式
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        #含隐藏状态循环神经网络的输出
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

#训练和预测
vocab_size=len(vocab)
nun_hiddens=256
device=d2l.try_gpu()

num_epochs=500
lr=1
net=d2l.RNNModelScratch(len(vocab),nun_hiddens,device,get_params,init_gru_state,gru)
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值