在执行 torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)
这句话的时候,出现了报错信息。
RuntimeError: eye_out not supported on CUDAType for Bool
参考这篇文章,链接
将 dtype=torch.bool
改为 dtype=torch.uint8
,即torch.eye(weights.size(1), dtype=torch.uint8, device=weights.device)
问题解决。