Custom FlyingChairs Dataset

The FlyingChairs dataset consists of image pairs together with the ground-truth optical flow between them (00001_img1.ppm, 00001_img2.ppm, 00001_flow.flo).

A custom dataset must subclass the abstract class torch.utils.data.Dataset and override two key methods, __len__ and __getitem__. In practice:
(1) __init__ initializes the basic parameters of the class;
(2) __getitem__ is where the data is actually read: the loader fetches samples by index, so all of the reading logic only needs to live in this one method;
(3) __len__ returns the size of the whole dataset, which determines the valid index range.
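As a minimal sketch of that protocol (SquaresDataset is a made-up toy class, not part of the FlyingChairs code below), the three methods fit together like this:

import torch


class SquaresDataset(torch.utils.data.Dataset):
    """Toy dataset: sample i is simply the pair (i, i * i)."""

    def __init__(self, n: int):
        self.n = n                      # basic parameters are stored here

    def __getitem__(self, idx: int):
        return idx, idx * idx           # all per-sample loading logic goes here

    def __len__(self) -> int:
        return self.n                   # defines the valid index range


# SquaresDataset(5)[3] -> (3, 9); len(SquaresDataset(5)) -> 5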

from typing import Union
from pathlib import Path
import torch
import numpy as np
from PIL import Image


class FlyingChairDataset(torch.utils.data.Dataset):

    def __init__(self, root: Union[Path, str], transform=None):
        super(FlyingChairDataset, self).__init__()
        self.root = Path(root)
        # Collect the unique sample ids (each id appears as *_img1.ppm, *_img2.ppm
        # and *_flow.flo) and sort them so the index order is deterministic.
        self.ids = sorted({o.stem.split('_')[0] for o in self.root.iterdir()})
        self.transform = transform

    def __getitem__(self, idx: int):
        id_ = self.ids[idx]
        frame1_path = self.root / (id_ + "_img1.ppm")
        frame2_path = self.root / (id_ + "_img2.ppm")
        optical_flow_path = self.root / (id_ + "_flow.flo")

        frame1 = Image.open(frame1_path)
        frame2 = Image.open(frame2_path)
        optical_flow = read_flow(str(optical_flow_path))

        if self.transform is not None:
            (frame1, frame2), optical_flow = self.transform((frame1, frame2), optical_flow)

        return (frame1, frame2), optical_flow

    def __len__(self) -> int:
        return len(self.ids)

def read_flow(filename: str) -> np.ndarray:
    """Read a Middlebury .flo file into an (H, W, 2) float32 array."""
    with open(filename, 'rb') as f:
        # The .flo magic number is the float 202021.25, whose bytes read "PIEH".
        header = f.read(4)
        if header.decode('utf-8') != "PIEH":
            raise Exception("Flow file header does not contain PIEH")

        width = int(np.fromfile(f, np.int32, 1).squeeze())
        height = int(np.fromfile(f, np.int32, 1).squeeze())

        # The body holds height * width * 2 float32 values, interleaved as (u, v).
        flow = np.fromfile(f, np.float32, width * height * 2)
        flow = flow.reshape(height, width, 2)

    return flow.astype(np.float32)
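
A quick way to sanity-check read_flow is to write a tiny synthetic .flo file and read it back (an illustrative sketch; the file name dummy_flow.flo is arbitrary):

import struct
import numpy as np

# A 2x3 flow field dumped in the .flo layout: 4-byte magic (the float 202021.25,
# which reads as "PIEH"), int32 width, int32 height, then H * W * 2 float32 values.
dummy = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
with open("dummy_flow.flo", "wb") as f:
    f.write(struct.pack("<f", 202021.25))
    f.write(struct.pack("<ii", 3, 2))              # width = 3, height = 2
    f.write(dummy.tobytes())

flow = read_flow("dummy_flow.flo")
assert flow.shape == (2, 3, 2)
assert np.allclose(flow, dummy)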

Sometimes the dataset also needs transforms:

import random
from typing import Union, Tuple
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import numpy as np
from skimage import transform

ImageOrTensor = Union[Image.Image, torch.Tensor]
Transformed = Tuple[Tuple[ImageOrTensor, ImageOrTensor],
                    Union[np.ndarray, torch.Tensor]]


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        o = args  # (img1, img2), flow
        for t in self.transforms:
            o = t(*o)
        return o


class Resize(object):

    def __init__(self, height: int, width: int):
        self.height = height
        self.width = width

    def __call__(self,
                 frames: Tuple[Image.Image, Image.Image],
                 optical_flow: np.ndarray) -> Transformed:
        # Resize the input PIL Images to the given size.
        frame1 = TF.resize(frames[0], size=(self.height, self.width))
        frame2 = TF.resize(frames[1], size=(self.height, self.width))

        # Resize the flow field spatially, then rescale the displacement values:
        # the u component scales with the width ratio and the v component with the
        # height ratio, so the vectors still point at the matching pixels.
        h, w = optical_flow.shape[:2]
        optical_flow = transform.resize(optical_flow,
                                        output_shape=(self.height, self.width),
                                        preserve_range=True)
        optical_flow[..., 0] *= self.width / w
        optical_flow[..., 1] *= self.height / h

        return (frame1, frame2), optical_flow


class RandomRotate(object):
    def __init__(self, minmax: Union[Tuple[int, int], int]):
        self.minmax = minmax
        if isinstance(minmax, int):
            self.minmax = (-minmax, minmax)

    def __call__(self,
                 frames: Tuple[Image.Image, Image.Image],
                 optical_flow: np.ndarray) -> Transformed:
        angle = random.randint(*self.minmax)
        frame1 = TF.rotate(frames[0], angle)
        frame2 = TF.rotate(frames[1], angle)
        # Rotate the spatial layout of the flow field to match the rotated frames.
        # Strictly, the (u, v) vectors themselves should also be rotated by the same
        # angle; for the small angles used here that correction is omitted.
        optical_flow = transform.rotate(optical_flow, angle=angle, preserve_range=True)

        return (frame1, frame2), optical_flow

# TF.to_tensor converts a PIL Image or numpy.ndarray (H x W x C) in the range
# [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0],
# provided the PIL Image is in one of the modes (L, LA, P, I, F, RGB, YCbCr,
# RGBA, CMYK, 1) or the numpy.ndarray has dtype np.uint8.


class ToTensor(object):
    def __call__(self,
                 frames: Tuple[Image.Image, Image.Image],
                 optical_flow: np.ndarray) -> Transformed:
        # Frames become (C, H, W) float tensors; the flow is permuted from
        # (H, W, 2) to (2, H, W) to follow the same channel-first convention.
        frames = (TF.to_tensor(frames[0]), TF.to_tensor(frames[1]))
        optical_flow = torch.from_numpy(optical_flow).permute(2, 0, 1).float()
        return frames, optical_flow


class Normalize(object):
    def __init__(self,
                 mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]):
        self.mean = mean
        self.std = std

    def __call__(self,
                 frames: Tuple[torch.Tensor, torch.Tensor],
                 optical_flow: torch.Tensor) -> Transformed:
        frame1 = TF.normalize(frames[0], mean=self.mean, std=self.std)
        frame2 = TF.normalize(frames[1], mean=self.mean, std=self.std)

        return (frame1, frame2), optical_flow
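
A quick end-to-end check of the transform pipeline on dummy inputs (an illustrative sketch; the 64x64 target size and the ImageNet statistics are just example values):

from PIL import Image
import numpy as np

# Two dummy 100x80 RGB frames and a matching (H, W, 2) flow field.
frame1 = Image.new("RGB", (100, 80))                 # PIL size is (width, height)
frame2 = Image.new("RGB", (100, 80))
flow = np.random.randn(80, 100, 2).astype(np.float32)

pipeline = Compose([
    Resize(64, 64),
    RandomRotate(17),
    ToTensor(),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

(f1, f2), of = pipeline((frame1, frame2), flow)
print(f1.shape, f2.shape, of.shape)   # [3, 64, 64], [3, 64, 64], [2, 64, 64]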

Next, build the training and validation sets (a 9:1 split):

from typing import Tuple, Union, Sequence
import torch
from torch.utils.data import DataLoader, Subset
import spynet
import spynet.data.transforms as TF


def load_data(root: str, k: int) -> Tuple[Subset, Subset]:
    train_trans = TF.Compose([
        TF.Resize(*spynet.config.GConf(k).image_size),
        TF.RandomRotate(17),
        TF.ToTensor(),
        TF.Normalize(mean=(0.485, 0.456, 0.406),   # ImageNet channel statistics
                     std=(0.229, 0.224, 0.225))
    ])

    valid_trans = TF.Compose([
        TF.Resize(*spynet.config.GConf(k).image_size),
        TF.ToTensor(),
        TF.Normalize(mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225))
    ])

    train_ds = spynet.dataset.FlyingChairDataset(root, transform=train_trans)
    valid_ds = spynet.dataset.FlyingChairDataset(root, transform=valid_trans)

    rand_idx = torch.randperm(len(train_ds)).tolist()
    train_len = int(len(train_ds) * 0.9)

    train_ds = Subset(train_ds, rand_idx[:train_len])
    valid_ds = Subset(valid_ds, rand_idx[train_len:])

    return train_ds, valid_ds
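
A hypothetical call looks like the following; the data path is a placeholder and k is simply forwarded to spynet.config.GConf to pick the image size for that configuration. Note that torch.randperm is not seeded here, so the split changes between runs.

# Placeholder path to the extracted FlyingChairs data directory.
train_ds, valid_ds = load_data("/path/to/FlyingChairs_release/data", k=0)
print(len(train_ds), len(valid_ds))   # roughly a 9:1 split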

Build the DataLoaders:

# merges a list of samples to form a mini-batch of Tensor(s)
def collate_fn(batch):
    # batch = [((frame0_1, frame0_2), flow0), ((frame1_1, frame1_2), flow1), ...]
    # frames = ((frame0_1, frame0_2), (frame1_1, frame1_2), ...)
    # flow = (flow0, flow1, ...)
    frames, flow = zip(*batch)

    # frame1 = (frame0_1, frame1_1, ....)
    # frame2 = (frame0_2, frame1_2, ....)
    frame1, frame2 = zip(*frames)

    return (torch.stack(frame1), torch.stack(frame2)), torch.stack(flow)
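
collate_fn can be checked on its own with a fake batch (an illustrative sketch; the shapes are arbitrary):

# Two samples, each a pair of 3x8x8 frames plus a 2x8x8 flow field.
batch = [((torch.zeros(3, 8, 8), torch.zeros(3, 8, 8)), torch.zeros(2, 8, 8))
         for _ in range(2)]
(f1, f2), flows = collate_fn(batch)
print(f1.shape, f2.shape, flows.shape)   # [2, 3, 8, 8], [2, 3, 8, 8], [2, 2, 8, 8]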


# Create the training and validation DataLoaders
def build_dl(train_ds: Subset,
             valid_ds: Subset,
             batch_size: int,
             num_workers: int) -> Tuple[DataLoader, DataLoader]:
    train_dl = torch.utils.data.DataLoader(train_ds,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers,
                                           collate_fn=collate_fn)

    valid_dl = torch.utils.data.DataLoader(valid_ds,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=False,
                                           collate_fn=collate_fn)
    return train_dl, valid_dl
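
Putting it all together (a hedged sketch; the path, k, batch size and worker count are placeholder values):

train_ds, valid_ds = load_data("/path/to/FlyingChairs_release/data", k=0)
train_dl, valid_dl = build_dl(train_ds, valid_ds, batch_size=8, num_workers=2)

# One training batch: a pair of frame tensors and the flow targets.
(frames1, frames2), flows = next(iter(train_dl))
print(frames1.shape, frames2.shape, flows.shape)
# e.g. [8, 3, H, W], [8, 3, H, W], [8, 2, H, W] for the configured image size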

 
