Tensorflow slim 选择性读取加载权重

inception_v1.ckpt的下载地址
http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
并存入解压,存入文件夹中
我存储在’./tmp/checkpoints’
第一种方式:

import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim

checkpoints_dir = './tmp/checkpoints'
heigh, width, channels, n_class = 224, 224, 3, 3
with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    X = tf.placeholder(dtype=tf.float32, shape=[None, heigh, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)  # 根据logit返回的值,可以继续添加训练网络
        checkpoint_exclude_scopes = ["InceptionV1/Logits", "InceptionV1/AuxLogits"]  
        exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
        variables_to_restore = []
        for var in slim.get_model_variables():
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    break
            else:
                variables_to_restore.append(var)
        saver = tf.train.Saver(variables_to_restore)
        with tf.Session() as sess:
            saver.restore(sess, os.path.join(checkpoints_dir, 'inception_v1.ckpt'))
            print('Good Job')

第二种方式:

import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim

#添加inception-v1模型
checkpoints_dir = './tmp/checkpoints'
heigh, width, channels, n_class = 224, 224, 3, 3
with tf.Graph().as_default():
    #将 TensorFlow 日志信息输出到屏幕 TensorFlow有五个不同级别的日志信息。
    tf.logging.set_verbosity(tf.logging.INFO)
    #定义X
    X = tf.placeholder(dtype=tf.float32, shape=[None, heigh, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    #用 slim.arg_scope()为目标函数设置默认参数.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # 根据logit返回的值,可以继续添加训练网络
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)
        exclude=['InceptionV1/Logits']
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
        saver = tf.train.Saver(variables_to_restore)
        with tf.Session() as sess:
            saver.restore(sess, os.path.join(checkpoints_dir, 'inception_v1.ckpt'))
            print('Good Job')

第三种方式:

import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim

#添加inception-v1模型
checkpoints_dir = './tmp/checkpoints'
heigh, width, channels, n_class = 224, 224, 3, 3
with tf.Graph().as_default():
    #将 TensorFlow 日志信息输出到屏幕 TensorFlow有五个不同级别的日志信息。
    tf.logging.set_verbosity(tf.logging.INFO)
    #定义X
    X = tf.placeholder(dtype=tf.float32, shape=[None, heigh, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    #用 slim.arg_scope()为目标函数设置默认参数.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # 根据logit返回的值,可以继续添加训练网络
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)
        exclude=['InceptionV1/Logits']
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
        init_fn = slim.assign_from_checkpoint_fn(os.path.join(checkpoints_dir, 'inception_v1.ckpt'), variables_to_restore)

        with tf.Session() as sess:
            init_fn(sess)
            print('Good Job')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值