Pytorch源码学习之四:torchvision.models.squeezenet

0.介绍

Squeezenet网址
torchvision.model.squeeze官方文档
主要思想:堆叠Fire模块,每个Fire模块,分别采用1x1和3x3两个分支,最后做拼;,每个Fire的尺寸不变,channel数不变或增加;每个stage的Fire模块之间用nn.MaxPool2d进行下采样;使用卷积层代替FC层,channel数为类别数

1.源码

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.hub import load_state_dict_from_url

__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}

class Fire(nn.Module): #Fire模块
    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):

        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.squeeze(x)
        x = self.squeeze_activation(x)
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)

class SqueezeNet(nn.Module):

    def __init__(self, version='1.0', num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2 ,ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError("Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected".format(version=version))
        #使用卷积代替全连接层
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1,1))
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x.view(x.size(0), self.num_classes)

def _squeezenet(version, pretrained, progress, **kwargs):
    model = SqueezeNet(version, **kwargs)
    if pretrained:
        arch = 'squeezenet' + version
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def squeezenet1_0(pretrained=False, progress=True, **kwargs):
    return _squeezenet('1_0', pretrained, progress, **kwargs)

def squeezenet1_1(pretrained=False, progress=True, **kwargs):
    return _squeezenet('1_1', pretrained, progress, **kwargs)

2.一些用法

2.1 torch.cat

torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)
#按照第一个维度(channel维度)对[]内的Tensor进行拼接

2.2 nn.MaxPool2d()

nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1,
                         return_indices=False, ceil_mode=False)
# kernel_size(int or tuple) - max pooling的窗口大小
#stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size
#padding(int or tuple, optional) - 输入的每一条边补充0的层数
#dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数
#return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助
#ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作

2.3 使用全卷积代替全连接层

#使用全卷积代替FC层
self.classifier = nn.Sequential(
    nn.Dropout(0.5),
    final_conv,
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1,1))
    )
def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x.view(x.size(0), self.num_classes)
#即先采用AdaptiveAvgPool2D,将size变为1x1,channel数=num_classes,再做resize
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torchvision.modelsPyTorch中用于图像分类的模型库。它提供了许多预训练的深度学习模型,包括AlexNet、VGG、ResNet等,以及它们的各种变种。这些模型经过在大规模图像数据上的预训练,并在各种图像分类任务上展示出了很好的性能。 torchvision.models源码主要包括两部分:网络结构定义和预训练模型参数加载。 首先,torchvision.models中定义了各种经典的深度学习网络结构。这些网络结构通过继承自torch.nn.Module类来实现,其中包括了网络的前向传播过程和可训练参数的定义。例如,在ResNet模型中,torchvision.models定义了ResNet类,其中包含多个残差块,每个残差块由多个卷积层和恒等映射组成。而在VGG模型中,torchvision.models定义了VGG类,其中包含多个卷积层和全连接层。这些定义好的网络结构可以用于构建和训练图像分类模型。 其次,torchvision.models还提供了预训练模型参数的加载功能。这些预训练模型参数可以通过调用torchvision.models的load_state_dict方法加载到网络结构中。这些参数在大规模图像数据集上进行了预训练,并且经过了反向传播过程进行了优化,因此可以作为初始化参数来训练新的图像分类模型。通过加载预训练模型参数,可以帮助新模型更快地收敛和获得更好的性能。 总之,torchvision.models是一个方便使用的图像分类模型库,提供了各种经典的深度学习网络结构和预训练的模型参数。通过使用这些网络结构和加载预训练模型参数,可以更轻松地构建和训练图像分类模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值