Keras用tf的Strategy()分布式训练时候报XLA错误

文章讲述了在更新Keras和FocalLoss库时遇到的问题,版本冲突导致XLA错误。通过将变量创建移出XLA编译函数并回滚到Keras3.0.5版本,解决了多卡分布式训练中的问题。
摘要由CSDN通过智能技术生成

 We failed to lift variable creations out of this tf.function, so this tf.function cannot be run on XLA. A possible workaround is to move variable creation outside of the XLA compiled function.

最早用的pip -U 安装的keras没注意版本,直接可用。

之后装了一个第三方的Focal Loss库,结果自动把tf降了版本,后来再装keras只是==3.0结果就是这个版本不够新,导致了多卡分布式训练报xla错。折腾一下午,恍惚记得最早是3.0.5的keras,随后pip install keras==3.0.5,恢复正常。

Keras中设置分布式训练可以使用TensorFlow的tf.distribute.Strategy API。这个API提供了多种分布式策略,可以根据不同的使用场景选择适合的策略。其中,对于单机多卡训练,可以使用MirroredStrategy。\[1\] 使用MirroredStrategy时,需要在代码中引入tf.distribute.MirroredStrategy,并在创建模型之前实例化该策略。然后,将模型的创建和编译放在strategy.scope()的上下文中,以确保模型在所有可用的GPU上进行复制和训练。\[2\] 下面是一个设置分布式训练的示例代码: ```python import tensorflow as tf from tensorflow import keras # 实例化MirroredStrategy strategy = tf.distribute.MirroredStrategy() # 在strategy.scope()的上下文中创建和编译模型 with strategy.scope(): model = keras.Sequential(\[...\]) # 创建模型 model.compile(\[...\]) # 编译模型 # 加载数据集 train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE) eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE) # 在分布式环境下训练模型 model.fit(train_dataset, epochs=10, validation_data=eval_dataset) ``` 在上述代码中,MirroredStrategy会自动将模型复制到所有可用的GPU上,并在每个GPU上进行训练。这样可以充分利用多个GPU的计算资源,加快模型训练的速度。\[1\] 需要注意的是,分布式训练需要有多个GPU才能发挥作用。如果只有单个GPU,使用分布式训练可能不会带来性能上的提升。另外,分布式训练还需要适当调整batch size和学习率等超参数,以获得最佳的训练效果。 #### 引用[.reference_title] - *1* [【Keras】TensorFlow分布式训练](https://blog.csdn.net/qq_36643449/article/details/124592521)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [Keras分布式训练](https://blog.csdn.net/weixin_39693193/article/details/111539493)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Tensorflow2.0进阶学习-Keras分布式训练 (九)](https://blog.csdn.net/u010095372/article/details/124547254)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值