10. 定义dataset
传入 数据集的目录,目录中 每两张图片文件对应一个光流文件
传入 transform, 如果transform 不为空,对 文件进行数据增强和处理操作
class FlyingChairDataset(torch.utils.data.Dataset):
    """FlyingChairs-style dataset: every sample id maps to two frames and a flow file.

    The root directory is expected to contain, for each sample id ``XXXX``,
    the files ``XXXX_img1.ppm``, ``XXXX_img2.ppm`` and ``XXXX_flow.flo``.

    Args:
        root: directory holding the image and flow files.
        transform: optional callable applied as
            ``transform((frame1, frame2), flow)`` for augmentation/conversion.
    """

    def __init__(self,
                 root: Union[Path, str],
                 transform=None) -> None:
        self.root = Path(root)
        # Sort the de-duplicated ids so sample order is deterministic across
        # runs: set iteration order depends on string hash randomization, and a
        # nondeterministic order would break reproducible train/valid splits
        # that are built from index permutations over this dataset.
        self.ids = sorted({entry.stem.split('_')[0]
                           for entry in self.root.iterdir()})
        self.transform = transform

    def __getitem__(self, idx: int):
        """Return ``((frame1, frame2), optical_flow)`` for sample ``idx``."""
        id_ = self.ids[idx]
        frame1_path = self.root / (id_ + '_img1.ppm')
        frame2_path = self.root / (id_ + '_img2.ppm')
        optical_flow_path = self.root / (id_ + '_flow.flo')
        frame1 = Image.open(str(frame1_path))
        frame2 = Image.open(str(frame2_path))
        optical_flow = read_flow(str(optical_flow_path))
        if self.transform is not None:
            (frame1, frame2), optical_flow = \
                self.transform((frame1, frame2), optical_flow)
        return (frame1, frame2), optical_flow

    def __len__(self) -> int:
        """Number of distinct sample ids found under ``root``."""
        return len(self.ids)
11. PyTorch 的 torch.randperm()使用方法
返回一个0~n-1的数组,随机打散的
t = torch.randperm(8)
返回示例(结果是随机的,每次运行不同): tensor([5, 4, 2, 6, 7, 3, 1, 0])
12. 载入数据,生成dataset
这里的输入 是 一个路径和 一个level, 表示的是图像的大小,因为网络框架是金字塔,所以用到不同的size
初始size是 (24, 32)
GConf(1).image_size : (24,32) * (2^1)
GConf(2).image_size : (24,32) * (2^2)
GConf(3).image_size : (24,32) * (2^3)
但是下面的代码过于麻烦
from typing import NamedTuple, Tuple
MAX_G = 5  # deepest supported pyramid level


class BaseGConf(NamedTuple):
    """Base (level-0) configuration: the smallest image size in the pyramid."""
    image_size: Tuple[int, int] = (24, 32)


class GConf(object):
    """Configuration for one pyramid level.

    The working image size at level ``k`` is the base size scaled by ``2**k``.

    Args:
        level: pyramid level, must lie in ``[0, MAX_G]``.

    Raises:
        ValueError: if ``level`` is outside ``[0, MAX_G]``.
    """

    def __init__(self, level: int) -> None:
        # Validate explicitly: `assert` is stripped when running under -O,
        # so it must not be used for input validation.
        if not 0 <= level <= MAX_G:
            raise ValueError(
                f"level must be in [0, {MAX_G}], got {level}")
        self.base_conf = BaseGConf()
        self.scale = 2 ** level

    @property
    def image_size(self) -> Tuple[int, int]:
        """(height, width) of the images at this level."""
        height, width = self.base_conf.image_size
        return (height * self.scale, width * self.scale)
然后这里输入一个文件路径和level, 形成不同size(根据level)的 train_dataset 和 vali_dataset
def load_data(root: str, k: int) -> Tuple[Subset, Subset]:
    """Build train/validation datasets for pyramid level ``k``.

    Both datasets read the same files from ``root`` but apply different
    transform pipelines (the training one adds rotation augmentation); a
    single random permutation of indices is shared by both so the resulting
    90/10 train/validation index sets are disjoint.

    Args:
        root: directory passed to ``FlyingChairDataset``.
        k: pyramid level; the working resolution is the base size
           scaled by ``2 ** k`` (see ``GConf``).

    Returns:
        ``(train_subset, valid_subset)``.
    """
    # Resolve the level-dependent (height, width) once instead of per-pipeline.
    size = spynet.config.GConf(k).image_size
    # Standard ImageNet statistics in (R, G, B) order.
    # NOTE: the original listed the G and B entries swapped in both mean
    # ([.485, .406, .456]) and std ([.229, .225, .224]).
    normalize = OFT.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
    train_tfms = OFT.Compose([
        OFT.Resize(*size),
        OFT.RandomRotate(17),  # augmentation: random rotation, training only
        OFT.ToTensor(),
        normalize,
    ])
    valid_tfms = OFT.Compose([
        OFT.Resize(*size),
        OFT.ToTensor(),
        normalize,
    ])
    train_ds = spynet.dataset.FlyingChairDataset(root, transform=train_tfms)
    valid_ds = spynet.dataset.FlyingChairDataset(root, transform=valid_tfms)
    # 90/10 split over one shared shuffled index list: the head indexes the
    # augmented view, the tail indexes the plain view, with no overlap.
    train_len = int(len(train_ds) * 0.9)
    rand_idx = torch.randperm(len(train_ds)).tolist()
    train_ds = Subset(train_ds, rand_idx[:train_len])
    valid_ds = Subset(valid_ds, rand_idx[train_len:])
    return train_ds, valid_ds
13. 根据dataset生成dataloader
传入 batch_size和 num_workers
def collate_fn(batch):
    """Merge samples of the form ``((frame1, frame2), flow)`` into batch tensors.

    Returns:
        ``((stacked_frame1, stacked_frame2), stacked_flow)`` where each element
        gains a leading batch dimension.
    """
    firsts = torch.stack([frames[0] for frames, _ in batch])
    seconds = torch.stack([frames[1] for frames, _ in batch])
    flows = torch.stack([flow for _, flow in batch])
    return (firsts, seconds), flows
def build_dl(train_ds: Subset,
             valid_ds: Subset,
             batch_size: int,
             num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Wrap the train/validation subsets in ``DataLoader``s.

    Only the training loader shuffles; both share batch size, worker count
    and the pair-aware ``collate_fn``.

    Returns:
        ``(train_dataloader, valid_dataloader)``.
    """
    shared_kwargs = dict(batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=collate_fn)
    train_dl = DataLoader(train_ds, shuffle=True, **shared_kwargs)
    valid_dl = DataLoader(valid_ds, shuffle=False, **shared_kwargs)
    return train_dl, valid_dl
总之经过上面的处理之后,通过图像文件,光流文件得到了dataset和dataloader,方便后续加入到网络训练