import os,sys,cv2,torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image class EyeSegDataset(Dataset): def init(self, img_dir, mask_dir, transform=None): self.img_dir = img_dir self.mask_dir = mask_dir self.transform = transform self.images = os.listdir(img_dir) def __len__(self): return len(self.images) def __getitem__(self, idx): # 读取原始图片和标记图片 img_path = os.path.join(self.img_dir, self.images[idx]) mask_path = os.path.join(self.mask_dir, self.images[idx].split('.')[0] + '_mask.png') image = Image.open(img_path).convert('RGB') mask = Image.open(mask_path).convert('L') if self.transform: # 对原始图片和标记图片进行相同的变换 image = self.transform(image) mask = self.transform(mask) # 将标记图片转化为0-9的整数值 mask = torch.squeeze(mask) mask[mask == 255] = 0 return ima
PyTorch 医学图像分割（10 层）
于 2023-03-07 05:19:03 首次发布