2021/01/05:更新了代码,现在应该没啥问题了
不铺垫啥了,最近看深度学习和神经网络有一点启发,想写个文章算是做个记录,直接开始吧。我看的是b站李宏毅老师的机器学习视频。
我比较懒,了解完原理,写完一遍代码之后,就想写一个一劳永逸的代码,想着以后有不同的应用场景需要,直接改几个参数然后调用就行了,所以这也是我特别喜欢写注释的原因。写注释是就为了一劳永逸。
当然,做图片分类,大概分为三个过程。
第一,收集你的数据,把它做成数据集(后面我是把数据集做出.npy格式),然后对数据预处理一下。具体就是把你的数据按类分好(丢到不同类名称的文件夹里),然后用代码把整体数据分成训练集、验证集、测试集和生成对应的训练集label、验证集_label、测试集_label。
第二、搭建卷积神经网络,把训练好的模型保存为.h5文件。搭建神经网络主要是要确定你处理数据的卷积层结构是什么,比如一张图片进去,要经过的卷积层的filter什么样,池化层什么样,打算设计几个,最后把得到的数据flatten拉直成一个向量,把这个向量丢到一个全连接层去跑(这整个第二步就叫卷积 神经网络,卷积就是数据去全连接层训练前的一种处理方式,没有卷积处理层只有全连接层的神经网络就叫DNN深度神经网络)。
第三、调用保存的模型对某一张图片预测它的类别,输出预测结果。
贴代码:
接下来以给花分类为例子,给出详细的代码。注:花的数据集和.npy文件的创建代码是参考了这篇博文,后面附上博文的转载版权声明(版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。本文链接:https://blog.csdn.net/umbrellalalalala/article/details/86516928)
第一、收集数据,放到数据文件夹里面去分好类,一个类一个文件夹(如图),然后做成.npy文件,生成的.npy为6维数组,分别是训练集,训练集标签,验证集,验证集标签,测试集,测试集标签。其中训练集的数据是随机打乱了的,标签对应也是;而训练集、验证集和测试集的数据都是来自于数据文件夹,按比例确定的,比如随机拿出10%做验证集,10%测试集,剩下的当作数据集。注:文件名不能有中文文件夹名就是类别名,里面是同一类的图片
# -*- coding: utf-8 -*-
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# 将分好类的图片(一个文件夹一类图片)制作成npy文件并为每一个类别定好label。
# 其中npy包含6个内容,分别为
# training_images = np.array(data[0]) 顺序将会打乱
# training_labels = np.array(data[1]) 顺序会和training_images的顺序一样
# validation_images = np.array(data[2])
# validation_labels = np.array(data[3])
# testing_images = np.array(data[4])
# testing_labels = np.array(data[5])
# 且数据都是分好格式了的(picture_numbers, pixel-x, pixel_y, channel_number),可以直接代到模型里面
INPUT_DATA = 'E:\\projects\\manchine_learning\\flower_photos' # 原始输入数据的目录,其有五个子目录,每个目录下保存属于该类别的所有图片
OUTPUT_DATA = 'E:\\projects\\manchine_learning\\flower_processed_data.npy' # 将整理后的图片数据通过numpy的格式保存
# 测试数据和验证数据的比例
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
# 读取数据并将数据分割成训练数据、验证数据和测试数据
# 创建数据列表
def create_image_lists(sess, testing_percentage, validation_percentage):
# sub_dirs用于存储INPUT_DATA下的全部子文件夹目录名称,有5种花所以有五个元素
# os.walk() 方法用于通过在目录树中游走输出在目录中的文件名,向上或者向下
# os.walk(top, topdown=Ture, οnerrοr=None, followlinks=False)
# 该函数可以得到一个三元tupple(dirpath, dirnames, filenames)
# dirpath:string,代表目录的路径;
# dirnames:list,包含了当前dirpath路径下所有的子目录名字(不包含目录路径);
# filenames:list,包含了当前dirpath路径下所有的非目录子文件的名字(不包含目录路径)。
# 注意,dirnames和filenames均不包含路径信息,如需完整路径,可使用os.p