fcn, fully convolutional network,全卷积网络,为了解决图像分割问题,把vgg的全连接层换成转置卷积层。本文实现fcn_4s核心部分,不包括导包和上采样模块权重初始化等。
如上图所示,输入图片S_In经过conv1后得到S,经过pool1后大小减半,记录为S/2,经过conv2后大小不变,经过pool2后大小减半,记录为S/4,同样的分别经过pool3、pool4、pool5后分别记为S/8,S/16,S/32,经过kernel_size为7*7的conv6后变为S/32-6,经过kernel_size为1的conv7后大小不变还是S/32-6。接下来需要把S/32-6最终恢复成S_In的大小,为了实现方便,先恢复为S大小再由S恢复至S_In。有三种恢复方案:16s、8s、4s。本文讨论4s。4s是将conv7后得到的特征图(S/32-6)上采样至本身8倍、将pool4后得到的特征图(S/16)上采样至本身的4倍、将pool3后得到的特征图(S/8)上采样至本身的2倍,这样它们的大小都是S/4,将它们与pool2后得到的特征图(S/4)相加得到特征图S/4。将该特征图上采样至本身的4倍得到S。这里的S/32、S/16等大小并不一定是这个数,因为转置卷积的参数的设置,是的特征图的大小不一定恰好等于S/32、S/16等数字,写出S/32、S/16等数字是为了更方便的分析过程,也可以助于理解原理。
具体的代码过程并不是将(S/32-6)直接上采样至本身的8倍,而是跟其他特征图一起分步上采样。
如上图所示,先对conv7后得到的特征图上采样至本身的2倍再与pool4后得到的特征图相加得到特征图S/16,将S/16上采样至本身的两倍再与pool3后得到的特征图相加得到特征图S/8,将S/8上采样至本身2倍再与pool2后得到的特征图相加得到特征图S/4,将S/4上采样至本身4倍得到S。由于转置卷积参数的设置问题,这里的上采样至本身的n倍不一定严格等于本身的n倍,所以需要裁剪特征图的大小。
关于裁剪的计算,需要掌握转置卷积输出的长/宽公式:O=(I-1)*S+K-2P。
经过conv7后特征图大小为S/32-6,将S/32-6上采样至本身2倍的卷积核大小为4,stride为2,故得到(S/32-6-1)*2+4=S/16-10,要将S/16-10和pool4后得到的S/16相加,就要先将S/16裁切,要保留中间部分,故左边右边上边下边各裁掉5个单位,也就是forward函数里的“h = h[:, :, 5:5 + up2_feat.size(2), 5:5 + up2_feat.size(3)]”。
这两个特征图相加后得到S/16-10。
接着将(S/16-10)上采样至本身2倍,故同样地得到(S/16-10-1)*2+4=S/8-18,要将S/8-18与pool3后的S/8相加,就要把S/8裁切,
把S/8上下左右各裁9,也就是“h = h[:, :, 9:9 + up4_feat.size(2), 9:9 + up4_feat.size(3)]”。
这两个特征图相加后得到S/8-18。
继续将(S/8-18)上采样至本身2倍,同样地得到(S/8-18-1)*2+4=
S/4-34,要把S/4-34与pool2后S/4相加,需要把S/4上下左右各裁17,也就是“h = h[:,:,17:17+ up8_feat.size(2), 17:17+ up8_feat.size(3)]”。
这两个特征图相加后得到S/4-34。
将(S/4-34)上采样至本身的4倍,得到(S/4-34-1)4+8=S-132。
S是S_In经过conv1后得到的,conv1的kernel_size为3,padding为100,故有关系S=(S_In-3+2100)/1+1=S_In+198。那么S-132=S_In+66。故将S_In+66上下左右各裁33个单元得到S_In,也就是“final_scores = h[:, :, 33:33 + x.size(2), 33:33 + x.size(3)].contiguous()”。
具体代码如下:
class VGG_19bn_4s(nn.Module):
def __init__(self, n_class=21):
super(VGG_19bn_4s, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=100)
self.bn1 = nn.BatchNorm2d(64)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = Layer(64, [64])
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer2 = Layer(64, [128, 128])
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer3 = Layer(128, [256, 256, 256, 256])
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer4 = Layer(256, [512, 512, 512, 512])
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer5 = Layer(512, [512, 512, 512, 512])
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc6 = nn.Conv2d(512, 4096, 7) # padding=0
self.relu6 = nn.ReLU(inplace=True)
self.drop6 = nn.Dropout2d()
self.fc7 = nn.Conv2d(4096, 4096, 1)
self.relu7 = nn.ReLU(inplace=True)
self.drop7 = nn.Dropout2d()
self.score_fr = nn.Conv2d(4096, n_class, 1)
self.trans_f4 = nn.Conv2d(512, n_class, 1)
self.trans_f3 = nn.Conv2d(256, n_class, 1)
self.trans_f2 = nn.Conv2d(128, n_class, 1)
self.up2times = nn.ConvTranspose2d(
n_class, n_class, 4, stride=2, bias=False)
self.up4times = nn.ConvTranspose2d(
n_class, n_class, 4, stride=2, bias=False)
self.up8times = nn.ConvTranspose2d(
n_class, n_class, 4, stride=2, bias=False)
self.up32times = nn.ConvTranspose2d(
n_class, n_class, 8, stride=4, bias=False)
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
m.weight.data = bilinear_kernel(n_class, n_class, m.kernel_size[0])
def forward(self, x):
f0 = self.relu1(self.bn1(self.conv1(x)))
f1 = self.pool1(self.layer1(f0))
f2 = self.pool2(self.layer2(f1))
f3 = self.pool3(self.layer3(f2))
f4 = self.pool4(self.layer4(f3))
f5 = self.pool5(self.layer5(f4))
f6 = self.drop6(self.relu6(self.fc6(f5)))
f7 = self.score_fr(self.drop7(self.relu7(self.fc7(f6))))
up2_feat = self.up2times(f7)
h = self.trans_f4(f4)
print(h.shape)
print(up2_feat.shape)
h = h[:, :, 5:5 + up2_feat.size(2), 5:5 + up2_feat.size(3)]
h = h + up2_feat
up4_feat = self.up4times(h)
h = self.trans_f3(f3)
print(h.shape)
print(up4_feat.shape)
h = h[:, :, 9:9 + up4_feat.size(2), 9:9 + up4_feat.size(3)]
h = h + up4_feat
up8_feat = self.up8times(h)
h = self.trans_f2(f2)
print(h.shape)
print(up8_feat.shape)
h = h[:,:,17:17+ up8_feat.size(2), 17:17+ up8_feat.size(3)]
h = h + up8_feat
h = self.up32times(h)
print(h.shape)
final_scores = h[:, :, 33:33 + x.size(2), 33:33 + x.size(3)].contiguous()
return final_scores
model = VGG_19bn_4s(21)
x = torch.randn(2, 3, 64, 64)
model.eval()
y_vgg = model(x)
y_vgg.size()
输出为:
torch.Size([2, 21, 16, 16])
torch.Size([2, 21, 6, 6])
torch.Size([2, 21, 32, 32])
torch.Size([2, 21, 14, 14])
torch.Size([2, 21, 65, 65])
torch.Size([2, 21, 30, 30])
torch.Size([2, 21, 124, 124])
torch.Size([2, 21, 64, 64])