TensorFlow Tutorial (52): Reading Training Data with fit_generator and a Multiprocess Sequence

#
# Author: 韦访
# Blog: https://blog.csdn.net/rookie_wei
# WeChat: 1007895847
# When adding me on WeChat, please mention you came from CSDN
# Everyone is welcome to learn together
#

---- 韦访 2019-08-22

1. Overview

This tutorial uses TensorFlow 2.0. TensorFlow 2.0 is considerably cleaner than TensorFlow 1.x, and its direct integration of Keras makes it much easier to pick up. In an earlier tutorial we covered TensorFlow 1.x's multithreaded approach to reading and preprocessing dataset samples before feeding them into the model for training; now let's see how the same thing is done in TensorFlow 2.0. The 1.x multithreaded data-reading tutorial is here:

https://blog.csdn.net/rookie_wei/article/details/80187950

 

2. Reading Data in a Single Process

First, let's look at reading the MNIST data in a single thread. The code is as follows:

import tensorflow as tf
import time
import numpy as np
import os
import threading

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def gen():
    print('  generator initiated')
    idx = 0
    while True:
        # Note: this always yields the same first 32 samples; the demo only
        # tracks which process/thread the generator runs in, not training quality.
        yield x_train[:32], y_train[:32]
        print('  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d  PID:%d  ident:%d' % (idx, os.getpid(), threading.currentThread().ident))
        idx += 1
        time.sleep(3)  # simulate slow data loading / preprocessing

if __name__ == "__main__":

    tr_gen = gen()
    
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
    ])

    print('  wf>>>>>>>>>>>>>>>>>>>>> PID:%d  ident:%d' % (os.getpid(), threading.currentThread().ident))
    model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    start_time = time.time()
    model.fit_generator(generator=tr_gen, steps_per_epoch=20, max_queue_size=10)
    print('Total used time: %d '% (time.time() - start_time))

The output:

  wf>>>>>>>>>>>>>>>>>>>>> PID:5644  ident:17420
  generator initiated
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 0  PID:5644  ident:23484
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
 1/20 [>.............................] - ETA: 6s - loss: 2.4578 - accuracy: 0.0625  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1  PID:5644  ident:23484
 2/20 [==>...........................] - ETA: 27s - loss: 2.3681 - accuracy: 0.0938  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 2  PID:5644  ident:23484
 3/20 [===>..........................] - ETA: 34s - loss: 2.2340 - accuracy: 0.1771  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 3  PID:5644  ident:23484
 4/20 [=====>........................] - ETA: 36s - loss: 2.1372 - accuracy: 0.2266  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 4  PID:5644  ident:23484
 5/20 [======>.......................] - ETA: 36s - loss: 2.0525 - accuracy: 0.3000  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 5  PID:5644  ident:23484
 6/20 [========>.....................] - ETA: 35s - loss: 1.9673 - accuracy: 0.3542  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 6  PID:5644  ident:23484
 7/20 [=========>....................] - ETA: 33s - loss: 1.9038 - accuracy: 0.4018  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 7  PID:5644  ident:23484
 8/20 [===========>..................] - ETA: 31s - loss: 1.8344 - accuracy: 0.4414  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 8  PID:5644  ident:23484
 9/20 [============>.................] - ETA: 29s - loss: 1.7695 - accuracy: 0.4896  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 9  PID:5644  ident:23484
10/20 [==============>...............] - ETA: 27s - loss: 1.7086 - accuracy: 0.5281  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 10  PID:5644  ident:23484
11/20 [===============>..............] - ETA: 24s - loss: 1.6503 - accuracy: 0.5597  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 11  PID:5644  ident:23484
12/20 [=================>............] - ETA: 22s - loss: 1.5898 - accuracy: 0.5911  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 12  PID:5644  ident:23484
13/20 [==================>...........] - ETA: 19s - loss: 1.5349 - accuracy: 0.6202  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 13  PID:5644  ident:23484
14/20 [====================>.........] - ETA: 16s - loss: 1.4856 - accuracy: 0.6473  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 14  PID:5644  ident:23484
15/20 [=====================>........] - ETA: 14s - loss: 1.4359 - accuracy: 0.6708  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 15  PID:5644  ident:23484
16/20 [=======================>......] - ETA: 11s - loss: 1.3856 - accuracy: 0.6914  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 16  PID:5644  ident:23484
17/20 [========================>.....] - ETA: 8s - loss: 1.3397 - accuracy: 0.7096   wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 17  PID:5644  ident:23484
18/20 [==========================>...] - ETA: 5s - loss: 1.2944 - accuracy: 0.7257  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 18  PID:5644  ident:23484
19/20 [===========================>..] - ETA: 2s - loss: 1.2492 - accuracy: 0.7401  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 19  PID:5644  ident:23484
20/20 [==============================] - 57s 3s/step - loss: 1.2107 - accuracy: 0.7516
Total used time: 57

Both the main block and gen print the process ID and thread ID. Comparing them shows they belong to the same process, but Keras starts a separate thread to run gen. The total time is 57 seconds.
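As a side note, the generator above always yields the same first 32 samples regardless of idx, which is fine for timing the pipeline but not for real training. A generator that actually walks through the whole training set could look like this (batch_gen is a hypothetical helper of mine, sketched with toy NumPy data standing in for x_train and y_train):

```python
import numpy as np

def batch_gen(x, y, batch_size=32):
    """Yield consecutive (x, y) batches, wrapping around at the end of the data."""
    idx = 0
    n_batches = len(x) // batch_size  # drop the final partial batch for simplicity
    while True:
        start = (idx % n_batches) * batch_size
        yield x[start:start + batch_size], y[start:start + batch_size]
        idx += 1

# Quick check with toy data
x = np.arange(100).reshape(50, 2)
y = np.arange(50)
g = batch_gen(x, y, batch_size=8)
bx, by = next(g)
print(bx.shape, by.shape)  # (8, 2) (8,)
```

Because the generator wraps around, it can be passed to fit_generator for any number of epochs without raising StopIteration.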

3. Reading Data with Multiple Processes

Next, let's see how to read data with multiple processes, and check whether training gets any faster than the single-process version. The code is as follows:

import tensorflow as tf
import time
import numpy as np
import os
import threading

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class MNISTSequence(tf.keras.utils.Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        
    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Return batch number idx
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        print('  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d  PID:%d  ident:%d' % (idx, os.getpid(), threading.currentThread().ident))
        time.sleep(3)  # simulate slow data loading / preprocessing
        return np.array(batch_x), np.array(batch_y)

if __name__ == "__main__":
    tr_gen = MNISTSequence(x_train, y_train, 32)
    
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
    ])

    print('  wf>>>>>>>>>>>>>>>>>>>>> PID:%d  ident:%d' % (os.getpid(), threading.currentThread().ident))
    model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    start_time = time.time()
    model.fit_generator(generator=tr_gen, steps_per_epoch=20, max_queue_size=10, use_multiprocessing=True, workers=4)
    print('Total used time: %d '% (time.time() - start_time))

The output:

  wf>>>>>>>>>>>>>>>>>>>>> PID:24384  ident:24120
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 423  PID:12756  ident:15440
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1790  PID:23340  ident:5484
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 318  PID:22668  ident:128
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1742  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 840  PID:23340  ident:5484
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 169  PID:22668  ident:128
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1514  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 161  PID:12756  ident:15440
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
 1/20 [>.............................] - ETA: 5:26 - loss: 2.3349 - accuracy: 0.1250  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 285  PID:22668  ident:128
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1662  PID:12756  ident:15440
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 657  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 673  PID:23340  ident:5484
 5/20 [======>.......................] - ETA: 59s - loss: 2.2405 - accuracy: 0.1437   wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1540  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1803  PID:12756  ident:15440
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1509  PID:23340  ident:5484
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1241  PID:22668  ident:128
 9/20 [============>.................] - ETA: 27s - loss: 2.0978 - accuracy: 0.2535  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1841  PID:23340  ident:5484
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1764  PID:22668  ident:128
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 737  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 581  PID:12756  ident:15440
13/20 [==================>...........] - ETA: 13s - loss: 1.9722 - accuracy: 0.3558  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 179  PID:12756  ident:15440
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 787  PID:22668  ident:128
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1504  PID:21748  ident:14344
  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1845  PID:23340  ident:5484
20/20 [==============================] - 29s 1s/step - loss: 1.7766 - accuracy: 0.4719
Total used time: 28

As you can see, the process IDs printed from MNISTSequence's __getitem__ are all different, which shows that four worker processes are running __getitem__ in parallel (the scattered batch indices appear because, by default, Keras shuffles the order in which a Sequence's batches are requested each epoch). The total time is only 28 seconds, much faster than the 57 seconds of the single-process version above; with very large datasets this advantage becomes significant.
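One common extension worth knowing: Keras also calls an optional on_epoch_end method on a Sequence after every epoch, which is the usual hook for reshuffling the sample order yourself. A minimal sketch follows; the class name ShuffledSequence is mine, and the base class is commented out so the snippet runs without TensorFlow installed:

```python
import numpy as np

# In real code this would be: class ShuffledSequence(tf.keras.utils.Sequence)
# The base class is left out here so the sketch runs without TensorFlow.
class ShuffledSequence:
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.indices = np.arange(len(self.x))  # batches are drawn from these indices

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        ids = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[ids], self.y[ids]

    def on_epoch_end(self):
        # Keras calls this after every epoch; reshuffling here changes which
        # samples land in which batch from epoch to epoch.
        np.random.shuffle(self.indices)
```

Unlike a plain generator, a Sequence is indexable, which is exactly what lets multiple worker processes each fetch a different batch at the same time.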
