PyTorch 理论:源码
代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module): # shuffle Net 模仿的是resnet bottleblock的结构
def __init__(self, in_planes, out_planes, stride, groups):
    """Create the layers of one ShuffleNet bottleneck unit.

    The layers mirror a ResNet-style bottleneck: a 1x1 grouped convolution,
    a 3x3 depthwise convolution and a second 1x1 grouped convolution, each
    paired with its own ``BatchNorm2d`` layer.

    Args:
        in_planes: Channel count of the input to the first 1x1 convolution.
        out_planes: Channel count produced by the last 1x1 convolution.
        stride: Stride of the 3x3 depthwise convolution. When it equals 2,
            an average-pooling shortcut is created as well.
        groups: Group count for the 1x1 convolutions (overridden to 1 for
            the first 1x1 convolution when ``in_planes`` is 24).
    """
    super(Bottleneck, self).__init__()
    self.stride = stride

    # The bottleneck width of each ShuffleNet unit is a quarter of the
    # output width (original note: same setting as ResNet).
    bottleneck_channels = out_planes // 4

    # Original note: the first convolution of the network outputs 24
    # channels, so the first pointwise conv of that unit is not grouped.
    if in_planes == 24:
        self.groups = 1
    else:
        self.groups = groups

    # Layers are created in the same order as before so that weight
    # initialisation and module registration order stay unchanged.
    self.conv1 = nn.Conv2d(
        in_channels=in_planes,
        out_channels=bottleneck_channels,
        kernel_size=1,
        groups=self.groups,
        bias=False,
    )
    self.bn1 = nn.BatchNorm2d(num_features=bottleneck_channels)

    # groups == channel count makes this 3x3 convolution depthwise.
    self.conv2 = nn.Conv2d(
        in_channels=bottleneck_channels,
        out_channels=bottleneck_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=bottleneck_channels,
        bias=False,
    )
    self.bn2 = nn.BatchNorm2d(num_features=bottleneck_channels)

    # The closing 1x1 convolution always uses the caller's group count.
    self.conv3 = nn.Conv2d(
        in_channels=bottleneck_channels,
        out_channels=out_planes,
        kernel_size=1,
        groups=groups,
        bias=False,
    )
    self.bn3 = nn.BatchNorm2d(num_features=out_planes)

    # Original note: the first block of every stage uses stride 2 and the
    # next stage doubles the channel count.
    # NOTE(review): self.shortcut exists only for stride == 2; confirm that
    # forward() never reads it for other strides.
    if stride == 2:
        downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.shortcut = nn.Sequential(downsample)
@staticmethod