python dataset_python-TensorDataset上的PyTorch转换

我正在使用TensorDataset从numpy数组创建数据集.

# convert numpy arrays to pytorch tensors

X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])

y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])

# reshape into [C, H, W]

X_train = X_train.reshape((-1, 1, 28, 28)).float()

# create dataset and dataloaders

train_dataset = torch.utils.data.TensorDataset(X_train, y_train)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

如何将数据扩充(transforms)应用于TensorDataset?

例如,使用ImageFolder,我可以将转换指定为torchvision.datasets.ImageFolder(root,transform = …)的参数之一.

根据PyTorch团队成员之一的this reply,默认情况下不支持.有其他替代方法吗?

随意询问是否需要更多代码来解释问题.

解决方法:

默认情况下,TensorDataset不支持变换.但是我们可以创建我们的自定义类来添加该选项.但是,正如我已经提到的,大多数转换都是为PIL.Image开发的.但是无论如何,这里是带有非常虚拟转换的非常简单的MNIST示例.带有MNIST here的csv文件.

码:

import numpy as np

import torch

from torch.utils.data import Dataset, TensorDataset

import torchvision

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# Import mnist dataset from cvs file and convert it to torch tensor

with open('mnist_train.csv', 'r') as f:

mnist_train = f.readlines()

# Images

X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])

X_train = X_train.reshape((-1, 1, 28, 28))

X_train = torch.tensor(X_train)

# Labels

y_train = np.array([int(i[0]) for i in mnist_train])

y_train = y_train.reshape(y_train.shape[0], 1)

y_train = torch.tensor(y_train)

del mnist_train

class CustomTensorDataset(Dataset):

"""TensorDataset with support of transforms.

"""

def __init__(self, tensors, transform=None):

assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)

self.tensors = tensors

self.transform = transform

def __getitem__(self, index):

x = self.tensors[0][index]

if self.transform:

x = self.transform(x)

y = self.tensors[1][index]

return x, y

def __len__(self):

return self.tensors[0].size(0)

def imshow(img, title=''):

"""Plot the image batch.

"""

plt.figure(figsize=(10, 10))

plt.title(title)

plt.imshow(np.transpose( img.numpy(), (1, 2, 0)), cmap='gray')

plt.show()

# Dataset w/o any tranformations

train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)

train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

# iterate

for i, data in enumerate(train_loader):

x, y = data

imshow(torchvision.utils.make_grid(x, 4), title='Normal')

break # we need just one batch

# Let's add some transforms

# Dataset with flipping tranformations

def vflip(tensor):

"""Flips tensor vertically.

"""

tensor = tensor.flip(1)

return tensor

def hflip(tensor):

"""Flips tensor horizontally.

"""

tensor = tensor.flip(2)

return tensor

train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)

train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

result = []

for i, data in enumerate(train_loader):

x, y = data

imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')

break

train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)

train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

result = []

for i, data in enumerate(train_loader):

x, y = data

imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')

break

输出:

标签:pytorch,data-augmentation,python

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值