给大家推荐免费的kugou音乐 vip哈, 还不知道的一定有相见恨晚的感觉,获取地址http://www.service99.cn
------------题记
每个框架都有查看权重参数的机制,在tenstensorflow中
查看的例子如下
import tensorflow as tf
import numpy as np
reader = tf.train.NewCheckpointReader('model-100')
all_variables = reader.get_variable_to_shape_map()
w0 = reader.get_tensor("conv0/W")
print(type(w0))
print(w0.shape)
print(w0[0])
b0 = reader.get_tensor("conv0/b")
print(type(b0))
print(b0.shape)
print(b0)
注意这里,在保存moxi模型的目录中有checkpoint文件,有model-100.data-00000-of-00001和model-100.index文件,此处我们只写.之前的东西。
直接是Numpy.ndarray格式,这个很好。
使用txt文件保存权重的代码为
import tensorflow as tf
import numpy as np
reader = tf.train.NewCheckpointReader('model-100')
all_variables = reader.get_variable_to_shape_map()
quantized_conv_list = ['conv1','conv2','conv3','conv4']
pf = open('result.txt', 'w+')
for quantized_conv_name in quantized_conv_list:
weight = reader.get_tensor(quantized_conv_name+"/W")
print quantized_conv_name
print '***************************************'
print weight.shape
[n,cout,h,w]=weight.shape
print cout,h,w
pf.write(quantized_conv_name)
pf.write('\n')
pf.write(str(n)+' '+str(cout)+' '+str(h)+' '+str(w)+'\n')
#for c in range(cout):
#pf.write('***********'+str(c)+'**********\n')
for n1 in range(n):
pf.write('***********'+str(n1)+'**********\n')
for h1 in range(h):
for w1 in range(w):
for c in range(cout):
pf.write('%f ' %weight[n1][c][h1][w1])
pf.write('\n')
#pf.write('\n')
try:
bias = reader.get_tensor(quantized_conv_name+"/b")
n2=bias.shape
print bias.shape
print n2
print '***************************************'
pf.write('\n')
pf.write('**************************bias:')
pf.write('\n')
pf.write(str(n)+'\n')
#for n1 in range(n2):
# pf.write('%f, ' %bias[n1])
#pf.write('\n')
for b in bias:
pf.write('%f '%b)
except:
print 'no bias'
pf.write('\n')
pf.close()
注意这里的conv1/W可以在log文件中看到。不同的命名方式不一样。