import torch
from torch import nn
from torch.nn import functional as F
# 1x1 conv -> 3x3 conv -> 1x1 conv
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv.

    Main branch:
      * a 1x1 conv reduces ``in_channels`` to ``channels``;
      * a 3x3 conv (carrying ``stride``) keeps ``channels``;
      * a 1x1 conv expands to ``channels * expansion``.

    Each conv is followed by BatchNorm. ReLU follows the first two convs and
    the residual addition.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Width of the inner convolutions; the block outputs
            ``channels * expansion`` channels.
        stride: Stride of the 3x3 conv and of the projection shortcut.
        use_1x1conv: Force a 1x1 projection conv on the shortcut. A projection
            is also created automatically whenever the input could not be
            added to the main-branch output as-is, i.e. when the stride is not
            1 or ``in_channels != channels * expansion``.
    """

    # Output channels of the block relative to ``channels``.
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, use_1x1conv=False):
        super().__init__()
        out_channels = channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Conv2d accepts an int or a per-dimension sequence for ``stride``.
        strides = tuple(stride) if isinstance(stride, (tuple, list)) else (stride,)
        keeps_resolution = all(s == 1 for s in strides)
        # An identity shortcut only works when the main branch preserves both
        # the spatial size and the channel count. Otherwise `out += identity`
        # would fail in forward(), so a projection is required.
        needs_projection = (
            use_1x1conv or not keeps_resolution or in_channels != out_channels
        )
        if needs_projection:
            # NOTE(review): this projection has a bias and no BatchNorm, unlike
            # torchvision's ResNet (conv1x1 + BatchNorm). Left as-is to keep
            # the parameter set of existing checkpoints; confirm intent.
            self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv4 = None

    def forward(self, x):
        """Apply the bottleneck block and return ``relu(main(x) + shortcut(x))``."""
        # 1x1 conv, channels: in_channels -> channels
        out = F.relu(self.bn1(self.conv1(x)))
        # 3x3 conv, channels: channels -> channels (applies the block's stride)
        out = F.relu(self.bn2(self.conv2(out)))
        # 1x1 conv, channels: channels -> channels * expansion; no ReLU yet,
        # it is applied after the residual addition
        out = self.bn3(self.conv3(out))
        # Shortcut: identity mapping, or the 1x1 projection when shapes differ
        identity = x if self.conv4 is None else self.conv4(x)
        out += identity
        return F.relu(out)
def bottleneck_block(in_channels, channels, num_bottlenecks, not_FirstBlock=True):
    """Build the list of Bottleneck modules for one ResNet stage.

    The first bottleneck maps ``in_channels`` to ``channels * 4`` through a
    1x1 projection shortcut and carries the stage's stride. The remaining
    bottlenecks take ``channels * 4`` inputs, use stride 1, and use identity
    shortcuts.

    Args:
        in_channels: Channels entering the stage.
        channels: Bottleneck width; every block outputs ``channels * 4``.
        num_bottlenecks: Number of Bottleneck modules in the stage (0 yields
            an empty list).
        not_FirstBlock: Truthy for every stage except the first. Those stages
            downsample with stride 2 in their first bottleneck; the first
            stage uses stride 1.

    Returns:
        A list of Bottleneck modules, suitable for ``nn.Sequential(*blocks)``.
    """
    # Explicit choice instead of bool arithmetic (`not_FirstBlock + 1`), which
    # produced stride 3 for a truthy non-bool value such as 2.
    stride = 2 if not_FirstBlock else 1
    return [
        Bottleneck(in_channels, channels, stride=stride, use_1x1conv=True)
        if i == 0
        else Bottleneck(channels * 4, channels)
        for i in range(num_bottlenecks)
    ]
# Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool
# (two stride-2 ops, i.e. 4x spatial downsampling). It takes 1 input channel.
b1 = nn.Sequential(
    # bias=False, like every other conv followed by BatchNorm in this file:
    # BN subtracts the per-channel mean, so a conv bias would be redundant.
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# ResNet-50 stacks [3, 4, 6, 3] bottlenecks in its four stages:
# 16 blocks x 3 convs + stem conv + final linear layer = 50 weighted layers.
b2 = nn.Sequential(*bottleneck_block(64, 64, 3, not_FirstBlock=False))  # 64 -> 256 channels, stride 1
b3 = nn.Sequential(*bottleneck_block(64 * 4, 128, 4))   # 256 -> 512 channels, stride 2
b4 = nn.Sequential(*bottleneck_block(128 * 4, 256, 6))  # 512 -> 1024 channels, stride 2
b5 = nn.Sequential(*bottleneck_block(256 * 4, 512, 3))  # 1024 -> 2048 channels, stride 2
resnet50 = nn.Sequential(
    b1, b2, b3, b4, b5,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    # The last stage outputs 512 * 4 = 2048 channels; 10 output classes.
    nn.Linear(512 * 4, 10),
)
# ResNet 基础代码
# 最新推荐文章于 2024-08-01 09:00:01 发布
# 本文介绍了如何在 PyTorch 中定义 Bottleneck 模块（主分支为 1x1 conv → 3x3 conv → 1x1 conv，捷径分支为可选的 1x1 conv），以及用于构建 ResNet-50 各个 stage 的 bottleneck_block 函数。
# 摘要由 CSDN 通过智能技术生成