1. Import the required packages
import tensorflow as tf
import os
import zipfile
import requests
import glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
2. Download and unzip the Inception model
inception_model_url = 'https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip'
filename = inception_model_url.split('/')[-1]
filepath = os.path.join('./', filename)
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_model_url, stream=True)
    with open(filepath, 'wb') as f:
        # Stream the download in 2 KB chunks to keep memory usage low.
        for chunk in r.iter_content(chunk_size=2048):
            if chunk:
                f.write(chunk)
    print("finish: ", filename)
path = './inception_dec_2015'
if not os.path.exists(path):
    os.makedirs(path)
    print(path + '\tdirectory created successfully')
else:
    print(path + '\tdirectory already exists')
zipfile.ZipFile(filepath).extractall(path)
print(path + '\tarchive extracted')
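To confirm the archive really unpacked, a quick sanity check (a small sketch; the pb file name is the one used later in this post):

assert os.path.exists(os.path.join(path, 'tensorflow_inception_graph.pb'))
print(os.listdir(path))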
3. Dataset processing and visualization
Split the dataset into a 10% test set and a 90% training set and shuffle it randomly. Note that image datasets are usually handled as lists of file paths rather than raw pixel data.
datas_path = './flower_photos'
sub_dirs = [x[0] for x in os.walk(datas_path)][1:]
class_name = ['daisy','dandelion','roses','sunflowers','tulips']
# Initialize the path/label lists for each split.
training_images_path = []
training_labels = []
testing_images_path = []
testing_labels = []
current_label = 0
# Walk every class sub-directory.
for sub_dir in sub_dirs:
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    # basename returns the last component of the path, i.e. the class name.
    dir_name = os.path.basename(sub_dir)
    for extension in extensions:
        file_glob = os.path.join(datas_path, dir_name, '*.' + extension)
        # glob.glob returns every path matching the pattern.
        file_list.extend(glob.glob(file_glob))
    if not file_list:
        print('the path {} does not contain any image files'.format(os.path.join(datas_path, dir_name)))
    test_count = 0
    train_count = 0
    for file_name in file_list:
        # Randomly assign roughly 10% of the images to the test set.
        chance = np.random.randint(100)
        if chance < 10:
            testing_images_path.append(file_name)
            testing_labels.append(current_label)
            test_count += 1
        else:
            training_images_path.append(file_name)
            training_labels.append(current_label)
            train_count += 1
    print('for the {} class, label is {}, there are {} testing images and {} training images'
          .format(dir_name, current_label, test_count, train_count))
    current_label += 1
# Shuffle the training data: restoring the RNG state replays the same
# permutation, so paths and labels stay aligned.
state = np.random.get_state()
np.random.shuffle(training_images_path)
np.random.set_state(state)
np.random.shuffle(training_labels)
plt.figure(figsize=(6, 6))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    img = cv2.imread(training_images_path[i])
    # cv2.imread returns BGR; reverse the channel axis to get RGB for matplotlib.
    plt.imshow(img[:, :, ::-1])
    plt.xlabel(class_name[training_labels[i]])
plt.show()
The result is shown below. There are roughly 5,000 training images and 500 test images; the dataset is fairly small, which is why data augmentation is applied later during training.
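As a quick sanity check of the split, the list lengths built above can be printed directly:

print('training images:', len(training_images_path))
print('testing images:', len(testing_images_path))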
4. Data preparation
Next, build the data iterator. The network is fed image tensors, so the pixel data must be decoded from each path; the iterator also assembles batches and performs data augmentation.
class dataset(object):
    def __init__(self, data_paths, labels, batch_size, trainable):
        # Shuffle paths and labels with the same random permutation.
        ix = np.arange(len(data_paths))
        np.random.shuffle(ix)
        self.labels = [labels[i] for i in ix]
        self.data_paths = [data_paths[i] for i in ix]
        self.batch_size = batch_size
        self.batch_num = int(np.floor(len(data_paths) / batch_size))
        # Number of leftover samples that do not fill a complete batch.
        self.res = len(data_paths) % batch_size
        self.batch_count = 0
        self.trainable = trainable

    def __iter__(self):
        return self

    def __next__(self):
        with tf.device('/cpu:0'):
            if self.trainable:
                # Training mode: every image is expanded into four samples
                # (flips around both axes, the x axis, and the y axis, plus the
                # original). BGR is reversed to RGB, and pixels are normalized
                # to [-1, 1] via (x - 128) * 0.0078125, i.e. (x - 128) / 128.
                if self.res != 0:
                    # Serve the leftover samples first, then the full batches.
                    batch_images = np.zeros(shape=(self.res * 4, 299, 299, 3)).astype(np.float32)
                    batch_labels = np.zeros(shape=(self.res * 4)).astype(np.int64)
                    for num in range(self.res):
                        image_path = self.data_paths[len(self.data_paths) - self.res + num]
                        image_data = cv2.imread(image_path)
                        image_data = cv2.resize(image_data, (299, 299))
                        batch_images[4 * num, :] = (cv2.flip(image_data, -1)[:, :, ::-1] - 128) * 0.0078125
                        batch_images[4 * num + 1, :] = (cv2.flip(image_data, 0)[:, :, ::-1] - 128) * 0.0078125
                        batch_images[4 * num + 2, :] = (cv2.flip(image_data, 1)[:, :, ::-1] - 128) * 0.0078125
                        batch_images[4 * num + 3, :] = (image_data[:, :, ::-1] - 128) * 0.0078125
                        batch_labels[4 * num:4 * num + 4] = self.labels[len(self.labels) - self.res + num]
                    self.res = 0
                    # Shuffle images and labels in unison within the batch.
                    state = np.random.get_state()
                    np.random.shuffle(batch_images)
                    np.random.set_state(state)
                    np.random.shuffle(batch_labels)
                    return batch_images, batch_labels
                else:
                    if self.batch_count < self.batch_num:
                        batch_labels = np.zeros(shape=(self.batch_size * 4)).astype(np.int64)
                        batch_images = np.zeros(shape=(self.batch_size * 4, 299, 299, 3)).astype(np.float32)
                        for num in range(self.batch_size):
                            image_path = self.data_paths[num + self.batch_count * self.batch_size]
                            image_data = cv2.imread(image_path)
                            image_data = cv2.resize(image_data, (299, 299))
                            batch_images[4 * num, :] = (cv2.flip(image_data, -1)[:, :, ::-1] - 128) * 0.0078125
                            batch_images[4 * num + 1, :] = (cv2.flip(image_data, 0)[:, :, ::-1] - 128) * 0.0078125
                            batch_images[4 * num + 2, :] = (cv2.flip(image_data, 1)[:, :, ::-1] - 128) * 0.0078125
                            batch_images[4 * num + 3, :] = (image_data[:, :, ::-1] - 128) * 0.0078125
                            batch_labels[4 * num:4 * num + 4] = self.labels[self.batch_count * self.batch_size + num]
                        self.batch_count += 1
                        state = np.random.get_state()
                        np.random.shuffle(batch_images)
                        np.random.set_state(state)
                        np.random.shuffle(batch_labels)
                        return batch_images, batch_labels
                    else:
                        self.batch_count = 0
                        raise StopIteration
            else:
                # Evaluation mode: no augmentation, just resize and normalize.
                if self.res != 0:
                    batch_images = np.zeros(shape=(self.res, 299, 299, 3)).astype(np.float32)
                    batch_labels = np.array(self.labels[len(self.labels) - self.res:]).astype(np.int64)
                    for num in range(self.res):
                        image_path = self.data_paths[len(self.data_paths) - self.res + num]
                        image_data = cv2.imread(image_path)
                        image_data = cv2.resize(image_data, (299, 299))
                        batch_images[num, :] = (image_data[:, :, ::-1] - 128) * 0.0078125
                    self.res = 0
                    state = np.random.get_state()
                    np.random.shuffle(batch_images)
                    np.random.set_state(state)
                    np.random.shuffle(batch_labels)
                    return batch_images, batch_labels
                else:
                    if self.batch_count < self.batch_num:
                        batch_labels = np.array(self.labels[self.batch_count * self.batch_size:(1 + self.batch_count) * self.batch_size]).astype(np.int64)
                        batch_images = np.zeros(shape=(self.batch_size, 299, 299, 3)).astype(np.float32)
                        for num in range(self.batch_size):
                            image_path = self.data_paths[num + self.batch_count * self.batch_size]
                            image_data = cv2.imread(image_path)
                            image_data = cv2.resize(image_data, (299, 299))
                            batch_images[num, :] = (image_data[:, :, ::-1] - 128) * 0.0078125
                        self.batch_count += 1
                        state = np.random.get_state()
                        np.random.shuffle(batch_images)
                        np.random.set_state(state)
                        np.random.shuffle(batch_labels)
                        return batch_images, batch_labels
                    else:
                        self.batch_count = 0
                        raise StopIteration
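A minimal usage sketch of the iterator (using the path/label lists built in section 3): with augmentation each full training batch holds batch_size * 4 samples, and the leftover remainder batch is served first.

train_dataset = dataset(training_images_path, training_labels, 15, True)
for images, labels in train_dataset:
    # remainder batch first, then (60, 299, 299, 3) / (60,) batches for batch_size = 15
    print(images.shape, labels.shape)
    break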
5. Adjusting the network structure
Inspect the downloaded Inception PB model with Netron, as shown in the figure below.
To make batched training easier, the pb file can be converted to a pbtxt file, edited, and then converted back to pb. The result of the change is shown below: that part of the graph is simply replaced with a placeholder.
The code is as follows:
# Convert the pb model to an editable pbtxt file
from tensorflow.core.framework import graph_pb2
with tf.Session() as sess:
    with tf.gfile.FastGFile("./inception_dec_2015/tensorflow_inception_graph.pb", 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.train.write_graph(graph_def, './', "./inception_dec_2015/tensorflow_inception_graph.pbtxt", as_text=True)
# Convert the edited pbtxt back to a pb file
from google.protobuf import text_format
with tf.Session() as sess:
    # Open in text mode: text_format.Merge expects a string, not bytes.
    with tf.gfile.FastGFile('./inception_dec_2015/tensorflow_inception_graph.pbtxt', 'r') as f:
        graph_def = graph_pb2.GraphDef()
        new_graph_def = text_format.Merge(f.read(), graph_def)
        tf.train.write_graph(new_graph_def, './', './inception_dec_2015/tensorflow_inception_graph_new.pb', as_text=False)
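To verify the round trip, one option (a small sketch) is to reload the regenerated pb file and list a few node names:

check_graph_def = tf.GraphDef()
with tf.gfile.FastGFile('./inception_dec_2015/tensorflow_inception_graph_new.pb', 'rb') as f:
    check_graph_def.ParseFromString(f.read())
for node in check_graph_def.node[:5]:
    print(node.name, node.op)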
6. Building and training the network
My poor laptop only has 4 GB of GPU memory, so running networks often hits memory problems. The forward pass through Inception for feature extraction kept throwing errors until I added the following code:
config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)
'''log_device_placement=True : print device placement logs
allow_soft_placement=True : if the specified device does not exist, let TF pick one automatically'''
config.gpu_options.allow_growth = True
'''allow_growth = True requests GPU memory on demand instead of grabbing it all up front'''
# config.gpu_options.per_process_gpu_memory_fraction = 0.95  # use at most 95% of GPU memory
Now the fun part, training. I strip off Inception's final classification layer and replace it with a fully connected layer that recognizes the five flower species; the only thing to watch is that just the variables of this new classification layer are set as trainable.
graph = tf.Graph()
model_filename = "./inception_dec_2015/tensorflow_inception_graph_new.pb"
with tf.gfile.FastGFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# input_image is the imported input placeholder; pool3_output is the extracted feature tensor
with graph.as_default():
    input_image, pool3_output = tf.import_graph_def(graph_def, return_elements=["Placeholder:0", "pool_3:0"])
    print(pool3_output.name, pool3_output.shape)
    label_input = tf.placeholder(tf.int64, [None])
    trainable = tf.placeholder(tf.bool)
    with tf.name_scope('flatten1'):
        pool3_output_shape = pool3_output.get_shape().as_list()
        nodes = pool3_output_shape[1] * pool3_output_shape[2] * pool3_output_shape[3]
        flatten1_output = tf.reshape(pool3_output, [-1, nodes])
        # `trainable` is a tensor, so a Python `if` cannot branch on it;
        # tf.cond applies dropout only when the training flag is fed as True.
        flatten1_output = tf.cond(trainable,
                                  lambda: tf.nn.dropout(flatten1_output, 0.5),
                                  lambda: flatten1_output)
        print(flatten1_output.name, flatten1_output.shape)
    # Fully connected classification layer for the five flower classes
    with tf.variable_scope('fc1'):
        weights = tf.get_variable("weight", [flatten1_output.get_shape().as_list()[-1], 5],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("bias", [5], initializer=tf.constant_initializer(0.1))
        # Keep the raw logits for the loss; softmax is only for the predictions.
        logits = tf.matmul(flatten1_output, weights) + biases
        output_tensor = tf.nn.softmax(logits)
        print(output_tensor.name, output_tensor.shape)
    # Only the classification-layer variables will be optimized
    output_vars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc1')
    # Loss: sparse softmax cross entropy, fed the raw logits
    # (feeding softmax outputs here would apply softmax twice)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label_input)
    loss = tf.reduce_mean(cross_entropy)
    # Optimizer restricted to the new layer's variables
    Optimizer = tf.train.AdamOptimizer().minimize(loss, var_list=output_vars1)
    # Accuracy of the network output
    accuracy = tf.equal(tf.argmax(output_tensor, 1), label_input)
    accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))
# Build the init op inside `graph` so the session can actually run it.
with graph.as_default(), tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        aver_train_acc = []
        aver_train_loss = []
        train_dataset = dataset(training_images_path, training_labels, 15, True)
        test_dataset = dataset(testing_images_path, testing_labels, 80, False)
        pbar = tqdm(train_dataset)
        for traindata in pbar:
            _, loss_value, train_acc = sess.run([Optimizer, loss, accuracy],
                                                feed_dict={input_image: traindata[0], trainable: True, label_input: traindata[1]})
            aver_train_acc.append(train_acc)
            aver_train_loss.append(loss_value)
        aver_test_acc = []
        for testdata in test_dataset:
            test_acc = sess.run(accuracy, feed_dict={input_image: testdata[0], trainable: False, label_input: testdata[1]})
            aver_test_acc.append(test_acc)
        print('\nepoch {}: aver-loss {}, aver-train-acc {}, aver-test-acc {}'
              .format(i, np.mean(aver_train_loss), np.mean(aver_train_acc), np.mean(aver_test_acc)))
Even though only one layer is trained, the forward computation is still heavy, and my poor laptop runs it quite slowly. The results after 10 epochs are shown below; they are not great, so this recipe clearly needs more work:
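For completeness, a minimal inference sketch (it must run inside the same session and graph context as the training loop above; the image path below is a hypothetical example):

img = cv2.imread('./flower_photos/roses/example.jpg')  # hypothetical example image
img = (cv2.resize(img, (299, 299))[:, :, ::-1] - 128) * 0.0078125  # same preprocessing as the dataset class
probs = sess.run(output_tensor, feed_dict={input_image: img[np.newaxis].astype(np.float32), trainable: False})
print('predicted class:', class_name[int(np.argmax(probs[0]))])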