我在训练vilbert的时候,报错如下:
File "/vilbert/vilbert.py", line 1351, in forward
dtype=next(self.parameters()).dtype
StopIteration
查了一下可能是pytorch版本问题,所以有两种解决方法:
解决方法一
降级pytorch==1.4.0
解决方法二
修改代码:
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
为:
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility