pytorch入门07--Dataset与DataLoader

这篇博客介绍了如何在PyTorch中自定义数据集`WineDataset`,该数据集从CSV文件加载酒类数据,并实现了`__init__`、`__getitem__`和`__len__`方法。然后使用`DataLoader`进行批量处理和随机打乱数据。此外,还展示了训练循环的简单实现。最后,提到了 torchvision 中的预封装数据集如MNIST。
摘要由CSDN通过智能技术生成
#Dataset and DataLoader
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

class WineDataset(Dataset): #定义加载酒数据集类

    def __init__(self):    #加载数据集赋予x和y
        # read with numpy or pandas
        xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
        self.n_samples = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 1:])  # size [n_samples, n_features]
        self.y_data = torch.from_numpy(xy[:, [0]])  # size [n_samples, 1]

    # support indexing such that dataset[i] can be used to get i-th sample 检索x、y中元素
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # we can call len(dataset) to return the size 获取样本数
    def __len__(self):
        return self.n_samples

# create dataset  完成酒类的定义赋予dataset
dataset = WineDataset()
#调用DataLoader函数,第一个参数把dataset获得数据赋给dataloader,第二个参数为每次投喂的是数据大小。第三参数要打乱数据,第四个参数为多进程个数。
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0)

# Dummy Training loop
total_samples = len(dataset)   #获得样本数
n_iterations = math.ceil(total_samples / 4)   #向上取整
print(total_samples, n_iterations)

for i, (inputs, labels) in enumerate(dataloader):
    #forward backward, update  每隔5次看一下输出的数据
    if (i+1)% 5 == 0:
        print(f'Step {i+1}/{n_iterations}| Inputs {inputs.shape} | Labels {labels.shape}')

# some famous datasets are available in torchvision.datasets
# e.g. MNIST, Fashion-MNIST, CIFAR10, COCO

# train_dataset = torchvision.datasets.MNIST(root='./mnist_data',
#                                            train=True,
#                                            transform=torchvision.transforms.ToTensor(),
#                                            download=True)
#
# train_loader = DataLoader(dataset=train_dataset,
#                           batch_size=3,
#                           shuffle=True)
# # look at one random sample
# dataiter = iter(train_loader)
# data = dataiter.next()
# inputs, targets = data
# print(inputs.shape, targets.shape)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值