tf.nn.dynamic_rnn()实现的一个例子。

tf.nn.dynamic_rnn()在处理变长输入时特别方便,具体解释可以看这篇博文

#coding=utf-8
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)

# 第二个example长度为6
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

result = tf.contrib.learn.run_n(
    {"outputs": outputs, "last_states": last_states},
    n=1,
    feed_dict=None)
a = result[0]
print(a)

assert result[0]["outputs"].shape == (2, 10, 5)

# 第二个example中的outputs超过6步(7-10步)的值应该为0
assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()

输出:

[[ 0.13337141  0.10697078  0.11238842 -0.16187296  0.04447445]
 [ 0.06800554  0.29581101  0.14454009 -0.11857419 -0.08062822]
 [-0.02766501  0.20230338  0.25521379 -0.10196185  0.02908   ]
 [ 0.07160553  0.39891538  0.03997988 -0.43861938 -0.00340179]
 [-0.12841535  0.35346241  0.08577594 -0.29574161 -0.06306395]
 [-0.19022677  0.11256105 -0.13190501 -0.20170257 -0.02765217]
 [-0.04303006 -0.42253068 -0.02945417 -0.0817529   0.03569792]
 [-0.01433148  0.00066725 -0.08619441 -0.1063433   0.36421112]
 [ 0.19718385  0.06653057  0.02880462 -0.31631752  0.04064322]
 [ 0.07665874  0.15330013  0.11820727 -0.28386946 -0.06841132]]
-------------------------------------------------------
[[ 0.09817442  0.12635493  0.14153314 -0.13827174 -0.14350587]
 [ 0.09484242  0.05155221  0.11429032 -0.04175748 -0.11621833]
 [ 0.21802519  0.17491722  0.17653461 -0.2161642  -0.17876485]
 [ 0.05461165 -0.01181785  0.31818148 -0.18725258 -0.06083239]
 [ 0.03753194  0.04578742  0.30538616 -0.09413831 -0.41238963]
 [-0.04687686  0.01701181  0.21276684 -0.02761401 -0.07971509]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]

参考资料:
tensorflow高阶教程:tf.dynamic_rnn

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值