pytorch自定义dataset

pytorch自定义dataset


记录一下进程 经过一晚上的尝试,代码如下:
import os

import numpy as np
from PIL import Image
from torch.utils.data import  DataLoader
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class MaskDataset(Dataset):
    def __init__(self, img_path, mask_img_path, transform=None):
        self.img_path = img_path
        self.mask_img_path = mask_img_path
        self.transform = transform
        self.mask_img = os.listdir(mask_img_path)

    def __len__(self):

        return len(self.mask_img)

    def __getitem__(self, idx):
        label_name=self.mask_img [idx]
        image_name=label_name
        label_path=os.path.join(mask_img_path,label_name)
        image_path=os.path.join(img_path,image_name)
        # image=cv2.imread(image_path)
        # label=cv2.imread(label_path)
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # image = image.reshape(1, image.shape[0], image.shape[1])
        # label = label.reshape(1, label.shape[0], label.shape[1])
        image=Image.open(image_path)
        label=Image.open(label_path)
        # image=np.array(image)
        # label=np.array(label)
        # image=image[np.newaxis,:,:]
        # label = label[np.newaxis, :, :]
        # image=Image.fromarray(image)
        # label=Image.fromarray(label)


        if self.transform is not None:
            image=self.transform(image)
            label=self.transform(label)
        return image,label





if __name__=="__main__":

    img_path='E:\\data1\\train\\image1'
    mask_img_path='E:\data1\\train\label1'
    my_transform= transforms.Compose([transforms.Resize((400,400)), transforms.ToTensor()])
    test = MaskDataset(img_path, mask_img_path,my_transform)
    print(len(test))
    dataloader= torch.utils.data.DataLoader(dataset=test,
                                               batch_size=2,
                                               shuffle=False)
    for image, label in dataloader:
        print(image.shape)

在这里插入图片描述
记录一下错误:
之前用cv2读取图像,但图像是空值,不知道怎么回事
在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值