Pytorch学习笔记【8】---经典MNIST

Pytorch学习笔记【8】—经典MNIST

Pytorch笔记目录:点击进入


虽然已经不知道写了多少次MNIST手写数据集了,当是为了更加熟悉Pytorch还是把经典一些的案例都敲一遍吧

1. 网络结构

在这里插入图片描述
一个简单的网络结构,随便设计啦,反正MNIST手写识别的网络,你随便玩一玩都可以有很好的效果的,别设计太奇怪应该没问题的哈

2. 代码

相关的解释我都有写注释,也不难懂

#!/usr/bin/env python
# coding: utf-8

# In[2]:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms


# In[3]:


# 设置一些基本变量
batch_size = 200
learning_rate = 0.01
epochs = 10


# In[5]:


# 加载torch内置的MNIST数据集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',train=True,download = True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,),(0.3081,))
            ]
        )
    ),
    batch_size = batch_size,shuffle=True
)
# 加载测试集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data',train=False,transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),(0.3081,))
    ])),
    batch_size = batch_size,shuffle=True
)


# In[6]:


# 随机初始化权值
w1,b1 = torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)
w2,b2 = torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)
w3,b3 = torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)


# In[8]:


# 正态分布填充数据
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)


# In[9]:


# 前向传播函数
def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t()
    x = F.relu(x)
    return x


# In[11]:


# 创建优化器,指定学习率
optimizer = optim.SGD([w1,b1,w2,b2,w3,b3],lr = learning_rate)
# 交叉熵损失函数
criteon = nn.CrossEntropyLoss()


# In[13]:


# 进行训练
for epoch in range(epochs):
    for batch_idx,(data,target) in enumerate(train_loader):
        # 将数据打平
        data= data.view(-1,28*28)
        # 获取前向传播结果
        logits = forward(data)
        # 计算误差
        loss = criteon(logits,target)
        # 优化器优化
        optimizer.zero_grad()
        # 方向传播
        loss.backward()
        # print(w1,grad.norm(),w2.grad.norm())
        # 更新参数
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{} / {}({:.0f}#)]\tLoss:{:.6f}'.format(
                epoch,batch_idx * len(data),len(train_loader.dataset),100. * batch_idx/len(train_loader),
                loss.item()
            ))
    test_loss = 0
    correct = 0
    for data,target in test_loader:
        data = data.view(-1,28*28)
        logits = forward(data)
        test_loss += criteon(logits,target).item()
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
        
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss : {:.4f}, Accuracy:{}/ {}({:.0f}%)\n'.format(
        test_loss,correct,len(test_loader.dataset),
        100.*correct / len(test_loader.dataset)
    ))

3. API分析

总结一下这里面用到的一些API吧

torch.utils.data.DataLoader()

这个基本是个算法代码就要用到,不过以后我们都是加载自己的数据集像这次一样直接加载的情况是很少的

1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。

2、batch_size,根据具体情况设置即可。

3、shuffle,一般在训练数据中会采用。用于把数据打散

4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。

5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。

6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。

7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。

8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。

9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

torch.nn.init.kaiming_normal_()

把输入的内容覆盖为正态分布的参数

optimizer = optim.SGD([w1,b1,w2,b2,w3,b3],lr = learning_rate)

优化器函数,可以指定参数和需要优化的函数

里面有两个常用函数:

  • optimizer.zero_grad() # 清空上一步的残余更新参数值
  • optimizer.step() # 将参数更新值施加到 net 的 parameters 上

criteon = nn.CrossEntropyLoss()

交叉熵函数对象,注意这里返回的是一个类对象,是可以传参进行计算的,传参得到的对象可以直接进行反向传播的计算

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值