Building a PyTorch Dataset from NumPy

This post shows how to build a Dataset for deep learning from NumPy data. Assume earlier preprocessing has already saved all the data to .npy files; as an example we use grayscale 2D images and assume there are N of them. First, load the data with np.load.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

images = np.load('images.npy')  # shape = (N, W, H)
labels = np.load('labels.npy')  # shape = (N, W, H)

Next, define a class that inherits from torch.utils.data.Dataset. This class must implement two methods: __getitem__(self, index) and __len__(self). In addition, the class should have an __init__() method that receives its inputs; it is called automatically when the class is instantiated. __getitem__(self, index) returns the sample at the given index and is what the DataLoader calls later; __len__(self) returns the total number of samples.
The class can also contain data-processing helper methods, called from __getitem__.

class MyDataset(Dataset):

    def __init__(self, X, *Y):
        '''
        X: images, shape = (N, w, h); N images, each of shape (w, h), pixel values in [0, 1]
        Y: optional labels, shape = (N, w, h); N labels, each of shape (w, h), pixel values 0 or 1
        '''
        self.X = torch.from_numpy(np.expand_dims(X, axis=1))   # shape = (N, 1, w, h) , tensor
        self.haveLabels = False
        if len(Y) == 1:
            self.Y = torch.from_numpy(np.expand_dims(Y[0], axis=1))   # shape = (N, 1, w, h), tensor
            self.haveLabels = True

    def __getitem__(self, index):

        image = self.X[index, ...]  # shape = (1, w, h)
        
        if self.haveLabels:
            label = self.Y[index, ...]  # shape = (1, w, h)
            return image, label  # return both the image and its label
        else:
            return image
       
    
    def __len__(self):
        return self.X.shape[0]  # total number of samples
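A quick sanity check of the class before wiring it into training. The snippet below repeats the class so it runs on its own, and uses small random arrays whose shapes are invented for illustration:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, X, *Y):
        # add a channel axis: (N, w, h) -> (N, 1, w, h)
        self.X = torch.from_numpy(np.expand_dims(X, axis=1))
        self.haveLabels = False
        if len(Y) == 1:
            self.Y = torch.from_numpy(np.expand_dims(Y[0], axis=1))
            self.haveLabels = True

    def __getitem__(self, index):
        image = self.X[index, ...]
        if self.haveLabels:
            return image, self.Y[index, ...]
        return image

    def __len__(self):
        return self.X.shape[0]

# 8 fake grayscale images and 8 fake binary masks (shapes arbitrary)
X = np.random.rand(8, 32, 32).astype(np.float32)
Y = (np.random.rand(8, 32, 32) > 0.5).astype(np.float32)

ds = MyDataset(X, Y)
print(len(ds))     # 8
img, lbl = ds[0]
print(img.shape)   # torch.Size([1, 32, 32])
```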

After defining MyDataset, you can feed it the data, as in the following code:

X_l_train = images[0:40, ...] # shape = (40, 384, 512)
Y_l_train = labels[0:40, ...] # shape = (40, 384, 512)
X_l_val = images[40:45, ...] # shape = (5, 384, 512)
Y_l_val = labels[40:45, ...]  # shape = (5, 384, 512)
X_u = images[45:, ...]    # shape = (235, 384, 512)
train_data_l = MyDataset(X_l_train, Y_l_train)
val_data_l = MyDataset(X_l_val, Y_l_val)
train_data_u = MyDataset(X_u)
train_loader_l = DataLoader(train_data_l, batch_size=16, shuffle=True)
val_loader_l = DataLoader(val_data_l, batch_size=16, shuffle=False)
train_loader_u = DataLoader(train_data_u, batch_size=16, shuffle=True)
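For the fully labeled case, PyTorch's built-in torch.utils.data.TensorDataset covers the same need without a custom class: it pairs tensors along their first axis. A minimal sketch, with small invented shapes in place of the real data:

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# invented shapes for illustration
X = np.random.rand(45, 32, 32).astype(np.float32)
Y = (np.random.rand(45, 32, 32) > 0.5).astype(np.float32)

X_t = torch.from_numpy(X).unsqueeze(1)  # (45, 1, 32, 32): add the channel axis
Y_t = torch.from_numpy(Y).unsqueeze(1)

ds = TensorDataset(X_t, Y_t)            # indexing returns an (image, label) pair
loader = DataLoader(ds, batch_size=16, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape)   # torch.Size([16, 1, 32, 32])
```

A custom Dataset is still the better choice when __getitem__ needs extra logic, such as the optional-label handling above or per-sample augmentation.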

Data wrapped in torch's DataLoader is accessed with a for loop, as in the code below. Here inputs and labels have shape = (B, C, W, H), where B is the batch_size set when the DataLoader is created (16 in this post).

for inputs, labels in train_loader_l:
    inputs, labels = (
        inputs.to(device),
        labels.to(device),
    )
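Note that with batch_size=16 and 40 training samples the last batch only contains 8 samples; passing drop_last=True to the DataLoader discards that incomplete batch. A small sketch (tensor shapes invented) demonstrating both behaviors:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# 40 fake samples, matching the labeled training split above
X = torch.rand(40, 1, 32, 32)
Y = torch.rand(40, 1, 32, 32)

loader = DataLoader(TensorDataset(X, Y), batch_size=16, shuffle=True)
sizes = [xb.shape[0] for xb, yb in loader]
print(sizes)        # [16, 16, 8] -- the last batch is smaller
print(len(loader))  # 3

loader_drop = DataLoader(TensorDataset(X, Y), batch_size=16,
                         shuffle=True, drop_last=True)
print(len(loader_drop))  # 2 -- the incomplete batch of 8 is dropped
```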

If you want to pull batches from a DataLoader one at a time, you can use Python's iter() and next() functions:

train_iter_l = iter(train_loader_l)
for i in range(4):
    try:
        inputs, labels = next(train_iter_l)
    except StopIteration:
        train_iter_l = iter(train_loader_l)
        inputs, labels = next(train_iter_l)
    print(inputs.shape, inputs.max())