基于Pytorch框架的ResNet:MNIST数据集手写数字识别

Debug经验总结

一、常规ResBlock的输出尺寸与输入尺寸相同,否则需要进行尺寸变换;

二、在数据集较大时设置num_work进行多线程处理,可以很大提高训练效率;

三、较复杂的网络在搭建前可以先用草图计算每个输出位置的矩阵尺寸,减少Debug难度;

四、选用ReLU激活函数时,应适当降低学习率,避免出现损失函数值无法下降的情况;

五、比较训练集的准确率和测试集的准确率,判断是否出现过拟合。

六、ResBlock是在激活前加入输入值作为偏移量,不能放错位置;

代码展示

import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

batch=50
iteration=1

#数据载入
trans=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.15,0.30)])
train_set=datasets.MNIST("D:\桌面\ResNet",train=True,download=True,transform=trans)
train_loader=DataLoader(train_set,batch_size=batch,shuffle=True,num_workers=16)
test_set=datasets.MNIST("D:\桌面\ResNet",train=False,download=True,transform=trans)
test_loader=DataLoader(test_set,batch_size=batch,num_workers=16)

#模块搭建
class ResBlock(torch.nn.Module):
    def __init__(self,channels_in):
        super().__init__()
        self.conv1=torch.nn.Conv2d(channels_in,30,5,padding=2)
        self.conv2=torch.nn.Conv2d(30,channels_in,3,padding=1)

    def forward(self,x):
        out=self.conv1(x)
        out=self.conv2(out)
        return F.relu(out+x)

#网络搭建
class ResNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=torch.nn.Conv2d(1,20,5)
        self.conv2=torch.nn.Conv2d(20,15,3)
        self.maxpool=torch.nn.MaxPool2d(2)
        self.resblock1=ResBlock(channels_in=20)
        self.resblock2=ResBlock(channels_in=15)
        self.full_c=torch.nn.Linear(375,10)

    def forward(self,x):
        size=x.shape[0]
        x=F.relu(self.maxpool(self.conv1(x)))
        x=self.resblock1(x)
        x=F.relu(self.maxpool(self.conv2(x)))
        x=self.resblock2(x)
        x=x.view(size,-1)
        x=self.full_c(x)
        return x

#损失函数、优化器、学习率衰减
model=ResNet()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.005)
schedular=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.999)

#训练函数
def train():
    for epoch in range(iteration):
        for batch_index,data in enumerate(train_loader,0):
            l=0.0
            train_data,train_labels=data
            optimizer.zero_grad()
            pred_data=model(train_data)
            loss=criterion(pred_data,train_labels)
            loss.backward()
            l+=loss.item()
            optimizer.step()
            schedular.step()
            if batch_index%50==0:
                print("epoch:",epoch,"batch_index:",batch_index/50,"loss:",l)

#测试函数
def test():
    with torch.no_grad():
        correct=0.0
        total=0.0
        for batch_index,data in enumerate(test_loader,0):
            test_data,test_labels=data
            pred_data=model(test_data)
            _,pred_labels=torch.max(pred_data,dim=1)
            total+=test_labels.shape[0]
            correct+=(pred_labels==test_labels).sum().item()
            if batch_index%20==0:
                print("测试进度:",100.0*batch_index/200,"%")
        print("准确率为:",correct*100.0/total,"%")

#主函数
if __name__ == '__main__':
    train()
    test()

运行结果

在这里插入图片描述

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值