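Training VTimeLLM uses bf16 and tf32 mixed precision. transformers.TrainingArguments is what parses the command-line arguments, and setting bf16 or tf32, individually or together, can raise one of the following two errors: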
1. ValueError: Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0
2. ValueError: --tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7
Command line:
deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT vtimellm/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
--version plain \
--data_path ./data/blip_laion_cc_sbu_558k.json \
--feat_folder ./feat/558k_clip_feat \
--tune_mm_mlp_adapter True \
--output_dir ./checkpoints/vtimellm-$MODEL_VERSION-stage1_test \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--learning_rate 1e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--tf32 True \
--logging_steps 1 \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
Relevant code:
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    training_stage: int = field(default=2)
Environment:
python: 3.10
pytorch: 2.1.0
flash-attn: 2.0.4
torchvision: 0.16.0
deepspeed: 0.14.2
transformers: 4.31.0
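Solution:
The problem is that transformers relies on torch.version.cuda when checking for bf16/tf32 support, and the AMD stack does not provide it: on AMD (ROCm) builds of PyTorch, torch.version.cuda is None, so both checks fail regardless of what the hardware can do. Judging from the issues, this seems to be a widespread situation (from 4.35.0 onward the bf16 check no longer uses torch.version.cuda, but the tf32 check still does).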
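To confirm that this is what is happening on a given machine, it helps to look at which backend the installed PyTorch was built for. The following is a minimal sketch, not part of the original fix: describe_torch_build is a hypothetical helper name, and it only relies on torch.version.cuda and torch.version.hip, where a ROCm build has the former set to None and the latter set to a version string.

```python
from typing import Optional


def describe_torch_build(cuda_version: Optional[str], hip_version: Optional[str]) -> str:
    """Classify a PyTorch build from its torch.version.cuda / torch.version.hip values.

    Hypothetical helper for illustration; it is not part of transformers or torch.
    """
    if cuda_version is not None:
        # NVIDIA build: torch.version.cuda is a string such as "11.8"
        return "cuda " + cuda_version
    if hip_version is not None:
        # ROCm build: torch.version.cuda is None, torch.version.hip is a version string
        return "rocm " + hip_version
    return "cpu-only"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        # getattr guards against builds that do not define torch.version.hip
        print(describe_torch_build(torch.version.cuda, getattr(torch.version, "hip", None)))
```

If this reports a ROCm build, any check that requires torch.version.cuda to be set will return False no matter which GPU is installed.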
Since I already know that the system supports bf16 and tf32, I simply removed the torch.version.cuda-related code from the checks in transformers/utils/import_utils.py, as follows:
def is_torch_bf16_gpu_available():
    ...
    # Original condition: if torch.cuda.is_available() and torch.version.cuda is not None:
    if torch.cuda.is_available():
        if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
            return False
        # Removed CUDA version check:
        # if int(torch.version.cuda.split(".")[0]) < 11:
        #     return False
        if not hasattr(torch.cuda.amp, "autocast"):
            return False
    else:
        return False

    # Without this final return the function would return None, which is falsy
    return True

def is_torch_tf32_available():
    ...
    # Original condition: if not torch.cuda.is_available() or torch.version.cuda is None:
    if not torch.cuda.is_available():
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    # Removed CUDA version check:
    # if int(torch.version.cuda.split(".")[0]) < 11:
    #     return False
    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
        return False

    return True
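
The effect of the tf32 change can be sketched without a GPU by turning the check into pure functions. This is only an illustration under assumptions: original_tf32_gate and patched_tf32_gate are hypothetical names, their arguments stand in for torch.cuda.is_available(), torch.version.cuda and get_device_properties(...).major, the torch version check is omitted for brevity, and the value major=9 is just an example input, not a measured one.

```python
from typing import Optional


def original_tf32_gate(cuda_available: bool, cuda_version: Optional[str], major: int) -> bool:
    """Mirrors the unmodified logic: torch.version.cuda must be set and >= 11."""
    if not cuda_available or cuda_version is None:
        return False
    if major < 8:
        return False
    if int(cuda_version.split(".")[0]) < 11:
        return False
    return True


def patched_tf32_gate(cuda_available: bool, cuda_version: Optional[str], major: int) -> bool:
    """Mirrors the edited logic: torch.version.cuda is no longer consulted."""
    if not cuda_available:
        return False
    if major < 8:
        return False
    return True


if __name__ == "__main__":
    # ROCm-like inputs: a device is visible, but torch.version.cuda is None
    rocm_like = dict(cuda_available=True, cuda_version=None, major=9)
    print(original_tf32_gate(**rocm_like))  # False
    print(patched_tf32_gate(**rocm_like))   # True
```

The patched version still rejects devices that report a major version below 8, so it only bypasses the CUDA version requirement. Whether the hardware really supports tf32 remains something you have to know beforehand, as in the case described above.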