用Pytorch实现一个用于MNIST手写数字识别的卷积神经网络(CNN)类

class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

初始化部分 __init__

  1. super(CNNMnist, self).__init__(): 与之前一样,这调用基类nn.Module的构造函数。

  2. self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5): 第一层卷积层,接受args.num_channels个通道,输出10个特征图,并使用5x5的卷积核。

  3. self.conv2 = nn.Conv2d(10, 20, kernel_size=5): 第二层卷积层,接受10个通道,输出20个特征图,并使用5x5的卷积核。

  4. self.conv2_drop = nn.Dropout2d(): 二维Dropout层,可以在训练过程中随机设置一部分特征图为零,以防止过拟合。

  5. self.fc1 = nn.Linear(320, 50): 第一个全连接层,有320个输入单元和50个输出单元。

  6. self.fc2 = nn.Linear(50, args.num_classes): 第二个全连接层,有50个输入单元和args.num_classes个输出单元,通常等于分类任务中的类别数。

前向传播部分 forward

  1. x = F.relu(F.max_pool2d(self.conv1(x), 2)): 应用第一层卷积,然后执行最大池化,池化窗口大小为2,并使用ReLU激活函数。

  2. x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)): 应用第二层卷积,然后应用2D Dropout,再进行最大池化,最后使用ReLU激活函数。

  3. x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]): 改变张量的形状,以适应全连接层。

  4. x = F.relu(self.fc1(x)): 通过第一个全连接层并应用ReLU激活函数。

  5. x = F.dropout(x, training=self.training): 应用Dropout。

  6. x = self.fc2(x): 通过第二个全连接层。

  7. return F.log_softmax(x, dim=1): 应用对数Softmax激活函数,并返回网络的最终输出。

这个CNNMnist类定义了一个典型的卷积神经网络结构,它包括了两个卷积层、两个最大池化层、两个全连接层和Dropout正则化。这种架构特别适合用于处理MNIST这样的图像分类任务,其中的2D卷积层可以捕捉图像的空间特征。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值