Fine-tuning an Inception Network for Flower Classification

1. Import the required packages

import tensorflow as tf
import os
import zipfile
import requests
import glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

2. Download and extract the Inception model

inception_model_url = 'https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip'
filename = inception_model_url.split('/')[-1]
filepath = os.path.join('./', filename)
 
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_model_url, stream=True)
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=2048):
            if chunk:
                f.write(chunk)
print("finish: ", filename)

path = './inception_dec_2015'
if not os.path.exists(path):
    os.makedirs(path)
    print(path + '\tdirectory created successfully')
else:
    print(path + '\tdirectory already exists')

zipfile.ZipFile(filepath).extractall(path)
print(path + '\tarchive extracted')
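
As a quick sanity check, you can list the extracted files; the model file tensorflow_inception_graph.pb used in the later sections should be among them:

print(os.listdir(path))  # expect tensorflow_inception_graph.pb in the output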

3. Dataset processing and visualization

     Split the dataset into a 10% test set and a 90% training set, shuffled randomly. Note that image datasets are usually handled as lists of file paths rather than raw pixel data, so the splits below store paths.

datas_path = './flower_photos'
sub_dirs = [x[0] for x in os.walk(datas_path)][1:]
class_name = ['daisy','dandelion','roses','sunflowers','tulips']
# Initialize the dataset splits.
training_images_path = []
training_labels = []
testing_images_path = []
testing_labels = []
current_label = 0

# Walk through each class sub-directory.
for sub_dir in sub_dirs:
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    # basename returns the last component of the path (the class directory name)
    dir_name = os.path.basename(sub_dir)
    for extension in extensions:
        file_glob = os.path.join(datas_path, dir_name, '*.' + extension)
        # glob.glob returns all files matching the pattern
        file_list.extend(glob.glob(file_glob))
    if not file_list:
        print('the path {} does not contain any image files'.format(os.path.join(datas_path, dir_name)))

    test_count = 0
    train_count = 0
    for file_name in file_list:
        # Randomly assign each image to the training or test set.
        chance = np.random.randint(100)
        if chance < 10:
            testing_images_path.append(file_name)
            testing_labels.append(current_label)
            test_count += 1
        else:
            training_images_path.append(file_name)
            training_labels.append(current_label)
            train_count += 1
    print('class {} (label {}): {} test images and {} training images'
          .format(dir_name, current_label, test_count, train_count))
    current_label += 1

# Shuffle the training paths and labels in unison using a shared RNG state
state = np.random.get_state()
np.random.shuffle(training_images_path)
np.random.set_state(state)
np.random.shuffle(training_labels)

plt.figure(figsize=(6,6))
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    img = cv2.imread(training_images_path[i])
    plt.imshow(img[:, :, ::-1])  # OpenCV loads BGR; reverse the channels to RGB
    plt.xlabel(class_name[training_labels[i]])
plt.show()
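
As an optional check that the random 90/10 split is roughly balanced across classes, np.bincount tallies the labels per split:

# Count how many images of each of the 5 classes landed in each split.
print('train per-class counts:', np.bincount(training_labels, minlength=5))
print('test  per-class counts:', np.bincount(testing_labels, minlength=5))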

     The result is shown below: roughly 5,000 training images and 500 test images. The dataset is fairly small, which is why data augmentation is applied later during training.

4. Data preparation

     Next, build the data iterator. The network is fed image tensors, so each image must be decoded from its path; the iterator also assembles batches and applies data augmentation.

class dataset(object):
    
    def __init__(self, data_paths, labels, batch_size, trainable):
        
        # Shuffle paths and labels with the same index permutation so pairs stay aligned.
        ix = np.arange(len(data_paths))
        np.random.shuffle(ix)
        self.labels = [labels[i] for i in ix]
        self.data_paths = [data_paths[i] for i in ix]
        self.batch_size = batch_size
        self.batch_num = int(np.floor(len(data_paths) / batch_size))  # number of full batches
        self.res = len(data_paths) % batch_size                       # leftover samples
        self.batch_count = 0
        self.trainable = trainable

    def __iter__(self):
        
        return self
    
    def __next__(self):
        
        with tf.device('/cpu:0'):
            if self.trainable:
                # Training mode: the leftover samples (res) are emitted first as one batch;
                # every image is expanded 4x by flip augmentation.
                if self.res != 0:
                    batch_images = np.zeros(shape=(self.res*4, 299, 299, 3)).astype(np.float32)
                    batch_labels = np.zeros(shape=(self.res*4)).astype(np.int64)
                    for num in range(self.res):
                        image_path = self.data_paths[len(self.data_paths)-self.res+num]
                        image_data = cv2.imread(image_path)
                        image_data = cv2.resize(image_data,(299,299))
                        # cv2.flip: -1 flips both axes, 0 vertical, 1 horizontal;
                        # [:, :, ::-1] converts BGR to RGB; (x-128)*0.0078125 = (x-128)/128
                        # scales pixels to roughly [-1, 1).
                        batch_images[4*num, :] = (cv2.flip(image_data, -1)[:, :, ::-1]-128)*0.0078125
                        batch_images[4*num+1, :] = (cv2.flip(image_data, 0)[:, :, ::-1]-128)*0.0078125
                        batch_images[4*num+2, :] = (cv2.flip(image_data, 1)[:, :, ::-1]-128)*0.0078125
                        batch_images[4*num+3, :] = (image_data[:, :, ::-1]-128)*0.0078125
                        batch_labels[4*num] = self.labels[len(self.labels)-self.res+num]
                        batch_labels[4*num+1] = self.labels[len(self.labels)-self.res+num]
                        batch_labels[4*num+2] = self.labels[len(self.labels)-self.res+num]
                        batch_labels[4*num+3] = self.labels[len(self.labels)-self.res+num]
                    self.res = 0
                    state = np.random.get_state()
                    np.random.shuffle(batch_images)
                    np.random.set_state(state)
                    np.random.shuffle(batch_labels)
                    return batch_images,batch_labels
                else:
                    if self.batch_count < self.batch_num:
                        batch_labels = np.zeros(shape=(self.batch_size*4)).astype(np.int64)
                        batch_images = np.zeros(shape = (self.batch_size*4,299,299,3)).astype(np.float32)
                        for num in range(self.batch_size):
                            image_path = self.data_paths[num+self.batch_count*self.batch_size]
                            image_data = cv2.imread(image_path)
                            image_data = cv2.resize(image_data,(299,299))
                            batch_images[4*num,:] = (cv2.flip(image_data, -1)[:,:,::-1]-128)*0.0078125
                            batch_images[4*num+1,:] = (cv2.flip(image_data, 0)[:,:,::-1]-128)*0.0078125
                            batch_images[4*num+2,:] = (cv2.flip(image_data, 1)[:,:,::-1]-128)*0.0078125
                            batch_images[4*num+3,:] = (image_data[:,:,::-1]-128)*0.0078125
                            batch_labels[4*num] = self.labels[self.batch_count*self.batch_size+num]
                            batch_labels[4*num+1] = self.labels[self.batch_count*self.batch_size+num]
                            batch_labels[4*num+2] = self.labels[self.batch_count*self.batch_size+num]
                            batch_labels[4*num+3] = self.labels[self.batch_count*self.batch_size+num]
                        self.batch_count += 1
                        state = np.random.get_state()
                        np.random.shuffle(batch_images)
                        np.random.set_state(state)
                        np.random.shuffle(batch_labels)
                        return batch_images,batch_labels
                    else:
                        self.batch_count = 0
                        raise StopIteration
            else:
                # Evaluation mode: a single normalized copy per image, no augmentation.
                if(self.res!=0):
                    batch_images = np.zeros(shape = (self.res,299,299,3)).astype(np.float32)
                    batch_labels = np.array(self.labels[len(self.labels)-self.res:]).astype(np.int64)
                    for num in range(self.res):
                        image_path = self.data_paths[len(self.data_paths)-self.res+num]
                        image_data = cv2.imread(image_path)
                        image_data = cv2.resize(image_data,(299,299))
                        batch_images[num,:] = (image_data[:,:,::-1]-128)*0.0078125
                    self.res = 0
                    state = np.random.get_state()
                    np.random.shuffle(batch_images)
                    np.random.set_state(state)
                    np.random.shuffle(batch_labels)
                    return batch_images,batch_labels
                else:
                    if self.batch_count < self.batch_num:
                        batch_labels = np.array(self.labels[self.batch_count*self.batch_size:(1+self.batch_count)*self.batch_size]).astype(np.int64)
                        batch_images = np.zeros(shape = (self.batch_size,299,299,3)).astype(np.float32)
                        for num in range(self.batch_size):
                            image_path = self.data_paths[num+self.batch_count*self.batch_size]
                            image_data = cv2.imread(image_path)
                            image_data = cv2.resize(image_data,(299,299))
                            batch_images[num,:] = (image_data[:,:,::-1]-128)*0.0078125
                        self.batch_count += 1
                        state = np.random.get_state()
                        np.random.shuffle(batch_images)
                        np.random.set_state(state)
                        np.random.shuffle(batch_labels)
                        return batch_images,batch_labels
                    else:
                        self.batch_count = 0
                        raise StopIteration
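
A minimal usage sketch of this iterator (using the path and label lists from section 3): in training mode every image is expanded into 4 augmented copies, so a full batch of 15 paths yields 60 samples.

train_dataset = dataset(training_images_path, training_labels, 15, True)
for batch_images, batch_labels in train_dataset:
    # e.g. (60, 299, 299, 3) and (60,) for a full training batch
    print(batch_images.shape, batch_labels.shape)
    break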

5. Adjusting the network structure

     Inspect the downloaded Inception PB model with Netron, as shown in the figure below.

To make batch training easier, the .pb file can be converted to a .pbtxt file, edited, and then converted back to .pb. After the change, the original input section of the graph is simply replaced with a placeholder, as shown below.

     The code is as follows:

# Convert the .pb model to .pbtxt
from tensorflow.core.framework import graph_pb2

with tf.Session() as sess:
    with tf.gfile.FastGFile("./inception_dec_2015/tensorflow_inception_graph.pb", 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

tf.train.write_graph(graph_def, './', "./inception_dec_2015/tensorflow_inception_graph.pbtxt", as_text=True)

# Convert the .pbtxt back to a .pb file
from google.protobuf import text_format
with tf.Session() as sess:
    with tf.gfile.FastGFile('./inception_dec_2015/tensorflow_inception_graph.pbtxt', 'r') as f:  # text mode: text_format.Merge expects a string
        graph_def = graph_pb2.GraphDef()
        new_graph_def = text_format.Merge(f.read(), graph_def)
tf.train.write_graph(new_graph_def, './', './inception_dec_2015/tensorflow_inception_graph_new.pb', as_text=False)
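
To decide where to cut the graph and which names to pass to return_elements later, it can help to list the node names first. A short sketch over the graph_def parsed above:

# Print each node's name and op type to locate the input placeholder
# and the pool_3 feature node used in the next section.
for node in graph_def.node:
    print(node.name, node.op)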

6. Building and training the network

     My modest laptop has only 4 GB of GPU memory, so I kept running into memory errors; the Inception forward pass for feature extraction failed repeatedly until I added the following code:

config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)
'''log_device_placement=True: log which device each op is placed on
allow_soft_placement=True: let TF fall back to another device if the requested one is unavailable'''
config.gpu_options.allow_growth = True
'''allow_growth=True: allocate GPU memory on demand instead of reserving it all up front'''
#config.gpu_options.per_process_gpu_memory_fraction = 0.95  # cap GPU memory usage at 95%
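
This config object is passed in when the session is created below, via tf.Session(config=config).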

     Now the fine-tuning can begin. I removed Inception's final classification layer and replaced it with a fully-connected layer that classifies the five flower species. The key point is that only the variables of this new classification layer are passed to the optimizer as trainable.

graph=tf.Graph()

model_filename = "./inception_dec_2015/tensorflow_inception_graph_new.pb"
with tf.gfile.FastGFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# input_image is the imported input placeholder; pool3_output is the extracted feature tensor
with graph.as_default():
    input_image, pool3_output = tf.import_graph_def(graph_def, return_elements=["Placeholder:0", "pool_3:0"])
    print(pool3_output.name, pool3_output.shape)
    label_input = tf.placeholder(tf.int64, [None])
    trainable = tf.placeholder(tf.bool)

    with tf.name_scope('flatten1'):
        pool3_output_shape = pool3_output.get_shape().as_list()
        nodes = pool3_output_shape[1] * pool3_output_shape[2] * pool3_output_shape[3]
        flat = tf.reshape(pool3_output, [-1, nodes])
        # `trainable` is a tf.bool tensor, so a Python `if` cannot branch on it;
        # use tf.cond to apply dropout only during training.
        flatten1_output = tf.cond(trainable,
                                  lambda: tf.nn.dropout(flat, 0.5),
                                  lambda: flat)
        print(flatten1_output.name, flatten1_output.shape)

    # Fully-connected layer for classifying the 5 flower species
    with tf.variable_scope('fc1'):
        weights = tf.get_variable("weight", [flatten1_output.get_shape().as_list()[-1], 5],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("bias", [5], initializer=tf.constant_initializer(0.1))
        # Keep the raw logits: sparse_softmax_cross_entropy_with_logits applies
        # softmax internally, so adding relu/softmax here would distort the loss.
        output_tensor = tf.matmul(flatten1_output, weights) + biases
        print(output_tensor.name, output_tensor.shape)
        

    # Collect only the classification-layer variables for optimization
    output_vars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc1')

    # Loss: sparse softmax cross-entropy over the raw logits
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output_tensor, labels=label_input)
    loss = tf.reduce_mean(cross_entropy)

    # Optimizer: update only the fc1 variables
    Optimizer = tf.train.AdamOptimizer().minimize(loss, var_list=output_vars1)

    # Accuracy of the network output
    accuracy = tf.equal(tf.argmax(output_tensor, 1), label_input)
    accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(10):
            aver_train_acc = []
            aver_train_loss = []
            train_dataset = dataset(training_images_path,training_labels,15,True)
            test_dataset = dataset(testing_images_path,testing_labels,80,False)
            pbar = tqdm(train_dataset)
            for traindata in pbar:
                _, loss_value, train_acc = sess.run([Optimizer, loss, accuracy], feed_dict={input_image: traindata[0], trainable: True, label_input: traindata[1]})
                aver_train_acc.append(train_acc)
                aver_train_loss.append(loss_value)
            aver_test_acc = []
            for testdata in test_dataset:
                test_acc = sess.run(accuracy, feed_dict={input_image: testdata[0], trainable: False, label_input: testdata[1]})
                aver_test_acc.append(test_acc)
            print('\nepoch {}: aver-loss {}, aver-train_acc {}, aver-test_acc {}'
                  .format(i, np.mean(aver_train_loss), np.mean(aver_train_acc), np.mean(aver_test_acc)))

     Although only one layer is being trained, the forward computation is still heavy, so training on my laptop was quite slow. The results after 10 epochs are shown below; the accuracy is not great yet, and this fine-tuning recipe still needs more work:

 
