这篇文章中我放弃了以往的model.fit()训练方法,
改用model.train_on_batch方法。
两种方法的比较:
- model.fit():用起来十分简单,对新手非常友好
- model.train_on_batch():封装程度更低,可以玩更多花样。
此外我也引入了进度条的显示方式,更加方便我们及时查看模型训练过程中的情况,可以及时打印各项指标。
🚀 我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
- 显卡(GPU):NVIDIA GeForce RTX 3080
一、前期工作
1. 设置GPU
如果使用的是CPU可以注释掉这部分的代码。
import tensorflow as tf
##python学习裙: 660193417###
gpus = tf.config.list_physical_devices("GPU")
if gpus:
tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gp