1. Run the following command in a terminal and answer the interactive prompts to create a configuration:
accelerate config
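2. Answer the prompts to suit your setup. For a single machine with one GPU (id 0) and fp16 mixed precision, the generated configuration looks like this: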
cat default_config.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
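3. Start training by running the following command in the terminal, passing the generated file to --config_file: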
accelerate launch --config_file default_config.yaml \
train.py \
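Each value in default_config.yaml corresponds to an accelerate launch option. Flags given explicitly on the command line generally take precedence over values in the config file. The Python sketch below uses only the standard library and shows how a few of these values map to flags. It assumes the file contains only flat key: value lines, as in the example above, so it uses a tiny hand-written parser instead of a real YAML library. parse_flat_config, build_launch_argv and FLAG_KEYS are hypothetical helpers written for illustration and are not part of Accelerate. Only the five keys in FLAG_KEYS are mapped to flags; every other setting is left to the config file.

```python
"""Hypothetical helpers (not part of Accelerate): read a flat accelerate
config like default_config.yaml and build an equivalent `accelerate launch`
argument list with explicit flags.

Assumption: every non-empty line is a flat `key: value` pair, so a small
hand-written parser is used instead of a YAML library.
"""
from typing import Dict, List

CONFIG_TEXT = """\
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
"""

# Config keys mapped one-to-one onto `accelerate launch --<key> <value>` flags
# in this sketch; all other keys are ignored here.
FLAG_KEYS = ["num_processes", "num_machines", "machine_rank", "mixed_precision", "gpu_ids"]


def parse_flat_config(text: str) -> Dict[str, str]:
    """Parse flat `key: value` lines into strings, stripping quotes around values."""
    config: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        key, sep, value = line.partition(":")
        if not sep:
            raise ValueError(f"not a flat key: value line: {raw_line!r}")
        config[key.strip()] = value.strip().strip("'\"")
    return config


def build_launch_argv(config: Dict[str, str], script: str) -> List[str]:
    """Return an argv list for `accelerate launch` with the mapped keys as flags."""
    argv = ["accelerate", "launch"]
    for key in FLAG_KEYS:
        if key in config:
            argv += [f"--{key}", config[key]]
    argv.append(script)
    return argv


if __name__ == "__main__":
    cfg = parse_flat_config(CONFIG_TEXT)
    # Sanity checks for the single-process fp16 setup shown above.
    assert cfg["distributed_type"] == "NO"
    assert cfg["num_processes"] == "1" and cfg["mixed_precision"] == "fp16"
    print(" ".join(build_launch_argv(cfg, "train.py")))
```

Run as a script, it prints accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --mixed_precision fp16 --gpu_ids 0 train.py. Treat this as an illustration of the mapping only. For real runs, the --config_file form in step 3 keeps every setting in one place.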