1. tf.contrib.slim.model_analyzer.analyze_vars
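import tensorflow.contrib.slim as slim
# g_var: list of the generator's variables (assumed, e.g. tf.trainable_variables(scope='Generator'))
slim.model_analyzer.analyze_vars(g_var, print_info=True)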
This prints information about every variable in g_var (because print_info=True). The output looks like this:
---------
Variables: name (type shape) [size]
---------
Generator/g_1_deconv/Conv2d_transpose/weights:0 (float32_ref 2x2x384x100) [153600, bytes: 614400]
Generator/g_1_deconv/Conv2d_transpose/biases:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_1_deconv/BatchNorm/beta:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_1_deconv/BatchNorm/gamma:0 (float32_ref 384) [384, bytes: 1536]
Generator/g_2_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x128x384) [786432, bytes: 3145728]
Generator/g_2_deconv/Conv2d_transpose/biases:0 (float32_ref 128) [128, bytes: 512]
Generator/g_2_deconv/BatchNorm/beta:0 (float32_ref 128) [128, bytes: 512]
Generator/g_2_deconv/BatchNorm/gamma:0 (float32_ref 128) [128, bytes: 512]
Generator/g_3_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x64x128) [131072, bytes: 524288]
Generator/g_3_deconv/Conv2d_transpose/biases:0 (float32_ref 64) [64, bytes: 256]
Generator/g_3_deconv/BatchNorm/beta:0 (float32_ref 64) [64, bytes: 256]
Generator/g_3_deconv/BatchNorm/gamma:0 (float32_ref 64) [64, bytes: 256]
Generator/g_4_deconv/Conv2d_transpose/weights:0 (float32_ref 4x4x32x64) [32768, bytes: 131072]
Generator/g_4_deconv/Conv2d_transpose/biases:0 (float32_ref 32) [32, bytes: 128]
Generator/g_5_deconv/Conv2d_transpose/weights:0 (float32_ref 7x7x3x32) [4704, bytes: 18816]
Generator/g_5_deconv/Conv2d_transpose/biases:0 (float32_ref 3) [3, bytes: 12]
Total size of variables: 1110339
Total bytes of variables: 4441356
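The size and byte columns are simple arithmetic. size is the product of the shape's dimensions, and bytes is size times the element width (4 bytes for float32). The sketch below reproduces that bookkeeping in plain Python so the totals can be checked by hand. The shapes are copied from the output above, and `summarize` is a hypothetical helper, not part of slim.

```python
from math import prod

# Shapes copied from the analyze_vars output above (prefix "Generator/" and ":0" omitted).
GENERATOR_SHAPES = {
    "g_1_deconv/Conv2d_transpose/weights": (2, 2, 384, 100),
    "g_1_deconv/Conv2d_transpose/biases": (384,),
    "g_1_deconv/BatchNorm/beta": (384,),
    "g_1_deconv/BatchNorm/gamma": (384,),
    "g_2_deconv/Conv2d_transpose/weights": (4, 4, 128, 384),
    "g_2_deconv/Conv2d_transpose/biases": (128,),
    "g_2_deconv/BatchNorm/beta": (128,),
    "g_2_deconv/BatchNorm/gamma": (128,),
    "g_3_deconv/Conv2d_transpose/weights": (4, 4, 64, 128),
    "g_3_deconv/Conv2d_transpose/biases": (64,),
    "g_3_deconv/BatchNorm/beta": (64,),
    "g_3_deconv/BatchNorm/gamma": (64,),
    "g_4_deconv/Conv2d_transpose/weights": (4, 4, 32, 64),
    "g_4_deconv/Conv2d_transpose/biases": (32,),
    "g_5_deconv/Conv2d_transpose/weights": (7, 7, 3, 32),
    "g_5_deconv/Conv2d_transpose/biases": (3,),
}

FLOAT32_BYTES = 4  # each float32 element occupies 4 bytes


def summarize(shapes, bytes_per_elem=FLOAT32_BYTES, verbose=True):
    """Hypothetical helper: per-variable element count and bytes, plus totals."""
    total_size = 0
    for name, shape in shapes.items():
        size = prod(shape)  # number of elements = product of all dimensions
        if verbose:
            dims = "x".join(map(str, shape))
            print(f"Generator/{name}:0 ({dims}) [{size}, bytes: {size * bytes_per_elem}]")
        total_size += size
    return total_size, total_size * bytes_per_elem


if __name__ == "__main__":
    size, nbytes = summarize(GENERATOR_SHAPES)
    print("Total size of variables:", size)
    print("Total bytes of variables:", nbytes)
```

Running it gives the same totals as analyze_vars: 1110339 elements and 4441356 bytes. The 2x2x384x100 transposed-convolution kernel alone contributes 153600 of them.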
2. TypeError: Failed to convert object of type <type 'list'> to Tensor. Contents: [Dimension(16), Dimension(57), Dimension(57), 1024]. Consider casting elements to a supported type.
Solution:
TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, Dimension(4608)]. Consider casting elements to a supported type.
Tracing it down showed that the error is raised by tf.reshape() in this code:
flatten_shape = input.get_shape()[1] * input.get_shape()[2] * input.get_shape()[3]
return tf.reshape(input, [-1, flatten_shape], name="flatten")
It has to be changed to:
shape = input.get_shape().as_list()  # TensorShape -> list of Python ints (None for unknown dims)
flatten_shape = shape[1] * shape[2] * shape[3]
return tf.reshape(input, [-1, flatten_shape], name="flatten")
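You have to call .as_list() to convert the shape into a plain list. get_shape() returns a TensorShape whose elements are Dimension objects, and multiplying Dimensions produces another Dimension. As a result, the list passed to tf.reshape becomes [-1, Dimension(4608)], which cannot be converted to a Tensor. .as_list() returns ordinary Python ints instead.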
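The fix above hard-codes a rank-4 input (shape[1] * shape[2] * shape[3]). The same idea generalizes to any rank: multiply every dimension after the batch axis of the .as_list() result. The helper below does this in plain Python. `flat_dim` is a hypothetical name, and it assumes axis 0 is the batch axis and all other dimensions are statically known.

```python
from math import prod
from typing import List, Optional


def flat_dim(shape: List[Optional[int]]) -> int:
    """Product of all non-batch dims of a static shape such as
    input.get_shape().as_list(); axis 0 is assumed to be the batch axis."""
    rest = shape[1:]
    if any(d is None for d in rest):
        # None means the dimension is unknown at graph-construction time,
        # so it cannot be folded into a Python int.
        raise ValueError(f"non-batch dimensions must be known, got {shape}")
    return prod(rest)  # plain int, safe to pass to tf.reshape


# Usage with a hypothetical as_list() result:
print(flat_dim([None, 6, 6, 128]))  # 6 * 6 * 128
```

In the flatten code this would read `tf.reshape(input, [-1, flat_dim(input.get_shape().as_list())], name="flatten")`. Because the helper always returns an int, the Dimension-to-Tensor conversion error cannot occur.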
3