inception_v1.ckpt 的下载地址：
http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
下载后解压，将得到的 inception_v1.ckpt 存入一个文件夹中。
我把它存放在 `./tmp/checkpoints` 目录下（即下面代码中的 checkpoints_dir）。
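第一种方式：遍历 slim.get_model_variables()，手动排除新分类层（InceptionV1/Logits、InceptionV1/AuxLogits）的变量，再用 tf.train.Saver 恢复其余权重：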
import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim
# Approach 1: collect the model variables by hand, skip the classifier-head
# scopes, and restore everything else from the pretrained checkpoint.

# Directory that holds the extracted inception_v1.ckpt file.
checkpoints_dir = './tmp/checkpoints'
# Input image geometry and the number of classes for the new task.
# NOTE(review): n_class presumably differs from the checkpoint's classifier
# size, so the Logits variables would have incompatible shapes -- that is why
# they are excluded from the restore below.
height, width, channels, n_class = 224, 224, 3, 3

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # The returned logits can be used to keep building the training graph.
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)

    # Variables under these scopes are NOT loaded from the checkpoint.
    checkpoint_exclude_scopes = ["InceptionV1/Logits", "InceptionV1/AuxLogits"]
    # str.startswith accepts a tuple, so one call tests every excluded prefix.
    exclusions = tuple(scope.strip() for scope in checkpoint_exclude_scopes)
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]
    saver = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Initialize every variable first: the excluded layers are never
        # restored and would otherwise stay uninitialized. restore() then
        # overwrites the pretrained variables with checkpoint values.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, os.path.join(checkpoints_dir, 'inception_v1.ckpt'))
        print('Good Job')
第二种方式：用 slim.get_variables_to_restore(exclude=...) 排除 Logits 层，再用 tf.train.Saver 恢复其余权重：
import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim
# Approach 2: build the Inception-v1 model and let slim.get_variables_to_restore
# filter out the classifier head before restoring the pretrained weights.

# Directory that holds the extracted inception_v1.ckpt file.
checkpoints_dir = './tmp/checkpoints'
# Input image geometry and the number of classes for the new task.
height, width, channels, n_class = 224, 224, 3, 3

with tf.Graph().as_default():
    # Print TensorFlow log messages at INFO level to the screen
    # (TensorFlow defines five log levels).
    tf.logging.set_verbosity(tf.logging.INFO)
    # Input placeholder: a batch of height x width x channels images.
    X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    # slim.arg_scope() sets default arguments for the layer functions built inside it.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # The returned logits can be used to keep building the training graph.
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)

    # Skip the classifier head; it is sized for n_class and is not loaded
    # from the checkpoint.
    exclude = ['InceptionV1/Logits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
    saver = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Initialize every variable first so the excluded Logits layer holds
        # valid values; restore() then overwrites the pretrained variables.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, os.path.join(checkpoints_dir, 'inception_v1.ckpt'))
        print('Good Job')
第三种方式：用 slim.assign_from_checkpoint_fn 生成恢复函数，在会话中调用它加载权重：
import os
from nets import inception
import tensorflow as tf
from tensorflow.contrib import slim
# Approach 3: build the Inception-v1 model and load the pretrained weights
# through the callback returned by slim.assign_from_checkpoint_fn.

# Directory that holds the extracted inception_v1.ckpt file.
checkpoints_dir = './tmp/checkpoints'
# Input image geometry and the number of classes for the new task.
height, width, channels, n_class = 224, 224, 3, 3

with tf.Graph().as_default():
    # Print TensorFlow log messages at INFO level to the screen
    # (TensorFlow defines five log levels).
    tf.logging.set_verbosity(tf.logging.INFO)
    # Input placeholder: a batch of height x width x channels images.
    X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channels])
    # Create the model, use the default arg scope to configure the batch norm parameters.
    # slim.arg_scope() sets default arguments for the layer functions built inside it.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # The returned logits can be used to keep building the training graph.
        logits, _ = inception.inception_v1(X, num_classes=n_class, is_training=True)

    # Skip the classifier head; it is sized for n_class and is not loaded
    # from the checkpoint.
    exclude = ['InceptionV1/Logits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
    # init_fn(sess) assigns the checkpoint values to variables_to_restore.
    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'inception_v1.ckpt'), variables_to_restore)

    with tf.Session() as sess:
        # Initialize every variable first so the excluded Logits layer holds
        # valid values; init_fn then overwrites the pretrained variables.
        sess.run(tf.global_variables_initializer())
        init_fn(sess)
        print('Good Job')