参考博客:【TensorFlow】迁移学习(使用Inception-v3),非常感谢这个博主的这篇博客,我这篇博客的框架来自于这位博主,然后我针对评论区的问题以及自己的实践增加了一些内容以及解答。
github:代码
知识储备
- 迁移学习是将一个数据集上训练好的网络模型快速转移到另外一个数据集上,可以保留训练好的模型中倒数第一层之前的所有参数,替换最后一层即可,在最后层之前的网络层称之为瓶颈层。
- 迁移学习,首先尝试了Inception-V3,直接使用pool_3层的输出,接上一个全连接的分类层,使用softmax进行分类,使用Inception-V3的默认输入。
一、准备工作
1、数据集下载
2、Inception-v3模型下载
flower_photos/
daisy/
dandelion/
roses/
sunflowers/
tulips/
数据集文件夹包含5个子文件,每一个子文件夹的名称为一种花的名称,代表了不同的类别。平均每一种花有734张图片,每一张图片都是RGB色彩模式,大小也不相同,程序将直接处理没有整理过的图像数据。
- 模型解压后的目录:
imagenet_comp_graph_label_strings.txt
tensorflow_inception_graph.pb
3、目录结构
- 需要自行创建transfer-learning/data/tmp/bottlenec/model/train.py/eval.py文件。
transfer-learning/
data/
flower_photos/
......
tmp/
bottleneck/
......
model/
imagenet_comp_graph_label_strings.txt
tensorflow_inception_graph.pb
train.py
eval.py
二、代码实现
1、train.py
python3 train.py
import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# 数据参数
MODEL_DIR = 'model/' # inception-v3模型的文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb' # inception-v3模型文件名
CACHE_DIR = 'data/tmp/bottleneck' # 图像的特征向量保存地址
INPUT_DATA = 'data/flower_photos' # 图片数据文件夹
VALIDATION_PERCENTAGE = 10 # 验证数据的百分比
TEST_PERCENTAGE = 10 # 测试数据的百分比
# inception-v3模型参数
BOTTLENECK_TENSOR_SIZE = 2048 # inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' # 图像输入张量对应的名称
# 神经网络的训练参数
LEARNING_RATE = 0.01
STEPS = 1000
BATCH = 100
CHECKPOINT_EVERY = 100
NUM_CHECKPOINTS = 5
# 从数据文件夹中读取所有的图片列表并按训练、验证、测试分开
def create_image_lists(validation_percentage, test_percentage):
result = {} # 保存所有图像。key为类别名称。value也是字典,存储了所有的图片名称
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] # 获取所有子目录
is_root_dir = True # 第一个目录为当前目录,需要忽略
# 分别对每个子目录进行操作
for sub_dir in sub_dirs:
if is_root_dir:
is_root_dir = False
continue
# 获取当前目录下的所有有效图片
extensions = {'jpg', 'jpeg', 'JPG', 'JPEG'}
file_list = [] # 存储所有图像
dir_name = os.path.basename(sub_dir) # 获取路径的最后一个目录名字
for extension in extensions:
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
file_list.extend(glob.glob(file_glob))
if not file_list:
continue
# 将当前类别的图片随机分为训练数据集、测试数据集、验证数据集
label_name = dir_name.lower() # 通过目录名获取类别的名称
training_images = []
testing_images = []
validation_images = []
for file_name in file_list:
base_name = os.path.basename(file_name) # 获取该图片的名称
chance = np.random.randint(100) # 随机产