【无标题】

深度学习系列文章目录
例如:第一章  pytorch提供的众多模型


前言

提示:这里可以添加本文要记录的大概内容:

无需自己从头搭建已经各种顶会或预印的深度学习网络


提示:以下是本篇文章正文内容,下面案例可供参考

一、现有哪些网络?

torchvision.models模块里有的

from .alexnet import *
from .convnext import *
from .densenet import *
from .efficientnet import *
from .googlenet import *
from .inception import *
from .mnasnet import *
from .mobilenet import *
from .regnet import *
from .resnet import *
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *

 有监督学习中,用于图像分类的卷积神经网络,从最开始的alexnet到resnet到squeezenet各种轻量级网络,均提供了。这里不重复网络的各种论文原理

二、使用步骤

1.引入库

 

代码如下(示例):

import torchvision.models as models 
import pathlib import Path

2.修改网络

自定义一个新类,到网上自己下载预训练权重到本地,在自定义类中可以修改原网络,来实现2048分类任务为适合自己的目标分类任务,以下为5分类。分类网络的输出层通常有一个与类别数相等的输出节点,所以num_classes = 5

  
weights_path = Path('"/home/user/resnet_re/resnet50-19c8e357.pth"')
self.resnet = models.resnet50(pretrained=False, weights=weights_path) 
self.resnet.fc = nn.Linear(2048, num_classes)
num_classes = 5 

如果自己没有从网上手动下载,可以让模型自动下载:但一般会下载到临时文件夹里。形如"~/.cache/torch/hub/checkpoints/",在以后加载相同的预训练模型时,权重文件将从缓存目录中加载,而不需要再次下载。

self.resnet = models.resnet50(pretrained=True)

下载时在终端打印如下进度条 

 

Downloading model.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 346M/346M [00:41<00:00, 8.29MB/s]

 但例如VisionTransformer模型并没有内置的from_pretrained()方法。不能通过

self.vit = VisionTransformer.from_pretrained('vit-base-patch16-224')来加载预训练模型,

如果你希望使用预训练的Vision Transformer模型,可以尝试使用timm库(一个常用的图像模型库)中的预训练模型,可以通过命令来安装它

 

pip install timm
# 该库有create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value方法
#例如使用create_model可以使用预训练模型

import timm
self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)

 其中'vit_base_patch16_224'是模型的名称,pretrained=True表示使用预训练权重。

  • EfficientNet: efficientnet_b0, efficientnet_b1, ..., efficientnet_b7
  • ResNet: resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2
  • ResNeXt: resnext50_32x4d, resnext101_32x4d, resnext101_64x4d
  • VGG: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
  • DenseNet: densenet121, densenet169, densenet201, densenet161
  • Inception: inception_v3
  • MobileNet: mobilenet_v2, mobilenetv3_large, mobilenetv3_small
  • ViT (Vision Transformer): vit_base_patch16_224, vit_base_patch16_384, vit_base_patch32_384, vit_large_patch16_224, vit_large_patch16_384, vit_large_patch32_384, vit_huge_patch16_224, vit_huge_patch32_384

在timm官方GitHub仓库的模型目录中找到完整的模型列表:如下

https://github.com/huggingface/pytorch-image-models/tree/main/timm/models

或修改为回归网络num_classes=1,与分类网络使用 softmax 激活函数来输出每个类别的概率不同。回归网络中,使用线性激活函数来输出一个连续的数值;除了输出层的节点数,损失函数的替换,在原始分类网络通常使用交叉熵损失函数来衡量预测类别与真实类别之间的差异,回归网络则通常使用均方误差或者平均绝对误差来衡量预测值与真实值之间的差异。

self.resnet.fc = nn.Linear(2048, 1)
self.loss = nn.MSELoss()

 注意:以上修改网络是在self.resnet基础之上的,例如修改输出层的线性层self.resnet.fc,所以定义前向传播函数如下即可:

def forward(self, x):
        x = self.resnet(x)
        return x

但如何是重新定义一层 ,该层必须要前向传播

self.fc = nn.Linear(1000, 5)


def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        return x

 打印模型结构,可以看到模型哪些地方被修改了

model = RegressionModel()
# 打印模型结构
print(model)

总结

以上就是自己学习的记录,侵权立删。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值