基于Pytorch框架的GoogleNet:MNIST数据集手写数字识别

Debug问题总结:

一、使用x.view(size,-1)进行矩阵变换后没有赋值给x;

二、在编写测试函数时没有使用mini-hatch会出现报错(无法同时处理过多测试集);

三、使用torch.max()将pred_data转换成pred_labels时有两个返回值,分别是最大项及其索引值,要使用_,pred_labels=torch.max(pred_data)进行赋值,不能省略逗号;

四、准确率计算时是比较pred_labels和test_labels,不是pred_data,否则准确率输出为0%;

五、inception模块同层神经元的输入通道为上一个神经元的输出通道,不能写成channels_in,即整个模块的输入通道数;

六、在编写训练函数和测试函数时最好输出训练/测试进度,用以预判等待时间。

模型图解

在这里插入图片描述

代码展示

import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

#迭代次数
iteration=5
#小批量梯度下降的batch大小
batch=50

#数据载入
trans=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.15,0.30)])
train_set=datasets.MNIST("D:\桌面\GoogleNet",train=True,download=True,transform=trans)
test_set=datasets.MNIST("D:\桌面\GoogleNet",train=False,download=True,transform=trans)
train_loader=DataLoader(train_set,shuffle=True,batch_size=batch,num_workers=2)
test_loader=DataLoader(test_set,shuffle=False,batch_size=batch,num_workers=2)

#模块搭建
class InceptionModule(torch.nn.Module):
    def __init__(self,channels_in):
        super().__init__()
        self.layer1_kenel1_avg=torch.nn.AvgPool2d(3,stride=1,padding=1)
        self.layer1_kenel2_1by1=torch.nn.Conv2d(channels_in,24,1)
        self.layer2_kenel1_1by1=torch.nn.Conv2d(channels_in,16,1)
        self.layer3_kenel1_1by1=torch.nn.Conv2d(channels_in,16,1)
        self.layer3_kenel2_5by5=torch.nn.Conv2d(16,24,5,padding=2)
        self.layer4_kenel1_1by1=torch.nn.Conv2d(channels_in,16,1)
        self.layer4_kenel2_3by3=torch.nn.Conv2d(16,23,3,padding=1)
        self.layer4_kenrl3_3by3=torch.nn.Conv2d(23,24,3,padding=1)

    def forward(self,x):
        x_layer1=self.layer1_kenel2_1by1(self.layer1_kenel1_avg(x))
        x_layer2=self.layer2_kenel1_1by1(x)
        x_layer3=self.layer3_kenel2_5by5(self.layer3_kenel1_1by1(x))
        x_layer4=self.layer4_kenrl3_3by3(self.layer4_kenel2_3by3(self.layer4_kenel1_1by1(x)))
        output=torch.cat([x_layer1,x_layer2,x_layer3,x_layer4],dim=1)
        return output

#GoogleNet搭建
class GoogleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=torch.nn.Conv2d(1,10,5)
        self.conv2=torch.nn.Conv2d(88,20,5)
        self.incep1=InceptionModule(channels_in=10)
        self.incep2=InceptionModule(channels_in=20)
        self.maxpool=torch.nn.MaxPool2d(2)
        self.fully_connection=torch.nn.Linear(1408,10)

    def forward(self,x):
        size_fc=x.shape[0]
        x=F.relu(self.maxpool(self.conv1(x)))
        x=self.incep1(x)
        x=F.relu(self.maxpool(self.conv2(x)))
        x=self.incep2(x)
        #print(x.shape,size_fc)
        x=x.view(size_fc,-1)
        #print(x.shape)
        x=self.fully_connection(x)
        return x

#选择损失函数和优化器
model=GoogleNet()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.005)
#设置学习率衰减
schedular=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.999)

#训练函数
def train():
    for epoch in range(iteration):
        for batch_index,data in enumerate(train_loader,0):
            l=0.0
            train_data,train_labels=data
            optimizer.zero_grad()
            pred_data=model(train_data)
            loss=criterion(pred_data,train_labels)
            l+=loss.item()
            loss.backward()
            optimizer.step()
            schedular.step()
            if batch_index%50==0:
                print("epoch:",epoch,"batch_seq:",batch_index/50,"loss:",l)

#测试函数
def test():
    correct_total = 0.0
    total=0.0
    with torch.no_grad():
        for batch_index,data in enumerate(test_loader,0):
            test_data,test_labels=data
            pred_data=model(test_data)
            _,pred_labels=torch.max(pred_data,dim=1)
            correct_total+=(pred_labels==test_labels).sum().item()
            total+=test_labels.shape[0]
            if batch_index%10==0:
                print("测试进度:batch_sequence:",batch_index/10)
    print("准确率为:",100*correct_total/total,"%")

#主函数
if __name__=='__main__':
    train()
    test()

运行结果

迭代次数为5时,准确率为98.7%,可以尝试增加迭代次数获取更高的准确率,但对电脑硬件和时间要求较高。
在这里插入图片描述

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值