# SENet model (PyTorch version)
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet
from torch import nn
from sklearn.linear_model import LogisticRegression
import torch.nn as nn
import torch.nn.functional as F
import torch
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate that rescales each channel of a 4-D input.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor applied to ``channel`` to size the hidden layer
            of the bottleneck MLP (``channel // reduction`` units).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Squeeze: collapse every channel's spatial map to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bias-free bottleneck MLP producing a gate per channel.
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its learned gate."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class SEBasicBlock(nn.Module):
    """Two-conv residual block with an SE gate applied before the skip add.

    Args:
        inplanes: Channels of the block input.
        planes: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the
            residual branch; when None the input itself is used.
        groups, base_width, dilation: Accepted so the block can be called
            with the positional arguments a ResNet builder passes; they are
            not used by this block.
        norm_layer: Callable ``norm_layer(num_features)`` building the
            normalisation layers. Defaults to ``nn.BatchNorm2d``.
        reduction: Keyword-only reduction ratio forwarded to ``SELayer``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super().__init__()
        # Honour a caller-supplied normalisation layer instead of silently
        # ignoring it; the default keeps the original BatchNorm2d behaviour.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = norm_layer(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv-norm-relu, conv-norm, SE gate, then add the residual."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
class SEBottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck with an SE gate before the skip add.

    Args:
        inplanes: Channels of the block input.
        planes: Channels of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to form the
            residual branch; when None the input itself is used.
        groups, base_width, dilation: Accepted so the block can be called
            with the positional arguments a ResNet builder passes; they are
            not used by this block.
        norm_layer: Callable ``norm_layer(num_features)`` building the
            normalisation layers. Defaults to ``nn.BatchNorm2d``.
        reduction: Keyword-only reduction ratio forwarded to ``SELayer``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super().__init__()
        # Honour a caller-supplied normalisation layer instead of silently
        # ignoring it; the default keeps the original BatchNorm2d behaviour.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(out_planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv stages, apply the SE gate, then add the residual."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
def se_resnet18(num_classes=5):
    """Construct an SE-ResNet-18 model.

    Builds a ``ResNet`` from ``SEBasicBlock`` with stage depths
    ``[2, 2, 2, 2]`` and replaces its pooling layer with
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        num_classes (int): Number of outputs of the final classifier.

    Returns:
        The constructed model with freshly initialised weights.
    """
    model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model
def se_resnet34(num_classes=5):
    """Construct an SE-ResNet-34 model.

    Builds a ``ResNet`` from ``SEBasicBlock`` with stage depths
    ``[3, 4, 6, 3]`` and replaces its pooling layer with
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        num_classes (int): Number of outputs of the final classifier.

    Returns:
        The constructed model with freshly initialised weights.
    """
    model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model
def se_resnet50(num_classes=5, pretrained=False):
    """Construct an SE-ResNet-50 model, optionally loading published weights.

    Builds a ``ResNet`` from ``SEBottleneck`` with stage depths
    ``[3, 4, 6, 3]`` and replaces its pooling layer with
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        num_classes (int): Number of outputs of the final classifier.
        pretrained (bool): If True, download the SE-ResNet-50 checkpoint and
            load it into the model. Classifier tensors (``fc.weight`` /
            ``fc.bias``) whose shape does not match ``num_classes`` are
            skipped, leaving the classifier freshly initialised.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If the checkpoint and the model disagree on any key
            other than a skipped classifier tensor.
    """
    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if pretrained:
        state_dict = load_state_dict_from_url("https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl")
        # A strict load fails on a shape mismatch, so a checkpoint whose
        # classifier has a different number of outputs than ``num_classes``
        # could never be loaded. Drop only those mismatched head tensors
        # (in place, so any metadata attached to the mapping is preserved).
        own_state = model.state_dict()
        skipped = [
            key for key in ("fc.weight", "fc.bias")
            if key in state_dict
            and state_dict[key].shape != own_state[key].shape
        ]
        for key in skipped:
            del state_dict[key]
        result = model.load_state_dict(state_dict, strict=False)
        # Keep the strictness of the original load for everything else.
        problems = [k for k in result.missing_keys if k not in skipped]
        problems.extend(result.unexpected_keys)
        if problems:
            raise RuntimeError(
                "pretrained checkpoint does not match se_resnet50: "
                + ", ".join(problems)
            )
    return model
def se_resnet101(num_classes=5):
    """Construct an SE-ResNet-101 model.

    Builds a ``ResNet`` from ``SEBottleneck`` with stage depths
    ``[3, 4, 23, 3]`` and replaces its pooling layer with
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        num_classes (int): Number of outputs of the final classifier.

    Returns:
        The constructed model with freshly initialised weights.
    """
    model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model
def se_resnet152(num_classes=5):
    """Construct an SE-ResNet-152 model.

    Builds a ``ResNet`` from ``SEBottleneck`` with stage depths
    ``[3, 8, 36, 3]`` and replaces its pooling layer with
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        num_classes (int): Number of outputs of the final classifier.

    Returns:
        The constructed model with freshly initialised weights.
    """
    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model
# Demo: push one random image through three of the SE-ResNet variants and
# print the shape of each output between separator lines.
for _build_net in (se_resnet18, se_resnet50, se_resnet101):
    # Randomly generate the input data.
    rgb = torch.randn(1, 3, 512, 512)
    # Define the network.
    net = _build_net(num_classes=8)
    # Forward pass.
    out = net(rgb)
    print('-----' * 5)
    # Print the output size.
    print(out.shape)
    print('-----' * 5)