使用Pytorch构建简单的Resnet网络

该博客介绍了如何使用PyTorch实现ResNet模型中的残差块。通过代码展示了每个残差块中通道数的变化规则,并且详细解释了如何将多个残差块串联起来。通过实例演示了不同配置下残差块的输出尺寸,为理解ResNet网络结构提供了基础。
摘要由CSDN通过智能技术生成

来自李沐老师课上的图片
每一个残差块的第一个layer通道加倍,宽高减半,后边的所有layer保持不变。

import torch.nn as nn
import torchvision
import torch
class Resnet(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super(Resnet, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, stride=strides, padding=1)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides) # 当使用1*1卷积时步长指定为2
        else:
            self.conv3 = None
        self.BN1 = nn.BatchNorm2d(num_features=num_channels)
        self.BN2 = nn.BatchNorm2d(num_features=num_channels)
        self.relu = nn.ReLU()
    def forward(self, x):
        y = self.BN2(self.conv2(self.relu(self.BN1(self.conv1(x)))))
        if self.conv3 == None:
            x = x
        else:
            x = self.conv3(x)
        return self.relu(y + x)
# 测试
input = torch.ones((64, 6, 128, 128))
resnet = Resnet(3, 6, use_1x1conv=True, strides=2)   #torch.size(64, 6, 64, 64)  通道数加倍 宽高减半
resnet = Resnet(3, 3, use_1x1conv=False)  #torch.size(64, 3, 128, 128)  通道数不变  宽高不变
print(resnet(input).shape)

将残差块链接起来

def Resnet_num(input_channels, num_channels, resblock_num, first_block=False):
    blk = []
    for i in range(resblock_num):
        if i == 0 and first_block == False:   # 每一个残差块的第一个layer通道加倍  宽高减半。 好像有具体的名词
            blk.append(Resnet(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Resnet(num_channels, num_channels))  # 注意 第一个残差块以后的所有块的通道数都和输出一致
    return blk

主函数测试

if __name__ == '__main__':
    input = torch.ones((64, 6, 128, 128))
    # input2 = torch.ones((64, 3, 224, 224))
    # resnet = Resnet(3, 6, use_1x1conv=True, strides=2)   # torch.size(64, 6, 64, 64)  通道数加倍 宽高减半
    # resnet = Resnet(3, 3, use_1x1conv=False)  # torch.size(64, 3, 128, 128)  通道数不变  宽高不变
    # print(resnet(input).shape)
    # b1 = nn.Sequential(*Resnet_num(3, 6, 1, first_block=True))
    b1 = Resnet_num(6, 6, 1, first_block=True)
    b2 = Resnet_num(6, 12, 2)
    b3 = Resnet_num(12, 24, 2)
    # print(len(b1))  查看长度
    # b1 = nn.Sequential(b1[0])
    # b2 = nn.Sequential(b2[0], b2[1])   # *是python中的语法
    # print(b1)
    # print(b2)
    # wuze = nn.Sequential(b1[0], b2[0], b2[1], b3[0], b3[1])
    wuze = nn.Sequential(*b1, *b2, *b3)
    # print((input2).shape)
    for layer in wuze:
        output = layer(input)
        input = output
        print(layer.__class__.__name__, output.shape) # 打印每层输出的尺寸

感觉有了这个基础 就可以自己写一个resnet50之类的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值