FlyingChairs数据集包含一对图像和对应的光流序列(00001_img1.ppm, 00001_img2.ppm, 00001_flow.flo)
自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__,其中:(1). 在__init__中是初始化了该类的一些基本参数;
(2). __getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;
(3). __len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。
from typing import Union
from pathlib import Path
import torch
import numpy as np
from PIL import Image
class FlyingChairDataset(torch.utils.data.Dataset):
    """FlyingChairs-style optical-flow dataset.

    Each sample id maps to a triplet in `root`:
    <id>_img1.ppm, <id>_img2.ppm, <id>_flow.flo.
    """

    def __init__(self, root: Union[Path, str], transform=None):
        """
        root: directory containing the ppm/flo triplets.
        transform: optional callable ((img1, img2), flow) -> ((img1, img2), flow).
        """
        super(FlyingChairDataset, self).__init__()
        self.root = Path(root)
        # Sort the ids: iterdir() order is filesystem-dependent and set
        # iteration order is unstable across runs, which would make any
        # index-based train/valid split non-reproducible.
        self.ids = sorted({o.stem.split('_')[0] for o in self.root.iterdir()})
        self.transform = transform

    def __getitem__(self, idx: int):
        """Read and return ((frame1, frame2), optical_flow) for sample `idx`."""
        id_ = self.ids[idx]
        frame1_path = self.root / (id_ + "_img1.ppm")
        frame2_path = self.root / (id_ + "_img2.ppm")
        optical_flow_path = self.root / (id_ + "_flow.flo")
        frame1 = Image.open(frame1_path)
        frame2 = Image.open(frame2_path)
        optical_flow = read_flow(str(optical_flow_path))
        if self.transform is not None:
            (frame1, frame2), optical_flow = self.transform((frame1, frame2), optical_flow)
        return (frame1, frame2), optical_flow

    def __len__(self) -> int:
        """Dataset size; the iterator's index range is derived from this."""
        return len(self.ids)
def read_flow(filename: str) -> np.ndarray:
    """Read a Middlebury-format .flo optical-flow file.

    Layout: 4-byte magic "PIEH", then int32 width, int32 height, then
    width*height*2 float32 values interleaved as (u, v) per pixel.

    Returns: float32 ndarray of shape (height, width, 2).
    Raises: Exception if the magic header is missing.
    """
    # `with` guarantees the handle is closed even on a parse error
    # (the original leaked the file descriptor).
    with open(filename, 'rb') as f:
        header = f.read(4)
        if header.decode('utf-8') != "PIEH":
            raise Exception("Flow file header does not contain PIEH")
        # squeeze() yields 0-d arrays; int() makes them plain Python ints
        # so the count/reshape arguments below are unambiguous.
        width = int(np.fromfile(f, np.int32, 1).squeeze())
        height = int(np.fromfile(f, np.int32, 1).squeeze())
        flow = np.fromfile(f, np.float32, width * height * 2)
    flow = flow.reshape(height, width, 2)
    return flow.astype(np.float32)
有时需要对数据集进行Transform:
import random
from typing import Union, Tuple
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import numpy as np
from skimage import transform
# A frame is either a PIL Image (before ToTensor) or a torch.Tensor (after).
ImageOrTensor = Union['Image', torch.Tensor]
# Every transform takes and returns ((frame1, frame2), flow).
Transformed = Tuple[Tuple[ImageOrTensor, ImageOrTensor],
                    Union[np.ndarray, torch.Tensor]]
class Compose(object):
    """Chain several paired-frame transforms into one callable.

    Each transform receives the previous one's output unpacked as
    positional arguments: t((frame1, frame2), flow).
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        result = args  # ((img1, img2), flow)
        for transform_fn in self.transforms:
            result = transform_fn(*result)
        return result
class Resize(object):
    """Resize both frames and the flow field to (height, width), rescaling
    the flow displacement values to match the new resolution.
    """

    def __init__(self, height: int, width: int):
        self.height = height
        self.width = width

    def __call__(self,
                 frame: Tuple['Image', 'Image'],
                 optical_flow: np.ndarray) -> Transformed:
        # Resize the input PIL Images to the given size.
        frame1 = TF.resize(frame[0], size=(self.height, self.width))
        frame2 = TF.resize(frame[1], size=(self.height, self.width))
        old_h, old_w = optical_flow.shape[0], optical_flow.shape[1]
        optical_flow = transform.resize(
            optical_flow, output_shape=(self.height, self.width))
        # BUG FIX: displacements must be rescaled per axis. Channel 0 is u
        # (horizontal) and scales with width; channel 1 is v (vertical) and
        # scales with height. The original multiplied both channels by the
        # height ratio only, which is wrong whenever the aspect ratio changes.
        scale = np.array([self.width / old_w, self.height / old_h],
                         dtype=optical_flow.dtype)
        optical_flow = optical_flow * scale
        return (frame1, frame2), optical_flow
class RandomRotate(object):
    """Rotate frames and flow field by a random angle drawn from [min, max].

    An int argument `m` is treated as the symmetric range (-m, m).
    """

    def __init__(self, minmax: Union[Tuple[int, int], int]):
        self.minmax = minmax
        if isinstance(minmax, int):
            self.minmax = (-minmax, minmax)

    def __call__(self,
                 frames: Tuple['Image', 'Image'],
                 optical_flow: np.ndarray) -> Transformed:
        angle = random.randint(*self.minmax)
        rotated_first = TF.rotate(frames[0], angle)
        rotated_second = TF.rotate(frames[1], angle)
        # NOTE(review): this resamples the flow *field* spatially but does not
        # rotate the flow *vectors* (u, v) by the same angle — for a rotated
        # scene the displacements themselves should also turn. Confirm whether
        # that is intended before relying on this augmentation.
        optical_flow = transform.rotate(optical_flow, angle=angle)
        return (rotated_first, rotated_second), optical_flow
# Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a
# torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image
# belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if
# the numpy.ndarray has dtype = np.uint8
class ToTensor(object):
    """Convert the PIL frame pair to CxHxW float tensors in [0, 1] and the
    HxWx2 flow array to a 2xHxW float tensor.
    """

    def __call__(self,
                 frames: Tuple['Image', 'Image'],
                 optical_flow: np.ndarray) -> Transformed:
        frame_tensors = (TF.to_tensor(frames[0]), TF.to_tensor(frames[1]))
        # HWC -> CHW to match the frame tensor layout.
        flow_tensor = torch.from_numpy(optical_flow).permute(2, 0, 1).float()
        return frame_tensors, flow_tensor
class Normalize(object):
    """Channel-wise normalize both frame tensors with the given mean/std.

    The optical flow is passed through unchanged.
    """

    def __init__(self,
                 mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]):
        self.mean = mean
        self.std = std

    def __call__(self,
                 frames: Tuple[torch.Tensor, torch.Tensor],
                 optical_flow: torch.Tensor) -> Transformed:
        normalized = tuple(
            TF.normalize(f, mean=self.mean, std=self.std) for f in frames)
        return (normalized[0], normalized[1]), optical_flow
接下来制作训练集和验证集(9:1)
from typing import Tuple, Union, Sequence
import torch
from torch.utils.data import DataLoader, Subset
import spynet
import spynet.data.transforms as TF
def load_data(root: str, k: int) -> Tuple[Subset, Subset]:
    """Build the 9:1 train/validation split of the FlyingChairs dataset.

    root: dataset directory with the ppm/flo triplets.
    k: pyramid level; selects the image size via spynet.config.GConf(k).

    Returns: (train_subset, valid_subset) over disjoint random indices.
    """
    # ImageNet channel statistics in RGB order.
    # BUG FIX: the original had the G and B entries swapped in both tuples
    # (mean=(0.485, 0.406, 0.456), std=(0.229, 0.225, 0.224)).
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    train_trans = TF.Compose([
        TF.Resize(*spynet.config.GConf(k).image_size),
        TF.RandomRotate(17),
        TF.ToTensor(),
        TF.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    # Validation uses the same pipeline minus the random augmentation.
    valid_trans = TF.Compose([
        TF.Resize(*spynet.config.GConf(k).image_size),
        TF.ToTensor(),
        TF.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    # Two dataset objects over the same files so that train and valid
    # apply different transforms; the index split below keeps them disjoint.
    train_ds = spynet.dataset.FlyingChairDataset(root, transform=train_trans)
    valid_ds = spynet.dataset.FlyingChairDataset(root, transform=valid_trans)
    rand_idx = torch.randperm(len(train_ds)).tolist()
    train_len = int(len(train_ds) * 0.9)
    train_ds = Subset(train_ds, rand_idx[:train_len])
    valid_ds = Subset(valid_ds, rand_idx[train_len:])
    return train_ds, valid_ds
制作DataLoader:
# merges a list of samples to form a mini-batch of Tensor(s)
def collate_fn(batch):
    """Merge a list of samples into one mini-batch of stacked tensors.

    Input:  [((frame0_1, frame0_2), flow0), ((frame1_1, frame1_2), flow1), ...]
    Output: ((stacked_frame1s, stacked_frame2s), stacked_flows)
    """
    # Split the samples into the frame pairs and the flows.
    frame_pairs, flows = zip(*batch)
    # Regroup by position: all first frames together, all second frames together.
    first_frames, second_frames = zip(*frame_pairs)
    stacked = (torch.stack(first_frames), torch.stack(second_frames))
    return stacked, torch.stack(flows)
# 创建DataLoader
def build_dl(train_ds: Subset,
             valid_ds: Subset,
             batch_size: int,
             num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Wrap the train/valid subsets in DataLoaders sharing collate_fn.

    Training data is shuffled each epoch; validation data is not.
    """
    common = dict(batch_size=batch_size,
                  num_workers=num_workers,
                  collate_fn=collate_fn)
    train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **common)
    valid_dl = torch.utils.data.DataLoader(valid_ds, shuffle=False, **common)
    return train_dl, valid_dl