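"""PyTorch Dataset example.

Defines a custom dataset class by subclassing ``Dataset``.
The custom dataset class builds a training set from a directory path
and creates a PIL ``Image`` object from each image file name.
"""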
from torch.utils.data import Dataset
from PIL import Image
import os
class Mydata(Dataset):
    """Dataset of the images stored in one label directory.

    Every entry found in ``<base_dir>/<root_dir>/<label_dir>`` is treated as
    one sample, and the directory name ``label_dir`` is used as the label of
    all of them.
    """

    def __init__(self, root_dir, label_dir, base_dir='../'):
        """Collect the image file names of one label directory.

        :param root_dir: dataset root, e.g. ``'datasets/hymenoptera_data/train'``
        :param label_dir: sub-directory holding one class, e.g. ``'ants'``;
            it is also returned as the label of every sample
        :param base_dir: prefix joined in front of ``root_dir``. Defaults to
            ``'../'``, the value that used to be hard-coded. Note that
            ``os.path.join`` drops this prefix when ``root_dir`` is absolute.
        :raises OSError: if the resulting directory cannot be listed
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        # 1) Build the path of the directory that holds this label's images.
        self.label_path = os.path.join(base_dir, self.root_dir, self.label_dir)
        # 2) List the image names. os.listdir() returns entries in arbitrary
        #    order, so sort them to keep the index -> image mapping stable
        #    across runs and platforms.
        self.image_names = sorted(os.listdir(self.label_path))

    def __getitem__(self, index):
        """Return the image at ``index`` together with its label.

        :param index: position in ``self.image_names``
        :return: tuple ``(img, label)`` where ``img`` is the object returned
            by ``Image.open`` and ``label`` is ``self.label_dir``
        :raises IndexError: if ``index`` is out of range
        """
        img_name = self.image_names[index]
        img_path = os.path.join(self.label_path, img_name)
        img = Image.open(img_path)
        return img, self.label_dir

    def __len__(self):
        """Return the number of entries found in the label directory."""
        return len(self.image_names)
def test():
    """Interactively build a ``Mydata`` dataset and display one of its images.

    Reads the root directory, the label (sub-directory) name and an image
    index from standard input, prints how many images the dataset contains
    and opens the selected image in the default viewer.
    """
    root = input("输入根目录:")
    label_name = input("输入数据集名称:")
    dataset = Mydata(root, label_name)
    position = int(input("输入图片索引:"))
    # Only the image is needed here; the label half of the pair is ignored.
    picture, _ = dataset[position]
    total = len(dataset)
    print('该数据集包括{}张图片'.format(total))
    picture.show()
# Run the interactive demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    test()