备忘:tensorflow关于网络权重

一:用xx.npz文件初始化网络

使用tensorpack框架的时候,发现官方提供的训练好的权重文件是xx.npz格式的,我想将其某些层的参数用在自己的网络中。

import os
import random

import tensorflow as tf
import numpy as np

PRE_IMANET_NPZ = 'XX.npz'


def convert_param_name(param):
    # 获取.npz文件里的变量名称
    # 和网络的变量名称进行比较
    # 得到 网络变量名称:文件里的变量值 这样的字典
    # print('--> convert_param_name ...')
    resnet_param = {}
    for k in param.keys():
        # print(k) 
        var_name = k.replace('W', 'weights')
        var_name = var_name.replace('bn', 'BatchNorm')
        resnet_param[var_name] = param[k]
    return resnet_param


def initial_imagenet(sess, path_to_npz):
    print('Initializing through npz file trained on ImageNet ...')
    sess.run(tf.global_variables_initializer())  # 先初始化网络 避免有些网络的变量不存在在文件里
    param = np.load(path_to_npz, encoding='latin1')  # 加载文件
    param 
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值