from torch.utils.data import Dataset
import cv2
import os
class MyData(Dataset):
    """Image dataset whose label is the name of the folder holding the images.

    Samples are read from ``os.path.join(root_dir, label_dir)`` and every
    sample carries ``label_dir`` itself as its label (e.g. ``"ants"``).

    Args:
        root_dir: Directory that contains the class sub-folder.
        label_dir: Name of the class sub-folder to load; also used as label.
    """

    def __init__(self, root_dir: str, label_dir: str):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir returns entries in arbitrary order; sort so that the
        # integer index -> file mapping is reproducible across runs/machines.
        self.image_dir_list = sorted(os.listdir(self.path))

    def _resolve_path(self, item) -> str:
        """Return the full path of the sample selected by ``item``.

        ``item`` is normally an integer position in ``image_dir_list`` (what a
        map-style Dataset receives from a sampler / DataLoader). For backward
        compatibility a file name inside ``self.path`` is also accepted.
        """
        if isinstance(item, str):
            name = item
        else:
            name = self.image_dir_list[item]
        return os.path.join(self.path, name)

    def __getitem__(self, item):
        """Load one image and return ``(image, label)``.

        The image is whatever ``cv2.imread`` returns for the file; the label
        is ``self.label_dir``.

        Raises:
            OSError: if OpenCV cannot read the file (missing or not an image).
        """
        image_path = self._resolve_path(item)
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        return image, self.label_dir

    def __len__(self):
        """Return the number of entries listed in the class folder."""
        return len(self.image_dir_list)
if __name__ == '__main__':
    # Demo: show every image of the "ants" class for one second each.
    mydata = MyData(root_dir="hymenoptera_data/train", label_dir="ants")
    for image_name in mydata.image_dir_list:
        # Index by file name and unpack once, so each image is read from
        # disk a single time.
        image, label = mydata[image_name]
        # cv2.imshow(window_name, image): the label doubles as window title.
        cv2.imshow(label, image)
        cv2.waitKey(1000)
    cv2.destroyAllWindows()
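# PyTorch: creating a dataset (original title: "pytorch创建数据集")
# Source page footer: "最新推荐文章于 2024-09-05 11:42:07 发布"
# (latest recommended article published 2024-09-05 11:42:07)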