from __future__ import print_function, absolute_import
import os
from PIL import Image
import numpy as np
import os.path as osp
import torch
from torch.utils.data import Dataset
def read_image(img_path, max_retries=10):
    """Read an image from disk as an RGB ``PIL.Image``, retrying on IOError.

    Under heavy IO load, opening an image can fail spuriously with
    ``IOError``; such failures are retried. The number of attempts is
    bounded so that a persistently unreadable file (e.g. a corrupted image,
    or one removed after the existence check) raises instead of hanging
    the caller forever.

    Args:
        img_path: path to the image file.
        max_retries: maximum number of read attempts before giving up.
            ``None`` retries indefinitely (the previous behavior). Values
            below 1 behave like 1.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        IOError: if ``img_path`` does not exist, or if every attempt fails.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    attempt = 0
    while True:
        attempt += 1
        try:
            # The context manager closes the underlying file handle promptly;
            # convert() returns a new, independent image, so it stays usable.
            with Image.open(img_path) as raw:
                return raw.convert('RGB')
        except IOError as err:
            if max_retries is not None and attempt >= max_retries:
                raise IOError(
                    "Failed to read '{}' after {} attempt(s): {}".format(
                        img_path, attempt, err))
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
class ImageDataset(Dataset):
    """Person re-identification dataset yielding ``(image, pid, camid)``.

    Args:
        dataset: indexable sequence of ``(img_path, pid, camid)`` records.
        transform: optional callable applied to each loaded image; when it
            is ``None`` the raw image returned by ``read_image`` is yielded.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        record = self.dataset[index]
        path, person_id, camera_id = record
        image = read_image(path)
        transform = self.transform
        if transform is None:
            return image, person_id, camera_id
        return transform(image), person_id, camera_id
if __name__ == '__main__':
    # Ad-hoc smoke test: load Market-1501 and drop into an IPython shell to
    # inspect samples interactively. The dataset root may be passed as the
    # first command-line argument; otherwise the original path is used.
    import sys
    import data_manager

    data_root = sys.argv[1] if len(sys.argv) > 1 else '/home/ubuntu/reid/AlignedReID/data'
    dataset = data_manager.init_img_dataset(root=data_root, name='market1501')
    # NOTE: despite its name this is an ImageDataset, not a torch DataLoader;
    # the name is kept so existing interactive sessions keep working.
    train_loader = ImageDataset(dataset.train)
    from IPython import embed
    embed()
运行结果 (run output):
=> Market1501 loaded
Dataset statistics:
------------------------------
subset | # ids | # images
------------------------------
train | 751 | 12936
query | 750 | 3368
gallery | 751 | 15913
------------------------------
total | 1501 | 32217
------------------------------
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: train_loader
Out[1]: <__main__.ImageDataset at 0x7f30d44e8fa0>
In [2]: for batch_id, (imgs, pid, camid) in enumerate(train_loader):
   ...:     break
   ...:
In [3]: batch_id
Out[3]: 0
In [4]: imgs
Out[4]: <PIL.Image.Image image mode=RGB size=64x128 at 0x7F304A3DA820>
In [5]: pid
Out[5]: 25
In [6]: camid
Out[6]: 4
In [7]: imgs.save('aaaa.jpg')
In [8]:
保存的aaaa.jpg如下 (the saved aaaa.jpg is shown below):