cnn 预测过程代码_CNN卷积神经网络入门个人理解和实例代码(注释详细,一劳永逸)...

本文介绍了使用CNN进行花类识别的过程,包括数据预处理、模型构建与训练、模型保存及预测。通过详细注释的代码,展示了如何将图像数据转化为适合CNN的格式,构建卷积神经网络模型,并使用Keras实现模型训练。最后,演示了如何加载模型对单张图片进行预测。
摘要由CSDN通过智能技术生成

2021/01/05:更新了代码,现在应该没啥问题了

不铺垫啥了,最近看深度学习和神经网络有一点启发,想写个文章算是做个记录,直接开始吧。我看的是b站李宏毅老师的机器学习视频。

我比较懒,了解完原理,写完一遍代码之后,就想写一个一劳永逸的代码,想着以后有不同的应用场景需要,直接改几个参数然后调用就行了,所以这也是我特别喜欢写注释的原因。写注释是就为了一劳永逸。

当然,做图片分类,大概分为三个过程。

第一,收集你的数据,把它做成数据集(后面我是把数据集做出.npy格式),然后对数据预处理一下。具体就是把你的数据按类分好(丢到不同类名称的文件夹里),然后用代码把整体数据分成训练集、验证集、测试集和生成对应的训练集label、验证集_label、测试集_label。

第二、搭建卷积神经网络,把训练好的模型保存为.h5文件。搭建神经网络主要是要确定你处理数据的卷积层结构是什么,比如一张图片进去,要经过的卷积层的filter什么样,池化层什么样,打算设计几个,最后把得到的数据flatten拉直成一个向量,把这个向量丢到一个全连接层去跑(这整个第二步就叫卷积 神经网络,卷积就是数据去全连接层训练前的一种处理方式,没有卷积处理层只有全连接层的神经网络就叫DNN深度神经网络)。

第三、调用保存的模型对某一张图片预测它的类别,输出预测结果。

贴代码:

接下来以给花分类为例子,给出详细的代码。注:花的数据集和.npy文件的创建代码是参考了这篇博文,后面附上博文的转载版权声明(版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。本文链接:https://blog.csdn.net/umbrellalalalala/article/details/86516928)

第一、收集数据,放到数据文件夹里面去分好类,一个类一个文件夹(如图),然后做成.npy文件,生成的.npy为6维数组,分别是训练集,训练集标签,验证集,验证集标签,测试集,测试集标签。其中训练集的数据是随机打乱了的,标签对应也是;而训练集、验证集和测试集的数据都是来自于数据文件夹,按比例确定的,比如随机拿出10%做验证集,10%测试集,剩下的当作数据集。注:文件名不能有中文文件夹名就是类别名,里面是同一类的图片

# -*- coding: utf-8 -*-

import glob

import os.path

import numpy as np

import tensorflow as tf

from tensorflow.python.platform import gfile

# 将分好类的图片(一个文件夹一类图片)制作成npy文件并为每一个类别定好label。

# 其中npy包含6个内容,分别为

# training_images = np.array(data[0]) 顺序将会打乱

# training_labels = np.array(data[1]) 顺序会和training_images的顺序一样

# validation_images = np.array(data[2])

# validation_labels = np.array(data[3])

# testing_images = np.array(data[4])

# testing_labels = np.array(data[5])

# 且数据都是分好格式了的(picture_numbers, pixel-x, pixel_y, channel_number),可以直接代到模型里面

INPUT_DATA = 'E:\\projects\\manchine_learning\\flower_photos' # 原始输入数据的目录,其有五个子目录,每个目录下保存属于该类别的所有图片

OUTPUT_DATA = 'E:\\projects\\manchine_learning\\flower_processed_data.npy' # 将整理后的图片数据通过numpy的格式保存

# 测试数据和验证数据的比例

VALIDATION_PERCENTAGE = 10

TEST_PERCENTAGE = 10

# 读取数据并将数据分割成训练数据、验证数据和测试数据

# 创建数据列表

def create_image_lists(sess, testing_percentage, validation_percentage):

# sub_dirs用于存储INPUT_DATA下的全部子文件夹目录名称,有5种花所以有五个元素

# os.walk() 方法用于通过在目录树中游走输出在目录中的文件名,向上或者向下

# os.walk(top, topdown=Ture, οnerrοr=None, followlinks=False)

# 该函数可以得到一个三元tupple(dirpath, dirnames, filenames)

# dirpath:string,代表目录的路径;

# dirnames:list,包含了当前dirpath路径下所有的子目录名字(不包含目录路径);

# filenames:list,包含了当前dirpath路径下所有的非目录子文件的名字(不包含目录路径)。

# 注意,dirnames和filenames均不包含路径信息,如需完整路径,可使用os.p

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值