PyTorch 循环模块解析

本文深入解析了LSTM、GRU和RNN的内部工作机制,通过Python代码展示了它们的实现过程。LSTM模块详细阐述了其参数、输入输出数据格式,并提供了功能验证。GRU结构的PyTorch计算过程也被详尽解释。最后,RNN模块的简单实现被给出,便于理解这些循环神经网络的基础操作。
摘要由CSDN通过智能技术生成

LSTM模块

参数说明

输入的参数列表包括:

  • input_size:输入数据的特征维数
  • hidden_size:LSTM中隐层的维度
  • num_layer:循环神经网络的层数
  • bias:是否用bias参数,默认为True
  • batch_first:是否将batch设置为输入数据第一位,设置后output同样按照此规则进行。默认为False
  • dropout 默认是0,代表不用dropout
  • bidirectional默认是false,代表不用双向LSTM

输入数据:input,(h_0,c_0):

  • input:形状为(seq_length,batch_size,input_size)的张量
  • h_0:形状为(num_layers*num_directions,batch,hidden_size)的张量,它包含了在当前这个batch_size中每个句子的初始隐藏状态,num_layers就是LSTM的层数,如果bidirectional=True则:num_directions=2,否则就是1,表示只有一个方向,
  • c_0h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始cell状态。h_0,c_0如果不提供,那么默认是0

输出数据包括output,(h_n,c_n):

  • output:形状(seq_length,batch_size,num_directions*hidden_size),
    它包含的LSTM的最后一层的输出特征(h_t),t是batch_size中每个句子的长度.
  • h_n:形状(num_directions * num_layers,batch,hidden_size)
  • c_n.shape==h_n.shape
  • h_n包含的是句子的最后一个单词的隐藏状态,c_n包含的是句子的最后一个单词的cell状态,所以它们都与句子的长度seq_length无关。
    各个层计算流程
    在这里插入图片描述

python功能验证

import torch


def layer_output(input_data, w_ii, w_hi, b_ii, b_hi, h, fn='sigmoid'):
    output = torch.matmul(w_ii, input_data.T)+b_ii.view(-1, 1) + \
        torch.matmul(w_hi, h)+b_hi.view(-1, 1)
    if fn == 'sigmoid':
        return torch.sigmoid(output)
    else:
        return torch.tanh(output)


def lstm_output(input_data, hh_weight, ih_weight, hh_bias, ih_bias, hidden_data, current):
    seq_length, batch, input_num = input_data.shape
    stack, _, hidden = hidden_data.shape
    for i in range(seq_length):
        input_data_0 = input_data[i][:][:]
        w_ii, w_if, w_ig, w_io = ih_weight[:hidden][:], ih_weight[hidden:2 *
                                                                  hidden][:], ih_weight[2*hidden:3*hidden][:], ih_weight[3*hidden:4*hidden][:]
        w_hi, w_hf, w_hg, w_ho = hh_weight[:hidden][:], hh_weight[hidden:2 *
                                                               hidden][:], hh_weight[2*hidden:3*hidden][:], hh_weight[3*hidden:4*hidden][:]
        b_ii, b_if, b_ig, b_io = ih_bias[:hidden], ih_bias[hidden:2 *
                                                           hidden], ih_bias[2*hidden:3*hidden], ih_bias[3*hidden:4*hidden]
        b_hi, b_hf, b_hg, b_ho = hh_bias[:hidden], hh_bias[hidden:2 *
                                                           hidden], hh_bias[2*hidden:3*hidden], hh_bias[3*hidden:4*hidden]

        h = hidden_data.view(-1, 1)
        i_t = layer_output(input_data_0, w_ii, w_hi, b_ii, b_hi, h)
        f_t = layer_output(input_data_0, w_if, w_hf, b_if, b_hf, h)
        g_t = layer_output(input_data_0, w_ig, w_hg, b_ig, b_hg, h, 'tanh')

        o_t = layer_output(input_data_0, w_io, w_ho, b_io, b_ho, h)
        c_t = f_t*current.view(-1, 1)+i_t*g_t
        h_t = o_t*torch.tanh(c_t)
        input_data_0 = h_t
        hidden_data = h_t
        current = c_t
    output = h_t.resize(stack,batch,hidden)
    return output, (c_t.resize(stack,batch,hidden), h_t.resize(stack,batch,hidden))
def compare(my_data,torch_data):
    res = torch.sum(my_data-torch_data)
    if res<1e-5:
        print("Verify passed")
    else:
        print("Verify Faied")

input_num = 57
hidden_num = 64
torch.manual_seed(1)
lstm = torch.nn.LSTM(input_size=input_num, hidden_size=hidden_num,
                     num_layers=1, bias=True)

hh_weight, ih_weight = lstm.weight_hh_l0, lstm.weight_ih_l0
hh_bias, ih_bias = lstm.bias_hh_l0, lstm.bias_ih_l0
current = torch.randn(1, 1, hidden_num)
hidden = torch.randn(1, 1, hidden_num)
input_data = torch.rand(5, 1, 57)
o_t, (m_c, m_h) = lstm_output(input_data, hh_weight, ih_weight,
                            hh_bias, ih_bias, hidden, current)
# weights_shape = [weights.shape for weights in weights]

# data = torch.ones(size=(1, 1, input_num), dtype=torch.float)
output_res, (hn_res, cn_res) = lstm(
    input_data, (hidden, current))
print("Torch output:", output_res.shape, hn_res.shape, cn_res.shape)
print("My output:", o_t,m_h.shape, m_c.shape)
compare(output_res,o_t)
compare(hn_res,m_h)
compare(cn_res,m_c)


GRU结构
在这里插入图片描述
PyTorch计算:

import torch
from torch import nn
def dense(input_data,weight,bias):
    return torch.matmul(weight,input_data.view(-1,1))+bias.view(-1,1)
def my_gru(input_data,hidden,weight_hh_l0,weight_ih_l0,bias_ih_l0,bias_hh_l0):
    W_ir,W_iz,W_in = weight_ih_l0[:2,:],weight_ih_l0[2:4,:],weight_ih_l0[4:,:]
    b_ir,b_iz,b_in = bias_ih_l0[:2],bias_ih_l0[2:4],bias_ih_l0[4:]
    
    W_hr,W_hz,W_hn = weight_hh_l0[:2,:],weight_hh_l0[2:4,:],weight_hh_l0[4:,:]
    b_hr,b_hz,b_hn = bias_hh_l0[:2],bias_hh_l0[2:4],bias_hh_l0[4:]

    r_t = torch.sigmoid(dense(input_data,W_ir,b_ir)+dense(hidden,W_hr,b_hr))
    z_t = torch.sigmoid(dense(input_data,W_iz,b_iz)+dense(hidden,W_hz,b_hz))
    n_t = torch.tanh(dense(input_data,W_in,b_in)+r_t*dense(hidden,W_hn,b_hn))
    h_t = (1-z_t)*n_t+z_t*hidden.view(-1,1)
    return h_t

gru = nn.GRU(input_size=3,hidden_size=2,num_layers=1)
input_data = torch.randn(1, 1, 3)
init_hidden = torch.randn(1, 1, 2)
output, hn = gru(input_data, init_hidden)
weight_hh_l0,weight_ih_l0,bias_ih_l0,bias_hh_l0 = gru.weight_hh_l0,gru.weight_ih_l0,gru.bias_ih_l0,gru.bias_hh_l0
my_output = my_gru(input_data,init_hidden,weight_hh_l0,weight_ih_l0,bias_ih_l0,bias_hh_l0)
print("PyTorch output:{} my output:{}".format(output,my_output))

RNN模块

def my_rnn(input_data, weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, h_0):
    input_data_0 = input_data[0][:][:].reshape(1, -1)
    h_0 = h_0.reshape(1, -1)

    h_output_0 = torch.matmul(weight_hh_l0, h_0.T)+bias_hh_l0.reshape(-1, 1) # update hidden_0 ==> output_hidden_0
    h_1 = torch.tanh(torch.matmul(weight_ih_l0, input_data_0.T) +
                          bias_ih_l0.reshape(-1, 1)+h_output_0).T # hidden_0 ==> hidden_1

    input_data_1 = input_data[1][:][:].reshape(1, -1)
    h_output_1 = torch.matmul(weight_hh_l0, h_1.T)+bias_hh_l0.reshape(-1, 1)

    output_2 = torch.tanh(torch.matmul(weight_ih_l0, input_data_1.T) +
                        bias_ih_l0.reshape(-1, 1)+h_output_1).T

    return (h_1, output_2), output_2
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值