RuntimeError: view size is not compatible
运行具体错误如下:
报错的语句为:
flat = prob.view(-1)
修改为如下即可:
flat = prob.contiguous().view(-1)
原因是 .view() 需要 Tensor 对象的元素地址是连续分布的,但我们的 prob 可能是不连续的,因此先使用 .contiguous() 将 prob 转化为连续分布的。
更多关于 tensor 连续性的理解,可以去参考Pytorch 官方教程。