Implementing the convolutional neural network ResNet in PyTorch for CIFAR-10 classification

Last time we covered the CIFAR-10 classification problem with LeNet-5:

Implementing the convolutional neural network LeNet-5 in PyTorch for CIFAR-10 classification

Today we'll look at a very simple ResNet with ten layers in total: one convolutional layer, four residual blocks (two convolutional layers each), and one fully connected layer, i.e. 1 + 4×2 + 1 = 10 weighted layers. The basic unit of a residual block is shown in the figure.

Each block has a shortcut connection that carries x around the convolutions, and at the end x is added to F(x).
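
In other words (my notation, mirroring the code below), each block computes

    out = F(x) + extra(x)

where F(x) is the conv-BN-ReLU, conv-BN branch and extra(x) is a 1x1 convolution plus BatchNorm that reshapes x so the element-wise addition works even when the channel count or spatial size changes.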

Here is the definition of the residual block:

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):

    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()
        # Shortcut branch: a 1x1 conv + BN that projects x to the same
        # channel count and spatial size as the main branch output.
        self.extra = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
            nn.BatchNorm2d(ch_out)
        )

        # A stride > 1 shrinks the feature map, which keeps memory usage down.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # Applying ReLU here as well is a matter of preference.
        out = F.relu(self.bn2(self.conv2(out)))
        # Shortcut: self.extra maps [b, ch_in, h, w] --> [b, ch_out, h/stride, w/stride],
        # so the element-wise addition with out has matching shapes.
        # Since this is a residual network, we add F(x) and the shortcut of x.
        out = self.extra(x) + out
        return out
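
To see how the shortcut keeps the two branches addable, here is a quick sanity check (my own snippet, not part of the original post) that pushes a random CIFAR-sized feature map through one block:

blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 32, 32)   # fake batch of 2 feature maps
out = blk(tmp)
print(out.shape)                   # torch.Size([2, 128, 16, 16])

Both the main branch and the 1x1 shortcut use stride 2, so the two outputs line up at [2, 128, 16, 16] and can be added element-wise.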

Next we chain one convolutional layer, the four residual blocks, and a fully connected layer together:

class ResNet10(nn.Module):

    def __init__(self):
        super(ResNet10, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks
        # [b, 64, h, w] --> [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] --> [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] --> [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] --> [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)

        # The input size of the linear layer can be found by printing the
        # shape after the blocks and the pooling (see the commented print in forward).
        self.outlayer = nn.Linear(512 * 1 * 1, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # [b, 64, 32, 32] --> [b, 512, 2, 2] after the four blocks
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # print('after conv:', x.shape)  # [b, 512, 2, 2]
        # [b, 512, 2, 2] --> [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x
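
A similar quick check (again my own snippet) shows why the linear layer takes 512*1*1 inputs and that the network produces 10 logits per image:

net = ResNet10()
tmp = torch.randn(2, 3, 32, 32)    # fake batch of 2 CIFAR-10 images
logits = net(tmp)
print(logits.shape)                # torch.Size([2, 10])

With 32x32 inputs, the four stride-2 blocks shrink the feature map to [b, 512, 2, 2], and adaptive_avg_pool2d reduces it to [b, 512, 1, 1] before the flatten.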

Here is the main training script:

import torch
import torchvision
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

#from lenet5 import Lenet5
from resnet import ResNet10

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def main():
    training = datasets.CIFAR10('data', True, transform=transforms.Compose([
        transforms.Resize((32, 32)), transforms.ToTensor()]), download=True)
    trainloader = DataLoader(training, batch_size=32, shuffle=True)
    test = datasets.CIFAR10('data', False, transform=transforms.Compose([
        transforms.Resize((32, 32)), transforms.ToTensor()]), download=True)
    testloader = DataLoader(test, batch_size=16, shuffle=True)
    # x, label = next(iter(trainloader))
    # print('x:', x.shape, 'label:', label.shape)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ResNet10().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    print(model)
    for epoch in range(75):
        model.train()
        for batchidx, (x, label) in enumerate(trainloader):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # loss of the last batch in this epoch
        print(epoch, loss.item())
        model.eval()
        class_correct = list(0. for i in range(10))
        class_total = list(0. for i in range(10))
        with torch.no_grad():
            total_correct = 0
            total_num = 0

            for x, label in testloader:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
                c = (pred == label).squeeze()

                # per-class bookkeeping; use the actual batch size so the
                # loop is safe even if the last batch has fewer than 16 samples
                for i in range(label.size(0)):
                    lbl = label[i].item()
                    class_correct[lbl] += c[i].item()
                    class_total[lbl] += 1
            acc = total_correct / total_num
            print(epoch, acc)
        for i in range(10):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))


if __name__ == '__main__':
    main()

After 75 epochs the network reaches about 83.19% test accuracy, roughly 15% higher than LeNet-5.
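
The script above never saves the trained weights. If you want to reuse the network later, a minimal sketch (my addition, with 'resnet10.pth' as an arbitrary file name) is to add this at the end of main():

torch.save(model.state_dict(), 'resnet10.pth')

and to reload it later with:

model = ResNet10()
model.load_state_dict(torch.load('resnet10.pth'))
model.eval()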
