在昇腾910芯片,训练Transformer大模型时,我们的序列长度特别长,例如8192的序列长度,这边遇到的一个问题是,当我们预先将数据处理为mindrecords格式的数据后,在设置 dataset_sink_mode=True
的情况下,数据迭代会因为超时而报错(如下错误信息)。
[ERROR] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.119 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:721] DumpTaskExceptionInfo] Dump node (Default/GetNext-op3507) task error input/output data to: ./rank_21/node_dump
[WARNING] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.161 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:728] DumpTaskExceptionInfo] GetNext error may be caused by slow data processing (bigger than 20s / batch) or transfer data to device error.
[WARNING] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.171 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:730] DumpTaskExceptionInfo] Suggestion:
[WARNING] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.180 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:731] DumpTaskExceptionInfo] 1) Set the parameter dataset_sink_mode=False of model.train(...) or model.eval(...) and try again.
[WARNING] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.188 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:733] DumpTaskExceptionInfo] 2) Reduce the batch_size in data processing and try again.
[WARNING] DEVICE(301,fff158ff6160,python):2023-01-09-11:35:20.909.196 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc:734] DumpTaskExceptionInfo] 3) You can create iterator by interface create_dict_iterator() of dataset class to independently verify the performance of data processing without training.
我们的尝试在构建数据迭代类中,设置大的进程数,例如 dataset.map(operations=map_func, num_parallel_workers=32)
和 dataset.batch(batch_size, drop_remainder=drop, num_parallel_workers=32)
这样似乎会导致机器的内存占用过大。如果我将 dataset_sink_mode=False
,那是不是会导致训练中CPU将成为瓶颈,降低昇腾芯片的利用率。
****************************************************解答*****************************************************
1. 太大的worker数可能反而性能会变低,建议降低一些worker数,如12 16
2. dataset.map(operations=map_func, num_parallel_workers=32) 可以使用参数python_multiprocessing=True提升效率
3.map_func可以考虑拆成多个map,缓解当前节点计算效率阻塞
4. dataset_sink_mode=False会降低性能昇腾芯片,因为数据下发会变成同步,建议还是通过提升数据集的效率解决超时报错