数据集下载:经典的cat数据
inception-v3下载:classify_image_graph_def.pb
一、注意事项
- 创建sess时启动inception-v3: tf.Session(graph=graph);
- 最后一层全连接层的输入使用tf.placeholder_with_default()占位,而不是tf.placeholder();
- 训练时记得保存labels.txt,用于后续使用remodel.pb;(主要是label值和顺序)
- 在使用remodel.pb时,注意导入tensor的名称书写正确性;(可通过查看tensorboard中graph进行判断)
二、使用Inception-v3进行retrain
1、指定路径和参数
路径:
- 输入:数据集(注意文件夹格式)、inception-v3存放路径、
- 输出:inception-v3模型输出值、tensorboard训练日志、重训练后的新模型、数据集labels保存
参数:
- 验证集\测试集占比、迭代次数、batch大小、inception-v3模型output的节点数、学习率
import numpy as np
import tensorflow as tf
import os
import glob
import random
from tensorflow.python.framework import graph_util
import shutil
path = r'E:\cat' #主文件夹
datapath = os.path.join(path, 'dataset') #图像数据路径
bottleneckpath = os.path.join(path, 'bottleneck') #inception-v3输出结果缓存路径
premodel_path = os.path.join(path, 'inception-2015-12-05', 'classify_image_graph_def.pb') #inception-v3模型路径
log_path = os.path.join(path, 'log') #tensorboard训练日志保存路径
remodel_path = os.path.join(path, 'remodel', 'remodel.pb') #重训练后的模型保存路径
label_path = os.path.join(path, 'labels.txt') #数据集labels保存路径(用于retrain.pb测试)
valpct = 10 #验证集占比
testpct = 10 #测试集占比
max_steps = 6000 #迭代次数
batch = 100 #batch_size
bottleneck_size = 2048 #inception-v3模型output的节点数(固定)
learn_rate = 0.01 #学习率
tf.reset_default_graph()
2、导入inception-v3,构建最后一层全连接层,对新数据训练验证测试
#读数据:获取图片路径
def get_img_path():
sub_dirs = [i[0] for i in os.walk(datapath)]
is_root_file = True
results = {}
for sub_dir in sub_dirs:
if is_root_file:
is_root_file = False
continue
label = os.path.basename(sub_dir).lower()
imgs