spynet(四):光流估计代码数据加载和处理

10. 定义dataset

传入 数据集的目录,目录中 每两张图片文件对应一个光流文件
传入 transform, 如果transform 不为空,对 文件进行数据增强和处理操作

class FlyingChairDataset(torch.utils.data.Dataset):

    def __init__(self, 
                 root: Union[Path, str],
                 transform = None) -> None:

        self.root = Path(root)
        self.ids = set([o.stem.split('_')[0] for o in self.root.iterdir()])
        self.ids = list(self.ids)
        self.transform = transform

    def __getitem__(self, idx: int):
        id_ = self.ids[idx]
        frame1_path = self.root / (id_ + '_img1.ppm')
        frame2_path = self.root / (id_ + '_img2.ppm')
        optical_flow_path = self.root / (id_ + '_flow.flo')

        frame1 = Image.open(str(frame1_path))
        frame2 = Image.open(str(frame2_path))
        optical_flow = read_flow(str(optical_flow_path))

        if self.transform is not None:
            (frame1, frame2), optical_flow = \
                self.transform((frame1, frame2), optical_flow)
            
        return (frame1, frame2), optical_flow

    def __len__(self) -> int:
        return len(self.ids)

11. PyTorch 的 torch.randperm()使用方法

返回一个0~n-1的数组,随机打散的

t = torch.randperm(8)
return:tensor([5, 4, 2, 6, 7, 3, 1, 0])

12. 载入数据,生成dataset

这里的输入 是 一个路径和 一个level, 表示的是图像的大小,因为网络框架是金字塔,所以用到不同的size

初始size是 (24, 32)

GConf(1).image_size : (24,32)2
GConf(2).image_size : (24,32)
(2^2)
GConf(3).image_size : (24,32)*(2^3)

但是下面的代码过于麻烦

from typing import NamedTuple, Tuple
MAX_G = 5
class BaseGConf(NamedTuple):
    image_size: Tuple[int, int] = (24, 32)

class GConf(object):

    def __init__(self, level: int) -> None:
        assert level >= 0 and level <= MAX_G
        self.base_conf = BaseGConf()
        self.scale = 2 ** level

    @property
    def image_size(self):
        return (self.base_conf.image_size[0] * self.scale, 
                self.base_conf.image_size[1] * self.scale)

然后这里输入一个文件路径和level, 形成不同size(根据level)的 train_dataset 和 vali_dataset

def load_data(root: str, k: int) -> Tuple[Subset, Subset]:
    train_tfms = OFT.Compose([
        OFT.Resize(*spynet.config.GConf(k).image_size), # *spynet.config.GConf(k).image_size写的很繁琐,建议忽略,其实就是 获得一个size,这个size和k有关系,等于init_size * k
        OFT.RandomRotate(17),
        OFT.ToTensor(),
        OFT.Normalize(mean=[.485, .406, .456], 
                      std= [.229, .225, .224])
    ]) 

    valid_tfms = OFT.Compose([
        OFT.Resize(*spynet.config.GConf(k).image_size),
        OFT.ToTensor(),
        OFT.Normalize(mean=[.485, .406, .456], 
                      std= [.229, .225, .224])
    ])
    
    train_ds = spynet.dataset.FlyingChairDataset(root,  transform=train_tfms)
    valid_ds = spynet.dataset.FlyingChairDataset(root, transform=valid_tfms)
    train_len = int(len(train_ds) * 0.9)              # 9:1划分训练和测试集
    rand_idx = torch.randperm(len(train_ds)).tolist() # 返回一个0~n-1的数组,随机打散的

    train_ds = Subset(train_ds, rand_idx[:train_len]) # 获得训练集和测试集,这段写的不好,直接在Dataset类中应该直接写好
    valid_ds = Subset(valid_ds, rand_idx[train_len:])

    return train_ds, valid_ds

13. 根据dataset生成dataloader

传入 batch_size和 num_workers

def collate_fn(batch):  # 有点多此一举,换来换去
    frames, flow = zip(*batch)
    frame1, frame2 = zip(*frames)
    return (torch.stack(frame1), torch.stack(frame2)), torch.stack(flow)

def build_dl(train_ds: Subset, 
             valid_ds: Subset,
             batch_size: int,
             num_workers: int) -> Tuple[DataLoader, DataLoader]:

    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=True,
                          collate_fn=collate_fn)

    valid_dl = DataLoader(valid_ds,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=False,
                          collate_fn=collate_fn)

    return train_dl, valid_dl

总之经过上面的处理之后,通过图像文件,光流文件得到了dataset和dataloader,方便后续加入到网络训练
请添加图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值