pytorch_gru理解

gru理解

原理图

在这里插入图片描述

在这里插入图片描述

代码(pytorch)

import torch
import torch.nn as nn

batch_size = 3
seq_lens = 4
input_size = 2
hidden_size = 5
num_layers = 1
bidirectional = True

n_direction = 2 if bidirectional else 1

gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)

input = torch.randn(seq_lens, batch_size, input_size)
hidden = torch.zeros(num_layers * n_direction, batch_size, hidden_size)

out, hn = gru(input, hidden)

if n_direction == 2:
    hidden_cat = torch.cat([hn[-1], hn[-2]], dim=1)
else:
    hidden_cat = hidden[-1]

print(gru.parameters())

print("input:", input)
print("input shape:", input.shape)

print("hidden:", hidden)
print("hidden shape:", hidden.shape)

print("out:", out)
print("out shape:", out.shape)
print("hn:", hn)
print("hn shape:", hn.shape)
print("hidden_cat:", hidden_cat)
print("hidden_cat shape:", hidden_cat.shape)

输出结果:

<generator object Module.parameters at 0x7fab2e0e0bf8>
input: tensor([[[-1.4577, -0.0883],
         [ 0.0624,  1.3746],
         [-0.2722, -0.6366]],
        [[-0.5036,  1.2694],
         [-0.7932, -0.4175],
         [-1.1812, -1.3432]],
        [[ 0.1605, -2.0594],
         [ 0.8873,  0.7250],
         [ 1.0732,  0.8253]],
        [[ 0.6679, -0.4159],
         [-0.6118,  0.7531],
         [-0.1456,  0.1618]]])
input shape: torch.Size([4, 3, 2])
hidden: tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
hidden shape: torch.Size([2, 3, 5])


out: tensor([[[ 0.1100,  0.3579, -0.2854,  0.0465,  0.2265,  0.0457,  0.3667,
           0.0135,  0.0013,  0.1330],
         [ 0.1946,  0.2779,  0.2816,  0.0291,  0.0342, -0.0239,  0.2849,
           0.1947,  0.2360,  0.0147],
         [-0.0233,  0.2177, -0.1126,  0.2133, -0.0898,  0.0341, -0.0355,
          -0.3002, -0.0752, -0.0234]],
        [[ 0.2873,  0.5285,  0.0650,  0.0538,  0.2787, -0.0166,  0.1987,
           0.0093,  0.1445, -0.0308],
         [ 0.1473,  0.3692, -0.2229,  0.1248,  0.0641,  0.0155,  0.1096,
          -0.0015,  0.0621,  0.0221],
         [-0.0325,  0.2317, -0.5337,  0.2691,  0.0074,  0.0518, -0.0028,
          -0.2400,  0.0070, -0.0293]],
        [[ 0.0492, -0.0384, -0.4666,  0.4176, -0.1227,  0.0282, -0.3614,
          -0.3667, -0.1225, -0.1316],
         [ 0.1222,  0.4731,  0.2719,  0.2220, -0.0679, -0.0373,  0.0059,
           0.1497,  0.1780, -0.0708],
         [ 0.0511,  0.4006,  0.2200,  0.2951, -0.0967, -0.0563, -0.0760,
           0.0471,  0.1901, -0.1464]],
        [[-0.0551,  0.2066,  0.0331,  0.3877, -0.2199, -0.0202, -0.1984,
          -0.1131, -0.0051, -0.0988],
         [ 0.2098,  0.5760,  0.0440,  0.0953,  0.1283, -0.0013,  0.2712,
           0.1320,  0.0844,  0.0797],
         [ 0.0770,  0.4780,  0.0344,  0.2017, -0.0510, -0.0039,  0.0620,
          -0.0098,  0.0251,  0.0057]]], grad_fn=<CatBackward>)
out shape: torch.Size([4, 3, 10])


hn: tensor([[[-0.0551,  0.2066,  0.0331,  0.3877, -0.2199],
         [ 0.2098,  0.5760,  0.0440,  0.0953,  0.1283],
         [ 0.0770,  0.4780,  0.0344,  0.2017, -0.0510]],
        [[ 0.0457,  0.3667,  0.0135,  0.0013,  0.1330],
         [-0.0239,  0.2849,  0.1947,  0.2360,  0.0147],
         [ 0.0341, -0.0355, -0.3002, -0.0752, -0.0234]]],
       grad_fn=<StackBackward>)		# [[h~N~^b^],[h~N~^f^]] 反着拼的
hn shape: torch.Size([2, 3, 5])


hidden_cat: tensor([[ 0.0457,  0.3667,  0.0135,  0.0013,  0.1330, -0.0551,  0.2066,  0.0331,
          0.3877, -0.2199],
        [-0.0239,  0.2849,  0.1947,  0.2360,  0.0147,  0.2098,  0.5760,  0.0440,
          0.0953,  0.1283],
        [ 0.0341, -0.0355, -0.3002, -0.0752, -0.0234,  0.0770,  0.4780,  0.0344,
          0.2017, -0.0510]], grad_fn=<CatBackward>)
hidden_cat shape: torch.Size([3, 10])

总结(直观理解)

  • 喂入下一层的数据是hidden,而不是out ???

  • 需要按序列长度(就是不包括填充的长度)从大到小排列

  • 对于input:每个数据(序列)是竖着放的(维度就是:seq_lens, batch_size, input_size【循环网络的input好像都是这个】)

  • 对与hidden:(num_layers * n_direction, batch_size, hidden_size)

  • 数据上的改变:即把原数据的input_size这一维的长度变成了hidden的hidden_size这一维

  • 如果是双向循环,

    • 则output:hidden_size这一维的长度会double(见原理图就明白了)
    • 则hn:第一维的长度会double(见原理图就明白了)
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值