Xception原理加pytorch代码实现

Xception原理加代码实现

Xception原理加代码实现

1 Xception原理

第二层的每个3x3卷积核只处理第一层共享1x1卷积核的一个通道,grouped的极限进化版本

Note:depthwise(3x3处理单个通道)和pointwise(1x1处理所有通道)的前后顺序不影响结果,且在这两层之间不加入激活函数效果更好

在这里插入图片描述

2 Xception代码

from torch import nn
from torch.nn import Conv2d,BatchNorm2d
class Xception(nn.Module):
    def __init__(self,inp,oup):
        super(Xception, self).__init__()
        # depthwise
        self.conv1 = Conv2d(inp, inp, kernel_size=(3, 3), stride=(1, 1), padding=1, groups=inp)
        self.bn1 = BatchNorm2d(inp)  # 输入为上一层输出的通道数
        # pointwise
        self.conv2 = Conv2d(inp, oup, (1, 1))  # Stride of the convolution. Default: 1
        self.bn2 = BatchNorm2d(oup)
        self.relu = nn.ReLU()

    def forward(self, input):
        output = self.conv1(input)
        output = self.bn1(output)
        output = self.relu(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = self.relu(output)
        return output

以上就是全部内容

  • 0
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
Xception是一种深度卷积神经网络模型,它在ImageNet数据集上取得了很好的性能。下面是一个简单的Xception模型的PyTorch代码示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class SeparableConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False): super(SeparableConv2d, self).__init__() self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias) self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) def forward(self, x): x = self.depthwise(x) x = self.pointwise(x) return x class Block(nn.Module): def __init__(self, in_channels, out_channels, reps, stride=1, start_with_relu=True, grow_first=True): super(Block, self).__init__() if out_channels != in_channels or stride != 1: self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False) self.skipbn = nn.BatchNorm2d(out_channels) else: self.skip = None self.relu = nn.ReLU(inplace=True) rep = [] filters = in_channels if grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(out_channels)) filters = out_channels for i in range(reps - 1): rep.append(self.relu) rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(filters)) if not grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(out_channels)) if stride != 1: rep.append(self.relu) rep.append(SeparableConv2d(out_channels, out_channels, 3, stride=2, padding=1, bias=False)) rep.append(nn.BatchNorm2d(out_channels)) self.rep = nn.Sequential(*rep) def forward(self, inp): x = self.rep(inp) if self.skip is not None: skip = self.skip(inp) skip = self.skipbn(skip) else: skip = inp x += skip return x class Xception(nn.Module): def __init__(self, num_classes=1000): super(Xception, self).__init__() self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, 3, bias=False) self.bn2 = nn.BatchNorm2d(64) self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, padding=1) self.bn3 = nn.BatchNorm2d(1536) self.conv4 = SeparableConv2d(1536, 2048, 3, stride=1, padding=1) self.bn4 = nn.BatchNorm2d(2048) self.fc = nn.Linear(2048, num_classes) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.block1(x) x = self.block2(x) x = self.block3(x) x = self.block4(x) x = self.block5(x) x = self.block6(x) x = self.block7(x) x = self.block8(x) x = self.block9(x) x = self.block10(x) x = self.block11(x) x = self.block12(x) x = self.conv3(x) x = self.bn3(x) x = self.relu(x) x = self.conv4(x) x = self.bn4(x) x = self.relu(x) x = F.adaptive_avg_pool2d(x, (1, 1)) x = torch.flatten(x, 1) x = self.fc(x) return x model = Xception() ``` 这段代码定义了一个Xception模型,包括了各个模块和层的定义。你可以根据自己的需求进行修改和使用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值