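This article shows how to build a Dataset for deep learning from data stored as NumPy arrays. It assumes an earlier preprocessing step has already saved all the data to .npy files. Grayscale 2D images are used as the example, with N images in total. First, read the data with np.load.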
import numpy as np

images = np.load('images.npy')  # shape = (N, W, H)
labels = np.load('labels.npy')  # shape = (N, W, H)
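Before building a Dataset it can help to confirm that the two files really line up. The sketch below is only an illustration: it writes small synthetic arrays to a temporary directory (the sizes N=8 and 32x48 are made up) and then loads and checks them the same way the real files would be loaded.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-ins for the real data (hypothetical sizes: N=8, 32x48 images).
rng = np.random.default_rng(0)
fake_images = rng.random((8, 32, 48), dtype=np.float32)  # pixel values in [0, 1)
fake_labels = (fake_images > 0.5).astype(np.float32)     # binary masks, 0 or 1

with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "images.npy"), fake_images)
    np.save(os.path.join(tmp, "labels.npy"), fake_labels)
    images = np.load(os.path.join(tmp, "images.npy"))
    labels = np.load(os.path.join(tmp, "labels.npy"))

# Images and labels must match sample by sample and pixel by pixel.
assert images.shape == labels.shape, "images and labels differ in shape"
print(images.shape, images.dtype, images.min(), images.max())
```

With the real files, only the two np.load calls and the shape check are needed.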
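Next, define a class that inherits from torch.utils.data.Dataset. Two methods are required: __getitem__(self, index) and __len__(self). As with any class, there should also be an __init__() method that receives the inputs; it is called automatically when the class is instantiated. __getitem__(self, index) returns the sample at position index and is what the DataLoader built later calls to fetch data. __len__(self) returns the total number of samples.
The class can also contain data-processing functions, which are called from __getitem__.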
import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, X, *Y):
        '''
        X: images, shape = (N, w, h); N images of shape (w, h), pixel values in [0, 1]
        Y: optional labels, shape = (N, w, h); N labels of shape (w, h), pixel values 0 or 1
        '''
        self.X = torch.from_numpy(np.expand_dims(X, axis=1))  # tensor, shape = (N, 1, w, h)
        self.haveLabels = False
        if len(Y) == 1:
            self.Y = torch.from_numpy(np.expand_dims(Y[0], axis=1))  # tensor, shape = (N, 1, w, h)
            self.haveLabels = True

    def __getitem__(self, index):
        image = self.X[index, ...]  # shape = (1, w, h)
        if self.haveLabels:
            label = self.Y[index, ...]  # shape = (1, w, h)
            return image, label  # return the image together with its label
        else:
            return image

    def __len__(self):
        return self.X.shape[0]  # total number of samples
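As mentioned above, a processing function can live inside the class and be called from __getitem__. The following is a minimal sketch of that idea, not part of the original code: NormalizedDataset and its _normalize helper are hypothetical names, and min-max scaling is just one possible preprocessing step.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class NormalizedDataset(Dataset):
    """Hypothetical variant of MyDataset that rescales each image on access."""

    def __init__(self, X):
        # X: array of shape (N, w, h); a channel axis is added as in MyDataset.
        self.X = torch.from_numpy(np.expand_dims(X, axis=1)).float()  # (N, 1, w, h)

    def _normalize(self, image):
        # Min-max scale one image to [0, 1]; the epsilon avoids division by
        # zero when every pixel of the image has the same value.
        lo, hi = image.min(), image.max()
        return (image - lo) / (hi - lo + 1e-8)

    def __getitem__(self, index):
        return self._normalize(self.X[index])  # shape = (1, w, h)

    def __len__(self):
        return self.X.shape[0]


# Synthetic data with pixel values in [0, 255) to show the effect.
demo = NormalizedDataset(np.random.rand(4, 8, 8) * 255.0)
sample = demo[0]
print(sample.shape, float(sample.min()), float(sample.max()))
```

Doing the processing in __getitem__ means it runs per sample each time the sample is requested, which suits random augmentations; fixed transformations could equally be applied once in __init__.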
Once MyDataset is defined, it can be instantiated with the data and wrapped in a DataLoader. The code is as follows:
from torch.utils.data import DataLoader

# labeled training set, labeled validation set, and unlabeled set
X_l_train = images[0:40, ...]  # shape = (40, 384, 512)
Y_l_train = labels[0:40, ...]  # shape = (40, 384, 512)
X_l_val = images[40:45, ...]  # shape = (5, 384, 512)
Y_l_val = labels[40:45, ...]  # shape = (5, 384, 512)
X_u = images[45:, ...]  # shape = (235, 384, 512)

train_data_l = MyDataset(X_l_train, Y_l_train)
val_data_l = MyDataset(X_l_val, Y_l_val)
train_data_u = MyDataset(X_u)  # no labels passed

train_loader_l = DataLoader(train_data_l, batch_size=16, shuffle=True)
val_loader_l = DataLoader(val_data_l, batch_size=16, shuffle=False)
train_loader_u = DataLoader(train_data_u, batch_size=16, shuffle=True)
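Data wrapped in torch's DataLoader is accessed with a for loop, as in the code below. Here inputs and labels have shape (B, C, W, H), where B is the batch_size set in the DataLoader (16 in this code) and C = 1 is the channel axis added in MyDataset. Because the labeled training set has 40 images, the last batch of each pass holds only 8. train_loader_u was built without labels, so each of its batches is just the inputs tensor.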
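device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for inputs, labels in train_loader_l:
    # move the batch to the device the model lives on
    inputs, labels = (
        inputs.to(device),
        labels.to(device),
    )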
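The batch dimension is not always equal to batch_size. The sketch below illustrates this with synthetic tensors in place of the real images (TensorDataset is used here only to keep the example self-contained): 40 samples with batch_size=16 give batches of 16, 16 and 8, and passing drop_last=True discards the incomplete one.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins: 40 single-channel 8x8 images with binary masks.
X = torch.rand(40, 1, 8, 8)
Y = (torch.rand(40, 1, 8, 8) > 0.5).float()
dataset = TensorDataset(X, Y)

# Default behaviour (drop_last=False): the final batch holds the remainder.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
batch_sizes = [inputs.shape[0] for inputs, _ in loader]

# drop_last=True skips the final incomplete batch.
loader_full = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
full_sizes = [inputs.shape[0] for inputs, _ in loader_full]

print(batch_sizes)  # [16, 16, 8]
print(full_sizes)   # [16, 16]
```

Whether to drop the last batch is a design choice; it mainly matters when later code assumes a fixed batch size.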
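To pull batches from a DataLoader one at a time instead of running a full for loop, use Python's built-in iter() and next() functions. When the iterator is exhausted, next() raises StopIteration; the code below catches it, creates a fresh iterator, and continues.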
train_iter_l = iter(train_loader_l)
for i in range(4):
    try:
        inputs, labels = next(train_iter_l)
    except StopIteration:
        # the loader is exhausted: restart it and take the first batch of the new pass
        train_iter_l = iter(train_loader_l)
        inputs, labels = next(train_iter_l)
    print(inputs.shape, inputs.max())
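The try/except restart above can also be wrapped in a small generator. This is a sketch under assumptions, not part of the original code: infinite_batches is a hypothetical helper name, it expects a non-empty loader, and a plain list stands in for the DataLoader so that the example runs without any data.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_batches(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches from `loader` forever, restarting it after each pass.

    Assumes `loader` is non-empty; an empty loader would loop without yielding.
    """
    while True:
        # Each `for` pass calls iter(loader) again, so a DataLoader created
        # with shuffle=True draws a new order on every restart.
        for batch in loader:
            yield batch


# A list of three stand-in "batches" plays the role of the DataLoader.
stream = infinite_batches([10, 20, 30])
first_five = [next(stream) for _ in range(5)]
print(first_five)  # [10, 20, 30, 10, 20]
```

With the loaders defined earlier, the same helper would be used as stream = infinite_batches(train_loader_l) followed by inputs, labels = next(stream). itertools.cycle is a less suitable substitute here because it replays a saved copy of the first pass rather than restarting the loader.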