pytorch学习第二章——神经网络分类任务

Mnist分类任务:

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)#将数据转化为tensor类型
import torch.nn.functional as F
loss_func = F.cross_entropy#可以使用functional中的不用训练的函数
def model(xb):
    return xb.mm(weights) + bias
from torch import nn
#定义全连接神经网络,继承module不用定义反向传播
class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
        
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
#两个模块帮助你分批次使用Data

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
from torch import optim
#定义model和优化器
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
#损失函数+反向传播
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)
import numpy as np
#训练过程
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()#训练过程
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        model.eval()#测试过程
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        #计算损失,是很多数据的平均值
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
#主函数
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(10, model, loss_func, opt, train_dl, valid_dl)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值