1、Epoch, Batch, Iteration的关系
cifar10数据集有 60000 张图片作为训练数据,10000 张图片作为测试数据。
假设现在选择 Batch Size = 100 对模型进行训练。
那么:
每个 Epoch 要训练的图片数量:60000(训练集上的所有图像)
训练集具有的 Batch 个数: 60000/100=600
每个 Epoch 需要完成的 Batch 个数: 600
每个 Epoch 具有的 Iteration 个数: 600(完成一个Batch训练,相当于参数迭代一次)
每个 Epoch 中发生模型权重更新的次数:600
训练 10 个Epoch后,模型权重更新的次数: 600*10=6000
不同Epoch的训练,其实用的是同一个训练集的数据。第1个Epoch和第10个Epoch虽然用的都是训练集的60000图片,但是对模型的权重更新值却是完全不同的。因为不同Epoch的模型处于代价函数空间上的不同位置,模型的训练代越靠后,越接近谷底,其代价越小。
总共完成30000次迭代,相当于完成了 30000/600=50 个Epoch
————————————————
版权声明:本文为CSDN博主「xytywh」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/xiaohuihui1994/article/details/80624593
2、repeat
repeat(2)表示2epock,但是repeat的位置不同,实际效果不同。
看下面的例子
(1)repeat在batch后面
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
生成的11个数组
[[0.5488135 0.71518937]
[0.60276338 0.54488318]
[0.4236548 0.64589411]
[0.43758721 0.891773 ]
[0.96366276 0.38344152]
[0.79172504 0.52889492]
[0.56804456 0.92559664]
[0.07103606 0.0871293 ]
[0.0202184 0.83261985]
[0.77815675 0.87001215]
[0.97861834 0.79915856]]
效果是,把生成的11个数组,按batch_size=4取,4、4、3,然后再取一遍4、4、3.
[[0.60276338 0.54488318]
[0.4236548 0.64589411]
[0.5488135 0.71518937]
[0.79172504 0.52889492]]
[[0.96366276 0.38344152]
[0.56804456 0.92559664]
[0.07103606 0.0871293 ]
[0.43758721 0.891773 ]]
[[0.0202184 0.83261985]
[0.77815675 0.87001215]
[0.97861834 0.79915856]]
[[0.5488135 0.71518937]
[0.43758721 0.891773 ]
[0.60276338 0.54488318]
[0.4236548 0.64589411]]
[[0.56804456 0.92559664]
[0.79172504 0.52889492]
[0.07103606 0.0871293 ]
[0.77815675 0.87001215]]
[[0.0202184 0.83261985]
[0.96366276 0.38344152]
[0.97861834 0.79915856]]
(2)如果把repeat放到shuffle前面,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,然后shuffle和batch。
那么就是22个数组,按4、4、4、4、4、2取
[[0.4236548 0.64589411]
[0.5488135 0.71518937]
[0.43758721 0.891773 ]
[0.79172504 0.52889492]]
[[0.96366276 0.38344152]
[0.56804456 0.92559664]
[0.07103606 0.0871293 ]
[0.77815675 0.87001215]]
[[0.0202184 0.83261985]
[0.60276338 0.54488318]
[0.97861834 0.79915856]
[0.60276338 0.54488318]]
[[0.5488135 0.71518937]
[0.43758721 0.891773 ]
[0.4236548 0.64589411]
[0.56804456 0.92559664]]
[[0.07103606 0.0871293 ]
[0.0202184 0.83261985]
[0.77815675 0.87001215]
[0.97861834 0.79915856]]
[[0.79172504 0.52889492]
[0.96366276 0.38344152]]
官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系。