【TensorFlow】迁移学习(使用Inception-v3)

本文档详细介绍了如何使用TensorFlow进行迁移学习,特别是利用Inception-v3模型对flower数据集进行训练。包括数据集的下载、模型结构、代码实现、训练过程以及测试方法。通过替换Inception-v3的全连接层,实现针对特定任务的分类。
摘要由CSDN通过智能技术生成

项目已上传至 GitHub —— transfer-learning

更新:

  • 2018/3/16:添加了保存模型和摘要的代码,都保存在 runs/ 目录下
  • 2018/3/17:添加了测试代码,用于测试一张图片的类别
  • 2018/3/18:添加了训练之后输出标签文件的代码

1. 数据集及模型下载

1.1 数据集

通过官方下载地址,下载之后解压。解压后的目录结构:

flower_photos/
    daisy/
    dandelion/
    roses/
    sunflowers/
    tulips/

解压之后包含 5 个子文件夹,每个子文件夹的名称为一种花的名称,平均每一种花有 734 张图片,每张图片都是 RGB 色彩模式,大小不同。

1.2 Inception-v3模型

以下有两种下载方式,如果链接失效可以搜索网上的资源:

解压后有两个文件,将要使用的是 .pb 文件:

imagenet_comp_graph_label_strings.txt
tensorflow_inception_graph.pb

2. 目录结构

将数据集及模型文件下载好之后,分别放在 data/ 和 model/ 文件夹下,然后新建一个 train.py 文件用于实现迁移学习。

还需要新建一个 tmp/bottleneck/ 文件夹用于存放每张图片通过 Inception-v3 模型计算得到的特征向量。该文件夹的结构与 flower_photos 文件夹类似,可以在代码中生成各子文件夹,或者手动创建。

目录结构如下:

transfer-learning/
    data/
        flower_photos/
            ......
        tmp/
            bottleneck/
                ......
    model/
        imagenet_comp_graph_label_strings.txt
        tensorflow_inception_graph.pb
    train.py

3. 完整代码

该代码实现自《TensorFlow:实战Google深度学习框架》

该迁移学习方法的实现是,替换掉了 Inception-v3 模型的最后一层全连接层。用瓶颈层的输出来训练一个新的全连接层处理花的分类问题。

由于训练数据、验证数据和测试数据都是训练的时候随机分配的,所以训练正确率是不可再现的,并且差距较大,甚至能达到 10% 左右的差距。不过能用这么少的数据集达到 90% 以上的正确率也是很不错了。

源码如下:

import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# 数据参数
MODEL_DIR = 'model/'  # inception-v3模型的文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb'  # inception-v3模型文件名
CACHE_DIR = 'data/tmp/bottleneck'  # 图像的特征向量保存地址
INPUT_DATA = 'data/flower_photos'  # 图片数据文件夹
VALIDATION_PERCENTAGE = 10  # 验证数据的百分比
TEST_PERCENTAGE = 10  # 测试数据的百分比

# inception-v3模型参数
BOTTLENECK_TENSOR_SIZE = 2048  # inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 图像输入张量对应的名称

# 神经网络的训练参数
LEARNING_RATE = 0.01
STEPS = 1000
BATCH = 100
CHECKPOINT_EVERY = 100
NUM_CHECKPOINTS = 5


# 从数据文件夹中读取所有的图片列表并按训练、验证、测试分开
def create_image_lists(validation_percentage, test_percentage):
    result = {
    }  # 保存所有图像。key为类别名称。value也是字典,存储了所有的图片名称
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]  # 获取所有子目录
    is_root_dir = True  # 第一个目录为当前目录,需要忽略

    # 分别对每个子目录进行操作
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue

        # 获取当前目录下的所有有效图片
        extensions = {
    'jpg', 'jpeg', 'JPG', 'JPEG'}
        file_list = []  # 存储所有图像
        dir_name = os.path.basename(sub_dir)  # 获取路径的最后一个目录名字
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            continue

        # 将当前类别的图片随机分为训练数据集、测试数据集、验证数据集
        label_name = dir_name.lower()  # 通过目录名获取类别的名称
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)  # 获取该图片的名称
            chance = np.random.randint(100)  # 随机产生100个数代表百分比
            if chance < validation_percentage:
                validation_images.append(base_name)
            elif chance < (validation_percentage + test_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)

        # 将当前类别的数据集放入结果字典
        result[label_name] = {
    
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images
        }

    # 返回整理好的所有数据
    return result


# 通过类别名称、所属数据集、图片编号获取一张图片的地址
def get_image_path(image_lists, image_dir, label_name, index, category):
    label_lists = image_lists[label_name]  # 获取给定类别中的所有图片
    category_list = label_lists[category]  # 根据所属数据集的名称获取该集合中的全部图片
    mod_index = index % len(category_list)  # 规范图片的索引
    base_name = category_list[mod_index]  # 获取图片的文件名
    sub_dir = label_lists['dir'] 
评论 105
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值