目录
第一部分:def conv3x3():-----创建并返回二维卷积层
第三部分:class BasicBlock(nn.Module):
残差网络,残差映射(是残差网络ResNet中基本构建块之一)
第四部分:class Bottleneck(nn.Module):
第五部分:class ResNetCifar10(nn.Module):!
第六部分:def ResNet18_cifar10(**kwargs):
第七部分:def ResNet50_cifar10(**kwargs):
第八部分:def ResNet101_cifar10(**kwargs):
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
def forward(self, x):
class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
def forward(self, x):
class ResNetCifar10(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
def _forward_impl(self, x):
def forward(self, x):
def ResNet18_cifar10(**kwargs):
def ResNet50_cifar10(**kwargs):
def ResNet101_cifar10(**kwargs):
第一部分:def conv3x3():-----创建并返回二维卷积层
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
参数
-
in_planes (
int
): 输入通道数。这是指进入卷积层的特征图的深度(或称为通道数)。例如,对于RGB图像,输入通道数为3(红、绿、蓝)。 -
out_planes (
int
): 输出通道数。这是指卷积操作后输出的特征图的深度。这个值决定了该卷积层将学习到多少种不同的特征。 -
stride (
int
, 可选): 卷积核移动的步长。默认值为1,意味着卷积核将逐个像素地滑动。如果步长大于1,则卷积核在滑动时会跳过一些像素,这有助于减少输出特征图的尺寸。 -
groups (
int
, 可选): 控制输入和输出之间的连接。groups=1
表示所有输入通道都会被用来计算输出通道的每个通道(这是最常见的设置)。如果groups
大于1,则输入通道会被分成多个组,每个组的输入通道只会与对应组的输出通道相连接。这可以用于实现深度可分离卷积等。 -
dilation (
int
, 可选): 空洞卷积(也称为扩张卷积)的扩张率。它用于在卷积核元素之间插入空格,从而在不增加参数数量的情况下增加感受野。默认值为1,表示不进行空洞卷积。
第二部分:def conv1x1():
第三部分:class BasicBlock(nn.Module):
残差网络,残差映射(是残差网络ResNet中基本构建块之一)
class BasicBlock(nn.Module):
expansion = 1'''*1'''
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)#第一个
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes) #第二个
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
BasicBlock
类,它继承自 PyTorch 的 nn.Module
类(cnn)BasicBlock
通常是残差网络(ResNet)中的基本构建块之一,用于学习输入和输出之间的残差映射。下面是对这个类的详细解读:
expansion = 1
:该基本块不改变特征图的通道数(或称为深度)。在ResNet的某些变体(如Bottleneck块)中,这个值可能大于1,用于在块内部增加特征图的通道数,然后在块的末尾通过1x1卷积将其还原到原始通道数。但在BasicBlock
中,它始终为1。downsample
(下采样操作,如果输入和输出的特征图尺寸不一致时使用)、groups
(分组卷积的组数,默认为1)、base_width
(基础宽度,这里未使用)、dilation
(空洞卷积的扩张率,默认为1)、norm_layer
(归一化层的类型,默认为nn.BatchNorm2d
)。- (if,if,if)检查
groups
和base_width
是否为默认值(因为BasicBlock
不支持非默认值),然后检查dilation
是否大于1(BasicBlock
不支持空洞卷积)。 - 接着,定义了两个3x3的卷积层(
conv1
和conv2
),每个卷积层后面都跟着一个归一化层(bn1
和bn2
),归一化层的类型由norm_layer
参数指定。 relu
是一个ReLU激活函数,设置为inplace=True
以节省内存。downsample
是一个可选的下采样操作,用于在输入和输出的特征图尺寸不一致时调整输入特征图的尺寸,以便进行残差连接。
前向传播 forward
- 输入
x
首先被保存为identity
,用于后续的残差连接。 x
通过第一个卷积层conv1
、归一化层bn1
和ReLU激活函数后,得到中间特征图。- 中间特征图再经过第二个卷积层
conv2
和归一化层bn2
。 - 如果存在
downsample
操作,则将原始输入x
通过downsample
以匹配输出特征图的尺寸。 - 最后,将处理后的特征图与
identity
(原始输入或经过下采样的输入)相加,并通过ReLU激活函数,得到最终的输出。
这种结构允许网络学习输入和输出之间的残差,有助于解决深度网络训练中的梯度消失或梯度爆炸问题,从而可以训练更深的网络。
第四部分:class Bottleneck(nn.Module):
(是残差网络ResNet中基本构建块之一)
class Bottleneck(nn.Module):
expansion = 4'''*2'''
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
构建深度残差网络(ResNet)中常用的一个构建块。与 BasicBlock
相比,Bottleneck
使用了三个卷积层,并且通常用于更深的ResNet变体(如ResNet-50, ResNet-101等),以减少参数数量和计算复杂度,同时保持模型的性能。
- 参数与
BasicBlock
类似,但增加了对base_width
的支持,这允许在不改变Bottleneck基本结构的情况下调整内部宽度。 width
的计算考虑了base_width
和groups
,用于确定中间两个卷积层的宽度。- 第一个卷积层
conv1
是1x1卷积,用于减少通道数(或称为“压缩”)。 - 第二个卷积层
conv2
是3x3卷积,具有可变的步长和空洞卷积支持,是Bottleneck的核心部分。 - 第三个卷积层
conv3
是1x1卷积,用于扩展通道数到原始输入通道数的四倍(或根据expansion
属性指定的倍数)。 - 每个卷积层后面都跟着一个归一化层(
bn1
,bn2
,bn3
),用于稳定训练过程。
第五部分:class ResNetCifar10(nn.Module):!
class ResNetCifar10(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNetCifar10, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def __init__
- 初始化 (
__init__
方法):- 初始化网络的基本设置,包括归一化层 (
norm_layer
,默认为nn.BatchNorm2d
)、输入通道数 (inplanes
,初始化为 64)、空洞卷积的扩张率 (dilation
,初始化为 1)。 replace_stride_with_dilation
是一个可选参数,用于指定是否在某些层中用空洞卷积(dilation convolution)替换 2x2 的步长(stride),这有助于在不减少特征图空间分辨率的情况下增加感受野。- 设置组卷积的参数(
groups
和width_per_group
),尽管在这个版本的代码中它们可能没有被直接使用来定义卷积层。 - 定义第一个卷积层
conv1
和相应的归一化层bn1
,以及ReLU激活函数。
- 初始化网络的基本设置,包括归一化层 (
- 层定义:
- 定义了四个主要的层 (
layer1
到layer4
),每个层都是由多个残差块(block
)组成的。_make_layer
方法(尽管在这段代码中未给出,但通常在 ResNet 实现中定义)用于创建这些层。每个层中的残差块数量由layers
参数指定。 - 每个层都可以选择性地通过
dilate
参数来替换步长为空洞卷积,这取决于replace_stride_with_dilation
列表中的相应值。
- 定义了四个主要的层 (
- 全局平均池化和全连接层:
- 使用
nn.AdaptiveAvgPool2d
对特征图进行全局平均池化,将其输出大小调整为(1, 1)
。 - 定义一个全连接层
fc
,其输入特征数为512 * block.expansion
(最后一个残差块的输出通道数乘以扩展倍数),输出特征数为num_classes
(类别数)。
- 使用
- 循环初始化参数:
- 遍历网络中的所有模块(
self.modules()
),并对某些类型的模块(如卷积层和批归一化层)的参数进行初始化。例如,你可能会看到对权重进行正态分布初始化,对偏置进行零初始化的代码。
- 遍历网络中的所有模块(
def _make_layer
这段代码是ResNet中用于构建残差层(residual layer)的_make_layer
函数的实现。这个函数根据给定的参数(如残差块类型block
、输出通道数planes
、残差块数量blocks
、步长stride
以及是否使用空洞卷积dilate
)来创建一系列残差块,并将它们顺序地连接成一个层。下面是对这段代码的详细解读:
- 初始化变量:
norm_layer
:获取网络配置中指定的归一化层类型,默认为nn.BatchNorm2d
。downsample
:用于调整输入维度以匹配残差块的输出维度(如果需要的话)。初始化为None
。previous_dilation
:保存当前层的初始空洞率(dilation rate),用于后续计算。
- 处理空洞卷积:
- 如果
dilate
为True
,则将当前层的空洞率乘以步长stride
,并将步长stride
设置为1。这是因为在空洞卷积中,我们希望在不减小特征图尺寸的情况下增加感受野,所以通过调整空洞率来实现这一点,而不是通过步长。
- 如果
- 构建下采样路径(如果需要):
- 如果步长
stride
不为1,或者输入通道数self.inplanes
与经过扩展(block.expansion
)后的输出通道数planes * block.expansion
不匹配,则需要构建一个下采样路径(downsample
)。这个路径包括一个1x1的卷积层(用于调整通道数)和一个归一化层。
- 如果步长
- 构建第一个残差块:
- 将第一个残差块添加到
layers
列表中。这个残差块使用当前的输入通道数self.inplanes
、输出通道数planes
、步长stride
(可能已经调整为1)、下采样路径(如果需要的话)、组数self.groups
、基础宽度self.base_width
、初始空洞率previous_dilation
以及归一化层类型norm_layer
作为参数。 - 更新
self.inplanes
为当前残差块的输出通道数(即planes * block.expansion
),以便后续残差块使用。
- 将第一个残差块添加到
- 构建剩余的残差块:
- 使用一个循环来构建剩余的
blocks-1
个残差块。这些残差块只需要输入通道数(已更新为self.inplanes
)、输出通道数planes
、组数self.groups
、基础宽度self.base_width
、当前的空洞率self.dilation
以及归一化层类型norm_layer
作为参数。
- 使用一个循环来构建剩余的
- 返回顺序容器:
- 使用
nn.Sequential
将layers
列表中的所有残差块顺序地连接成一个层,并返回这个层。
- 使用
这个函数是ResNet中构建层次结构的关键部分,它允许我们灵活地构建具有不同深度和宽度的残差网络。通过调整block
、planes
、blocks
等参数,我们可以轻松地定制网络结构以适应不同的任务和数据集。
def _forward_impl
这段代码是一个典型的卷积神经网络(CNN)的前向传播实现,特别是在使用PyTorch框架时常见。这个函数_forward_impl
定义了一个网络如何通过输入x
(通常是一个批次的图像数据)来计算输出。这个过程模拟了数据在网络中的流动,从输入层开始,通过一系列的卷积层、批归一化层、激活函数层、以及可能的池化层和全连接层,最终产生输出。下面是对这个过程中每个步骤的详细解读:
-
卷积层(
self.conv1(x)
):这是网络的第一层,通常是一个卷积层。它会对输入x
应用卷积操作,目的是提取输入数据的低级特征。卷积层的输出会传递到下一个层。 -
批归一化层(
self.bn1(x)
):批归一化层会对卷积层的输出进行归一化处理,使得输出的数据分布具有相同的均值和方差。这有助于加速训练过程并提高模型的稳定性。 -
激活函数(
self.relu(x)
):这里使用的是ReLU(Rectified Linear Unit)激活函数。ReLU函数将输入的所有负值置为0,而保持正值不变。它引入了非线性,使得网络能够学习复杂的模式。 -
残差层(
self.layer1(x)
,self.layer2(x)
,self.layer3(x)
,self.layer4(x)
):这些层通常包含多个卷积块,每个块可能包括卷积层、批归一化层、激活函数层,以及可能的下采样层(通过步长大于1的卷积或池化层实现)。在残差网络中,这些层还通过残差连接将输入直接加到输出上,这有助于解决深度网络中的梯度消失问题。 -
平均池化层(
self.avgpool(x)
):在通过所有残差层之后,通常会使用一个池化层来进一步降低特征图的维度。这里使用的是平均池化,它计算特征图中每个区域的平均值,并将这个平均值作为该区域的输出。 -
展平(
torch.flatten(x, 1)
):在将特征图传递到全连接层之前,需要将其展平为一维张量。torch.flatten(x, 1)
将x
从第二个维度开始展平,因为第一个维度通常是批次大小。 -
全连接层(
self.fc(x)
):最后,展平后的特征被传递到全连接层(也称为线性层或密集层)。这个层会学习特征之间的非线性组合,并产生最终的输出。在分类任务中,输出层的神经元数量通常等于类别的数量,并且每个神经元的输出可以解释为输入属于对应类别的概率。
总的来说,_forward_impl
函数定义了网络如何通过一系列的操作(卷积、归一化、激活、池化、展平和全连接)来将输入数据转换为输出。这是构建和训练深度学习模型的核心部分。
def forward(self, x):
forward
定义了数据通过网络的前向传播过程。这个方法通过调用 _forward_impl
方法(作为类的另一个成员函数)来实现具体的前向传播逻辑。
第六部分:def ResNet18_cifar10(**kwargs):
def ResNet18_cifar10(**kwargs):
return ResNetCifar10(BasicBlock, [2, 2, 2, 2], **kwargs)
一个工厂函数,用于创建并返回一个针对 CIFAR-10 数据集优化的 ResNet 模型的实例。BasicBlock
是构建 ResNet 网络的基本残差块类型,而 [2, 2, 2, 2]
指定了每个残差层(或称为阶段)中残差块的数量。
-
[2, 2, 2, 2]
: 这是一个列表,指定了每个残差层(或称为阶段)中BasicBlock
的数量。在 ResNet-18 的典型实现中,网络被划分为四个这样的阶段,每个阶段的特征图大小逐渐减小(通过步长为 2 的卷积或池化操作实现),同时特征图的数量(即通道数)逐渐增加。这里的[2, 2, 2, 2]
意味着每个阶段都有 2 个BasicBlock
。 -
**kwargs
: 这是一个关键字参数,允许在调用ResNet18_cifar10
函数时传递额外的参数给ResNetCifar10
类的构造函数。这些参数可以用于进一步自定义模型的行为或结构,比如设置不同的学习率、权重衰减等。
是一个方便的接口,用于快速创建和初始化一个针对 CIFAR-10 数据集优化的 ResNet 模型的实例。
第七部分:def ResNet50_cifar10(**kwargs):
ResNet50_cifar10
函数与 ResNet18_cifar10
函数之间的主要区别在于它们构建的 ResNet 模型的架构和深度。这些差异主要体现在所使用的残差块类型(BasicBlock
vs Bottleneck
)以及每个阶段中残差块的数量上。
残差块类型:
ResNet18_cifar10
使用 BasicBlock
作为残差块。BasicBlock
包含两个 3x3 的卷积层,每个卷积层后面跟着一个批归一化层和一个 ReLU 激活函数。
ResNet50_cifar10
使用 Bottleneck
作为残差块。Bottleneck
块的设计是为了在保持计算量相对不变的同时增加模型的深度。它通常包含一个 1x1 的卷积用于减少通道数(瓶颈),然后是一个 3x3 的卷积用于提取特征,最后是一个 1x1 的卷积用于恢复通道数。这种设计可以减少参数数量和计算量,同时保持模型的表示能力。
- 残差块数量:
ResNet18_cifar10
中每个阶段的残差块数量为[2, 2, 2, 2]
,这意味着整个网络相对较浅,尽管名字中有“18”,但实际的层数(特别是卷积层或残差块的数量)会根据BasicBlock
的内部结构和输入/输出层的设计而有所不同。ResNet50_cifar10
中每个阶段的残差块数量为[3, 4, 6, 3]
,这表明网络更深,有更多的层来提取和组合特征。这种更深的架构通常能够学习更复杂的模式,但也可能需要更多的数据和计算资源来训练。
- 模型复杂度:
- 由于使用了更复杂的
Bottleneck
块和更多的残差块,ResNet50_cifar10
模型的复杂度和参数数量通常会比ResNet18_cifar10
更高。这意味着它可能具有更强的表示能力,但也可能更容易过拟合,特别是在数据量较少的情况下。
- 由于使用了更复杂的
ResNet50_cifar10
函数创建了一个比 ResNet18_cifar10
更深、更复杂且可能具有更强表示能力的 ResNet 模型,用于 CIFAR-10 数据集。这种差异主要体现在残差块的选择和数量上,以及它们对模型整体架构和性能的影响上。选择哪个模型取决于具体的应用场景、可用的计算资源和数据集的大小。
第八部分:def ResNet101_cifar10(**kwargs):
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
convolution(self.conv2)
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNetCifar10(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNetCifar10, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def ResNet18_cifar10(**kwargs):
return ResNetCifar10(BasicBlock, [2, 2, 2, 2], **kwargs)
def ResNet50_cifar10(**kwargs):
return ResNetCifar10(Bottleneck, [3, 4, 6, 3], **kwargs)
def ResNet101_cifar10(**kwargs):
return ResNetCifar10(Bottleneck, [3, 4, 23, 3], **kwargs)