#
# Author: 韦访
# Blog: https://blog.csdn.net/rookie_wei
# WeChat: 1007895847
# When adding me on WeChat, please mention that you came from CSDN
# Everyone is welcome to learn together
#
---- 韦访, 2019-08-22
1. Overview
This article uses TensorFlow 2.0, which is indeed much leaner than TensorFlow 1.x; in particular, the built-in Keras integration makes it far easier to pick up. In an earlier tutorial we covered how TensorFlow 1.x reads and preprocesses dataset samples with multiple threads before feeding them to the model for training; now let's see how to do the same in TensorFlow 2.0. The 1.x multithreaded data-reading tutorial is here:
https://blog.csdn.net/rookie_wei/article/details/80187950
2. Reading data in a single process
First, let's read the MNIST data in a single process (as we'll see, Keras runs the generator on its own thread inside that process). The code is as follows:
import tensorflow as tf
import time
import numpy as np
import os
import threading

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def gen():
    print(' generator initiated')
    idx = 0
    while True:
        # Note: idx is only logged below; this always yields the first 32 samples.
        yield x_train[:32], y_train[:32]
        print(' wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d PID:%d ident:%d' % (idx, os.getpid(), threading.currentThread().ident))
        idx += 1
        time.sleep(3)  # artificial delay, to make the data pipeline the bottleneck
if __name__ == "__main__":
    tr_gen = gen()

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    print(' wf>>>>>>>>>>>>>>>>>>>>> PID:%d ident:%d' % (os.getpid(), threading.currentThread().ident))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    start_time = time.time()
    model.fit_generator(generator=tr_gen, steps_per_epoch=20, max_queue_size=10)
    print('Total used time: %d' % (time.time() - start_time))
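One caveat before looking at the output: `gen()` above always yields `x_train[:32]`, since `idx` is only printed and never used for slicing, so the model trains on the same 32 samples at every step. A minimal sketch of a generator that actually walks the dataset (the name `make_batches` and the synthetic demo data are my own, not from the original code):

```python
import numpy as np

def make_batches(x, y, batch_size=32):
    """Cycle through (x, y) in fixed-size batches, wrapping at the end.
    Unlike gen() above, the batch index actually selects the slice."""
    n_batches = len(x) // batch_size
    idx = 0
    while True:
        lo = (idx % n_batches) * batch_size
        yield x[lo:lo + batch_size], y[lo:lo + batch_size]
        idx += 1

# Quick check with synthetic data: three 32-sample batches, then wrap around.
x_demo, y_demo = np.arange(96), np.arange(96)
batches = make_batches(x_demo, y_demo)
print(next(batches)[0][0], next(batches)[0][0],
      next(batches)[0][0], next(batches)[0][0])  # 0 32 64 0
```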
Run output:
wf>>>>>>>>>>>>>>>>>>>>> PID:5644 ident:17420
generator initiated
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 0 PID:5644 ident:23484
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
1/20 [>.............................] - ETA: 6s - loss: 2.4578 - accuracy: 0.0625 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1 PID:5644 ident:23484
2/20 [==>...........................] - ETA: 27s - loss: 2.3681 - accuracy: 0.0938 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 2 PID:5644 ident:23484
3/20 [===>..........................] - ETA: 34s - loss: 2.2340 - accuracy: 0.1771 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 3 PID:5644 ident:23484
4/20 [=====>........................] - ETA: 36s - loss: 2.1372 - accuracy: 0.2266 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 4 PID:5644 ident:23484
5/20 [======>.......................] - ETA: 36s - loss: 2.0525 - accuracy: 0.3000 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 5 PID:5644 ident:23484
6/20 [========>.....................] - ETA: 35s - loss: 1.9673 - accuracy: 0.3542 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 6 PID:5644 ident:23484
7/20 [=========>....................] - ETA: 33s - loss: 1.9038 - accuracy: 0.4018 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 7 PID:5644 ident:23484
8/20 [===========>..................] - ETA: 31s - loss: 1.8344 - accuracy: 0.4414 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 8 PID:5644 ident:23484
9/20 [============>.................] - ETA: 29s - loss: 1.7695 - accuracy: 0.4896 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 9 PID:5644 ident:23484
10/20 [==============>...............] - ETA: 27s - loss: 1.7086 - accuracy: 0.5281 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 10 PID:5644 ident:23484
11/20 [===============>..............] - ETA: 24s - loss: 1.6503 - accuracy: 0.5597 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 11 PID:5644 ident:23484
12/20 [=================>............] - ETA: 22s - loss: 1.5898 - accuracy: 0.5911 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 12 PID:5644 ident:23484
13/20 [==================>...........] - ETA: 19s - loss: 1.5349 - accuracy: 0.6202 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 13 PID:5644 ident:23484
14/20 [====================>.........] - ETA: 16s - loss: 1.4856 - accuracy: 0.6473 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 14 PID:5644 ident:23484
15/20 [=====================>........] - ETA: 14s - loss: 1.4359 - accuracy: 0.6708 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 15 PID:5644 ident:23484
16/20 [=======================>......] - ETA: 11s - loss: 1.3856 - accuracy: 0.6914 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 16 PID:5644 ident:23484
17/20 [========================>.....] - ETA: 8s - loss: 1.3397 - accuracy: 0.7096 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 17 PID:5644 ident:23484
18/20 [==========================>...] - ETA: 5s - loss: 1.2944 - accuracy: 0.7257 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 18 PID:5644 ident:23484
19/20 [===========================>..] - ETA: 2s - loss: 1.2492 - accuracy: 0.7401 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 19 PID:5644 ident:23484
20/20 [==============================] - 57s 3s/step - loss: 1.2107 - accuracy: 0.7516
Total used time: 57
In both the main block and gen(), we print the process ID and the thread ID. Comparing them shows they belong to the same process (PID 5644), but Keras runs gen() on a separate thread (ident 23484 vs. the main thread's 17420). The total time is 57 seconds.
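This same-process, different-thread pattern is easy to reproduce with the standard `threading` module alone, independent of Keras (a minimal sketch; the `worker` function is illustrative):

```python
import os
import threading

def worker(out):
    # Record which process and thread this function runs on.
    out['pid'] = os.getpid()
    out['ident'] = threading.current_thread().ident

out = {}
t = threading.Thread(target=worker, args=(out,))
t.start()
t.join()

same_process = (out['pid'] == os.getpid())
different_thread = (out['ident'] != threading.current_thread().ident)
print(same_process, different_thread)  # True True
```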
3. Reading data with multiple processes
Next, let's see how to read data with multiple processes, and check whether multi-process training is actually faster than the single-process run. The code is as follows:
import tensorflow as tf
import time
import numpy as np
import os
import threading

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class MNISTSequence(tf.keras.utils.Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Slice out batch idx; each worker process calls this independently.
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        print(' wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d PID:%d ident:%d' % (idx, os.getpid(), threading.currentThread().ident))
        time.sleep(3)  # artificial delay, to make the data pipeline the bottleneck
        return np.array(batch_x), np.array(batch_y)
if __name__ == "__main__":
    tr_gen = MNISTSequence(x_train, y_train, 32)

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    print(' wf>>>>>>>>>>>>>>>>>>>>> PID:%d ident:%d' % (os.getpid(), threading.currentThread().ident))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    start_time = time.time()
    model.fit_generator(generator=tr_gen, steps_per_epoch=20, max_queue_size=10, use_multiprocessing=True, workers=4)
    print('Total used time: %d' % (time.time() - start_time))
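A side note on `Sequence`: Keras also calls an optional `on_epoch_end` method on the sequence after each epoch, which is the usual place to reshuffle data between epochs. The sketch below is a hypothetical variant of mine that mirrors the `MNISTSequence` interface; the real base class `tf.keras.utils.Sequence` is omitted only so the snippet runs without TensorFlow installed.

```python
import numpy as np

class ShuffledMNISTSequence:
    """Same interface as MNISTSequence above, plus per-epoch shuffling.
    In the real script this would subclass tf.keras.utils.Sequence."""

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        # Permute an index array instead of the data itself.
        self.indices = np.arange(len(self.x))

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[batch_idx], self.y[batch_idx]

    def on_epoch_end(self):
        # Keras calls this after each epoch; reshuffle the index array.
        np.random.shuffle(self.indices)
```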
Run output:
wf>>>>>>>>>>>>>>>>>>>>> PID:24384 ident:24120
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 423 PID:12756 ident:15440
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1790 PID:23340 ident:5484
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 318 PID:22668 ident:128
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1742 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 840 PID:23340 ident:5484
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 169 PID:22668 ident:128
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1514 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 161 PID:12756 ident:15440
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
1/20 [>.............................] - ETA: 5:26 - loss: 2.3349 - accuracy: 0.1250 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 285 PID:22668 ident:128
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1662 PID:12756 ident:15440
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 657 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 673 PID:23340 ident:5484
5/20 [======>.......................] - ETA: 59s - loss: 2.2405 - accuracy: 0.1437 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1540 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1803 PID:12756 ident:15440
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1509 PID:23340 ident:5484
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1241 PID:22668 ident:128
9/20 [============>.................] - ETA: 27s - loss: 2.0978 - accuracy: 0.2535 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1841 PID:23340 ident:5484
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1764 PID:22668 ident:128
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 737 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 581 PID:12756 ident:15440
13/20 [==================>...........] - ETA: 13s - loss: 1.9722 - accuracy: 0.3558 wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 179 PID:12756 ident:15440
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 787 PID:22668 ident:128
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1504 PID:21748 ident:14344
wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1845 PID:23340 ident:5484
20/20 [==============================] - 29s 1s/step - loss: 1.7766 - accuracy: 0.4719
Total used time: 28
As you can see, the process IDs printed from MNISTSequence's __getitem__ are all different, which shows that 4 worker processes are running __getitem__ in parallel. The total time is only 28 seconds, much faster than the 57 seconds of the single-process run above; with a very large training dataset, this advantage becomes even more significant.
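The measured times line up roughly with what the artificial 3-second sleep predicts. Here is my own back-of-envelope model, assuming the sleep dominates each batch and the workers overlap perfectly:

```python
# Rough cost model for the two runs above.
SLEEP_PER_BATCH = 3            # seconds, from time.sleep(3)
STEPS = 20                     # steps_per_epoch
WORKERS = 4

serial_estimate = STEPS * SLEEP_PER_BATCH              # one batch at a time
parallel_estimate = STEPS * SLEEP_PER_BATCH / WORKERS  # 4 batches in flight

print(serial_estimate)    # 60 -- measured 57 s; prefetching hides a little of it
print(parallel_estimate)  # 15.0 -- measured 28 s adds process startup/IPC overhead
```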