代码链接:https://github.com/csgwon/pytorch-deconvnet/blob/master/models/vgg16_deconv.py
主要内容:
该论文主要使用UnPooling层
使用原因:池化通过用单一的代表值来概括感受野内的激活,从而滤除下层的噪声激活。虽然只保留稳健的激活有助于上层的分类,但在池化过程中,感受野内的空间信息会丢失,而这些信息对于语义分割所需的精确定位可能至关重要。为了解决这一问题,论文在反卷积网络中使用了UnPooling(反池化)层,它执行池化的逆操作:在开关变量(switch variables)中记录池化时所选最大激活的位置,并据此将每个激活放回其原来的池化位置。这种反池化策略对于重建输入对象的结构特别有用,方法如图:
代码如下:即在池化时记住最大值的索引,反池化时按索引把每个值放回原位置,其余位置填充0
# 2x2 max pooling that also returns the index of the maximum in each window
pool = nn.MaxPool2d(2, stride=2, return_indices=True)
# Inverse operation: puts every pooled value back at its recorded index and
# fills all other positions with zeros
unpool = nn.MaxUnpool2d(2, stride=2)
input = torch.tensor([[[[ 1., 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]]]])
# output holds the four window maxima (6, 8, 14, 16); indices holds where
# each of them was found in the input
output, indices = pool(input)
# Rebuilds a 4x4 map with the maxima at their original positions
unpool(output, indices)
》》》tensor([[[[ 0., 0., 0., 0.],
[ 0., 6., 0., 8.],
[ 0., 0., 0., 0.],
[ 0., 14., 0., 16.]]]])
# specify a different output size than input size: the recorded flat indices
# are reused unchanged, so in a 5x5 map the values land at different
# (row, col) positions than in the 4x4 case
unpool(output, indices, output_size=torch.Size([1, 1, 5, 5]))
》》》tensor([[[[ 0., 0., 0., 0., 0.],
[ 6., 0., 8., 0., 0.],
[ 0., 0., 0., 14., 0.],
[ 16., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.]]]])
模型结构:
代码:
采用VGG16作为backbone:
import torch
import torchvision.models as models
import numpy as np
# torchvision VGG16 with pretrained weights, created once at import time and
# read by the _initialize_weights methods of VGG16_conv and VGG16_deconv.
# NOTE(review): `pretrained=` is deprecated in newer torchvision releases in
# favour of `weights=` -- confirm against the targeted torchvision version.
vgg16_pretrained = models.vgg16(pretrained=True)
class VGG16_conv(torch.nn.Module):
    """VGG16 that records every intermediate activation and pooling index.

    ``self.features`` is indexed in lock-step with
    ``vgg16_pretrained.features`` (the weights are copied index by index), but
    every ``MaxPool2d`` is built with ``return_indices=True`` so the argmax
    locations can later be handed to a ``MaxUnpool2d``.

    Args:
        n_classes: Number of outputs of the last linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        # VGG16 (using return_indices=True on the MaxPool2d layers)
        self.features = torch.nn.Sequential(
            # conv1
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv2
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv3
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv4
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv5
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
        )
        # One slot per layer of self.features; filled by forward_features().
        self.feature_outputs = [0] * len(self.features)
        # Maps the index of each MaxPool2d in self.features to the indices
        # tensor it returned on the most recent forward pass.
        self.pool_indices = dict()
        self.classifier = torch.nn.Sequential(
            # 224x224 image pooled down to 7x7 from features
            torch.nn.Linear(512 * 7 * 7, 4096),
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, n_classes),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Copy the pretrained convolution weights and biases into ``self.features``.

        The values are copied rather than re-pointed at the pretrained
        tensors, so this model never shares parameter storage with the global
        ``vgg16_pretrained`` (an in-place update of one would otherwise
        silently change the other). Only the convolution layers are
        initialised here; the classifier keeps PyTorch's default init.
        """
        with torch.no_grad():
            for i, layer in enumerate(vgg16_pretrained.features):
                if isinstance(layer, torch.nn.Conv2d):
                    self.features[i].weight.copy_(layer.weight)
                    self.features[i].bias.copy_(layer.bias)

    def get_conv_layer_indices(self):
        """Return the positions of the ``Conv2d`` layers inside ``self.features``.

        Derived from the actual layer list so it cannot drift from the
        architecture; for the network built in ``__init__`` this is
        ``[0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]``.
        """
        return [
            i
            for i, layer in enumerate(self.features)
            if isinstance(layer, torch.nn.Conv2d)
        ]

    def forward_features(self, x):
        """Run ``x`` through ``self.features``, recording every layer's output.

        Side effects: ``self.feature_outputs[i]`` is set to the output of
        layer ``i`` and ``self.pool_indices[i]`` to the indices returned by
        each pooling layer ``i``. Both are overwritten on every call.

        Returns:
            The output of the last feature layer.
        """
        output = x
        for i, layer in enumerate(self.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                # Pooling layers return (values, indices) because they were
                # built with return_indices=True.
                output, indices = layer(output)
                self.pool_indices[i] = indices
            else:
                output = layer(output)
            self.feature_outputs[i] = output
        return output

    def forward(self, x):
        """Return the classifier output for ``x``.

        The feature maps are flattened per sample before the classifier, so
        the flattened size must equal ``512 * 7 * 7``.
        """
        output = self.forward_features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)
        return output
class VGG16_deconv(torch.nn.Module):
    """Transposed-convolution mirror of the VGG16 feature stack.

    ``deconv_features`` lists the VGG16 feature layers in reverse order, with
    each ``Conv2d`` replaced by a ``ConvTranspose2d`` carrying the same
    weights and each ``MaxPool2d`` replaced by a ``MaxUnpool2d``. ``forward``
    projects a single feature map of a chosen conv layer back through the
    remaining layers.
    """

    def __init__(self):
        super().__init__()
        # Index of a Conv2d in the VGG16 feature stack -> index of the
        # ConvTranspose2d in deconv_features that receives its weight.
        self.conv2DeconvIdx = {0:17, 2:16, 5:14, 7:13, 10:11, 12:10, 14:9, 17:7, 19:6, 21:5, 24:3, 26:2, 28:1}
        # Index of a Conv2d in the VGG16 feature stack -> index of the
        # deconv layer that receives its bias (0 means "no such layer",
        # because position 0 is a MaxUnpool2d).
        self.conv2DeconvBiasIdx = {0:16, 2:14, 5:13, 7:11, 10:10, 12:9, 14:7, 17:6, 19:5, 21:3, 24:2, 26:1, 28:0}
        # Index of a MaxUnpool2d in deconv_features -> index of the matching
        # MaxPool2d in the VGG16 feature stack (key into pool_indices).
        self.unpool2PoolIdx = {15:4, 12:9, 8:16, 4:23, 0:30}
        self.deconv_features = torch.nn.Sequential(
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 256, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(256, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(256, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(256, 128, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(128, 128, 3, padding=1),
            torch.nn.ConvTranspose2d(128, 64, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(64, 64, 3, padding=1),
            torch.nn.ConvTranspose2d(64, 3, 3, padding=1),
        )
        # Single-input-channel twins of deconv_features, used only for the
        # first step of forward().
        # not the most elegant, given that I don't need the MaxUnpools here
        self.deconv_first_layers = torch.nn.ModuleList([
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 128, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 128, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 64, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 64, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 3, 3, padding=1),
        ])
        self._initialize_weights()

    def _initialize_weights(self):
        """Copy the pretrained convolution parameters into ``deconv_features``.

        Weights go to the layer named by ``conv2DeconvIdx`` and biases to the
        layer named by ``conv2DeconvBiasIdx``. Values are copied, not
        re-pointed, so no parameter storage is shared with the global
        ``vgg16_pretrained``.
        """
        with torch.no_grad():
            for i, layer in enumerate(vgg16_pretrained.features):
                if isinstance(layer, torch.nn.Conv2d):
                    self.deconv_features[self.conv2DeconvIdx[i]].weight.copy_(layer.weight)
                    bias_idx = self.conv2DeconvBiasIdx[i]
                    # Index 0 is a MaxUnpool2d and has no bias to set.
                    if bias_idx > 0:
                        self.deconv_features[bias_idx].bias.copy_(layer.bias)
            # The last layer is never a target in conv2DeconvBiasIdx, so its
            # bias would otherwise keep its random default initialisation and
            # add a different constant offset per instance; zero it instead.
            self.deconv_features[-1].bias.zero_()

    def forward(self, x, layer_number, map_number, pool_indices):
        """Project one feature map back through the remaining deconv layers.

        Args:
            x: Input with a single channel (the first layer applied has
                ``in_channels=1``).
            layer_number: Index of the source conv layer; must be a key of
                ``conv2DeconvIdx``.
            map_number: Which filter of that layer to use (index into the
                first dimension of the layer's weight).
            pool_indices: Mapping from pooling-layer index (the values of
                ``unpool2PoolIdx``) to the indices tensor recorded by that
                pooling layer.

        Returns:
            The output of the last layer of ``deconv_features``.

        Raises:
            KeyError: If ``layer_number`` is not a key of ``conv2DeconvIdx``.
            ValueError: If the mapped layer is not a ``ConvTranspose2d``
                (not expected with the tables defined in ``__init__``).
        """
        start_idx = self.conv2DeconvIdx[layer_number]
        if not isinstance(self.deconv_first_layers[start_idx], torch.nn.ConvTranspose2d):
            raise ValueError('Layer '+str(layer_number)+' is not of type Conv2d')
        # set weight and bias: the single-channel first layer gets only the
        # selected filter, reshaped to (1, out_channels, kH, kW)
        self.deconv_first_layers[start_idx].weight.data = self.deconv_features[start_idx].weight[map_number].data[None, :, :, :]
        self.deconv_first_layers[start_idx].bias.data = self.deconv_features[start_idx].bias.data
        # first layer will be single channeled, since we're picking a particular filter
        output = self.deconv_first_layers[start_idx](x)
        # transpose conv through the rest of the network
        for i in range(start_idx + 1, len(self.deconv_features)):
            layer = self.deconv_features[i]
            if isinstance(layer, torch.nn.MaxUnpool2d):
                # Unpooling needs the indices recorded by the matching pool.
                output = layer(output, pool_indices[self.unpool2PoolIdx[i]])
            else:
                output = layer(output)
        return output