前言:
pytorch的torchvision.models模块中封装了alexnet、resnet、squeezenet、vgg、inception等常见网络的结构,并可以供我们方便地调用在ImageNet数据集上预训练过的模型。
一、finetune vgg16:
以torchvision.models.vgg16_bn为例(_bn表示包含BN层),首先来看一下它的网络结构,通过源码发现网络结构包含了以下三个部分:
1. features(包含了一堆卷积和最大池化操作):
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace=True)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace=True)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace=True)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace=True)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(26): ReLU(inplace=True)
(27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace=True)
(33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(36): ReLU(inplace=True)
(37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(39): ReLU(inplace=True)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace=True)
(43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
2. avgpool(包含一个平均池化操作):
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
3. classifier(包含了全连接操作,用于分类):
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
从网络结构中可以看出,当我们finetune时,只需修改网络中的classifier部分以匹配自己的数据集的类别数即可。
定义FineTuneVGG16类如下:
import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneVGG16(nn.Module):
    """VGG16-BN backbone with a smaller fully connected classifier head.

    Keeps the original ``features`` (conv/BN/ReLU/pool stack) and
    ``avgpool`` modules, and replaces the stock
    25088 -> 4096 -> 4096 -> 1000 classifier with a lighter
    25088 -> 512 -> 128 -> ``num_class`` stack.
    """

    def __init__(self, num_class=10):
        super(FineTuneVGG16, self).__init__()
        # NOTE(review): pretrained=False means no ImageNet weights are
        # loaded; for real fine-tuning this would normally be True.
        # Newer torchvision versions use the `weights=` argument instead.
        vgg16_net = models.vgg16_bn(pretrained=False)
        self.num_class = num_class
        self.features = vgg16_net.features  # conv feature extractor
        self.avgpool = vgg16_net.avgpool    # AdaptiveAvgPool2d((7, 7))
        # Replacement classifier sized for num_class outputs.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(128, self.num_class),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 512, 7, 7) -> (N, 25088)
        x = self.classifier(x)
        return x


def get_parameter_number(net):
    """Return the total and trainable parameter counts of *net* as a dict."""
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == '__main__':
    # Smoke test: run a single dummy 224x224 RGB image through the model.
    input_test = torch.ones(1, 3, 224, 224)
    vgg16 = FineTuneVGG16(num_class=10)
    output_test = vgg16(input_test)
    print(len(list(vgg16.parameters())))
    print(get_parameter_number(vgg16))
    print(output_test.shape)
也可以使用1×1卷积和全局平均池化来替代全连接层,达到分类的效果,这样做的目的是减少模型参数,节约显存,如下所示:
import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneVGG16(nn.Module):
    """VGG16-BN backbone with a convolutional classifier head.

    Instead of fully connected layers, the classifier uses 1x1
    convolutions followed by global average pooling, which greatly
    reduces the parameter count compared with the Linear-based head.
    """

    def __init__(self, num_class=10):
        super(FineTuneVGG16, self).__init__()
        # NOTE(review): pretrained=False means no ImageNet weights are
        # loaded; for real fine-tuning this would normally be True.
        vgg16_net = models.vgg16_bn(pretrained=False)
        self.num_class = num_class
        self.features = vgg16_net.features  # conv feature extractor
        self.avgpool = vgg16_net.avgpool    # AdaptiveAvgPool2d((7, 7))
        # 1x1 convs map 512 channels -> 200 -> num_class, then global
        # average pooling collapses the 7x7 spatial grid to one score
        # per class.
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 200, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(200),
            nn.ReLU(True),
            nn.Conv2d(200, self.num_class, kernel_size=1, stride=1, padding=0),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        batchsize = x.size(0)
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        x = x.view(batchsize, -1)  # (N, num_class, 1, 1) -> (N, num_class)
        return x


def get_parameter_number(net):
    """Return the total and trainable parameter counts of *net* as a dict."""
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == '__main__':
    # Smoke test: run a single dummy 224x224 RGB image through the model.
    input_test = torch.ones(1, 3, 224, 224)
    vgg16 = FineTuneVGG16(num_class=10)
    output_test = vgg16(input_test)
    print(len(list(vgg16.parameters())))
    print(get_parameter_number(vgg16))
    print(output_test.shape)
二、finetune resnet50:
类似的步骤,这里是通过切片去掉resnet50中最后一层全连接层,然后添加上匹配自己数据集的类别数的全连接层即可。
定义的FineTuneResnet50类如下:
import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneResnet50(nn.Module):
    """ResNet-50 with its final fully connected layer swapped out.

    Takes every child module of the stock ResNet-50 except the last
    (the 2048 -> 1000 ``fc`` layer) as the feature extractor, then adds
    a fresh ``Linear(2048, num_class)`` head.
    """

    def __init__(self, num_class=10):
        super(FineTuneResnet50, self).__init__()
        self.num_class = num_class
        # NOTE(review): pretrained=True downloads ImageNet weights on
        # first use; newer torchvision versions use `weights=` instead.
        resnet50_net = models.resnet50(pretrained=True)
        # Alternative: load weights from a local checkpoint instead of
        # downloading them.
        # state_dict = torch.load("./models/resnet50-19c8e357.pth")
        # resnet50_net.load_state_dict(state_dict)
        # Drop the last child (the original fc layer); what remains ends
        # with global average pooling producing (N, 2048, 1, 1).
        self.features = nn.Sequential(*list(resnet50_net.children())[:-1])
        self.classifier = nn.Linear(2048, self.num_class)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 2048, 1, 1) -> (N, 2048)
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    # Smoke test on GPU when available, otherwise CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    input_test = torch.ones(1, 3, 224, 224).to(device)
    resnet50_net = FineTuneResnet50(num_class=10).to(device)
    output_test = resnet50_net(input_test)
    # print(resnet50_net)
    # print(output_test.shape)