【Inception-v3模型】迁移学习 实战训练 花朵种类识别

本文介绍了使用Inception-v3模型进行花朵种类识别的迁移学习实战,包括知识储备、准备工作和代码实现。通过调整模型的最后一层进行分类,并提供数据集下载及模型解压后的目录结构。训练和评估代码分别在train.py和eval.py中,解决更换图片后accuracy始终为0的问题,需确保图片标签与实际类别对应。
摘要由CSDN通过智能技术生成

参考博客:【TensorFlow】迁移学习(使用Inception-v3),非常感谢这个博主的这篇博客,我这篇博客的框架来自于这位博主,然后我针对评论区的问题以及自己的实践增加了一些内容以及解答。

github:代码

知识储备

  • 迁移学习是将一个数据集上训练好的网络模型快速转移到另外一个数据集上,可以保留训练好的模型中倒数第一层之前的所有参数,替换最后一层即可,在最后层之前的网络层称之为瓶颈层。
  • 迁移学习,首先尝试了Inception-V3,直接使用pool_3层的输出,接上一个全连接的分类层,使用softmax进行分类,使用Inception-V3的默认输入。

一、准备工作

1、数据集下载
2、Inception-v3模型下载

flower_photos/
    daisy/
    dandelion/
    roses/
    sunflowers/
    tulips/

数据集文件夹包含5个子文件,每一个子文件夹的名称为一种花的名称,代表了不同的类别。平均每一种花有734张图片,每一张图片都是RGB色彩模式,大小也不相同,程序将直接处理没有整理过的图像数据。

  • 模型解压后的目录:
imagenet_comp_graph_label_strings.txt
tensorflow_inception_graph.pb

3、目录结构

  • 需要自行创建transfer-learning/data/tmp/bottlenec/model/train.py/eval.py文件。
transfer-learning/
    data/
        flower_photos/
            ......
        tmp/
            bottleneck/
                ......
    model/
        imagenet_comp_graph_label_strings.txt
        tensorflow_inception_graph.pb
    train.py
    eval.py

二、代码实现

需要写两个文件:1、train.py 2、eval.py

1、train.py

python3 train.py

import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# 数据参数
MODEL_DIR = 'model/'  # inception-v3模型的文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb'  # inception-v3模型文件名
CACHE_DIR = 'data/tmp/bottleneck'  # 图像的特征向量保存地址
INPUT_DATA = 'data/flower_photos'  # 图片数据文件夹
VALIDATION_PERCENTAGE = 10  # 验证数据的百分比
TEST_PERCENTAGE = 10  # 测试数据的百分比

# inception-v3模型参数
BOTTLENECK_TENSOR_SIZE = 2048  # inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 图像输入张量对应的名称

# 神经网络的训练参数
LEARNING_RATE = 0.01
STEPS = 1000
BATCH = 100
CHECKPOINT_EVERY = 100
NUM_CHECKPOINTS = 5


# 从数据文件夹中读取所有的图片列表并按训练、验证、测试分开
def create_image_lists(validation_percentage, test_percentage):
    result = {}  # 保存所有图像。key为类别名称。value也是字典,存储了所有的图片名称
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]  # 获取所有子目录
    is_root_dir = True  # 第一个目录为当前目录,需要忽略

    # 分别对每个子目录进行操作
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue

        # 获取当前目录下的所有有效图片
        extensions = {'jpg', 'jpeg', 'JPG', 'JPEG'}
        file_list = []  # 存储所有图像
        dir_name = os.path.basename(sub_dir)  # 获取路径的最后一个目录名字
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            continue

        # 将当前类别的图片随机分为训练数据集、测试数据集、验证数据集
        label_name = dir_name.lower()  # 通过目录名获取类别的名称
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)  # 获取该图片的名称
            chance = np.random.randint(100)  # 随机产
  • 2
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值