PyTorch Basics 04 -- the ResNet Network

Normalization:
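
The ResNet block below leans heavily on batch normalization (nn.BatchNorm2d). As a minimal sketch of what those layers do (my own illustration, not part of the original code): for each channel, BatchNorm normalizes the batch of feature maps to roughly zero mean and unit variance, then applies a learnable per-channel scale and shift.

import torch
from torch import nn

x = torch.randn(4, 16, 7, 7) * 5 + 3   # deliberately shifted/scaled feature maps
bn = nn.BatchNorm2d(16)                 # one learnable (gamma, beta) pair per channel
bn.train()                              # training mode: normalize with batch statistics
out = bn(x)

# per-channel mean/std over (batch, h, w) are ~0 and ~1 after BN
print(out.mean(dim=(0, 2, 3)))
print(out.std(dim=(0, 2, 3)))
print(bn.running_mean.shape, bn.weight.shape)   # both torch.Size([16])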

模型:

import  torch
from    torch import  nn
from    torch.nn import functional as F  # functional API, used as F.<fn>

class ResBlk(nn.Module):    # define the ResNet block module
    """
    resnet block
    """
    def __init__(self, ch_in, ch_out, stride=2):
        # the block needs its input and output channel counts up front
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first conv (and of the shortcut projection)
        """
        super(ResBlk, self).__init__()

        # we add stride support for ResBlk, which is distinct from the tutorials.
        # build the two conv layers of the block: the first conv uses a 3*3
        # kernel with the given stride and padding 1
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        # pass the output of the first conv through BatchNorm2d
        self.bn1 = nn.BatchNorm2d(ch_out)
        # the second conv takes the output of the first; same setup, but stride 1
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # make sure the shortcut matches the block output in both channel
        # count and spatial size
        self.extra = nn.Sequential()   # start with an empty (identity) shortcut
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):  # forward pass of a single block
        """
        :param x: [b, ch, h, w]
        :return:
        """
        out = F.relu(self.bn1(self.conv1(x)))  # first conv, then BN, then relu
        out = self.bn2(self.conv2(out))   # second conv, then BN
        # short cut.
        # extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
        # element-wise add:
        out = self.extra(x) + out  # add the (possibly projected) input back onto the block output
        out = F.relu(out)  # final relu, called through the functional API F

        return out

class ResNet18(nn.Module):   # build a simplified ResNet-18

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(   # first, a stem conv layer
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 ResNet blocks; each stage halves the spatial size and
        # (except the last) doubles the channels
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)

        self.outlayer = nn.Linear(512*1*1, 10)   # final fully connected layer

    def forward(self, x):  # forward pass of the whole network
        """
        :param x:
        :return:
        """
        x = F.relu(self.conv1(x))  # stem conv first

        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)  # then the four ResNet block stages
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # print('after conv:', x.shape)  # [b, 512, 1, 1] for a 32*32 input
        # F.adaptive_avg_pool2d squeezes the spatial tail to 1x1: [b, 512, h, w] => [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)  # flatten to [b, 512]
        x = self.outlayer(x)    # fully connected layer => [b, 10]
        return x

def main():  # quick shape test

    blk = ResBlk(64, 128, stride=4)  # pick the input/output channels of one ResNet block
    tmp = torch.randn(2, 64, 32, 32)  # dummy input
    out = blk(tmp)  # run it through the block
    print('block:', out.shape)   # print the output shape

    x = torch.randn(2, 3, 32, 32)  # dummy image batch: two 32*32 color images
    model = ResNet18()  # build the full ResNet-18
    out = model(x)
    print('resnet:', out.shape)

if __name__ == '__main__':
    main()
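
One detail worth stressing: the shortcut branch has to match the main branch in both channel count and spatial size, which is why the extra projection above is built whenever ch_out != ch_in or stride != 1 (blk4 keeps 512 channels but still halves the feature map, so an identity shortcut would not line up). A quick check of this (my own sketch, not part of the original code):

blk = ResBlk(512, 512, stride=2)   # same channels, but stride 2 halves h and w
x = torch.randn(2, 512, 8, 8)
print(blk(x).shape)                # torch.Size([2, 512, 4, 4]) -- the shortcut downsampled too
print(len(blk.extra))              # 2: the 1x1-conv + BN projection was built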

Main script:

import  torch
from    torch.utils.data import DataLoader    # batched data loading
from    torchvision import datasets      # dataset download and loading
from    torchvision import transforms    # data transforms
from    torch import nn, optim   # nn modules and optimizers
# from    lenet5 import Lenet5     # the LeNet-5 class from an earlier post
from    resnet import ResNet18
def main():
    batchsz = 16   # number of samples fed in per batch
    # datasets downloads CIFAR10 into the local folder 'cifar'; the transform
    # resizes to 32*32, converts to a tensor, and normalizes with the usual
    # ImageNet channel statistics
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)  # yields batchsz samples at a time
    # the test set is built the same way as the training set
    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)


    x, label = next(iter(cifar_train))   # print the shapes of one batch of data and labels
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')   # use CUDA acceleration
    # model = Lenet5().to(device)    # the LeNet-5 model would be moved to the GPU the same way
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)     # the loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-3)   # Adam optimizer
    print(model)   # print the model instance

    for epoch in range(1000):
        model.train()  # switch to training mode
        for batchidx, (x, label) in enumerate(cifar_train):  # fetch a batch
            # [b, 3, 32, 32]
            x, label = x.to(device), label.to(device)  # move to the GPU
            logits = model(x)  # forward pass through the network
            # logits: [b, 10]   # label:  [b]
            # loss: tensor scalar
            loss = criteon(logits, label)  # compute the loss
            # backprop
            optimizer.zero_grad()  # zero the gradients so they don't accumulate across batches
            loss.backward()
            optimizer.step()  # apply the optimizer update
        print(epoch, 'loss:', loss.item())  # loss of the last batch; item() converts the scalar tensor to a Python number


        model.eval()  # switch to evaluation mode
        with torch.no_grad():   # tell pytorch not to build the computation graph here
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:   # fetch test batches
                # [b, 3, 32, 32]
                # [b]
                x, label = x.to(device), label.to(device)   # move to the GPU

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)  # index of the largest value along the class dimension (dim=1)
                # [b] vs [b] => scalar tensor: eq tests element-wise equality, so the sum counts correct predictions
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)  # count the total number of samples
                # print(correct)

            acc = total_correct / total_num  # overall accuracy
            print(epoch, 'test acc:', acc)

if __name__ == '__main__':
    main()
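
Training for 1000 epochs in one sitting is rarely practical, so it is worth persisting the weights whenever test accuracy improves. A minimal sketch of how that could slot into the loop above (the best_acc variable and the filename are my additions, not in the original):

# before the epoch loop:
best_acc = 0.0

# at the end of each epoch, right after acc is computed:
if acc > best_acc:
    best_acc = acc
    torch.save(model.state_dict(), 'resnet18_cifar10_best.pth')

# to restore the model later:
model = ResNet18()
model.load_state_dict(torch.load('resnet18_cifar10_best.pth'))
model.eval()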

 
