按照《tensorflow2caffe(2) : 如何在tensorflow中取出模型参数》一文介绍的方法，
把下面的代码放到 main.py 的 generator 部分（需放在 sess 创建并恢复完模型参数之后，因为代码要通过 sess.run 读取各变量的值）：
#-------------------------------------------------------------
# Export every TensorFlow variable of the current graph to a text file in a
# Caffe-prototxt-like "blobs { data: ... shape { dim: ... } }" layout, so the
# values can be compared against (or converted into) a Caffe model.
#
# Assumes `sess` is a tf.Session in which the model's variables have already
# been restored, and that `tf` and `np` are in scope where this is pasted.


def _write_blob(fp, values, shape):
    """Write one ``blobs { ... }`` section to the open text file *fp*.

    values: flat iterable of numbers; each becomes a " data:" line with
        8 decimal places.
    shape: iterable of ints; each becomes a " dim:" line, in order.
    """
    lines = [' blobs {\n']
    lines.extend(' data: %.8f\n' % vv for vv in values)
    lines.append(' shape {\n')
    lines.extend(' dim: ' + str(s) + '\n' for s in shape)
    lines.append(' }\n')
    lines.append(' }\n')
    # One write per blob instead of two write() calls per value.
    fp.write(''.join(lines))


# Use global_variables(), not trainable_variables(): the latter omits
# non-trainable variables, so some parameters would be missing from the dump
# (presumably e.g. BatchNorm moving mean/variance -- confirm in the graph).
all_vars = tf.global_variables()
# `with` guarantees the file is closed even if sess.run() raises mid-export.
with open('SRGAN_generator_model.txt', 'w') as fp:
    for v in all_vars:
        # e.g. "generator/.../weights:0" -> "generator_..._weights:0.prototxt"
        fname = (v.name + '.prototxt').replace('/', '_')
        print(fname)
        fp.write(fname + '\n')
        value = np.array(sess.run(v))
        if value.ndim == 4:
            # TF conv kernels are [H, W, I, O]; reorder to [O, I, H, W].
            value = np.transpose(value, (3, 2, 0, 1))
            # ravel() flattens in C order, matching the [O, I, H, W] layout.
            _write_blob(fp, value.ravel(), value.shape)
        elif value.ndim == 1:
            # 1-D tensors (e.g. biases, PReLU alpha, BatchNorm beta) need no
            # reordering.
            _write_blob(fp, value, value.shape)
        # Variables of any other rank get only their name line, no blob.
#-------------------------------------------------------------
运行后即可导出一个文本格式的参数文件，
SRGAN_generator_model.txt 的内容如下（每个变量先输出一行名称，再输出其 blobs 数据和 shape）：
generator_generator_unit_input_stage_conv_Conv_weights:0.prototxt
blobs {
data: -0.022789
data: -0.008191
data: -0.001650
...省略
data: 0.007882
data: 0.007484
shape {
dim: 64
dim: 3
dim: 9
dim: 9
}
}
generator_generator_unit_input_stage_conv_Conv_biases:0.prototxt
blobs {
data: -0.09940426
...省略
data: -0.06667865
shape {
dim: 64
}
}
generator_generator_unit_input_stage_Prelu_alpha:0.prototxt
blobs {
...省略
shape {
dim: 64
}
}
generator_generator_unit_resblock_1_conv_1_Conv_weights:0.prototxt
blobs {
...省略
shape {
dim: 64
dim: 64
dim: 3
dim: 3
}
}
generator_generator_unit_resblock_1_BatchNorm_beta:0.prototxt
...省略
这样就可以与 caffe_srgan-master 中的数据进行对比，检查两者的异同。