如何零基础用tensorflow搭建基本的CNN框架 | 附训练断点续练、图像展示、参数保存模块
嗨,我是error。
这次的笔记是关于tensorflow基本框架的搭建,零基础带你熟悉如何应用keras搭建自己的CNN模型,并训练自己的数据,实现深度学习。
代码主要参考来源自
【国家精品课程】北京大学人工智能实践-TensorFlow2.0
CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型
Keras的八股文构建方法
这篇文章主要是写给tensorflow零基础但深度学习对CNN结构有一定了解的朋友,故重点会放在详细介绍代码实现CNN结构的方法上面。
首先要了解的是keras最基本的八股文式构建法,即
首先是import各类的库,然后分train和test数据(可以直接使用keras官方的数据也可以使用自己准备的数据,下面会分别讲解)。构建model后compile最后fit就完成了整个训练,最后的summary可有可无,主要是打印网络结构。
下面老师给出了三个主要对象的参数文档说明。
CNN模型的详细模块解说
首先说明下我的数据来源,来自CIFAR10数据集,照片都放在train文件夹内,在根目录下有trainlabels的csv表格标记着每一张图片的标签
根目录下面还有x_train和y_train的npy文件,每次运行代码时都检测是否存在这两个文件,若不存在,则从官网下载CIFAR10文件并保存,若存在则直接调用,不再重复下载。
根目录下的checkpoint文件夹存放着每次训练的参数,以便下次训练时沿用上次已经训练好的参数而不是重复计算,实现了断点续训。
import tensorflow as tf
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_path = './train/'
train_csv = './trainLabels.csv'
x_train_savepath = './x_train.npy'
y_train_savepath = './y_train.npy'
在import完所需要的库后定义需要的路径以便后面传参时方便。
然后我们定义一个generateds的函数来把我们从csv表格标记的每一个图片的标签对应上就好了。
先定义两个空列表,分别是x和y,因为列表是有顺序的,只要我们放入列表的顺序和表格的顺序是一致,那么它们的标签就是对应匹配的。
def generateds(path, csv):
f = open(csv, 'r') # 以只读形式打开csv文件
contents = f.readlines()[1:] # 读取文件中除了第一行的所有行,因为第一行是id/labels的头
f.close() # 关闭csv文件
x, y_ = [], [] # 建立空列表
for content in contents: # 逐行取出
value = content.split(",") # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表
img_path = path + value[0] + '.png' # 拼出图片路径和文件名
img = Image.open(img_path) # 读入图片
img = np.array(img.convert('L')) # 图片变为8位宽灰度值的np.array格式