最近一直在复现TransUnet网络的结构,经常会遇到修改输入图片维度的问题。比如在train的时候,读入的数据的outputs要4维的,而label则要3维的。此时就需要调用torch.squeeze(),让原本4维的label变为3维。
又比如,在test的时候,读入的image和label都要是3维的,可是,label是灰度图,本身只有1维。这时最佳的解决方法不是修改图片的维度数,而是在读取图片时就要把灰度图按彩色(RGB)图来读取。
我们依次来看:一、修改图片的维度数:
label = torch.squeeze(label,dim =1)
修改之前的label的维度是(24,1,224,224) 删掉第2维度后,变为(24,224,224)
二、
使用cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 将灰度图按彩色图来读取。读取后的label的维度 变为(3,224,224)