摘要
本文分享了在使用transformers库进行BERT模型训练时遇到的AttributeError: ‘AdamW’ object has no attribute 'train’错误的解决过程。通过查找相关信息,发现问题源于accelerate库版本过低,并通过将库升级至0.34.2版本成功解决报错。本文详细介绍了问题排查、版本更新的步骤,以及如何忽略更新中的警告提示,以帮助读者快速解决类似问题。
报错信息描述
在使用 transformers
库的 Trainer
训练 BERT 模型时,遇到了以下报错信息:
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
return inner_training_loop(
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 3477, in training_step
self.optimizer.train()
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/optimizer.py", line 128, in train
return self.optimizer.train()
AttributeError: 'AdamW' object has no attribute 'train'
Traceback (most recent call last):
File "/home/jie/gitee/pku_industry/general/pipeline.py", line 202, in <module>
run("optical_communication_laser")
File "/home/jie/gitee/pku_industry/general/pipeline.py", line 97, in run
bert_cls.train(5)
File "/home/jie/gitee/pku_industry/general/bert_train.py", line 108, in train
self.trainer.train()
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
return inner_training_loop(
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 3477, in training_step
self.optimizer.train()
File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/optimizer.py", line 128, in train
return self.optimizer.train()
AttributeError: 'AdamW' object has no attribute 'train'
因为之前这段代码是可以正常运行的,所以我怀疑问题可能与某些库的版本更新有关。
Bug修复过程
我使用搜索引擎查询了报错信息,尝试找到解决方案。友情提醒,如果使用百度搜索,可能很难找到有用的信息,因为对于这类专业性较强的问题,百度的表现还有待提升。百度加油!
点击查看:GitHub 上的相关 issue
在一个 GitHub issue 中,有人提到了需要将 accelerate
库更新到 0.34.2
版本,解决这个问题。
检查当前库的版本
使用以下命令查看当前安装的 accelerate
库版本:
pip show accelerate
发现当前版本低于 0.34.2
,所以需要进行更新。
更新 accelerate
库
使用以下命令将 accelerate
库更新到 0.34.2
:
pip install accelerate==0.34.2
在更新过程中,可能会出现一些警告信息,不过这些警告可以忽略。
验证更新结果
更新完成后,重新运行代码,问题已经解决,程序可以正常执行了。
这个过程表明,部分依赖库的更新可能会引入不兼容的改动。定期检查并更新项目的依赖项,可以避免遇到类似问题。希望这篇博客能帮助大家解决类似的报错问题。