MedMNIST 数据集由上海交通大学(倪冰冰团队)提供,共包含十个医学图像分类数据集(图像分辨率均为 28×28)。由于我对眼底图片相对熟悉一些,所以先看了一下其中眼底图片数据集的情况。
首先可视化这些 28×28 的糖尿病视网膜病变(Diabetic Retinopathy,DR)图片。
- 数据来源是ISBI2020 challenge(The 2nd diabetic retinopathy – grading and image quality estimation challenge)
- 名称:DeepDR Diabetic Retinopathy Image Dataset (DeepDRiD)
- 数据集下载链接为https://isbi.deepdr.org/data.html
原始图片大小为 1736×1824,而 MedMNIST 中的图片只有 28×28,可想而知分辨率有多低。
code:
from medmnist import environ
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
class MedMNIST(Dataset):
    """Map-style dataset over one split of a ``<flag>.npz`` archive.

    The archive path is ``os.path.join(environ.dataroot, "<flag>.npz")`` and
    the archive is expected to hold ``<split>_images`` and ``<split>_labels``
    arrays. Subclasses pick the archive by overriding ``flag``.
    """

    # Base name of the ``.npz`` archive; subclasses override this value.
    flag = "retinamnist"

    # Split names for which ``<split>_images`` / ``<split>_labels`` are read.
    _SPLITS = ("train", "val", "test")

    def __init__(self, split='train', transform=None, target_transform=None):
        """Load the image and label arrays of one split into memory.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            transform: Optional callable applied to the PIL image of a sample.
            target_transform: Optional callable applied to the sample label.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        # Reject unknown splits up front; otherwise ``img``/``label`` would
        # never be set and the error would only surface on first access.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split)
            )

        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        archive_path = os.path.join(
            environ.dataroot, "{}.npz".format(self.flag)
        )
        # Indexing the archive materialises each array, so the underlying
        # file can be closed as soon as both arrays have been read.
        with np.load(archive_path) as npz_file:
            self.img = npz_file["{}_images".format(split)]
            self.label = npz_file["{}_labels".format(split)]

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is converted to ``uint8`` and wrapped in a PIL image before
        ``transform`` is applied. A label row holding a single value is
        returned as a Python ``int``; a row holding several values is
        returned as an integer array, since it has no scalar form.
        """
        img = Image.fromarray(np.uint8(self.img[index]))

        label = np.asarray(self.label[index])
        if label.size == 1:
            # ``item()`` extracts the scalar explicitly instead of relying on
            # implicit array-to-scalar conversion inside ``int()``.
            target = int(label.item())
        else:
            target = label.astype(int)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples, i.e. the first axis of ``img``."""
        return self.img.shape[0]
class PathMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``pathmnist.npz``."""

    flag = "pathmnist"
class OCTMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``octmnist.npz``."""

    flag = "octmnist"
class PneumoniaMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``pneumoniamnist.npz``."""

    flag = "pneumoniamnist"
class ChestMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``chestmnist.npz``."""

    flag = "chestmnist"
class DermaMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``dermamnist.npz``."""

    flag = "dermamnist"
class RetinaMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``retinamnist.npz``."""

    flag = "retinamnist"
class BreastMNIST(MedMNIST):
    """Dataset variant that reads its splits from ``breastmnist.npz``."""

    flag = "breastmnist"
class OrganMNIST_Axial(MedMNIST):
    """Dataset variant that reads its splits from ``organmnist_axial.npz``."""

    flag = "organmnist_axial"
class OrganMNIST_Coronal(MedMNIST):
    """Dataset variant that reads its splits from ``organmnist_coronal.npz``."""

    flag = "organmnist_coronal"
class OrganMNIST_Sagittal(MedMNIST):
    """Dataset variant that reads its splits from ``organmnist_sagittal.npz``."""

    flag = "organmnist_sagittal"
if __name__ == '__main__':
    # Demo: step through the RetinaMNIST training split one image at a time,
    # displaying each image and printing its label.
    import torch.utils.data as data
    from torchvision import transforms
    import matplotlib.pyplot as plt

    train_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Maps each dataset flag to the class that loads it.
    dataclass = {
        "pathmnist": PathMNIST,
        "chestmnist": ChestMNIST,
        "dermamnist": DermaMNIST,
        "octmnist": OCTMNIST,
        "pneumoniamnist": PneumoniaMNIST,
        "retinamnist": RetinaMNIST,
        "breastmnist": BreastMNIST,
        "organmnist_axial": OrganMNIST_Axial,
        "organmnist_coronal": OrganMNIST_Coronal,
        "organmnist_sagittal": OrganMNIST_Sagittal,
    }
    print(dataclass["retinamnist"])

    train_dataset = dataclass["retinamnist"](
        split='train', transform=train_transform
    )
    train_loader = data.DataLoader(
        dataset=train_dataset, batch_size=1, shuffle=False
    )

    for img, target in train_loader:
        # Drop only the batch axis (batch_size is 1), then reorder
        # (C, H, W) -> (H, W, C), the channel-last layout imshow expects.
        # squeeze(0) leaves any other size-1 axis in place so permute still
        # receives three axes.
        img = img.squeeze(0).permute(1, 2, 0)
        plt.imshow(img)
        plt.show()  # blocks until the figure window is closed
        print(target)