取tf中的矩阵维度,并转成int

在写模型的时候,我们希望一个模型的参数是灵活的,例如矩阵乘的时候可以根据输入最后一维的大小来定义一个W。

获得矩阵的维度

  1. tf.shape(input)
    input为所求矩阵,返回该矩阵的维度,但是是一个Tensor。经常取出来的维度值并不能直接用,因为会出现类似这样的报错:
    TypeError: int() argument must be a string or a number, not ‘Tensor’

  2. input.get_shape()
    这样得到的是Dimension类型的对象。

解决办法:
使用as_list()函数将Dimention

k=input.get_shape().as_list()[-1]

例子:

u=tf.reshape(np.arange(0,6),[3,2])
k=u.get_shape().as_list()[-1]
w=tf.Variable(tf.random_uniform([k,4]))
prod=tf.matmul(tf.cast(u,tf.float32),w)
with tf.Session()as s:
    s.run(tf.initialize_all_variables())
    print s.run(prod)

output:

[[ 0.61596847  0.58131492  0.46035814  0.58667159]
 [ 3.24906611  2.39435339  1.6604023   3.1891706 ]
 [ 5.882164    4.20739174  2.86044645  5.79166985]]
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值