tf.contrib.rnn.static_bidirectional_rnn和MultiRNNCell构建多层静态双向LSTM

import tensorflow as tf
import numpy as np

# 设置训练参数
learning_rate = 0.01
max_examples = 40
batch_size = 128
display_step = 10  # 每间隔10次训练就展示一次训练情况

n_input = 100#词向量维度
n_steps = 300#时间步长
fw_n_hidden = 256#正向神经元数量
bw_n_hidden = 128#反向神经元数量
n_classes = 10

x = tf.placeholder("float", [max_examples, n_steps, n_input])
y = tf.placeholder('float', [max_examples, n_classes])
weights = tf.Variable(tf.random_normal([(fw_n_hidden + bw_n_hidden), n_classes]))
biases = tf.Variable(tf.random_normal([n_classes]))

x = tf.transpose(x, [1, 0, 2])
print(x.shape)  
x = tf.reshape(x, [-1, n_input])
print(x.shape) 
x = tf.split(x, n_steps)
print(len(x), x[0].shape) 

# lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(fw_n_hidden, forget_bias=1.0)  # 正向RNN,输出神经元数量为256
# lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(bw_n_hidden, forget_bias=1.0)  # 反向RNN,输出神经元数量为128

lstm_fw_cell=[]
lstm_bw_cell=[]
for i in range(3):
    lstm_fw_cell.append(tf.contrib.rnn.BasicLSTMCell(fw_n_hidden, forget_bias=1.0) )
    lstm_bw_cell.append( tf.contrib.rnn.BasicLSTMCell(bw_n_hidden, forget_bias=1.0))

mul_lstm_fw_cell=tf.contrib.rnn.MultiRNNCell(lstm_fw_cell)
mul_lstm_bw_cell=tf.contrib.rnn.MultiRNNCell(lstm_bw_cell)

outputs, fw_state, bw_state = tf.contrib.rnn.static_bidirectional_rnn(mul_lstm_fw_cell, mul_lstm_bw_cell, x, dtype=tf.float32)

print(len(outputs))##300,等于时间步的长度,一般取outputs[-1]也就是最后一步的输出进行运算
print(outputs[0].shape)#(40, 384)
print(outputs[-1].shape)#(40, 384),一般取最后一个时间步的输出来进行运算

print(len(fw_state))#三个LSTM隐藏层
# print(fw_state)

#正向RNN第一个LSTM隐藏层的c状态
print(fw_state[0].c.shape)#(40, 256)
print(fw_state[1].c.shape)#(40, 256)
print(fw_state[2].c.shape)#(40, 256)

#正向RNN第一个LSTM隐藏层的h状态
print(fw_state[0].h.shape)#(40, 256)
print(fw_state[1].h.shape)#(40, 256)
print(fw_state[2].h.shape)#(40, 256)

#反向RNN第一个LSTM隐藏层的c状态
print(bw_state[0].c.shape)#(40, 256)
print(bw_state[1].c.shape)#(40, 256)
print(bw_state[2].c.shape)#(40, 256)

#反向RNN第一个LSTM隐藏层的h状态
print(bw_state[0].h.shape)#(40, 256)
print(bw_state[1].h.shape)#(40, 256)
print(bw_state[2].h.shape)#(40, 256)

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值