一:用xx.npz文件初始化网络
使用tensorpack框架的时候,发现官方提供的训练好的权重文件是xx.npz格式的,我想将其某些层的参数用在自己的网络中。
import os
import random
import tensorflow as tf
import numpy as np
PRE_IMANET_NPZ = 'XX.npz'
def convert_param_name(param):
# 获取.npz文件里的变量名称
# 和网络的变量名称进行比较
# 得到 网络变量名称:文件里的变量值 这样的字典
# print('--> convert_param_name ...')
resnet_param = {}
for k in param.keys():
# print(k)
var_name = k.replace('W', 'weights')
var_name = var_name.replace('bn', 'BatchNorm')
resnet_param[var_name] = param[k]
return resnet_param
def initial_imagenet(sess, path_to_npz):
print('Initializing through npz file trained on ImageNet ...')
sess.run(tf.global_variables_initializer()) # 先初始化网络 避免有些网络的变量不存在在文件里
param = np.load(path_to_npz, encoding='latin1') # 加载文件
param