基于Inception-V3模型的迁移学习在图像识别分类中的应用

1.背景

        自1998年LeNet-5模型的提出一直到现在,卷积神经网络模型的层数和复杂度都发生了巨大的变化,下表中罗列了ILSVRC(Lareg Scale Visual Recognition Challenge)第一名模型的表现:

年份模型名称层数Top5错误率
2012AlexNet815.3%
2013ZF Net814.8%
2014GoogLeNet226.67%
2015ResNet1523.57%

        可以看到随着层数的增加,模型的识别错误率也在降低。然而ImagNet图像分类数据集中有120万标注图片, 在真实的引用中,很难收集到海量的训练数据,要训练一个复杂的卷积神经网络需要几天甚至几周的时间。为了解决该问题,迁移学习的概念被提出来。所谓迁移学习,就是讲一个问题上训练好的模型通过调整使之适应于一个新的问题。

2.基于tensorflow实现

        该迁移学习用到的数据集合为http://download.tensorflow.org/example_images/flower_photos.tgz.

        数据预处理代码,将彩色的三通道的图片转化为矩阵形式并通过numpy保存到一个文件中,模型的下载地址为http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,具体代码如下:

 

"""
将原始图像数据整理成模型需要的输入数据
"""
import glob
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

INPUT_DATA = r"E:\master_paper\DATA_SET\flower_photos"

OUTPUT_DATA = r"E:\paper_exp\paper1\tansferlearning\flower_procssed_data.npy"

VALIDATION_PRECENTAGE = 10
TEST_PERCENTAGE = 10


def create_image_list(sess, test_percentage, validation_percentage):
    """
    读取数据并将数据分割成训练数据,验证数据和测试数据
    :param sess:
    :param test_percentage:
    :param validation_percentage:
    :return:
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    is_root_dir = True

    # 初始化各个数据集
    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels = []

    current_label = 0

    # 读取子目录
    for sub_dir in sub_dirs:
        if is_root_dir:  # 过滤掉根目录
            is_root_dir = False
            continue
        # 获取一个子目录中的所有图片文件
        extensions = ['jpg', 'jpeg']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, "*." + extension)
            file_list.extend(glob.glob(file_glob))  # glob 获取满足正则表达式的文件名

        if not file_list:
            continue

        # 处理图片数据
        counter = 0
        total = len(file_list)
        for file_name in file_list:
            counter += 1
            print("total: [%s / %s] current process image is: %s" % (counter, total, file_name))

            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            image = tf.image.decode_jpeg(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)

            image = tf.image.resize_images(image, [299, 299])
            image_value = sess.run(image)

            # 随机划分数据集
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (test_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)

        current_label += 1

    # 将训练数据随机打乱以获得更好的训练效果
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    return np.asarray([training_images, training_labels,
                       validation_images, validation_labels,
                       testing_images, testing_labels])


def main():
    with tf.Session() as sess:
        processed_data = create_image_list(sess, TEST_PERCENTAGE, VALIDATION_PRECENTAGE)
        np.save(OUTPUT_DATA, processed_data)


if __name__ == '__main__':
    main()

        接着运行上述数据预处理代码,后会在本地文件夹下新增数据处理的结果,接下来,使用迁移学习将Inception-V3应用于图像分类任务中,具体实现代码如下:

"""
迁移学习
"""
import numpy as np
import tensorflow as tf

from tensorflow.contrib import slim

# 处理结束之后的数据文件
import tensorflow.contrib.slim.python.slim.nets.inception_v3  as inception_v3

INPUT_DATA = "./flower_procssed_data.npy"
# 保存训练好的模型路径
TRAIN_FILE = "./savemodel"
# 谷歌提供训练好的模型文件地址
CKPT_FILE = "./inception_v3.ckpt"

# 定义训练中使用的参数
LEARNING_RAGE = 0.0001
STEPS = 300
BATCH = 32
N_CLASS = 5

CHECKPOINT_EXCLUDE_SCOPES = "InceptionV3/Logits,InceptionV3/AuxLogits"
TRAINABLE_SCOPES = "InceptionV3/Logits,InceptionV3/AuxLogits"


def get_tuned_variables():
    """
    获取所有需要从谷歌训练好的模型中加载的参数
    :return:
    """
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
    variables_to_restore = []

    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    return variables_to_restore


def get_trainable_variables():
    """
    获取所有需要训练的变量列表
    :return:
    """
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []

    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train


def main(argv=None):
    processed_data = np.load(INPUT_DATA, allow_pickle=True)

    training_images = processed_data[0]
    training_labels = processed_data[1]
    n_training_example = len(training_labels)

    validation_images = processed_data[2]
    validation_labels = processed_data[3]

    test_images = processed_data[4]
    test_labels = processed_data[5]

    print("%d training examples, %d validation examples and %d testing examples."
          % (n_training_example, len(validation_labels), len(test_labels)))

    # 定义图片的输入
    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # 定义Inception-v3模型
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(images, num_classes=N_CLASS)

    # 获取要训练的变量
    trainable_variables = get_trainable_variables()
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASS), logits, weights=1.0)

    # 定义训练过程
    train_step = tf.train.RMSPropOptimizer(LEARNING_RAGE).minimize(tf.losses.get_total_loss())

    # 计算正确率
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 定义加载模型的函数
    load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True)

    # 定义保存新训练好的模型的函数
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        # 加载谷歌已经训练好的模型
        print("loading tunned variables from %s" % CKPT_FILE)
        load_fn(sess)

        start = 0
        end = BATCH
        for i in range(STEPS):
            # 运行训练过程
            sess.run(train_step, feed_dict={
                images: training_images[start:end],
                labels: training_labels[start:end]
            })

            # 输出日志
            if i % 30 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)
                validation_accuracy = sess.run(evaluation_step, feed_dict={
                    images: validation_images,
                    labels: validation_labels
                })
                print("step %d: validation accuracy = %1.f%%" % (
                    i, validation_accuracy * 100.0
                ))

            start = end
            if start == n_training_example:
                start = 0

            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # 在最后的测试集上测试正确率
        test_accuracy = sess.run(evaluation_step, feed_dict={
            images: test_images,
            labels: test_labels
        })

        print("final test accuracy = %.1f%%" % (test_accuracy * 100))


if __name__ == '__main__':
    tf.app.run()

 

 

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值