Fine-tuning (finetune) VGG16 and ResNet50 in PyTorch

Preface:

The torchvision.models module in PyTorch bundles the architectures of common networks such as AlexNet, ResNet, SqueezeNet, VGG, and Inception, and lets us conveniently load models pretrained on the ImageNet dataset.

I. Fine-tuning VGG16:

Take torchvision.models.vgg16_bn as an example (the _bn suffix means the network includes BatchNorm layers). Looking at its source, the network consists of the following three parts:

1. features (a stack of convolution and max-pooling operations):

(features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU(inplace=True)
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace=True)
    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU(inplace=True)
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU(inplace=True)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace=True)
    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )

2. avgpool (a single adaptive average-pooling operation):

(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))

3. classifier (the fully connected layers used for classification):

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

As the structure shows, to fine-tune the network we only need to modify the classifier part so that its output matches the number of classes in our own dataset.

The FineTuneVGG16 class is defined as follows:

import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneVGG16(nn.Module):
    def __init__(self, num_class=10):
        super(FineTuneVGG16, self).__init__()
        vgg16_net = models.vgg16_bn(pretrained=True)  # load ImageNet weights; set False to skip the download when only testing shapes
        self.num_class = num_class
        self.features = vgg16_net.features
        self.avgpool = vgg16_net.avgpool
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(128, self.num_class),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == '__main__':
    input_test = torch.ones(1, 3, 224, 224)
    vgg16 = FineTuneVGG16(num_class=10)
    output_test = vgg16(input_test)
    print(len(list(vgg16.parameters())))
    print(get_parameter_number(vgg16))
    print(output_test.shape)

The fully connected layers can also be replaced with 1x1 convolutions followed by global average pooling to produce the classification output. This greatly reduces the number of model parameters and saves GPU memory, as shown below:

import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneVGG16(nn.Module):
    def __init__(self, num_class=10):
        super(FineTuneVGG16, self).__init__()
        vgg16_net = models.vgg16_bn(pretrained=True)  # load ImageNet weights; set False to skip the download when only testing shapes
        self.num_class = num_class
        self.features = vgg16_net.features
        self.avgpool = vgg16_net.avgpool
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 200, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(200),
            nn.ReLU(True),

            nn.Conv2d(200, self.num_class, kernel_size=1, stride=1, padding=0),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        batchsize = x.size(0)
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        x = x.view(batchsize, -1)
        return x
    
    
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == '__main__':
    input_test = torch.ones(1, 3, 224, 224)
    vgg16 = FineTuneVGG16(num_class=10)
    output_test = vgg16(input_test)
    print(len(list(vgg16.parameters())))
    print(get_parameter_number(vgg16))
    print(output_test.shape)
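To see how much the convolutional head saves, the two classifier variants can be compared directly (the counts below cover only the heads, not the shared features):

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# Fully connected head from the first version of FineTuneVGG16.
fc_head = nn.Sequential(
    nn.Linear(512 * 7 * 7, 512), nn.ReLU(True), nn.Dropout(),
    nn.Linear(512, 128), nn.ReLU(True), nn.Dropout(),
    nn.Linear(128, 10),
)

# 1x1-convolution + global-average-pooling head from the second version.
conv_head = nn.Sequential(
    nn.Conv2d(512, 200, kernel_size=1),
    nn.BatchNorm2d(200),
    nn.ReLU(True),
    nn.Conv2d(200, 10, kernel_size=1),
    nn.AdaptiveAvgPool2d((1, 1)),
)

print(count_params(fc_head))   # 12912522
print(count_params(conv_head)) # 105010
```

The convolutional head has roughly 100x fewer parameters, which is where the memory savings come from.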

II. Fine-tuning ResNet50:

The procedure is similar. Here we slice off the final fully connected layer of resnet50 with list(resnet50_net.children())[:-1], then append a new fully connected layer that matches the number of classes in our own dataset.

The FineTuneResnet50 class is defined as follows:

import torch
import torchvision.models as models
import torch.nn as nn


class FineTuneResnet50(nn.Module):
    def __init__(self, num_class=10):
        super(FineTuneResnet50, self).__init__()
        self.num_class = num_class
        resnet50_net = models.resnet50(pretrained=True)
        # state_dict = torch.load("./models/resnet50-19c8e357.pth")
        # resnet50_net.load_state_dict(state_dict)
        self.features = nn.Sequential(*list(resnet50_net.children())[:-1])
        self.classifier = nn.Linear(2048, self.num_class)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    input_test = torch.ones(1, 3, 224, 224).to(device)
    resnet50_net = FineTuneResnet50(num_class=10).to(device)
    output_test = resnet50_net(input_test)
    # print(resnet50_net)
    # print(output_test.shape)
