# 定义一个数据加载器
import cv2
import numpy as np
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import os
class MyDataset(Dataset):
    """Image-classification dataset built from a folder-per-class tree.

    Expected layout: each sub-directory of ``path`` is one class whose name
    is a key of ``class_to_idx``, and every file inside it is an image of
    that class::

        path/cat/001.jpg
        path/dog/002.jpg

    All images are decoded eagerly in ``__init__`` and held in memory.

    Args:
        path: Root directory of the dataset.
        image_size: Side length (pixels) every image is resized to.
            Defaults to 512, the size the original implementation produced.
        class_to_idx: Mapping from class-folder name to integer label.
            Defaults to ``{'cat': 0, 'dog': 1}``.
    """

    def __init__(self, path, image_size=512, class_to_idx=None):
        super().__init__()
        self.path = path
        self.image_size = image_size
        self.class_to_idx = (
            {'cat': 0, 'dog': 1} if class_to_idx is None else dict(class_to_idx)
        )
        # Scan and decode once; calling get_img_label() separately for each
        # attribute would read every image from disk twice.
        self.img, self.label = self.get_img_label()

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair at index ``item``."""
        return self.img[item], self.label[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.img)

    def get_img_label(self):
        """Load every image under ``self.path`` together with its label.

        Returns:
            ``(img, label)``: a list of arrays of shape
            ``(3, image_size, image_size)`` and an index-aligned list of ints.

        Raises:
            KeyError: if a class folder name is missing from ``class_to_idx``.
            ValueError: if OpenCV cannot decode one of the files.
        """
        samples = self._collect_samples()
        # Map labels before decoding so an unexpected class folder fails fast
        # instead of after every image has already been loaded.
        label = [self.class_to_idx[class_name] for _, class_name in samples]
        img = [self._load_image(file_path) for file_path, _ in samples]
        return img, label

    def _collect_samples(self):
        """Return sorted ``(file_path, class_name)`` pairs for every file."""
        samples = []
        # Sorted so sample order is deterministic across platforms/runs.
        for class_name in sorted(os.listdir(self.path)):
            class_dir = os.path.join(self.path, class_name)
            # Skip stray root-level files (e.g. .DS_Store) instead of crashing.
            if not os.path.isdir(class_dir):
                continue
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                # Nested directories are not images; ignore them.
                if os.path.isfile(file_path):
                    samples.append((file_path, class_name))
        return samples

    def _load_image(self, file_path):
        """Decode one file into a ``(3, image_size, image_size)`` array."""
        image = cv2.imread(file_path)  # default flag: 3-channel BGR, (H, W, 3)
        # cv2.imread reports failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"could not read image file: {file_path}")
        # Resample with cv2.resize: np.resize merely repeats or truncates the
        # flattened pixel buffer, which scrambled any image not already the
        # target size.
        image = cv2.resize(image, (self.image_size, self.image_size))
        # Axis order kept as (channel, width, height) for compatibility with
        # existing consumers; note this is the spatial transpose of the usual
        # CHW layout.
        return image.transpose(2, 1, 0)
# path = './samples/data'
# dataset = MyDataset(path)
# for i in dataset.img:
#     print(i.shape)
#     cv2.imshow('2', i.transpose(2, 1, 0))
#     # cv2.waitKey(0)
#
# data_loader = DataLoader(dataset, 1, shuffle=True)
# print(len(data_loader))
# 基于PyTorch的自定义图像数据集类
# 最新推荐文章于 2024-05-17 10:23:04 发布