使用pytorch实现手写数字识别
-
思路和流程分析
流程: (1)准备数据,这些需要准备DataLoader (2)构建模型,这里可以使用torch构造一个深层的神经网络 (3)模型的训练 (4)保存模型,后续持续使用 (5)模型的评估,使用测试集,观察模型的好坏
-
准备训练集和测试集
(1)torchvision.transforms的图形数据处理方法 a. torchvision.transforms.ToTensor i. 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为(C,H,W)的tensor ii. 其中(H,W,C)意思为(高,宽,通道数),黑白图片的通道数只有1,其中每个像素点的取值为[0,255],彩色图片的通道数为3,分别是(R,G,B),每个通道的每个像素点的取值范围为[0,255],三个通道的颜色互相叠加,形成了各种颜色 iii. img_tensor = transforms.ToTensor()(img) #其中img为(H,W,C)的numpy.ndarray b. torchvision.transforms.Normalize(mean,std) i. 均值和标准差的形状和通道数相同 ii. norm_img = transforms.Normalize((10,10,10),(1,1,1))(img) #通道数为3 c. torchvision.transforms.Compose(transforms) i. 传入list ii. 数据经过list中的每一个方法挨个进行处理 iii. transforms.Compose([ torchvision.transforms.ToTensor(), #先转化为Tensor torchvision.transforms.Normalize(mean,std) #再进行正则化 ]) (2)准备MNIST数据集的Dataset和DataLoader #准备数据集 def get_dataloader(train=True): transforms_fn = Compose([ ToTensor(), Normalize((0.1307),(0.3081)) ]) dataset = MNIST(root='./data',train=train,download=False,transform=transforms_fn) data_loader = torch.utils.data.DataLoader(dataset,batch_size=128,shuffle=True) return data_loader
-
构建模型
(1)模型的构建 a.使用了一个四层的神经网络,其中包括两个全连接层和一个输出层,第一个全连接层会经过激活函数的处理,将处理后的结果交给下一个全连接层,进行变换后输出结果 b.注意 i. 激活函数如何使用 ii. 每一层数据的形状 iii. 模型的损失函数 (2)激活函数的使用 a. import torch.nn.functional as F b. F.relu(x)即可对x进行处理 (3)模型中数据的形状 a.原始输入数据的形状为:[batch_size,1,28,28] b.进行形状的修改:[batch_size,28*28],(全连接层是在进行矩阵的乘法操作) c.第一个全连接层的输出形状:[batch_size,28],这里的28是个人设定的,你也可以设置为别的 d.激活函数不会修改数据的形状 e.第二个全连接层的输出形状:[batch_size,10],因为手写数字由10个类别,为[0,9] 代码如下: #构建模型 class MnistNet(nn.Module): def __init__(self): super(MnistNet, self).__init__() self.fc1 = nn.Linear(28*28*1,28) self.fc2 = nn.Linear(28,10) def forward(self,x): ''' :param x:[batch_size,1,28,28],可能最后一个不到batch_size这么多 :return: ''' #1.修改形状 x = x.view([-1,1*28*28])#x = x.view([x.size(0),1*28*28]) #2.进行全连接操作 x = self.fc1(x)#[batch_size,28] #3.进行激活函数的处理 x = F.relu(x)#形状没有变化[batch_size,28] #4.输出层 out = self.fc2(x)#[batch_size,10] return out
-
模型的损失函数
(1)二分类与多分类的区别 a.多分类和二分类中唯一的区别是我们不能够再使用sigmoid函数来计算当前样本属于某个类别的概率,而应该使用softmax函数 b.softmax和sigmoid的区别在于我们需要去计算样本属于每个类别的概率,需要计算多次,而sigmoid只需要计算一次 (2)softmax的使用方法 使用softmax将结果变为[0,1]之间的值,可以将它当作概率
a.公式
b.例如下图
(3)损失函数处理方法 a.多分类问题损失函数计算
b.交叉熵损失:softmax概率传入对数似然损失得到的损失函数称为交叉熵损失
c.pytorch实现交叉熵损失的方法
i. criterion = nn.CrossEntropyLoss()
loss = criterion(y_predict,y_true)
ii. #1.对输出值计算softmax和取对数
output = F.log_softmax(x,dim=-1)
#2.使用torch中带权损失
loss = F.nll_loss(output,y_true)
-
模型的训练
(1)训练流程 a.实例化模型,设置模型为训练模式 b.实例化优化器类,实例化损失函数 c.获取,遍历dataloader d.梯度置为0 e.进行向前计算 f.计算损失 g.反向传播 h.更新参数 (2)代码实现: #4.训练模型 def train(epoch): mode = True model.train(mode=mode) #模型设置为训练模式 train_dataloader = get_dataloader(train=mode) #获取训练数据集 for idx,(data,target) in enumerate(train_dataloader): y_predict = model(data) loss = F.nll_loss(y_predict,target) optimizer.zero_grad() loss.backward() optimizer.step() if idx % 100==0: print(epoch,idx,loss.item()) # 模型的保存 if idx % 100 == 0: torch.save(model.state_dict(), "./model/mnist.pkl") # 保存模型参数 torch.save(optimizer.state_dict(), "./results/mnist_optim.pkl") # 保存优化器参数
-
模型的保存和加载
(1)模型的保存 torch.save(model.state_dict(),"./model/mnist.pkl")#保存模型参数 torch.save(optimizer.state_dict(),"./results/mnist_optim.pkl")#保存优化器参数 (2)模型的加载 model.load_state_dict(torch.load("./model/mnist.pkl")) optimizer.load_state_dict(torch.load("./results/mnist_optim.pkl"))
-
模型的评估
(1)评估过程与训练过程的区别 a.不需要计算梯度 b.需要收集损失和准确率,用来计算平均损失和平均准确率 c.损失的计算和训练的时候损失计算方法不同 d.准确率的计算: 1)模型的输出为[batch_size,10]的形状 2)其中最大值的位置就是其预测的目标值(预测值进行过softmax后为概率,softmax中分母都是相同的,分子越大,概率越大) 3)最大值的位置获取的方法可以使用torch.max,返回最大值和最大值的位置 4)返回最大值的位置后,和真实值[batch_size]进行对比,相同表示预测成功 (2)代码实现: #6.模型的评估 def test(): loss_list = [] acc_list = [] test_dataloader = get_dataloader(train=False,batch_size=TEST_BATCH_SIZE) for idx,(input,target) in enumerate(test_dataloader): with torch.no_grad(): output = model(input) cur_loss = F.nll_loss(output,target) loss_list.append(cur_loss)#记录每个batch的损失 #计算准确率 # output:[batch_size,10] target:[batch_size,1] pred = output.max(dim=-1)[-1]#获取每行最大值的位置,即该行预测的数字 cur_acc = pred.eq(target).float().mean()#获取均值,即准确值 acc_list.append(cur_acc)#记录每个batch的准确率 print("平均准确率,平均损失:",np.mean(acc_list),np.mean(loss_list))