从零实现循环神经网络(二)

#本篇博客代码是基于上一篇《从零实现循环神经网络(一)》 

上一篇网址:从零实现循环神经网络(一)-CSDN博客

1.初始化时返回隐藏层状态

def init_rnn_state(batch_size, num_hiddens, device):
    """
    batch_size:每批次中的数据量
    num_hiddens:隐藏层中神经元的数量
    device:数据被创建在哪一类型的设备上
    """
    #返回的隐藏状态是一个元组,可能有多个
    return (torch.zeros((batch_size, num_hiddens), device=device), )

2.封装函数:实现RNN主体结构

def rnn(inputs, state, params):
    """
    inputs:输入的数据,形状:(时间步数量,批次大小,词表大小)
    state:状态
    params:参数
    """
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state   #H, 表示接收元祖中返回的第一个元素数值
    outputs = []
    ## X的shape: [批次大小, 词表大小]
    for X in inputs:  #inputs的时间步数量被循环了
        # 一般在循环神经网络中激活函数用tanh,效果比relu好
        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)  #H的shape=(batch_size, num_hiddens)
        Y = torch.matmul(H, W_hq) + b_q   #Y的shape=(batch_size, num_outputs)
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, )

 3.将上述封装的函数包装成类

class RNNModelScratch():
    #__init__()定义初始化传入的参数
    def __init__(self, vocab_size, num_hiddens, device, get_params, init_rnn_state, forward_fn):
        """
        vocab_size = len(vacob)
        num_hiddens:隐藏层的神经元个数
        device:使用GPU还是CPU
        get_params:初始化模型参数的函数
        init_rnn_state:隐藏层的状态
        forward_fn:前向传播的方法,即上面封装的rnn函数
        """
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        #调用封装的函数:获取初始化的模型参数
        self.params = get_params(vocab_size, num_hiddens, device=device)
        self.init_rnn_state, self.forward_fn = init_rnn_state, forward_fn
    
    #定义call方法,使类像函数一样被调用
    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)
    
    def begin_state(self, batch_size, device):
        return self.init_rnn_state(batch_size, self.num_hiddens, device)
# 试用上面封装的类和函数
num_hiddens = 512
rnn_net = RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_params, init_rnn_state, rnn)
state = rnn_net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = rnn_net(X.to(dltools.try_gpu()), state)
Y.shape
torch.Size([10, 28])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值