pytorch迁移学习
本例子以图像语义分割最终的Segnet网络为例进行说明,简单明了,希望对大家有所帮助。
1.完整代码
原始代码链接:见原文(超链接在转载/抓取过程中丢失)。
根据自己需要修改后的代码:
import torch.nn as nn
import torch
from torchvision import models
from tensorboardX import SummaryWriter
from torchsummary import summary
class _DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_conv_layers):
super(_DecoderBlock, self).__init__()
middle_channels = in_channels // 2 #原链接为 in_channels /2,会报错
layers = [
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2),
nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(middle_channels),
nn.ReLU(inplace=True)
]
layers += [
nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(middle_channels),
nn.ReLU(inplace=True),
] * (num_conv_layers - 2)
layers += [
nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
]
self.decode = nn.Sequential(*layers)
def forward(self, x):
return self.decode(x)
class SegNet(nn.Module):
    """SegNet-style encoder-decoder for semantic segmentation with a
    dual-input (two 1-channel images -> 2-channel) front end.

    Transfer learning: encoder stages enc1..enc5 are taken from a pretrained
    torchvision ``vgg19_bn`` ``features`` stack. The original 3-channel stem
    ``features[0:3]`` is replaced by ``self.input`` (2-channel Conv-BN-ReLU).
    Each decoder stage consumes the concatenation of the matching encoder
    output and the previous decoder output (skip connections).

    Args:
        num_classes: number of output segmentation classes (channels of dec1).
    """

    def __init__(self, num_classes):
        super(SegNet, self).__init__()
        # pretrained=True downloads the ImageNet weights for vgg19_bn.
        vgg = models.vgg19_bn(pretrained=True)
        features = list(vgg.features.children())
        # Skip features[0:3] (Conv2d(3, 64) + BN + ReLU) -- replaced by
        # self.input below so the network accepts 2 input channels.
        self.enc1 = nn.Sequential(*features[3:7])
        self.enc2 = nn.Sequential(*features[7:14])
        self.enc3 = nn.Sequential(*features[14:27])
        self.enc4 = nn.Sequential(*features[27:40])
        self.enc5 = nn.Sequential(*features[40:])
        # dec5 keeps 512 channels: one upsampling step plus four Conv-BN-ReLU
        # groups (built via list concatenation/repetition).
        self.dec5 = nn.Sequential(
            *([nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)] +
              [nn.Conv2d(512, 512, kernel_size=3, padding=1),
               nn.BatchNorm2d(512),
               nn.ReLU(inplace=True)] * 4)
        )
        # Decoder in_channels are doubled because forward() concatenates the
        # encoder skip connection with the previous decoder output.
        self.dec4 = _DecoderBlock(1024, 256, 4)
        self.dec3 = _DecoderBlock(512, 128, 4)
        self.dec2 = _DecoderBlock(256, 64, 2)
        self.dec1 = _DecoderBlock(128, num_classes, 2)
        # Replacement stem for the skipped vgg features[0:3]: maps the
        # 2-channel merged input to 64 channels.
        self.input = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        """Run the dual-input network.

        Args:
            x1, x2: tensors concatenated along the channel dim; together they
                must form 2 channels (e.g. each is (N, 1, H, W)).

        Returns:
            Segmentation logits of shape (N, num_classes, H, W).
        """
        # Dual-input structure: merge the two inputs on the channel axis.
        merged = torch.cat((x1, x2), 1)
        # `stem` instead of `input` to avoid shadowing the builtin.
        stem = self.input(merged)
        enc1 = self.enc1(stem)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        dec5 = self.dec5(enc5)
        # Skip connections: concatenate encoder features with the upsampled
        # decoder output at each resolution.
        dec4 = self.dec4(torch.cat([enc4, dec5], 1))
        dec3 = self.dec3(torch.cat([enc3, dec4], 1))
        dec2 = self.dec2(torch.cat([enc2, dec3], 1))
        dec1 = self.dec1(torch.cat([enc1, dec2], 1))
        return dec1
if __name__ == '__main__':
    # Bind the instance to a new name so the SegNet class itself is not
    # shadowed (the original did `SegNet = SegNet(...)`).
    model = SegNet(num_classes=31)
    # Two single-channel 160x160 inputs; summary prints per-layer shapes.
    # (See torchsummary docs for the multi-input list-of-shapes form.)
    summary(model, [(1, 160, 160), (1, 160, 160)])
2.迁移学习
首先找到体现迁移学习的代码:
from torchvision import models
vgg = models.vgg19_bn(pretrained=True)#载入models模块儿中训练好的vgg_bn模型
features = list(vgg.features.children())
self.enc1 = nn.Sequential(*features[3:7])
self.enc2 = nn.Sequential(*features[7:14])
self.enc3 = nn.Sequential(*features[14:27])
self.enc4 = nn.Sequential(*features[27:40])
self.enc5 = nn.Sequential(*features[40:])
1.首先分析“vgg.features.children()”这句话的精髓,理解了这句代码,一切就好办了。
晦涩难懂,vgg是加载进来的模型,那么feature是什么?**children()**又是啥意思?别着急,接下来我们一一分析:
First,我们理解feature从哪儿来的。
from torchvision import models
vgg = models.vgg19_bn()
pprint.pprint(vgg)
结果如下:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace)
(23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(25): ReLU(inplace)
(26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace)
(33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(35): ReLU(inplace)
(36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(38): ReLU(inplace)
(39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace)
(43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(45): ReLU(inplace)
(46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(48): ReLU(inplace)
(49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(51): ReLU(inplace)
(52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
这样,我们看到了调用的vgg19_bn()这个模型的结构,注意到该网络的结构如下:
VGG(
(features):
(avgpool):
(classifier):
)
我们的segnet的迁移学习主要是提取了(features)的网络层用在了自身的网络中作为“编码——解码”结构的编码过程。
那么,vgg.features就可以理解了,就是提取vgg模型的features网络层部分。
Second,现在理解vgg.features.children()中的children()
pytorch中与children()对应的是modules()
简单例子:
a=[1,2,[3,4]]
# children返回:
1,2,[3,4]
# modules返回:
[1,2,[3,4]], 1, 2, [3,4], 3, 4
(features)下的网络层较多,我们拿网络中的 (classifier)部分进行演示:
演示children:
from torchvision import models
vgg = models.vgg19_bn()
# pprint.pprint(vgg)
classifier = list(vgg.classifier.children())
pprint.pprint(classifier)
结果如下:
[Linear(in_features=25088, out_features=4096, bias=True),
ReLU(inplace),
Dropout(p=0.5),
Linear(in_features=4096, out_features=4096, bias=True),
ReLU(inplace),
Dropout(p=0.5),
Linear(in_features=4096, out_features=1000, bias=True)]
演示modules:
from torchvision import models
vgg = models.vgg19_bn()
# pprint.pprint(vgg)
classifier = list(vgg.classifier.modules())
pprint.pprint(classifier)
结果如下:
[Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
),
Linear(in_features=25088, out_features=4096, bias=True),
ReLU(inplace),
Dropout(p=0.5),
Linear(in_features=4096, out_features=4096, bias=True),
ReLU(inplace),
Dropout(p=0.5),
Linear(in_features=4096, out_features=1000, bias=True)]
因此,可以发现,children()返回的是结构中的每一层子网络,即 Sequential 中的每一层;而 modules() 不但逐层返回每一个子网络,还会先返回 Sequential(…) 这个完整的容器本身。
2. 最后,把目光转向 self.enc1 = nn.Sequential(*features[3:7])
经过以上的分析,这句程序理解起来就没有问题了:
现在列出features的(0)到(9)层的内容
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace)
**features[3:7]**得到的就是以下的网络层;
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
简单示例,希望对大家有所帮助。