Pytorch:卷积神经网络-预训练网络微调

教程介绍了如何使用PyTorch对预训练的VGG16模型进行微调,以识别十种猴子。首先,冻结VGG16的特征提取层,然后添加新的全连接层进行分类。数据预处理包括随机裁剪和水平翻转(训练集)以及重置分辨率和中心裁剪(验证集)。使用Adam优化器和交叉熵损失函数进行训练,并通过History库记录训练过程。最终,模型在训练集和验证集上表现出良好的识别效果。
摘要由CSDN通过智能技术生成

Pytorch: 微调预训练好的卷积神经网络(VGG) 识别十类猴子

Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, School of Artificial and Intelligence, Huazhong University of Science and Technology

Pytorch教程专栏链接


本教程不商用,仅供学习和参考交流使用,如需转载,请联系本人。

猴子的数据地址为:https://www.kaggle.com/slothkong/10-monkey-species 。其中包含了训练集和验证集

微调预训练的VGG16网络
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import hiddenlayer as hl
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torchvision import models
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchsummary import summary
from torchviz import make_dot
# 模型加载选择GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
cuda
1
GeForce MX250

冻结特征提取层参数,不更新权重,以提高网络训练速度

# 导入vgg16网络
vgg16 = models.vgg16(pretrained=True)
# 获取vgg16的特征提取层
vgg = vgg16.features
# 将vgg16的特征提取层的参数冻结,不对其进行更新
for param in vgg.parameters():
    param.requires_grad_(False)

在VGG16的基础上,设计全连接层,512、256、10。在前向传播函数中,由 self.classify 得到输出

class MyVggModel(nn.Module):
    def __init__(self):
        super(MyVggModel, self).__init__()
        # vgg16的特征提取层
        self.vgg = vgg
        # 添加新的全连接层
        self.classifier = nn.Sequential(nn.Linear(25088, 512), 
                                       nn.ReLU(), 
                                       nn.Dropout(p=0.5),
                                       nn.Linear(512, 256),
                                       nn.ReLU(),
                                       nn.Dropout(p=0.5),
                                       nn.Linear(256, 10),
                                       nn.Softmax(dim=1))
    def forward(self, x):
        # 前向传播
        x = self.vgg(x)
        x = x.view
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值