Official | Keras Distributed Training Tutorial


Overview

The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

This tutorial uses tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.
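If you want finer control, the strategy constructor can also take an explicit device list and an all-reduce implementation. The snippet below is an illustrative sketch only; the two-GPU device list and the NCCL choice are assumptions, not part of this tutorial's code:

# Illustrative only: mirror across two named GPUs and combine gradients
# with NCCL all-reduce (both arguments are optional and assumed here).
sketch_strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"],
    cross_device_ops=tf.distribute.NcclAllReduce())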

MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.

Keras API

This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.

Import dependencies

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow and TensorFlow Datasets
try:
  !pip install -q tf-nightly
except Exception:
  pass

import os

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

print(tf.__version__)
2.1.0-dev20191004

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)


mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...


/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)


WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Define distribution strategy

Create a MirroredStrategy object. This will handle distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Setup input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.


num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples


BUFFER_SIZE = 10000


BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
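
A common heuristic for the "tune the learning rate accordingly" advice above (an assumption here, not part of the original tutorial) is to scale the base learning rate linearly with the number of replicas, since the effective batch size grows the same way:

# Hypothetical illustration: linear learning-rate scaling with replica count.
BASE_LEARNING_RATE = 1e-3
LEARNING_RATE = BASE_LEARNING_RATE * strategy.num_replicas_in_sync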

Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255


  return image, label

Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
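
Optionally (this is an addition, not part of the original tutorial), you can also prefetch batches so the input pipeline prepares the next batch while the current one trains:

# Optional: prefetch overlaps preprocessing with training on the device;
# AUTOTUNE picks the buffer size dynamically.
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)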

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])


  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Define the callbacks

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.

  • Model Checkpoint: This callback saves the model after every epoch.

  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

For illustrative purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints


checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")


# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5


# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()]

Train and evaluate

Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.
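
The training and evaluation calls themselves are missing from this copy of the tutorial; the sketch below follows the steps above (the epoch count of 12 is an assumption chosen to cover the full decay schedule defined earlier):

# fit() runs distributed training transparently under MirroredStrategy.
# 12 epochs is an assumed value covering the decay schedule above.
model.fit(train_dataset, epochs=12, callbacks=callbacks)

# Restore the latest checkpoint, then evaluate on the test set.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))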

