every blog every motto: You can do more than you think.
0. 前言
在训练模型时,我们往往不一次将数据全部加载进内存中,而是将数据分批次加载到内存中。
- 一种方法是用 while True 遍历数据,用yeid产生,具体可参考语义分割代码讲解部分
- 另一种方法是本文即将讲解的tf.keras.utils.Sequence方法
1. 正文
1.1 基础用法
__ len __ 中返回的即1个epoch迭代的次数,即:
总样本数/ batch_size
__ getitem __ 根据len中的迭代次数,生成数据
注意: __ len __ ,__ getitem __ 必须要实现
"""
测试
__getitem__
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
class Date(tf.keras.utils.Sequence):
def __init__(self):
print('初始化相关参数')
def __len__(self):
"""
此方法要实现,否则会报错
正常程序中返回1个epoch迭代的次数
:return:
"""
return 5
def __getitem__(self, index):
"""生成一个batch的数据"""
print('index:', index)
x_batch = ['x1', 'x2', 'x3', 'x4']
y_batch = ['y1', 'y2', 'y3', 'y4']
print('-'*20)
return x_batch, y_batch
# 实例化数据
date = Date()
for batch_number, (x, y) in enumerate(date):
print('正在进行第{} batch'.format(batch_number))
print('x_batch:', x)
print('y_batcxh:', y)
结果:
1.2 扩展(2020.11.12 15:37增补)
可以在类中实现on_epoch_end方法,保证在每个epoch后打乱原有数据的顺序
1.2.1 训练样例:
测试代码,如下:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
print('tensorflow version: ', tf.__version__)
class ZerosFirstEpochOnesAfter(tf.keras.utils.Sequence):
def __init__(self):
self.shuffle = True
def __len__(self):
return 2
def on_epoch_end(self):
print('---------------on_epoch_end------------')
# 打乱索引
# if self.shuffle:
# print('==============================================================shuffle')
# np.random.shuffle(self.indices)
def __getitem__(self, item):
return np.zeros((16, 1)), np.zeros((16,))
def main():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_dim=1, activation="softmax"))
model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy']
)
model.fit(ZerosFirstEpochOnesAfter(), epochs=3, )
if __name__ == '__main__':
main()
tensorflow 2.0:
tensorflow 2.1:
tesorflow 2.3:
由以上三个版本的训练结果,我们可以发现,
- 在2.0和2.1版本中,是没有进行on_epoch_end方法调用的,即没有实现on_epoch_end方法内注释部分的打乱顺序,这是tensorflow早期版本的一个bug,具体可参考文后第4个链接。
- 在2.3版本中已得到改进
1.2.2 循环遍历:
1.2.2.1 原始版测试
循环遍历,如下所示:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
print('tensorflow version: ', tf.__version__)
class Date(tf.keras.utils.Sequence):
def __init__(self):
print('初始化相关参数')
self.lines = [1,2,3,4,5]
self.shuffle = True
def __len__(self):
"""
此方法要实现,否则会报错
正常程序中返回1个epoch迭代的次数
:return:
"""
return 2
def on_epoch_end(self):
print('=======================')
if self.shuffle == True:
print('------------一个epoch结束,打乱了顺序---')
np.random.shuffle(self.lines)
def __getitem__(self, index):
"""生成一个batch的数据"""
print('index:', index)
x_batch = ['x1', 'x2', 'x3', 'x4']
y_batch = ['y1', 'y2', 'y3', 'y4']
print('-' * 20)
return x_batch, y_batch
# 实例化数据
date = Date()
for epoch in range(2):
for batch_number, (x, y) in enumerate(date):
print('正在进行第{} batch'.format(batch_number))
print('x_batch:', x)
print('y_batcxh:', y)
print('一个epoch结束=============================')
结果:
如上图所示,通过循环遍历这种方法仍然不能调用on_epoch_end,即无法打乱顺序
1.2.2.2 改进版
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
print('tensorflow version: ', tf.__version__)
class Date(tf.keras.utils.Sequence):
def __init__(self):
print('初始化相关参数')
self.lines = [1,2,3,4,5]
self.shuffle = True
def __len__(self):
"""
此方法要实现,否则会报错
正常程序中返回1个epoch迭代的次数
:return:
"""
return 2
def on_epoch_end(self):
print('=======================')
if self.shuffle == True:
print('------------一个epoch结束,打乱了顺序---')
np.random.shuffle(self.lines)
def __getitem__(self, index):
"""生成一个batch的数据"""
print('index:', index)
x_batch = ['x1', 'x2', 'x3', 'x4']
y_batch = ['y1', 'y2', 'y3', 'y4']
print('-' * 20)
return x_batch, y_batch
# 实例化数据
date = Date()
for epoch in range(2):
print(date.lines)
for batch_number, (x, y) in enumerate(date):
print('正在进行第{} batch'.format(batch_number))
print('x_batch:', x)
print('y_batcxh:', y)
np.random.shuffle(date.lines)
print('一个epoch结束=============================')
如下图所示,我们发现已经打乱了“样本”顺序,
参考文献
[1] https://blog.csdn.net/weixin_39190382/article/details/105808830
[2] https://blog.csdn.net/weixin_43198141/article/details/89926262
[3] https://blog.csdn.net/u011311291/article/details/80991330
[4] https://github.com/tensorflow/tensorflow/issues/35911
[5] https://colab.research.google.com/gist/bfs15/fd18263f788a071225c60cedaf126748/35911.ipynb