PyTorch 自定义 Dataset
文章目录
记录一下过程。经过一晚上的尝试,代码如下:
import os
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class MaskDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Every file name found in ``mask_img_path`` is one sample; the matching
    input image is expected to have the *same* file name inside ``img_path``.

    Args:
        img_path: Directory holding the input images.
        mask_img_path: Directory holding the mask (label) images. Its
            (sorted) listing defines the samples and their order.
        transform: Optional callable applied to BOTH the image and the mask,
            e.g. a ``transforms.Compose`` pipeline.
    """

    def __init__(self, img_path, mask_img_path, transform=None):
        self.img_path = img_path
        self.mask_img_path = mask_img_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible across runs and machines.
        self.mask_img = sorted(os.listdir(mask_img_path))

    def __len__(self):
        """Return the number of samples (entries in the mask directory)."""
        return len(self.mask_img)

    def _pair_paths(self, idx):
        """Return ``(image_path, label_path)`` for sample ``idx``."""
        file_name = self.mask_img[idx]
        # Use the directories given to the constructor, not module globals.
        image_path = os.path.join(self.img_path, file_name)
        label_path = os.path.join(self.mask_img_path, file_name)
        return image_path, label_path

    @staticmethod
    def _load(path):
        """Read the image at ``path`` fully into memory and close the file."""
        # Image.open is lazy and keeps the file handle open until pixel data
        # is loaded; copying inside ``with`` forces the load and releases the
        # handle even when no transform is applied.
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        Both items are PIL images, or whatever ``self.transform`` returns for
        them. NOTE(review): the same transform is applied to the mask; if it
        resizes with a non-nearest interpolation (torchvision's ``Resize``
        defaults to bilinear), mask values get blended at region borders —
        consider a nearest-neighbour resize for masks.
        """
        image_path, label_path = self._pair_paths(idx)
        image = self._load(image_path)
        label = self._load(label_path)
        if self.transform is not None:
            image = self.transform(image)
            label = self.transform(label)
        return image, label
if __name__ == "__main__":
    # Smoke test: build the dataset, report its size, and print the shape of
    # each image batch.
    # Raw strings keep Windows backslashes literal: in a plain literal "\d"
    # and "\l" are invalid escape sequences (SyntaxWarning on recent Pythons)
    # and e.g. "\t" would silently turn into a TAB character.
    img_path = r"E:\data1\train\image1"
    mask_img_path = r"E:\data1\train\label1"
    my_transform = transforms.Compose(
        [transforms.Resize((400, 400)), transforms.ToTensor()]
    )
    dataset = MaskDataset(img_path, mask_img_path, my_transform)
    print(len(dataset))
    dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=False)
    for image, label in dataloader:
        print(image.shape)
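记录一下错误:
之前用 cv2.imread 读取图像时,得到的图像是空值(None),当时不知道是怎么回事。补充说明:cv2.imread 读取失败时不会抛出异常,而是直接返回 None;常见原因是路径写错、文件不存在,也可能是路径中含有中文等非 ASCII 字符(Windows 下较常见)。可以先打印路径,并检查返回值是否为 None。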