完整报错信息如下:
Traceback (most recent call last):
File "bert.py", line 172, in <module>
output = predict('../../../data/end2end/title_content5.csv', model_path='../../../data/end2end/bert.pth')
File "bert.py", line 149, in predict
model = model.load_state_dict(torch.load(kw['model_path'])).cuda()
AttributeError: '_IncompatibleKeys' object has no attribute 'cuda'
load_state_dict方法去加载模型时,模型是不会作为返回值返回的,所以不要用变量去接收
改为如下:
model.load_state_dict(torch.load(kw['model_path']))
model = model.cuda()