使用MSE误差函数进行手写字体识别时,需要保证label为one-hot形式。
可以通过数据集的target_transform参数将label转换为one-hot,也可以在训练过程中再进行转换。
本文采用后者,代码如下:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import mnist
import numpy as np
#网络模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size = 3, stride = 1), #16 26 26
nn.BatchNorm2d(16),
nn.ReLU(True))
self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size = 3, stride = 1), #16 24 24
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.MaxPool2d(2, 2)) #32 12 12
self.layer3 = nn.Sequential(nn.Conv2d(32, 6