迁移学习将已经训练好的模型导入并使用,提高了训练效率。
下面的示例将原始图像数据整理成模型所需的输入数据。
首先进行数据预处理:
#-*- coding: utf-8 -*-
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# Directory holding the raw images and the path the processed data is saved to.
INPUT_DATA = 'E:/train-data/Flower_Photo/flower_photos'
OUTPUT_FILE = "E:/train-data/Pro_Flower_Photo"

# Percentage of the images routed to the test and validation splits.
TEST_PERCENTAGE = 10
VALIDATION_PERCENTAGE = 10
# Read the images, label them by sub-directory and split them into datasets.
def create_image_lists(sess, testing_percentage, validation_percentage):
    """Load the JPEG images under INPUT_DATA and split them into three sets.

    Every directory that os.walk yields below INPUT_DATA is treated as one
    class; its images are looked up as INPUT_DATA/<basename>/*.<ext> and get
    an integer label that increases per non-empty directory. Each image is
    decoded, converted to float32 and resized to 299x299 before being stored.

    Args:
        sess: TensorFlow session used to evaluate the decoding ops.
        testing_percentage: Percent chance (0-100) for an image to go to the
            test split.
        validation_percentage: Percent chance (0-100) for an image to go to
            the validation split.

    Returns:
        A 1-D object ndarray of six lists, in this order: training images,
        training labels, validation images, validation labels, testing
        images, testing labels.
    """
    sub_dirs = [entry[0] for entry in os.walk(INPUT_DATA)]
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels = []
    current_label = 0

    # os.walk yields the root directory first; it is not a class, so skip it.
    for sub_dir in sub_dirs[1:]:
        dir_name = os.path.basename(sub_dir)

        # Collect the image paths. On case-insensitive file systems the
        # lower- and upper-case patterns match the same files, so duplicates
        # are dropped to keep an image from landing in two splits.
        file_list = []
        seen = set()
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            for match in glob.glob(file_glob):
                if match not in seen:
                    seen.add(match)
                    file_list.append(match)
        if not file_list:
            continue

        for file_name in file_list:
            # Read and decode the image, then resize it to 299x299 so the
            # Inception-v3 model can consume it.
            # NOTE(review): these ops are created anew for every image, so
            # the graph keeps growing; consider building them once around a
            # string placeholder if preprocessing becomes slow.
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image = tf.image.decode_jpeg(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.resize_images(image, [299, 299])
            image_value = sess.run(image)

            # Randomly assign the image to one of the three splits.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label += 1

    # Shuffle images and labels with the same RNG state so they stay aligned.
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    # The six lists have different lengths and element shapes, so store them
    # as the elements of an explicit object array instead of letting numpy
    # try to infer a regular shape.
    splits = [training_images, training_labels,
              validation_images, validation_labels,
              testing_images, testing_labels]
    processed = np.empty(len(splits), dtype=object)
    for index, split in enumerate(splits):
        processed[index] = split
    return processed
def main():
    """Preprocess the images inside a TensorFlow session and save the result."""
    with tf.Session() as session:
        splits = create_image_lists(
            session, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        # Persist the processed data in numpy format.
        np.save(OUTPUT_FILE, splits)


if __name__ == '__main__':
    main()
说明:np.asarray 不会复制已有的 ndarray。若 a 本身是 ndarray,则 c = np.asarray(a) 与 a 共享同一份数据,修改 a 会同时改变 c;若 a 是列表等其他类型,则会新建一个数组。
数据处理完成后,即可加载下载好的预训练模型进行迁移学习。
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# Processed dataset produced by the preprocessing step. np.save appends
# ".npy" to a path that lacks it while np.load opens the path verbatim, so the
# extension has to be spelled out here.
INPUT_DATA = 'E:/train-data/Pro_Flower_Photo.npy'
# Where the model is stored after training.
TRAIN_FILE = "E:/model/inception-v3"
# The downloaded pre-trained model.
# NOTE(review): identical to TRAIN_FILE — confirm this is the checkpoint path.
CKPT_FILE = "E:/model/inception-v3"
LEARNING_RATE = 0.0001
STEPS = 300
BATCH = 32
N_CLASSES = 5
# Parameters that must NOT be loaded from the checkpoint: the final fully
# connected layers, because they are retrained for the new classes.
# The values listed are variable-name prefixes.
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits, InceptionV3/AuxLogits'
# Scopes of the network layers whose parameters are trained.
TRAINABLE_SCOPES = 'InceptionV3/Logits, InceptionV3/AuxLogits'
# Collect the parameters that should be restored from the checkpoint.
def get_tuned_variables():
    """Return the model variables to restore from the pre-trained checkpoint.

    Every variable reported by slim.get_model_variables() is kept unless its
    op name starts with one of the comma-separated prefixes listed in
    CHECKPOINT_EXCLUDE_SCOPES.

    Returns:
        A list of the variables whose names match none of the excluded
        prefixes.
    """
    # str.startswith accepts a tuple, so one call tests all prefixes.
    exclusions = tuple(
        scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(','))
    return [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]
# Collect the variables that are to be trained.
def get_trainable_variables():
    """Return the trainable variables found under each TRAINABLE_SCOPES prefix.

    TRAINABLE_SCOPES is split on commas and each stripped entry is used as the
    scope filter for the TRAINABLE_VARIABLES graph collection; the results are
    concatenated in scope order.
    """
    collected = []
    for raw_scope in TRAINABLE_SCOPES.split(','):
        prefix = raw_scope.strip()
        collected += tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, prefix)
    return collected
def main():
    """Build Inception-v3, restore the pre-trained weights and run training.

    Loads the processed dataset from INPUT_DATA, defines the network, the
    softmax cross-entropy loss, an RMSProp training step and an accuracy op,
    restores the checkpoint variables selected by get_tuned_variables(), and
    then runs STEPS training steps.
    """
    # The saved array holds Python lists (dtype=object), which newer numpy
    # releases only unpickle when allow_pickle=True. Only load files written
    # by the preprocessing step: unpickling untrusted data is unsafe.
    processed_data = np.load(INPUT_DATA, allow_pickle=True)
    # The preprocessing step stores the splits as: training images, training
    # labels, validation images, validation labels, testing images, testing
    # labels.
    training_images = processed_data[0]
    training_labels = processed_data[1]
    # ... (the original listing elides the remaining setup here)

    images = tf.placeholder(
        tf.float32, [None, 299, 299, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(
            images, num_classes=N_CLASSES
        )

    # NOTE(review): trainable_variables is collected but never passed to the
    # optimizer (e.g. via var_list), so every trainable variable is updated —
    # confirm whether only the TRAINABLE_SCOPES layers should be trained.
    trainable_variables = get_trainable_variables()
    tf.losses.softmax_cross_entropy(
        tf.one_hot(labels, N_CLASSES), logits, weights=1.0
    )
    # get_total_loss must be called: minimize() needs the loss tensor, not
    # the function object.
    total_loss = tf.losses.get_total_loss()
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(total_loss)

    # Compute the accuracy.
    with tf.name_scope('evaluation'):
        correction_rate = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(
            tf.cast(correction_rate, tf.float32))

    # Define the function that loads the pre-trained model.
    load_fn = slim.assign_from_checkpoint_fn(
        CKPT_FILE,
        get_tuned_variables(),
        ignore_missing_vars=True
    )

    with tf.Session() as sess:
        # Initialise first: loading afterwards overwrites the restored
        # variables and leaves the excluded layers freshly initialised.
        init = tf.global_variables_initializer()
        sess.run(init)
        print('Loading tuned variables from %s ' % CKPT_FILE)
        load_fn(sess)  # Load the pre-trained model.

        start = 0
        end = BATCH
        # NOTE(review): the visible listing never advances start/end, never
        # runs evaluation_step and never saves to TRAIN_FILE; the original
        # appears to be truncated here — confirm and complete the loop.
        for i in range(STEPS):
            sess.run(train_step, feed_dict={
                images: training_images[start:end],
                labels: training_labels[start:end],
            })