TensorFlow Keras 使用Inception-resnet-v2模型训练自己的分类数据集(含源码)
运行环境
TensorFlow 1.13.1
TensorFlow.Keras 2.2.4-tf
简单介绍
使用TensorFlow自带的Inception-resnet-v2模型训练自己的数据集。数据读取用的是TensorFlow自己的Dataset类,且无需转存成TFrecord格式。使用TensorFlow中的Keras,简单易懂,容易上手
注意事项
- 先准备好一个分类数据集
- 使用GPU训练(用CPU应该训练不动Inception-resnet-v2模型,如果没有GPU你可以换成TensorFlow现有的其他模型,但代码需要进行一定的改动)
源码
废话不多说,上源码。
DATA_PATH放数据集路径
ds = ds.prefetch(buffer_size=10*BATCH_SIZE)
这一句用于预读取数据,用的时候注意下CPU和内存,特别是内存,如果百分之九十多了就把程序关了(友情提示Ctrl+C可关闭程序),把这句话注释掉再跑
mport pathlib
import random
import tensorflow as tf
# 训练数据集路径
DATA_PATH = 'E:\dataset'
# 一个批次的大小
BATCH_SI