pytorch中TensorDataset与DataLoader的使用

Datatset\TensorDataset\DataLoader

class torch.utils.data.Dataset

表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。dataset是用来做打包和预处理(比如输入资料路径自动读取)。

当对图片进行处理时,如通过定义一个transforms来随机旋转训练图片,将图片格式变成tensor。

import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_data = datasets.MNIST(root="./data/mnist", train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./data/mnist", train=False, transform=transforms.ToTensor(), download=True)

batch_size = 128
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(), ]
)
test_transform = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.ToTensor(), ]
)  # 测试集不需要翻转或旋转图片

# 继承Dataset
class ImgDataset(Dataset):
    def __init__(self, x, y=None, transform=None):
        self.x = x
        self.y = y
        # label is required to be a LongTensor
        if y is not None:
            self.y = torch.LongTensor(y)
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        X = self.x[index]
        if self.transform is not None:
            X = self.transform(X)
        if self.y is not None:
            Y = self.y[index]
            return X, Y
        else:
            return X

# 将dataset分装到dataloader里
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True
)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False
)

for batch_x, batch_y in test_dataloader:
    print(batch_x.shape, batch_x[0].shape, batch_y.shape)
    break

输出结果:

#变换前
torch.Size([128, 28, 28, 1]) torch.Size([28, 28, 1]) torch.Size([128])
#变换后
torch.Size([128, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([128])

我们发现一个batch的x[0]的shape由原先的(28, 28, 1)变成了(1, 28, 28)。
原因在于transformers.toTensor()方法有自动转换维度的功能,它会将channel变成第一维。

class torch.utils.data.TensorDataset

 class torch.utils.data.TensorDataset(data_tensor, target_tensor)

包装数据和目标张量的数据集。通过沿着第一个维度索引两个张量来恢复每个样本。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。

参数:
data_tensor (Tensor) -包含样本数据
target_tensor (Tensor) -包含样本目标(标签)

样例:

import torch
from torch.utils.data import TensorDataset
#a的形状为(4*3)
a = torch.tensor([[1,2,3],[2,3,4],[4,5,6],[6,7,8]])
#b的第一维与a相同
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
print(train_data[0:4])

输出结果:

(tensor([[1, 2, 3],
        [2, 3, 4],
        [4, 5, 6],
        [6, 7, 8]]), tensor([1, 2, 3, 4]))

当对图片进行处理时,如通过定义一个transforms来随机旋转训练图片,将图片格式变成tensor。

tsr_x_train, tsr_y_train = torch.tensor(x_train), torch.tensor(y_train)
tsr_x_val, tsr_y_val = torch.tensor(x_val), torch.tensor(y_val)
tsr_x_testing = torch.tensor(x_test)

#然后只需要一行就可以啦
train_dataset = TensorDataset(tsr_x_train, tsr_y_train)
val_dataset = TensorDataset(tsr_x_val, tsr_y_val)

#装入dataloader的步骤同上
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)
test_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False
)

这次的x[0]的shape同我们一开始设置的shape,TensorDataset并没有帮我们把channel数调成第一维。

class torch.utils.data.DataLoader

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存。DataLoader是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

参数:
dataset (Dataset) -加载数据的数据集。
batch_size (int, optional) -每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) -设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) - 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
num_workers (int, optional) -用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
collate_fn (callable, optional)
pin_memory (bool, optional)
drop_last (bool, optional) -如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

样例:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

a = torch.tensor([[1,2,3],[2,3,4],[4,5,6],[6,7,8]])
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
data = DataLoader(train_data, batch_size=2, shuffle=True)
for i, j in enumerate(data):
    x, y = j
    print(' batch:{0} x:{1}  y: {2}'.format(i, x, y))

输出结果:

 batch:0 x:tensor([[1, 2, 3],
        [2, 3, 4]])  y: tensor([1, 2])
 batch:1 x:tensor([[6, 7, 8],
        [4, 5, 6]])  y: tensor([4, 3])
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值