dataset类代码实战
(1)导包
from torch.utils.data import Dataset
from PIL import Image #Image可以对图像进行加载保存等处理
import os #可进行文件访问、进程管理、内存管理、网络通信等。目前用到的os.path可以拼接、分割、获取文件名
(2)定义一个类
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
#'sep'.join() 连接字符串数组。将字符串、元组、列表中的元素以指定的字符(分隔符)连接生成一个新的字符串
#os.path.join() 将多个路径组合后返回
self.path = os.path.join(self.root_dir,self.label_dir)
#os.listdir() 用于返回指定的文件夹包含的文件或文件夹的名字的列表
self.img_path = os.listdir(self.path)
def __getitem__(self, idx): #__getitem__(self,index) 一般用来迭代序列,或者求序列中的索引为index处的值
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
(3)给出数据集路径
例:
root_dir = 'Satellite Image/train'
desert_label_dir = 'desert'
desert_dataset = MyData(root_dir,desert_label_dir)