联邦聚合算法代码详解(第四期resnetcifar.py部分)

目录

第一部分:def conv3x3():-----创建并返回二维卷积层 

第二部分:def conv1x1():

第三部分:class BasicBlock(nn.Module):

残差网络,残差映射(是残差网络ResNet中基本构建块之一)

第四部分:class Bottleneck(nn.Module):

(是残差网络ResNet中基本构建块之一)

第五部分:class ResNetCifar10(nn.Module):!

def __init__

def _make_layer

def _forward_impl

def forward(self, x):

第六部分:def ResNet18_cifar10(**kwargs):

第七部分:def ResNet50_cifar10(**kwargs):

第八部分:def ResNet101_cifar10(**kwargs):


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):

def conv1x1(in_planes, out_planes, stride=1):
  
class BasicBlock(nn.Module):   
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):     
    def forward(self, x):
      
class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):       
    def forward(self, x):

class ResNetCifar10(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):        
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):       
    def _forward_impl(self, x):
    def forward(self, x):
       
def ResNet18_cifar10(**kwargs):
   
def ResNet50_cifar10(**kwargs):
  
def ResNet101_cifar10(**kwargs):
    

第一部分:def conv3x3():-----创建并返回二维卷积层 


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):   
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

参数

  • in_planes (int): 输入通道数。这是指进入卷积层的特征图的深度(或称为通道数)。例如,对于RGB图像,输入通道数为3(红、绿、蓝)。

  • out_planes (int): 输出通道数。这是指卷积操作后输出的特征图的深度。这个值决定了该卷积层将学习到多少种不同的特征。

  • stride (int, 可选): 卷积核移动的步长。默认值为1,意味着卷积核将逐个像素地滑动。如果步长大于1,则卷积核在滑动时会跳过一些像素,这有助于减少输出特征图的尺寸。

  • groups (int, 可选): 控制输入和输出之间的连接groups=1 表示所有输入通道都会被用来计算输出通道的每个通道(这是最常见的设置)。如果 groups 大于1,则输入通道会被分成多个组,每个组的输入通道只会与对应组的输出通道相连接。这可以用于实现深度可分离卷积等。

  • dilation (int, 可选): 空洞卷积(也称为扩张卷积)的扩张率。它用于在卷积核元素之间插入空格,从而在不增加参数数量的情况下增加感受野。默认值为1,表示不进行空洞卷积

第二部分:def conv1x1():

第三部分:class BasicBlock(nn.Module):

残差网络,残差映射(是残差网络ResNet中基本构建块之一)

 

class BasicBlock(nn.Module):
    expansion = 1'''*1'''
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")       
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)#第一个

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes) #第二个
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

 BasicBlock 类,它继承自 PyTorch 的 nn.Module 类(cnn)BasicBlock 通常是残差网络(ResNet)中的基本构建块之一,用于学习输入和输出之间的残差映射。下面是对这个类的详细解读:

  • expansion = 1该基本块不改变特征图的通道数(或称为深度)。在ResNet的某些变体(如Bottleneck块)中,这个值可能大于1,用于在块内部增加特征图的通道数,然后在块的末尾通过1x1卷积将其还原到原始通道数。但在BasicBlock中,它始终为1。
  • downsample下采样操作,如果输入和输出的特征图尺寸不一致时使用)、groups(分组卷积的组数,默认为1)、base_width(基础宽度,这里未使用)、dilation(空洞卷积的扩张率,默认为1)、norm_layer(归一化层的类型,默认为nn.BatchNorm2d)。
  • (if,if,if)检查groupsbase_width是否为默认值(因为BasicBlock不支持非默认值),然后检查dilation是否大于1(BasicBlock不支持空洞卷积)。
  • 接着,定义了两个3x3的卷积层(conv1conv2),每个卷积层后面都跟着一个归一化层bn1bn2),归一化层的类型由norm_layer参数指定。
  • relu是一个ReLU激活函数,设置为inplace=True以节省内存。
  • downsample是一个可选的下采样操作,用于在输入和输出的特征图尺寸不一致时调整输入特征图的尺寸,以便进行残差连接

前向传播 forward

  • 输入x首先被保存为identity,用于后续的残差连接。
  • x通过第一个卷积层conv1、归一化层bn1和ReLU激活函数后,得到中间特征图。
  • 中间特征图再经过第二个卷积层conv2和归一化层bn2
  • 如果存在downsample操作,则将原始输入x通过downsample以匹配输出特征图的尺寸。
  • 最后,将处理后的特征图与identity(原始输入或经过下采样的输入)相加,并通过ReLU激活函数,得到最终的输出。

这种结构允许网络学习输入和输出之间的残差,有助于解决深度网络训练中的梯度消失或梯度爆炸问题,从而可以训练更深的网络。

第四部分:class Bottleneck(nn.Module):

(是残差网络ResNet中基本构建块之一)

class Bottleneck(nn.Module):

    expansion = 4'''*2'''
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups        
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

构建深度残差网络(ResNet)中常用的一个构建块。与 BasicBlock 相比,Bottleneck 使用了三个卷积层,并且通常用于更深的ResNet变体(如ResNet-50, ResNet-101等),以减少参数数量计算复杂度,同时保持模型的性能。

  • 参数与 BasicBlock 类似,但增加了对 base_width 的支持,这允许在不改变Bottleneck基本结构的情况下调整内部宽度。
  • width 的计算考虑了 base_width 和 groups,用于确定中间两个卷积层的宽度
  • 第一个卷积层 conv1 是1x1卷积,用于减少通道数(或称为“压缩”)。
  • 第二个卷积层 conv2 是3x3卷积,具有可变的步长和空洞卷积支持,是Bottleneck的核心部分。
  • 第三个卷积层 conv3 是1x1卷积,用于扩展通道数到原始输入通道数的四倍(或根据 expansion 属性指定的倍数)。
  • 每个卷积层后面都跟着一个归一化层(bn1bn2bn3),用于稳定训练过程。

第五部分:class ResNetCifar10(nn.Module):

class ResNetCifar10(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNetCifar10, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
def __init__
  1. 初始化 (__init__ 方法):
    • 初始化网络的基本设置,包括归一化层 (norm_layer,默认为 nn.BatchNorm2d)、输入通道数 (inplanes,初始化为 64)、空洞卷积的扩张率 (dilation,初始化为 1)。
    • replace_stride_with_dilation 是一个可选参数,用于指定是否在某些层中用空洞卷积(dilation convolution)替换 2x2 的步长(stride),这有助于在不减少特征图空间分辨率的情况下增加感受野。
    • 设置组卷积的参数(groups 和 width_per_group),尽管在这个版本的代码中它们可能没有被直接使用来定义卷积层。
    • 定义第一个卷积层 conv1 和相应的归一化层 bn1,以及ReLU激活函数。
  2. 层定义:
    • 定义了四个主要的层 (layer1 到 layer4),每个层都是由多个残差块(block组成的。_make_layer 方法(尽管在这段代码中未给出,但通常在 ResNet 实现中定义)用于创建这些层。每个层中的残差块数量由 layers 参数指定。
    • 每个层都可以选择性地通过 dilate 参数来替换步长为空洞卷积,这取决于 replace_stride_with_dilation 列表中的相应值。
  3. 全局平均池化和全连接层:
    • 使用 nn.AdaptiveAvgPool2d 对特征图进行全局平均池化,将其输出大小调整为 (1, 1)
    • 定义一个全连接层 fc,其输入特征数为 512 * block.expansion(最后一个残差块的输出通道数乘以扩展倍数),输出特征数为 num_classes(类别数)。
  4. 循环初始化参数:
    • 遍历网络中的所有模块(self.modules()),并对某些类型的模块(如卷积层和批归一化层)的参数进行初始化。例如,你可能会看到对权重进行正态分布初始化,对偏置进行零初始化的代码。
def _make_layer

这段代码是ResNet中用于构建残差层(residual layer)的_make_layer函数的实现。这个函数根据给定的参数(如残差块类型block、输出通道数planes、残差块数量blocks、步长stride以及是否使用空洞卷积dilate)来创建一系列残差块,并将它们顺序地连接成一个层。下面是对这段代码的详细解读:

  1. 初始化变量
    • norm_layer:获取网络配置中指定的归一化层类型,默认为nn.BatchNorm2d
    • downsample:用于调整输入维度以匹配残差块的输出维度(如果需要的话)。初始化为None
    • previous_dilation:保存当前层的初始空洞率(dilation rate),用于后续计算。
  2. 处理空洞卷积
    • 如果dilateTrue,则将当前层的空洞率乘以步长stride,并将步长stride设置为1。这是因为在空洞卷积中,我们希望在不减小特征图尺寸的情况下增加感受野,所以通过调整空洞率来实现这一点,而不是通过步长。
  3. 构建下采样路径(如果需要):
    • 如果步长stride不为1,或者输入通道数self.inplanes与经过扩展(block.expansion)后的输出通道数planes * block.expansion不匹配,则需要构建一个下采样路径(downsample)。这个路径包括一个1x1的卷积层(用于调整通道数)和一个归一化层。
  4. 构建第一个残差块
    • 将第一个残差块添加到layers列表中。这个残差块使用当前的输入通道数self.inplanes、输出通道数planes、步长stride(可能已经调整为1)、下采样路径(如果需要的话)、组数self.groups、基础宽度self.base_width、初始空洞率previous_dilation以及归一化层类型norm_layer作为参数。
    • 更新self.inplanes为当前残差块的输出通道数(即planes * block.expansion),以便后续残差块使用。
  5. 构建剩余的残差块
    • 使用一个循环来构建剩余的blocks-1个残差块。这些残差块只需要输入通道数(已更新为self.inplanes)、输出通道数planes、组数self.groups、基础宽度self.base_width、当前的空洞率self.dilation以及归一化层类型norm_layer作为参数。
  6. 返回顺序容器
    • 使用nn.Sequentiallayers列表中的所有残差块顺序地连接成一个层,并返回这个层。

这个函数是ResNet中构建层次结构的关键部分,它允许我们灵活地构建具有不同深度和宽度的残差网络。通过调整blockplanesblocks等参数,我们可以轻松地定制网络结构以适应不同的任务和数据集。

def _forward_impl

这段代码是一个典型的卷积神经网络(CNN)的前向传播实现,特别是在使用PyTorch框架时常见。这个函数_forward_impl定义了一个网络如何通过输入x(通常是一个批次的图像数据)来计算输出。这个过程模拟了数据在网络中的流动,从输入层开始,通过一系列的卷积层、批归一化层、激活函数层、以及可能的池化层和全连接层,最终产生输出。下面是对这个过程中每个步骤的详细解读:

  1. 卷积层(self.conv1(x):这是网络的第一层,通常是一个卷积层。它会对输入x应用卷积操作,目的是提取输入数据的低级特征。卷积层的输出会传递到下一个层。

  2. 批归一化层(self.bn1(x):批归一化层会对卷积层的输出进行归一化处理,使得输出的数据分布具有相同的均值和方差。这有助于加速训练过程并提高模型的稳定性。

  3. 激活函数(self.relu(x):这里使用的是ReLU(Rectified Linear Unit)激活函数。ReLU函数将输入的所有负值置为0,而保持正值不变。它引入了非线性,使得网络能够学习复杂的模式。

  4. 残差层(self.layer1(x)self.layer2(x)self.layer3(x)self.layer4(x):这些层通常包含多个卷积块,每个块可能包括卷积层、批归一化层、激活函数层,以及可能的下采样层(通过步长大于1的卷积或池化层实现)。在残差网络中,这些层还通过残差连接将输入直接加到输出上,这有助于解决深度网络中的梯度消失问题。

  5. 平均池化层(self.avgpool(x):在通过所有残差层之后,通常会使用一个池化层来进一步降低特征图的维度。这里使用的是平均池化,它计算特征图中每个区域的平均值,并将这个平均值作为该区域的输出。

  6. 展平(torch.flatten(x, 1):在将特征图传递到全连接层之前,需要将其展平为一维张量。torch.flatten(x, 1)x从第二个维度开始展平,因为第一个维度通常是批次大小。

  7. 全连接层(self.fc(x):最后,展平后的特征被传递到全连接层(也称为线性层或密集层)。这个层会学习特征之间的非线性组合,并产生最终的输出。在分类任务中,输出层的神经元数量通常等于类别的数量,并且每个神经元的输出可以解释为输入属于对应类别的概率。

总的来说,_forward_impl函数定义了网络如何通过一系列的操作(卷积、归一化、激活、池化、展平和全连接)来将输入数据转换为输出。这是构建和训练深度学习模型的核心部分。

def forward(self, x):

forward 定义了数据通过网络的前向传播过程。这个方法通过调用 _forward_impl 方法(作为类的另一个成员函数)来实现具体的前向传播逻辑。

第六部分:def ResNet18_cifar10(**kwargs):

def ResNet18_cifar10(**kwargs):  
    return ResNetCifar10(BasicBlock, [2, 2, 2, 2], **kwargs)

一个工厂函数,用于创建并返回一个针对 CIFAR-10 数据集优化的 ResNet 模型的实例BasicBlock 是构建 ResNet 网络的基本残差块类型,而 [2, 2, 2, 2] 指定了每个残差层(或称为阶段)中残差块的数量。

  • [2, 2, 2, 2]: 这是一个列表,指定了每个残差层(或称为阶段)中 BasicBlock 的数量。在 ResNet-18 的典型实现中,网络被划分为四个这样的阶段,每个阶段的特征图大小逐渐减小(通过步长为 2 的卷积或池化操作实现),同时特征图的数量(即通道数)逐渐增加。这里的 [2, 2, 2, 2] 意味着每个阶段都有 2 个 BasicBlock

  • **kwargs: 这是一个关键字参数,允许在调用 ResNet18_cifar10 函数时传递额外的参数给 ResNetCifar10 类的构造函数。这些参数可以用于进一步自定义模型的行为或结构,比如设置不同的学习率、权重衰减等。

一个方便的接口,用于快速创建和初始化一个针对 CIFAR-10 数据集优化的 ResNet 模型的实例。

第七部分:def ResNet50_cifar10(**kwargs):

 

ResNet50_cifar10 函数与 ResNet18_cifar10 函数之间的主要区别在于它们构建的 ResNet 模型的架构和深度。这些差异主要体现在所使用的残差块类型(BasicBlock vs Bottleneck)以及每个阶段中残差块的数量上。

残差块类型

ResNet18_cifar10 使用 BasicBlock 作为残差块。BasicBlock 包含两个 3x3 的卷积层,每个卷积层后面跟着一个批归一化层和一个 ReLU 激活函数。

ResNet50_cifar10 使用 Bottleneck 作为残差块。Bottleneck 块的设计是为了在保持计算量相对不变的同时增加模型的深度。它通常包含一个 1x1 的卷积用于减少通道数(瓶颈),然后是一个 3x3 的卷积用于提取特征,最后是一个 1x1 的卷积用于恢复通道数。这种设计可以减少参数数量和计算量,同时保持模型的表示能力。

  1. 残差块数量
    • ResNet18_cifar10 中每个阶段的残差块数量为 [2, 2, 2, 2],这意味着整个网络相对较浅,尽管名字中有“18”,但实际的层数(特别是卷积层或残差块的数量)会根据 BasicBlock 的内部结构和输入/输出层的设计而有所不同。
    • ResNet50_cifar10 中每个阶段的残差块数量为 [3, 4, 6, 3],这表明网络更深,有更多的层来提取和组合特征。这种更深的架构通常能够学习更复杂的模式,但也可能需要更多的数据和计算资源来训练。
  2. 模型复杂度
    • 由于使用了更复杂的 Bottleneck 块和更多的残差块,ResNet50_cifar10 模型的复杂度和参数数量通常会比 ResNet18_cifar10 更高。这意味着它可能具有更强的表示能力,但也可能更容易过拟合,特别是在数据量较少的情况下。

ResNet50_cifar10 函数创建了一个比 ResNet18_cifar10 更深、更复杂且可能具有更强表示能力的 ResNet 模型,用于 CIFAR-10 数据集。这种差异主要体现在残差块的选择和数量上,以及它们对模型整体架构和性能的影响上。选择哪个模型取决于具体的应用场景、可用的计算资源和数据集的大小。

第八部分:def ResNet101_cifar10(**kwargs):

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):   
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")       
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    convolution(self.conv2)   
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups        
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNetCifar10(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNetCifar10, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)

def ResNet18_cifar10(**kwargs):  
    return ResNetCifar10(BasicBlock, [2, 2, 2, 2], **kwargs)

def ResNet50_cifar10(**kwargs): 
    return ResNetCifar10(Bottleneck, [3, 4, 6, 3], **kwargs)

def ResNet101_cifar10(**kwargs):
    return ResNetCifar10(Bottleneck, [3, 4, 23, 3], **kwargs)

  • 8
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值