basic_rnn_seq2seq

import tensorflow as tf
import numpy as np

steps=2
batch_size=4
input_size=3

encoder_inputs = tf.placeholder("float", [None, steps, input_size]) # (4,2,3)
decoder_inputs = tf.placeholder("float", [None, steps, input_size])

en_input=np.zeros(shape=[batch_size,steps,input_size])
de_input=np.zeros(shape=[batch_size,steps,input_size])

cell=tf.nn.rnn_cell.BasicLSTMCell(5)

def get_result(encoder_inputs,decoder_inputs,cell):
    encoder_inputs=tf.unstack(encoder_inputs,axis=1)
    decoder_inputs=tf.unstack(decoder_inputs,axis=1)
    result=tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
        encoder_inputs,
        decoder_inputs,
        cell,
        dtype=tf.float32,
        scope=None
    )
    return result
result=get_result(encoder_inputs,decoder_inputs,cell)
print(len(result))
init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    result_value=sess.run(result,feed_dict={encoder_inputs:en_input,decoder_inputs:de_input})
    print(len(result_value))
    print(result_value)
    print('-----------------')
    print(type(result_value))
    print(result_value[0])
    print('------------------')
    print(result_value[1])
WARNING:tensorflow:From <ipython-input-2-369e707092eb>:14: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
2
2
([array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)], LSTMStateTuple(c=array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), h=array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)))
-----------------
<class 'tuple'>
[array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)]
------------------
LSTMStateTuple(c=array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), h=array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值