python中
import torch.utils.data import Dataset
通常会有三个函数:
def __init__(self):
def __len__(self):
def __getitem__(self,index):
其中def __init__()常用,就是初始化。不需要返回值。
def __len__(self): 是用来获取数据集的长度。需要返回值。
def __getitem__(self, index): 根据索引获取图片和标签。需要返回值。
在这个函数下面,可以做图像增强,比如:旋转、剪切、归一化等。比如:
def __getitem__(self, idx):
df = open(self.data_file) #打开文件
lines = df.readlines() #读取文件所有行并返回列表
lst = lines[idx].split() #通过指定分隔符对字符串进行切片
img_name = lst[0]
img_label = lst[1]
image_path = os.path.join(self.root_dir, img_name)
image = nib.load(image_path)
if img_label == 'Normal':
label = 0
elif img_label == 'AD':
label = 1
elif img_label == 'MCI':
label = 2
if self.transform: #通过找中心和缩放等实现标准化(比如降维、归一化等)
image = self.transform(image)
sample = {'image': image, 'label': label}
return sample