CNN手写数字识别——使用MSE误差函数

使用MSE误差函数进行手写字体识别时,需要保证label为one-hot形式

可以通过target-transform参数调整label数据为one-hot,也可以在学习过程中调整

本文采用后者,代码如下

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import mnist
import numpy as np

#网络模型
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size = 3, stride = 1),   #16 26 26
                                nn.BatchNorm2d(16),
                                nn.ReLU(True))
    self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size = 3, stride = 1),   #16 24 24
                                nn.BatchNorm2d(32),
                                nn.ReLU(True),
                                nn.MaxPool2d(2, 2))       #32 12 12
    self.layer3 = nn.Sequential(nn.Conv2d(32, 6
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值