RNN、RNNCell

# -*- encoding: utf-8 -*-
'''
@Author: Xiaosu Wang
@Email: 19110240018@fudan.edu.cn

@Version : 1.0
@File : rnn.py
@Time : 2020-01-30 22:44

@Description :

    Pytorch中RNN相关源码在文件:torch/nn/modules/rnn.py  
'''

import torch
import torch.nn as nn


init_seed = 2020
torch.manual_seed(init_seed)
torch.cuda.manual_seed(init_seed)
# np.random.seed(init_seed) # 用于numpy的随机数

def print_parameters(module):
    for name, params in module.named_parameters():
        print(name)
        print(params)

'''
input_size: 输入特征维度,即词向量维度
hidden_size: 隐藏特征维度
num_layers: 层数。 Default: 1
nonlinearity: 非线性函数,'tanh' or 'relu'. Default: 'tanh'
bias: 是否加偏移。 Default: ``True``
batch_first: If ``True``, then the input and output tensors are provided
             as `(batch, seq, feature)`. Default: ``False`` (seq_len, batch, input_size)
dropout: 非0表示添加dropout层. Default: 0
bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
'''
rnn = nn.RNN(4, 5, 1) # RNN : input_size, hidden_size, num_layers
rnn_cell = nn.RNNCell(4, 5) # RNNCell : input_size, hidden_size

input = torch.randn(3, 2, 4)  # (seq_len, batch, input_size)
h0 = torch.randn(1, 2, 5) # (num_layers * num_directions, batch, hidden_size)

# 计算RNN
output, hn = rnn(input, h0)

# 用 RNNCell 模拟 RNN
hx = h0[0]
output_cell = []
for i in range(3):
    hx = rnn_cell(input[i], hx)
    output_cell.append(hx)


print('----' * 5 + '两者结果不同' + '----' * 5 )

print(output)
print(output_cell)

print('----' * 5 + '观察 RNN 、RNNCell 的参数' + '----' * 5 )
print_parameters(rnn)
print('----' * 5)
print_parameters(rnn_cell)

print('----' * 5 + '将 RNN 的参数赋值给 RNNCell,使两者 Cell 的参数一样' + '----' * 5 )
rnn_cell.weight_ih = rnn.weight_ih_l0
rnn_cell.weight_hh = rnn.weight_hh_l0
rnn_cell.bias_ih = rnn.bias_ih_l0
rnn_cell.bias_hh = rnn.bias_hh_l0

print('----' * 5 + '观察 RNN 、RNNCell 的参数' + '----' * 5 )
print_parameters(rnn)
print('----' * 5)
print_parameters(rnn_cell)

print('----' * 5 + '重新用 RNNCell 模拟 RNN' + '----' * 5 )

output_cell = []
hx = h0[0]
for i in range(3):
    hx = rnn_cell(input[i], hx)
    output_cell.append(hx)

print('----' * 5 + '两者结果相同' + '----' * 5 )

print(output)
print(output_cell)

print('----' * 5 + '多层、双向都是相同的道理' + '----' * 5 )

 

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值