使用torch checkpoint时报错:
'NoneType' object has no attribute "'detach'"
报错地址:
site-packages/torch/utils/checkpoints.py
报错原因:
torch版本较旧,旧版本中checkpoints.py中缺失对Input是否是tensor的判断,导致报错。升级torch版本,可以解决这个问题。
报错根本原因:
sparse_masks是self.blocks[i] forward函数中的参数,在此处是None值。在高版本torch中这样写没有问题,但在低版本torch中会导致最开始的报错。
x = checkpoint.checkpoint(self.blocks[i], x, sparse_masks)
解决:
tmp = partial(self.blocks[i], sparse_masks=sparse_masks)
x = checkpoint.checkpoint(tmp, x)