导语
哈喽吖铁汁萌~今天来教大家写个猫狗识别系统
小猫小狗真的太可爱了🙈🙈
这篇文章中我放弃了以往的model.fit()训练方法,改用model.train_on_batch方法。两种方法的比较:
-
model.fit()
:用起来十分简单,对新手非常友好 -
model.train_on_batch()
:封装程度更低,可以玩更多花样。
此外我也引入了进度条的显示方式,更加方便我们及时查看模型训练过程中的情况,可以及时打印各项指标。
🚀 我的环境:
-
语言环境:Python3.6.5
-
编译器:jupyter notebook
-
深度学习环境:TensorFlow2.4.1
-
显卡(GPU):NVIDIA GeForce RTX 3080
🚀 来自专栏:《深度学习100例》
文章目录
一、前期工作
1. 设置GPU
2. 导入数据
3. 查看数据
二、数据预处理
1. 加载数据
2. 再次检查数据
3. 配置数据集
4. 可视化数据
三、构建VG-16网络
四、编译
五、训练模型
六、模型评估
七、保存and加载模型
八、预测
一、前期工作
1. 设置GPU
如果使用的是CPU可以注释掉这部分的代码。
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpus[0]],"GPU")
# 打印显卡信息,确认GPU可用
print(gpus)
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2. 导入数据
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
import os,PIL
# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)
# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)
#隐藏警告
import warnings
warnings.filterwarnings('ignore')
import pathlib
data_dir = "./data/train"
# data_dir = "D:/jupyter notebook/DL-100-days/datasets/017_Eye_dataset"
data_dir = pathlib.Path(data_dir)
3. 查看数据
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为:3400
二、数据预处理
1. 加载数据
使用image_dataset_from_directory
方法将磁盘中的数据加载到tf.data.Dataset
中
batch_size = 8
img_height = 224
img_width = 224
TensorFlow版本是2.2.0的同学可能会遇到module 'tensorflow.keras.preprocessing' has no attribute 'image_dataset_from_directory'
的报错,升级一下TensorFlow就OK了
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=12,
image_size=(img_height, img_width),
batch_size=batch_size)
Found 3400 files belonging to 2 classes.
Using 2720 files for training.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/detail