本文以行人重识别(ReID)的数据加载为例,介绍 PyTorch 中自定义数据加载的基本流程。
import math
import os
import random

from torch.utils.data import dataset, dataloader
from torchvision import transforms
一、建立自定义数据预处理类:如随机擦除、随机裁剪等
代码:
class RandomErasing(object):
    """Randomly erase one rectangle of an image tensor (Random Erasing augmentation).

    Meant to be used inside ``transforms.Compose`` after ``ToTensor()`` (and
    usually ``Normalize``), so ``img`` is a 3-D tensor indexed as
    ``img[channel, row, column]``. The tensor is modified in place and the
    same object is returned.

    Args:
        probability: chance that an image is erased at all (0 = never, 1 = always).
        sl: minimum erased area, as a fraction of the image area.
        sh: maximum erased area, as a fraction of the image area.
        r1: minimum aspect ratio (height / width) of the erased rectangle; the
            maximum is ``1 / r1``.
        mean: fill value for each channel; channel ``c`` is filled with
            ``mean[c % len(mean)]``. After ``Normalize`` a value of 0.0
            corresponds to the per-channel mean used for normalization.

    Raises:
        ValueError: if ``mean`` is empty.
    """

    # How many random rectangles to try before giving up and returning ``img`` unchanged.
    _MAX_ATTEMPTS = 100

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.0, 0.0, 0.0)):
        self.probability = probability
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        # Copied to a tuple so later changes to the caller's list cannot leak in.
        self.mean = tuple(mean)
        if not self.mean:
            raise ValueError("mean must contain at least one fill value")

    def __call__(self, img):
        """Erase a random rectangle of ``img`` with probability ``self.probability``; return ``img``."""
        # random.random() is in [0, 1): probability=0 never erases, probability=1 always tries.
        if random.random() >= self.probability:
            return img

        channels, height, width = img.shape
        area = height * width
        for _ in range(self._MAX_ATTEMPTS):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1.0 / self.r1)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            # Retry when the sampled rectangle is empty or does not fit inside the image.
            if 0 < h < height and 0 < w < width:
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                for c in range(channels):
                    img[c, top:top + h, left:left + w] = self.mean[c % len(self.mean)]
                return img
        return img
二、建立数据预处理组合(transforms.Compose)实例:依次进行尺寸缩放、随机水平翻转、转换为张量、归一化和随机擦除等
代码:
# Training-time augmentation pipeline; Compose applies the steps in list order.
train_transform = transforms.Compose(
    [
        # Resize to (height, width) = (384, 128); interpolation code 3 is bicubic in PIL's numbering.
        transforms.Resize(size=(384, 128), interpolation=3),
        # Mirror left/right with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
        # Per-channel standardization with the commonly used ImageNet mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        # Custom augmentation; it runs last, so it operates on the normalized tensor.
        RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0]),
    ]
)
三、建立数据集类(继承 torch.utils.data.Dataset):从本地路径读取图像,生成图像路径与标签列表等
代码:
from torchvision.datasets.folder import default_loader //解释见最后
class Market(dataset.Dataset):
    """Person re-identification dataset read from a local Market-1501 style folder.

    ``__init__`` builds the list of image paths (``self.imgs``) and a mapping
    from person id to a contiguous class label. ``__getitem__`` loads one image
    with ``self.loader``, applies ``transform`` and returns ``(img, target)``.

    NOTE(review): the original snippet elided how paths and labels are built.
    The directory layout (``_SUBDIRS``) and the file-name convention
    (``<person_id>_<more fields>.jpg``, with person id ``-1`` marking junk
    images) are assumptions based on the public Market-1501 release. Confirm
    them against the data actually on disk.

    Args:
        transform: callable applied to each loaded image, or ``None``.
        dtype: which split to read: ``"train"``, ``"test"`` or ``"query"``.
        data_path: root directory of the dataset.

    Raises:
        ValueError: if ``dtype`` is not one of the supported splits.
    """

    # Split name -> sub-folder of ``data_path`` (assumed Market-1501 layout).
    _SUBDIRS = {
        "train": "bounding_box_train",
        "test": "bounding_box_test",
        "query": "query",
    }
    # File suffixes treated as images (compared case-insensitively).
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, transform, dtype, data_path):
        self.transform = transform
        self.loader = default_loader
        if dtype not in self._SUBDIRS:
            raise ValueError(
                "dtype must be one of %s, got %r" % (sorted(self._SUBDIRS), dtype)
            )
        self.data_path = os.path.join(data_path, self._SUBDIRS[dtype])

        # Sorted so that the index -> image mapping is reproducible across runs.
        self.imgs = sorted(
            os.path.join(self.data_path, name)
            for name in os.listdir(self.data_path)
            if name.lower().endswith(self._IMAGE_SUFFIXES)
            and self.person_id(name) != -1
        )
        # Person ids are sparse integers; classification targets must be 0..K-1.
        self.unique_ids = sorted({self.person_id(path) for path in self.imgs})
        self._id2label = {pid: label for label, pid in enumerate(self.unique_ids)}

    def __getitem__(self, index):
        """Return ``(img, target)`` for the image at position ``index``."""
        path = self.imgs[index]
        target = self._id2label[self.person_id(path)]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.imgs)

    @staticmethod
    def person_id(file_path):
        """Return the integer before the first ``_`` in the file name of ``file_path``."""
        return int(os.path.basename(file_path).split("_")[0])
四、创建 DataLoader 实例(训练集与测试集各一个):
# Training loader.
# NOTE(review): RandomSampler is called with batch_id / batch_image, which
# torch.utils.data.RandomSampler does not accept, so it must be a project-specific
# sampler. Presumably each batch holds opt.batchid identities with
# opt.batchimage images each, which is why batch_size is their product.
# Confirm this against the sampler's source.
self.train_loader = dataloader.DataLoader(
    self.trainset,
    sampler=RandomSampler(
        self.trainset,
        batch_id=opt.batchid,
        batch_image=opt.batchimage,
    ),
    batch_size=opt.batchid * opt.batchimage,
    num_workers=8,      # 8 worker subprocesses load batches in parallel
    pin_memory=True,    # page-locked host memory for faster copies to the GPU
)
# Test loader: no sampler and shuffle left at its default (False), so testset
# is read in index order.
self.test_loader = dataloader.DataLoader(
    self.testset,
    batch_size=opt.batchtest,
    num_workers=8,
    pin_memory=True,
)
之后在训练阶段即可直接迭代 DataLoader 按批次获取数据,例如 `for imgs, labels in self.train_loader:`。
附:torchvision.datasets.folder 中的 default_loader 函数
该函数根据 torchvision 当前的图像后端选择加载函数:后端为 accimage 时调用 accimage_loader,否则(默认的 PIL 后端)调用 pil_loader,因此一般使用的是 pil_loader。
def pil_loader(path):
    """Read the image file at ``path`` with Pillow and return it converted to RGB."""
    # Both the file handle and the decoded source image are context managers,
    # so both are closed once the RGB copy has been produced.
    with open(path, 'rb') as stream, Image.open(stream) as source:
        rgb_image = source.convert('RGB')
    return rgb_image
def accimage_loader(path):
    """Decode ``path`` with accimage, falling back to ``pil_loader`` on an I/O error."""
    import accimage  # imported lazily, only when this function runs
    try:
        decoded = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # Potentially a decoding problem, fall back to PIL.Image
        decoded = pil_loader(path)
    return decoded
def default_loader(path):
    """Load ``path`` with the loader matching torchvision's current image backend."""
    from torchvision import get_image_backend
    backend = get_image_backend()
    if backend != 'accimage':
        # Any non-accimage backend (normally 'PIL') goes through Pillow.
        return pil_loader(path)
    return accimage_loader(path)