Pytorch学习(八)---- 批数据训练

莫烦python视频学习笔记 视频链接https://www.bilibili.com/video/BV1Vx411j7kT?from=search&seid=3065687802317837578

import torch
import torch.utils.data as Data

BATCH_SIZE = 5
if __name__ == '__main__':
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
# 将数据放入数据库,用x来训练,用y来计算误差
# 先转换成 torch 能识别的 Dataset
    torch_dataset = Data.TensorDataset(x, y)

    loader = Data.DataLoader(          # 将数据分批
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,  # Do you want to break this order?
        num_workers=2,
    )

    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training...
            print('Epoch', epoch, '|Step', step, '|batch x', batch_x.numpy(), '|batch y:', batch_y.numpy())

输出:

Epoch 0 |Step 0 |batch x [ 1.  3.  7.  6. 10.] |batch y: [10.  8.  4.  5.  1.]
Epoch 0 |Step 1 |batch x [5. 8. 9. 4. 2.] |batch y: [6. 3. 2. 7. 9.]
Epoch 1 |Step 0 |batch x [ 2. 10.  3.  7.  8.] |batch y: [9. 1. 8. 4. 3.]
Epoch 1 |Step 1 |batch x [9. 5. 1. 6. 4.] |batch y: [ 2.  6. 10.  5.  7.]
Epoch 2 |Step 0 |batch x [ 4.  1.  5. 10.  6.] |batch y: [ 7. 10.  6.  1.  5.]
Epoch 2 |Step 1 |batch x [2. 8. 7. 3. 9.] |batch y: [9. 3. 4. 8. 2.]

代码在运行中报错:init() got an unexpected keyword argument ‘data_tensor’
此处参考 (https://blog.csdn.net/thunderf/article/details/94733747)
其次还要注意代码的缩进问题。

PyTorch是一个用于深度学习的开源框架,它提供了一组工具和接口,使得我们可以轻松地进行模型训练、预测和部署。在PyTorch中,数据处理是深度学习应用的重要部分之一。 PyTorch中的数据处理主要涉及以下几个方面: 1.数据预处理:包括数据清洗、数据归一化、数据增强等操作,以提高模型的鲁棒性和泛化能力。 2.数据加载:PyTorch提供了多种数据加载方式,包括内置的数据集、自定义的数据集和数据加载器等,以便我们更好地管理和使用数据。 3.数据可视化:为了更好地理解数据和模型,PyTorch提供了多种数据可视化工具,如Matplotlib、TensorBoard等。 下面是一个简单的数据预处理示例,展示如何将图像进行归一化和数据增强: ```python import torch import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 # 定义一个数据预处理管道 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) ]) # 加载CIFAR10数据集,进行预处理 trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) ``` 在上面的例子中,我们首先定义了一个数据预处理管道,其中包括了对图像进行随机裁剪、水平翻转、归一化等操作。然后,我们使用PyTorch内置的CIFAR10数据集,并将其预处理后,使用DataLoader进行量加载。这个过程可以帮助我们更好地管理和使用数据,同时提高模型的训练效率和泛化能力。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值