Inception-ResNet Model Architecture (PyTorch)

I. Introduction

In *Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning*, besides proposing version 4 of the Inception network, the authors combined Inception with ResNet, introducing two models: Inception-ResNet-v1 and Inception-ResNet-v2.

II. Architecture Diagrams

[Figure: overall schema shared by Inception-ResNet-v1 and Inception-ResNet-v2]
[Note] Inception-ResNet-v1 and Inception-ResNet-v2 share the same overall architecture, but the internal structure of each component differs; the components are described below. Per the paper's schema, both chain the stages as: Stem, 5 x Inception-ResNet-A, Reduction-A, 10 x Inception-ResNet-B, Reduction-B, 5 x Inception-ResNet-C, average pooling, dropout, and a softmax classifier, as sketched below.
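A minimal sketch of that shared macro layout (the `build_backbone` helper and its parameters are illustrative, not from the paper; the concrete block classes are implemented in Section III):

import torch.nn as nn

def build_backbone(stem, make_a, red_a, make_b, red_b, make_c):
    """Chain the shared Inception-ResNet stage layout (classification head omitted)."""
    return nn.Sequential(
        stem,
        *[make_a() for _ in range(5)],   # 5 x Inception-ResNet-A
        red_a,                           # Reduction-A
        *[make_b() for _ in range(10)],  # 10 x Inception-ResNet-B
        red_b,                           # Reduction-B
        *[make_c() for _ in range(5)],   # 5 x Inception-ResNet-C
    )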

1. Inception-ResNet-v1

1.1 Stem

[Figure: Stem of Inception-ResNet-v1]

1.2 Inception-ResNet-A

[Figure: Inception-ResNet-A block of Inception-ResNet-v1]

1.3 Inception-ResNet-B

[Figure: Inception-ResNet-B block of Inception-ResNet-v1]

1.4 Inception-ResNet-C

[Figure: Inception-ResNet-C block of Inception-ResNet-v1]

1.5 Reduction-A

[Figure: Reduction-A block of Inception-ResNet-v1]

1.6 Reduction-B

[Figure: Reduction-B block of Inception-ResNet-v1]

2. Inception-ResNet-v2

2.1 Stem

[Figure: Stem of Inception-ResNet-v2]

2.2 Inception-ResNet-A

[Figure: Inception-ResNet-A block of Inception-ResNet-v2]

2.3 Inception-ResNet-B

[Figure: Inception-ResNet-B block of Inception-ResNet-v2]

2.4 Inception-ResNet-C

[Figure: Inception-ResNet-C block of Inception-ResNet-v2]

2.5 Reduction-A

[Figure: Reduction-A block of Inception-ResNet-v2]

2.6 Reduction-B

[Figure: Reduction-B block of Inception-ResNet-v2]

III. Code Reproduction

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    """Conv2d followed by BatchNorm. ReLU is deliberately left to the caller,
    because the 1x1 expansion convs in the residual blocks must stay linear."""
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # bias is redundant here since BatchNorm follows the convolution
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
class InceptionResNetA(nn.Module):
    def __init__(self, in_channels):
        super(InceptionResNetA, self).__init__()
        #branch1: conv1*1(32)
        self.b1 = BasicConv2d(in_channels, 32, kernel_size=1)
        
        #branch2: conv1*1(32) --> conv3*3(32)
        self.b2_1 = BasicConv2d(in_channels, 32, kernel_size=1)
        self.b2_2 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        
        #branch3: conv1*1(32) --> conv3*3(32) --> conv3*3(32)
        self.b3_1 = BasicConv2d(in_channels, 32, kernel_size=1)
        self.b3_2 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.b3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        
        #totalbranch: conv1*1(256)
        self.tb = BasicConv2d(96, 256, kernel_size=1)
        
    def forward(self, x):
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
        b_out3 = F.relu(self.b3_3(F.relu(self.b3_2(F.relu(self.b3_1(x))))))
        b_out = torch.cat([b_out1, b_out2, b_out3], 1)
        # 1x1 expansion back to 256 channels; kept linear (no ReLU) before the
        # residual addition, as in the paper. (The paper also suggests scaling
        # the residual by 0.1-0.3 for training stability; omitted here.)
        b_out = self.tb(b_out)
        return F.relu(b_out + x)
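# Quick sanity check (an added sketch, not from the original post): the block
# must preserve its input shape because of the residual addition. 35x35x256
# matches the v1 stem output for a 299x299 input.
blockA_check = InceptionResNetA(256)
assert blockA_check(torch.randn(2, 256, 35, 35)).shape == (2, 256, 35, 35)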
class InceptionResNetB(nn.Module):
    def __init__(self, in_channels):
        super(InceptionResNetB, self).__init__()
        #branch1: conv1*1(128)
        self.b1 = BasicConv2d(in_channels, 128, kernel_size=1)
        
        #branch2: conv1*1(128) --> conv1*7(128) --> conv7*1(128)
        self.b2_1 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.b2_2 = BasicConv2d(128, 128, kernel_size=(1,7), padding=(0,3))
        self.b2_3 = BasicConv2d(128, 128, kernel_size=(7,1), padding=(3,0))
    
        #totalbranch: conv1*1(896)
        self.tb = BasicConv2d(256, 896, kernel_size=1)
        
    def forward(self, x):
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_3(F.relu(self.b2_2(F.relu(self.b2_1(x))))))
        b_out = torch.cat([b_out1, b_out2], 1)
        b_out = self.tb(b_out)  # linear 1x1 expansion back to 896 channels
        return F.relu(b_out + x)
class InceptionResNetC(nn.Module):
    def __init__(self, in_channels):
        super(InceptionResNetC, self).__init__()
        #branch1: conv1*1(192)
        self.b1 = BasicConv2d(in_channels, 192, kernel_size=1)
        
        #branch2: conv1*1(192) --> conv1*3(192) --> conv3*1(192)
        self.b2_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.b2_2 = BasicConv2d(192, 192, kernel_size=(1,3), padding=(0,1))
        self.b2_3 = BasicConv2d(192, 192, kernel_size=(3,1), padding=(1,0))
    
        #totalbranch: conv1*1(1792)
        self.tb = BasicConv2d(384, 1792, kernel_size=1)
        
    def forward(self, x):
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_3(F.relu(self.b2_2(F.relu(self.b2_1(x))))))
        b_out = torch.cat([b_out1, b_out2], 1)
        b_out = self.tb(b_out)  # linear 1x1 expansion back to 1792 channels
        return F.relu(b_out + x)
class ReductionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(ReductionA, self).__init__()
        #branch1: maxpool3*3(stride2 valid)
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        #branch2: conv3*3(n stride2 valid)
        self.b2 = BasicConv2d(in_channels, n, kernel_size=3, stride=2)
        
        #branch3: conv1*1(k) --> conv3*3(l) --> conv3*3(m stride2 valid)
        self.b3_1 = BasicConv2d(in_channels, k, kernel_size=1)
        self.b3_2 = BasicConv2d(k, l, kernel_size=3, padding=1)
        self.b3_3 = BasicConv2d(l, m, kernel_size=3, stride=2)
        
    def forward(self, x):
        y1 = self.b1(x)
        y2 = F.relu(self.b2(x))
        y3 = F.relu(self.b3_3(F.relu(self.b3_2(F.relu(self.b3_1(x))))))
        
        outputsRedA = [y1, y2, y3]
        return torch.cat(outputsRedA, 1)
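# The (k, l, m, n) filter numbers are parameters because Table 1 of the paper
# assigns different values per variant: (192, 192, 256, 384) for
# Inception-ResNet-v1 and (256, 256, 384, 384) for Inception-ResNet-v2.
# Quick check (added sketch): for v1 a 35x35x256 map becomes 17x17x896
# (256 pooled + n + m channels).
redA_check = ReductionA(256, k=192, l=192, m=256, n=384)
assert redA_check(torch.randn(1, 256, 35, 35)).shape == (1, 896, 17, 17)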
class ReductionB(nn.Module):
    def __init__(self, in_channels):
        super(ReductionB, self).__init__()
        #branch1: maxpool3*3(stride2 valid)
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        
        #branch2: conv1*1(256) --> conv3*3(384 stride2 valid)
        self.b2_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b2_2 = BasicConv2d(256, 384, kernel_size=3, stride=2)
        
        #branch3: conv1*1(256) --> conv3*3(256 stride2 valid)
        self.b3_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b3_2 = BasicConv2d(256, 256, kernel_size=3, stride=2)
        
        #branch4: conv1*1(256) --> conv3*3(256) --> conv3*3(256 stride2 valid)
        self.b4_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b4_2 = BasicConv2d(256, 256, kernel_size=3, padding=1)
        self.b4_3 = BasicConv2d(256, 256, kernel_size=3, stride=2)
        
    def forward(self, x):
        y1 = self.b1(x)
        y2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
        y3 = F.relu(self.b3_2(F.relu(self.b3_1(x))))
        y4 = F.relu(self.b4_3(F.relu(self.b4_2(F.relu(self.b4_1(x))))))
        
        outputsRedB = [y1, y2, y3, y4]
        return torch.cat(outputsRedB, 1)
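# Channel arithmetic (added sketch): the four concatenated branches yield
# 896 (pooled input) + 384 + 256 + 256 = 1792 output channels, exactly the
# input width Inception-ResNet-C expects.
redB_check = ReductionB(896)
assert redB_check(torch.randn(1, 896, 17, 17)).shape == (1, 1792, 8, 8)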
class StemForIR1(nn.Module):
    def __init__(self, in_channels):
        super(StemForIR1, self).__init__()
        #conv3*3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        #conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        #conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        #maxpool3*3(stride2 valid)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        #conv1*1(80)
        self.conv4 = BasicConv2d(64, 80, kernel_size=1)
        #conv3*3(192 valid)
        self.conv5 = BasicConv2d(80, 192, kernel_size=3)
        #conv3*3(256, stride2 valid)
        self.conv6 = BasicConv2d(192, 256, kernel_size=3, stride=2)
        
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = self.maxpool1(out)
        out = F.relu(self.conv4(out))
        out = F.relu(self.conv5(out))
        out = F.relu(self.conv6(out))
        
        return out
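# Spatial trace (added sketch) for the standard 299x299 RGB input:
# 299 -(3x3 s2 valid)-> 149 -(3x3 valid)-> 147 -(3x3 pad1)-> 147
#     -(maxpool s2 valid)-> 73 -(1x1)-> 73 -(3x3 valid)-> 71 -(3x3 s2 valid)-> 35
stem_check = StemForIR1(3)
assert stem_check(torch.randn(1, 3, 299, 299)).shape == (1, 256, 35, 35)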
class InceptionResNetv1(nn.Module):
    def __init__(self, num_classes=1000):
        super(InceptionResNetv1, self).__init__()
        self.stem = StemForIR1(3)
        # The paper stacks 5 x A, 10 x B and 5 x C blocks. Each repetition
        # needs its own module instance so that weights are not shared;
        # the original if-based version only applied each block once.
        self.irA = nn.Sequential(*[InceptionResNetA(256) for _ in range(5)])
        self.redA = ReductionA(256, 192, 192, 256, 384)
        self.irB = nn.Sequential(*[InceptionResNetB(896) for _ in range(10)])
        self.redB = ReductionB(896)
        self.irC = nn.Sequential(*[InceptionResNetC(1792) for _ in range(5)])
        self.avgpool = nn.AvgPool2d(kernel_size=8)  # average (not max) pooling
        self.dropout = nn.Dropout(p=0.2)  # paper keeps units with prob 0.8
        self.linear = nn.Linear(1792, num_classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.irA(out)
        out = self.redA(out)
        out = self.irB(out)
        out = self.redB(out)
        out = self.irC(out)
        out = self.avgpool(out)
        out = self.dropout(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
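# Smoke test (added sketch): one forward pass on a random batch to confirm the
# assembled v1 network produces 1000-way logits.
if __name__ == '__main__':
    model = InceptionResNetv1()
    print(model(torch.randn(2, 3, 299, 299)).shape)  # torch.Size([2, 1000])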
### Background of Inception-ResNet-V2

Inception-ResNet-V2 is a deep learning model built on a convolutional neural network (CNN) architecture whose main goal is to improve classification performance while limiting the demand for computational resources. It was proposed by a research team at Google in 2016 [^1] and evolved as part of the Inception family.

#### Background

As deep learning developed, researchers found that adding more layers can significantly improve a model's capacity. However, naively stacking layers can cause vanishing or exploding gradients, which hinders effective training. To address this, residual connections were introduced into CNN architectures [^2]. This mechanism lets information bypass some layers and flow directly to later ones, greatly easing the optimization difficulties of deep networks.

Against this backdrop, the Google team combined the classic Inception structure with the emerging residual design to develop the Inception-ResNet series, including Inception-ResNet-V2. It inherits the strengths of earlier versions, namely multi-scale feature extraction for more expressive and flexible representations, and further incorporates residual units to enable deeper learning [^3].

#### Key characteristics

- **Hybrid structure**: standard Inception modules are combined with skip-connected residual units, improving convergence speed and final accuracy while keeping feature representations efficient.
- **Lightweight design**: compared with other large, complex models, this version adopts a more compact and efficient component configuration, keeping the parameter count under control without sacrificing much accuracy.
- **Progressive dimensionality reduction**: to reduce computational cost while preserving important spatial information, the authors carefully tuned the order of operations and the distribution of filter sizes within each stage.

Below is a simple PyTorch snippet for the core building block of this algorithm:

```python
import torch.nn as nn
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        # bias is redundant because BatchNorm follows the convolution
        self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

class StemBlock(nn.Module):
    ...  # define the full network here
```

The snippet above shows only part of the core logic; an actual deployment needs the complete definitions, initialization details, and so on.