一、PyTorch自定义Dataset的步骤：
- 继承torch.utils.data.Dataset
- 实现__getitem__方法（按索引返回一个样本）
- 实现__len__方法（返回样本总数）
二、PyTorch自定义Dataset的代码：
from torch.utils import data
from PIL import Image
from torchvision import transforms
import glob
import matplotlib.pyplot as plt
class MyDataset(data.Dataset):
    """Map-style dataset that loads images from file paths and pairs them with labels.

    Args:
        image_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of labels, aligned element-wise with
            ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` of each
            sample. When ``None`` the PIL image itself is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length, which
            would otherwise silently pair images with the wrong labels.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.images = image_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        path = self.images[index]
        label = self.labels[index]
        # Image.open is lazy and keeps the file open until the pixels are read;
        # convert() reads them into a new image, so the handle can be closed
        # here instead of leaking one open file per sample.
        with Image.open(path) as pil_img:
            rgb_img = pil_img.convert('RGB')
        if self.transforms is not None:
            rgb_img = self.transforms(rgb_img)
        return rgb_img, label

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.images)
三、结果展示
from torch.utils import data
from PIL import Image
from torchvision import transforms
import glob
import matplotlib.pyplot as plt
def collect_labeled_images(image_paths, classes):
    """Pair each image path with the index of the class named in its file name.

    Only the file name (not the directory part) is searched, and the first
    entry of ``classes`` that occurs in it wins, so every kept path gets exactly
    one label. Paths whose file name contains no class name are dropped.

    Args:
        image_paths: Iterable of image file paths.
        classes: Sequence of class-name substrings; a label is the index of
            the matching entry.

    Returns:
        ``(paths, labels)``: two lists of equal length, aligned element-wise.
    """
    import os

    kept_paths = []
    labels = []
    for path in image_paths:
        file_name = os.path.basename(path)
        for class_index, class_name in enumerate(classes):
            if class_name in file_name:
                kept_paths.append(path)
                labels.append(class_index)
                break  # one label per image keeps paths and labels aligned
    return kept_paths, labels


if __name__ == '__main__':
    # The class of each image is encoded in its file name.
    all_paths = glob.glob(r'dataset2/*.jpg')
    classes = ['cloudy', 'rain', 'shine', 'sunrise']
    classes_to_index = {cla: i for i, cla in enumerate(classes)}
    index_to_classes = {i: cla for cla, i in classes_to_index.items()}

    images_path, all_labels = collect_labeled_images(all_paths, classes)
    skipped = len(all_paths) - len(images_path)
    if skipped:
        print(f'skipped {skipped} file(s) whose name matches no class')
    if not images_path:
        raise SystemExit('no labelled .jpg images found in dataset2/')

    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
    ])
    weather_dataset = MyDataset(images_path, all_labels, transform)
    weather_dataloader = data.DataLoader(dataset=weather_dataset, batch_size=16, shuffle=True)
    imgs, labels = next(iter(weather_dataloader))

    # Show up to six images from the first (shuffled) batch in a 2x3 grid.
    plt.figure(figsize=(12, 8))
    for i, (imgg, labell) in enumerate(zip(imgs[:6], labels[:6])):
        # ToTensor yields channels-first (C, H, W); imshow expects (H, W, C).
        imgg = imgg.permute(1, 2, 0).numpy()
        plt.subplot(2, 3, i + 1)
        plt.axis('off')
        plt.title(index_to_classes.get(labell.item()))
        plt.imshow(imgg)
    plt.savefig('result.png')
    plt.show()