Pytorch学习笔记(二)

本文介绍了PyTorch中的DataLoader如何用于批训练,以及如何使用ConvNet训练cifar-10数据集。同时讲解了模型的保存与获取方法,包括保存网络结构和参数以及仅保存参数的情况。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

(3)批训练包装器DataLoader
Pytorch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练.

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print("Epoch: ", epoch, " | Step: ", step, " | batchx: ", batch_x.numpy(), " | batchy: ", batch_y.numpy())

运行结果如下:

Epoch:  0  | Step:  0  | batchx:  [ 6.  5.  1.  9.  3.]  | batchy:  [  5.   6.  10.   2.   8.]
Epoch:  0  | Step:  1  | batchx:  [ 10.   2.   7.   8.   4.]  | batchy:  [ 1.  9.  4.  3.  7.]
Epoch:  1  | Step:  0  | batchx:  [ 6.  5.  4.  7.  9.]  | batchy:  [ 5.  6.  7.  4.  2.]
Epoch:  1  | Step:  1  | batchx:  [  1.   8.   3.  10.   2.]  | batchy:  [ 10.   3.   8.   1.   9.]
Epoch:  2  | Step:  0  | batchx:  [ 5.  6.  7.  9.  3.]  | batchy:  [ 6.  5.  4.  2.  8.]
Epoch:  2  | Step:  1  | batchx:  [ 10.   4.   8.   1.   2.]  | batchy:  [  1.   7.   3.  10.   9.]

(4)使用ConvNet训练cifar-10数据集

# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值