tensorflow中函数的输出问题

1.tf.contrib.slim.model_analyzer.analyze_vars

 

import tensorflow.contrib.slim as slim
slim.model_analyzer.analyze_vars(g_var, print_info=True)

用以输出g_var参数的信息(print_info=True),输出结果如下所示:

---------
Variables: name (type shape) [size]
---------
Generator/g_1_deconv/Conv2d_transpose/weights:0 (float32_ref 2x2x384x100) [153600, bytes: 614400]
Generator/g_1_deconv/Conv2d_transpose/biases:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_1_deconv/BatchNorm/beta:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_1_deconv/BatchNorm/gamma:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_2_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x128x384) [786432, bytes: 3145728]
Generator/g_2_deconv/Conv2d_transpose/biases:0 (float32_ref 128) [128, bytes: 512]
Generator/g_2_deconv/BatchNorm/beta:0 (float32_ref 128) [128, bytes: 512]
Generator/g_2_deconv/BatchNorm/gamma:0 (float32_ref 128) [128, bytes: 512]
Generator/g_3_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x64x128) [131072, bytes: 524288]
Generator/g_3_deconv/Conv2d_transpose/biases:0 (float32_ref 64) [64, bytes: 256]
Generator/g_3_deconv/BatchNorm/beta:0 (float32_ref 64) [64, bytes: 256]
Generator/g_3_deconv/BatchNorm/gamma:0 (float32_ref 64) [64, bytes: 256]
Generator/g_4_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x32x64) [32768, bytes: 131072]
Generator/g_4_deconv/Conv2d_transpose/biases:0 (float32_ref 32) [32, bytes: 128]
Generator/g_5_deconv/Conv2d_transpose/weights:0 (float32_ref 7x7x3x32) [4704, bytes: 18816]
Generator/g_5_deconv/Conv2d_transpose/biases:0 (float32_ref 3) [3, bytes: 12]
Total size of variables: 1110339
Total bytes of variables: 4441356

2.TypeError: Failed to convert object of type <type 'list'> to Tensor. Contents: [Dimension(16), Dimension(57), Dimension(57), 1024]. Consider casting elements to a supported type.

Solution:

TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, Dimension(4608)]. Consider casting elements to a supported type.

跟踪发现是tf.reshape()时候报错!

1 flatten_shape = input.get_shape()[1] * input.get_shape()[2] * input.get_shape()[3]
2 return tf.reshape(input, [-1, flatten_shape], name="flatten")

这里需要改成

flatten_shape = input.get_shape().as_list()[1] * input.get_shape().as_list()[2] * input.get_shape().as_list()[3]
return tf.reshape(input, [-1, flatten_shape], name="flatten")

需要使用.as_list()将获取到的shape转换成list才行。

3

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值