tensorflow预训练权重导入

tensorflow预训练权重导入

1. 代码

def load(data_path, session):
    """
    load the VGG16_pretrain parameters file
    :param data_path:
    :param session:
    :return:
    """
    data_dict = np.load(data_path, encoding='latin1',allow_pickle=True).item()

    keys = sorted(data_dict.keys())
    for key in keys:
        with tf.variable_scope(key, reuse=True):
            for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                session.run(tf.get_variable(subkey).assign(data))


def load_with_skip(data_path, session, skip_layer):
    """
    Only load some layer parameters
    :param data_path:
    :param session:
    :param skip_layer:
    :return:
    """
    data_dict = np.load(data_path, encoding='latin1',allow_pickle=True).item()

    for key in data_dict:
        if key not in skip_layer:
            with tf.variable_scope(key, reuse=True):
                for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                    session.run(tf.get_variable(subkey).assign(data))

这里作为验证仅通过输入一张图片判断vgg16的输出类别

import tensorflow as tf
import VGG16 as vgg
from PIL import Image

data_path ='/opt/..../vgg16.npy'

input_maps = tf.placeholder(tf.float32, [None, 224, 224, 3])
prediction,_ = vgg.inference_op(input_maps,1.0)

image = Image.open('weasel.png')
image = image.convert('RGB')
image = image.resize((224,224))
img_raw = image.tobytes()
image = tf.reshape(tf.decode_raw(img_raw,out_type=tf.uint8),[1,224,224,3])
image = tf.cast(image, tf.float32)

# image = tf.read_file('cat.jpg')
# image = tf.image.decode_jpg(image)
# image = tf.image.convert_image_dtype(image,dtype=tf.float32)
# image = tf.image.resize_images(image, size=[224,224])
# image = tf.reshape(image,[1,224,224,3])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    #先载入vgg16.npy文件再去进行预测
    vgg.load(data_path, sess)
    image = sess.run(image)
    test_prediction = sess.run([prediction],feed_dict={input_maps:image})
    print(test_prediction)

2. 解析

之前定义网络结构时要使用tf.get_variable()来定义weights和biases,并且名字要和vgg16.npy中的名字相对应。变量名空间可以通过tf.name_scope()或者tf.variable_scope(),但是使用方法不同:

with tf.variable_scope(name):
  kernel = tf.get_variable('weights',shape=[kh,kw,n_input,n_out], dtype=tf.float32,
                   initializer=tf.contrib.layers.xavier_initializer_conv2d())
with tf.name_scope(name) as scope:
  kernel = tf.get_variable(scope+'weights', shape=[n_input, n_out], dtype=tf.float32,
                   initializer=tf.contrib.layers.xavier_initializer_conv2d())
#上面的load函数
with tf.variable_scope(key, reuse=True):
                for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                    session.run(tf.get_variable(subkey).assign(data))

因为tf.name_scope 只能对tf.Variabel()创建的变量的名字有影响,而对tf.get_variabel()创建的变量的名字没有影响。并且tf.get_variabel()只能对tf.get_variabel()创建的变量进行共享。
因此reuse设置为True之后,在模型载入时,我们可以使用tf.variable_scope配合tf.get_variable来载入已经训练好的变量参数。

参考:https://www.jianshu.com/p/14662e980fc0

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值