Pytorch实战4——ResNet网络对MNIST数据集分类

目录

一、ResNet(Residual Network) 残差网络

二、代码编程实战对mnist数据集分类


 

一、ResNet(Residual Network) 残差网络

       主要用于解决梯度消失问题 ,H(x)=F(x)+x,F表示残差部分,残差部分与上一层的输出相加就构成了下一层的输入,这整体结构也就被称为残差块(Residual Block)。

 上图左是普通的网络,在这个网络中,训练过程中可能会出现梯度消失的情况(梯度=0),权重参数W = W - 入g, 当梯度g趋于0时,W未发生改变。上图右是ResNet残差网络。

 上图右是ResNet残差网络,x为输入值,F(X)是经过第一层线性变化并激活后的输出,该图表示在残差网络中,第二层进行线性变化之后激活之前,F(x)加入了这一层输入值X,然后再进行激活后输出。在第二层输出值激活前加入X,这条路径称作shortcut连接。

ResNet最终更新某一个节点的参数时,由于H(x)=F(x)+x,求导后,不管括号内右边部分的F(x)求导有多小,始终有x求导后的1存在,不会发生梯度消失的情况。

二、代码编程实战对mnist数据集分类

代码基于之前实战2的卷积神经网络编程,只需要把实战2的代码中设计模型部分替换成ResNet网络的即可,其它部分代码不变,替换的模型如下

#  Residual Network(ResNet)网络模型
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = nn.Conv2d(channels, channels,
        kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels,
        kernel_size=3, padding=1)
    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y) #两次卷积后的输出y,加上两次卷积前的输入x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.rblock1 = ResidualBlock(16) #ResNet网络不改变输入输出维度
        self.rblock2 = ResidualBlock(32)
        self.fc = nn.Linear(512, 10)
    def forward(self, x):
        in_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(in_size, -1)
        x = self.fc(x)
        return x
     
model = Net()

运行结果:

[1,   300] loss: 0.526
[1,   600] loss: 0.152
[1,   900] loss: 0.110
Accuracy on test set: 97 %
[2,   300] loss: 0.090
[2,   600] loss: 0.080
[2,   900] loss: 0.072
Accuracy on test set: 98 %

使用残差网络,在只训练了2轮的情况下,准确率达到了98%,错误率降低了三分之一。

是的,我可以使用PyTorch框架来实现ResNet网络并进行MNIST数字识别。ResNet是一种深度残差网络结构,特别适用于图像分类任务。MNIST是一个经典的手写数字数据集,常用于测试和验证机器学习模型。 我可以为您提供一个示例代码来实现这个任务。以下是一个简单的代码示例: ```python import torch import torch.nn as nn import torch.optim as optim from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from torch.utils.data import DataLoader # 定义ResNet模型 class ResNet(nn.Module): def __init__(self): super(ResNet, self).__init__() # 定义网络结构... def forward(self, x): # 定义前向传播... # 加载MNIST数据集 train_dataset = MNIST(root='.', train=True, transform=ToTensor(), download=True) test_dataset = MNIST(root='.', train=False, transform=ToTensor()) # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 创建ResNet模型实例 model = ResNet() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练模型 num_epochs = 10 for epoch in range(num_epochs): for images, labels in train_loader: optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 在测试集上评估模型 with torch.no_grad(): total_correct = 0 total_samples = 0 for images, labels in test_loader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) total_samples += labels.size(0) total_correct += (predicted == labels).sum().item() accuracy = total_correct / total_samples print(f'Epoch {epoch+1}, Accuracy: {accuracy}') # 使用训练好的模型进行预测 # ... ``` 请注意,此示例仅提供了一个基本的框架,您可能需要根据实际情况进行适当的修改和调整。希望对您有所帮助!***
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值