项目场景:
跑U-net网络的时候,有一步是torch.cat()操作,出现
下面是代码
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
self.conv1 = DoubleConv(in_channels, 32)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = DoubleConv(32, 64)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = DoubleConv(64, 128)
self.pool3 = nn.MaxPool2d(2)
self.conv4 = DoubleConv(128, 256)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = DoubleConv(256, 512)
self.up6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv6 = DoubleConv(512, 256)
self.up7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv7 = DoubleConv(256, 128)
self.up8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv8 = DoubleConv(128, 64)
self.up9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
self.conv9 = DoubleConv(64, 32)
self.conv10 = nn.Conv2d(32, out_channels, 1)
def forward(self, x):
#print(x.shape)
c1 = self.conv1(x)
p1 = self.pool1(c1)
#print(p1.shape)
c2 = self.conv2(p1)
p2 = self.pool2(c2)
#print(p2.shape)
c3 = self.conv3(p2)
p3 = self.pool3(c3)
#print(p3.shape)
c4 = self.conv4(p3)
p4 = self.pool4(c4)
#print(p4.shape)
c5 = self.conv5(p4)
up_6 = self.up6(c5)
merge6 = torch.cat([up_6, c4], dim=1)
c6 = self.conv6(merge6)
up_7 = self.up7(c6)
#print(up_7.shape,c3.shape)
merge7 = torch.cat([up_7, c3], dim=1)
c7 = self.conv7(merge7)
up_8 = self.up8(c7)
merge8 = torch.cat([up_8, c2], dim=1)
c8 = self.conv8(merge8)
up_9 = self.up9(c8)
merge9 = torch.cat([up_9, c1], dim=1)
c9 = self.conv9(merge9)
c10 = self.conv10(c9)
out = nn.Sigmoid()(c10)
return out
问题描述
本来是在跑彩色图像的去噪任务,想尝试一些灰度图像的去噪。刚开始只是修改了img_channel,将3改为1,想当然的将crop_img_size改为了灰度图像数据集的统一尺寸,没有考虑pool(2)操作后和上采样后的对齐问题。
比如5/2=2
但是2*2=4
这时候,4和5就对不上了。
原因分析:
torch.cat()函数的功能是将多个tensor类型矩阵的连接。它有两个参数,第一个是tensor元组或者tensor列表;第二个是dim,如果tensor是二维的,dim=0指在行上连接,dim=1指在列上连接。
注意:torch.cat 进行连接的tensor的shape,除了需要连接的维度上的shape值可不同,必须拥有相同的shape,a是(2,3),b是(2,20)即torch.cat((a,b),-1)可以进行连接;torch.cat((a,b),0)不可以进行连接,因为3和20值不同
那么问题找到了,就是维度没有对上。
为什么没有对上呢?先开始跑彩色图像的时候crop_img_size是256,稳稳的在2的幂结果上,怎么除2都不会有余数,再乘回去也没有误差。
但是我改成了180后,180/2/2=45,45/2=22,这时候22上采样完是44,44与45不齐了,问题产生。
解决方案:
尺寸crop_img_size改为2的幂结果,比如128,这应该也是为什么我们看到的大部分输入的size都是256、512等等这些数的原因,为了能整除,没有余数。