- 做实验发现seg_gt图边缘乱码了(心里寻思gt怎么还能乱码呢,不合理啊),查了半天发现读取的label边缘就是乱码,如图(强迫症看了想把自己吊起来)
- 问了问师兄,师兄说是下采样插值的问题,把双线性插值改为最邻近插值就好了。(不愧是师兄,一眼就看出来问题的本质)
- 问题就出在了transforms.Resize这个函数上,默认采用双线性插值。(我用的数据集是CelebAMask-HQ,其中分割标签的分辨率为512512,我的模型需要将输入resize为256256,没有注意分割标签的插值问题)
torchvision.transforms.Resize文档
可以利用InterpolationMode类来设置interpolation选项。
- 但是发现不会用,看源码,调了个函数解决的。
torchvision.transforms.functional.InterpolationMode文档 - 具体如下,利用数字设置所选用的插值方式
InterpolationMode源代码
- 解决~
- 不止是resize函数,transforms中的一系列函数,如果带有插值操作,都是可以设置插值方式的哦!!
- 不过要注意使用的torchvision版本,查看对应的文档。太早的torchvision版本可能没有写这个功能。
- 以及,文档的阅读真的真的真的很重要。
整理一下:
- pytorch官方英文文档
- torchvision官方英文文档
- paddlepaddle官方文档 (最近在看论文复现赛,虽然不是真正意义上的复现,不过可以熟悉一下查文档的能力,xdm可以都去试试)