spynet(六):光流整体结构

本文详细介绍了光流估计的SPyNet网络结构,包括网络单元的定义、训练过程、光流的可视化方法以及torch.nn.functional.interpolate函数的使用。SPyNet采用多层金字塔网络逐层优化光流估计,通过残差学习逐步细化预测。此外,文章还讨论了光流估计的评价指标,并提供了SPyNet网络模型的代码实现,便于理解和复现。
摘要由CSDN通过智能技术生成

16. 网络结构

3层金字塔图

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LLnJzWig-1666147293880)(spynet_note_im/20221014114943.png)]

  1. G是各level的网络结构
    这里主要说下网络的输入,包括
    参考帧
    根据光流warp后的辅助帧
    光流
    对于每层而言:
    在这里插入图片描述

输出的是 光流的残差(上一个level 上采样后得到的 flow与 groundtruth flow之间的差值)
后面会有体现,interpolate插值函数 和 warp函数后面也会讲解

class SpyNetUnit(nn.Module):

    def __init__(self, input_channels: int = 8):
        super(SpyNetUnit, self).__init__()

        self.module = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=7, padding=3, stride=1),
            nn.ReLU(inplace=False),

            nn.Conv2d(32, 64, kernel_size=7, padding=3, stride=1),
            nn.ReLU(inplace=False),

            nn.Conv2d(64, 32, kernel_size=7, padding=3, stride=1),
            nn.ReLU(inplace=False),

            nn.Conv2d(32, 16, kernel_size=7, padding=3, stride=1),
            nn.ReLU(inplace=False),

            nn.Conv2d(16, 2, kernel_size=7, padding=3, stride=1))

    def forward(self, 
                frames: Tuple[torch.Tensor, torch.Tensor], 
                optical_flow: torch.Tensor = None,
                upsample_optical_flow: bool = True) -> torch.Tensor:
        f_frame, s_frame = frames

        # G的输入是两个图片和对应的光流
        # 在第0层也就是金字塔最上层,输入的光流是[]

        if optical_flow is None:
            # If optical flow is None (k = 0) then create empty one having the
            # same size as the input frames, therefore there is no need to 
            # upsample it later
            upsample_optical_flow = False
            b, c, h, w = f_frame.size()
            optical_flow = torch.zeros(b, 2, h, w, device=s_frame.device)

        # 其他层输入的光流是 上一层光流 的 2倍上采样(size和value都要扩大2倍)
        if upsample_optical_flow:
            optical_flow = F.interpolate(
                optical_flow, scale_factor=2, align_corners=True, 
                mode='bilinear') * 2

        s_frame = spynet.nn.warp(s_frame, optical_flow, s_frame.device)
        s_frame = torch.cat([s_frame, optical_flow], dim=1)
        
        inp = torch.cat([f_frame, s_frame], dim=1)
        # inp 是  f_frame,s_frame_warp,optical_flow
        return self.module(inp)
  1. 逐层训练

这套代码是从金字塔最高层到最底层逐个进行训练的:
总的来说条理不是很清晰,而且逐层训练的目的是什么?

def train(**kwargs):
    torch.manual_seed(0)
    previous = []
    for k in range(kwargs.pop('levels')):
        previous.append(train_one_level(k, previous, **kwargs))
    # previous 开始为空,最后是一个包含k层的网络

    # 训练完成后保存下来 
    final = spynet.SpyNet(previous)
    torch.save(final.state_dict(), 
               str(Path(kwargs['checkpoint_dir']) / f'final.pt'))


def train_one_level(k: int, 
                    previous: Sequence[spynet.SpyNetUnit],
                    **kwargs) -> spynet.SpyNetUnit:

    print(f'Training level {k}...')

    train_ds, valid_ds = load_data(kwargs['root'], k)
    train_dl, valid_dl = build_dl(train_ds, valid_ds, 
                                  kwargs['batch_size'],
                                  kwargs['dl_num_workers'])

    # 返回当前的网络 和 之前的网络, 比如3层的网络和2层的网络
    current_level, trained_pyramid = build_spynets(
        k, kwargs['finetune_name'], previous)
    
    optimizer = torch.optim.AdamW(current_level.parameters(),
                                  lr=1e-5,
                                  weight_decay=4e-5)
    loss_fn = spynet.nn.EPELoss()

    for epoch in range(kwargs['epochs']):
        train_one_epoch(train_dl, 
                        optimizer,
                        loss_fn,
                        current_level,
                        trained_pyramid,
                        print_freq=999999,
                        header=f'Epoch [{epoch}] [Level {k}]')

    torch.save(current_level.state_dict(), 
               str(Path(kwargs['checkpoint_dir']) / f'{k}.pt'))
    
    return current_level
def train_one_epoch(dl: DataLoader,
                    optimizer: torch.optim.AdamW,
                    criterion_fn: torch.nn.Module,
                    Gk: torch.nn.Module, 
                    prev_pyramid: torch.nn.Module = None, 
                    print_freq: int = 100,
                    header: str = ''):
    Gk.train()
    running_loss = 0.

    if prev_pyramid is not None:
        prev_pyramid.eval()

    for i, (x, y) in enumerate(dl):
        x = x[0].to(device), x[1].to(device)
        y = y.to(device)

        if prev_pyramid is not None:
            with torch.no_grad():
                Vk_1 = prev_pyramid(x)
                Vk_1 = F.interpolate(
                    Vk_1, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            Vk_1 = None

        predictions = Gk(x, Vk_1, upsample_optical_flow=False)

        if Vk_1 is not None:
            y = y - Vk_1

        loss = criterion_fn(y, predictions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % print_freq == 0:
            loss_mean = running_loss / i
            print(f'{header} [{i}/{len(dl)}] loss {loss_mean:.4f}')

    loss_mean = running_loss / len(dl)
    print(f'{header} loss {loss_mean:.4f}')

# 返回当前的网络 和 之前的网络
def build_spynets(k: int, name: str, 
                  previous: Sequence[torch.nn.Module]) \
                      -> Tuple[spynet.SpyNetUnit, spynet.SpyNet]:

    if name != 'none':
        pretrained = spynet.SpyNet.from_pretrained(name, map_location=device)
        current_train = pretrained.units[k]
    else:
        current_train = spynet.SpyNetUnit()
        
    current_train.to(device)
    current_train.train()
    
    if k == 0:
        Gk = None
    else:
        Gk = spynet.SpyNet(previous)
        Gk.to(device)
        Gk.eval()

    return current_train, Gk
  1. warp 和 epeloss的实现

F.grid_sample 函数 根据 flow (相当于一个查找表对应像素的位置)查表和插值得到一个新的图像

import torch
import torch.nn.functional as F


def warp(image: torch.Tensor, 
         optical_flow: torch.Tensor,
         device: torch.device = torch.device('cpu')) -> torch.Tensor:

    b, c, im_h, im_w = image.size() 
    
    hor = torch.linspace(-1.0, 1.0, im_w).view(1, 1, 1, im_w)
    hor = hor.expand(b, -1, im_h, -1)

    vert = torch.linspace(-1.0, 1.0, im_h).view(1, 1, im_h, 1)
    vert = vert.expand(b, -1, -1, im_w)

    grid = torch.cat([hor, vert], 1).to(device)

    # optical_flow是对应图像size的,因此首先将其缩放到[-1,1]
    # 再与grid相加
    optical_flow = torch.cat([
        optical_flow[:, 0:1, :, :] / ((im_w - 1.0) / 2.0), 
        optical_flow[:, 1:2, :, :] / ((im_h - 1.0) / 2.0)], dim=1)

    # Channels last (which corresponds to optical flow vectors coordinates)
    grid = (grid + optical_flow).permute(0, 2, 3, 1)
    return F.grid_sample(image, grid=grid, padding_mode='border', 
                         align_corners=True)

# 欧式距离
class EPELoss(torch.nn.Module): #end-point-error (EPE)

    def __init__(self):
        super(EPELoss, self).__init__()
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        dist = (target - pred).pow(2).sum().sqrt()
        return dist.mean()

17.光流的可视化方法

参见博客: light flow 光流的常见可视化方法,光流图像生成

18. torch.nn.functional.interpolate函数

常用于 tensord的 上采样,下采样操作

x = Variable(torch.randn([1, 3, 64, 64]))
y0 = F.interpolate(x, scale_factor=0.5)
y1 = F.interpolate(x, size=[32, 32])

y2 = F.interpolate(x, size=[128, 128], mode="bilinear")

print(y0.shape)
print(y1.shape)
print(y2.shape)

return:
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 128, 128])

19. 光流估计的评价指标

在这里插入图片描述

20. 一个比较规整,易懂的spynet 网络模型

"""
This code is based on Open-MMLab's one.
https://github.com/open-mmlab/mmediting
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import flow_warp

class SPyNet(nn.Module):
    """SPyNet network structure.
    The difference to the SPyNet in [tof.py] is that
        1. more SPyNetBasicModule is used in this version, and
        2. no batch normalization is used in this version.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    Args:
        pretrained (str): path for pre-trained SPyNet. Default: None.
    """

    def __init__(self):
        super().__init__()

        self.basic_module = nn.ModuleList(
            [SPyNetBasicModule() for _ in range(6)]
        )

        #self.load_state_dict(torch.load('spynet_20210409-c6c1bd09.pth'))

        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.
        Note that in this function, the images are already resized to a
        multiple of 32.
        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).
        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        n, _, h, w = ref.size()

        # normalize the input images
        ref = [(ref - self.mean) / self.std]
        supp = [(supp - self.mean) / self.std]

        # generate downsampled frames
        for level in range(5):
            ref.append(
                F.avg_pool2d(
                    input=ref[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False
                )
            )
            supp.append(
                F.avg_pool2d(
                    input=supp[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False
                )
            )
        ref = ref[::-1]
        supp = supp[::-1]

        # flow computation
        flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
        for level in range(len(ref)):
            if level == 0:
                flow_up = flow
            else:
                flow_up = F.interpolate(
                    input=flow,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True) * 2.0

            # add the residue to the upsampled flow
            flow = flow_up + self.basic_module[level](
                torch.cat([
                    ref[level],
                    flow_warp(
                        supp[level],
                        flow_up.permute(0, 2, 3, 1),
                        padding_mode='border'), flow_up
                ], 1))

        return flow

    def forward(self, ref, supp):
        """Forward function of SPyNet.
        This function computes the optical flow from ref to supp.
        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).
        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """

        # upsize to a multiple of 32
        h, w = ref.shape[2:4]
        w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
        h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
        ref = F.interpolate(
            input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
        supp = F.interpolate(
            input=supp,
            size=(h_up, w_up),
            mode='bilinear',
            align_corners=False)

        # compute flow, and resize back to the original resolution
        flow = F.interpolate(
            input=self.compute_flow(ref, supp),
            size=(h, w),
            mode='bilinear',
            align_corners=False)

        # adjust the flow values
        flow[:, 0, :, :] *= float(w) / float(w_up)
        flow[:, 1, :, :] *= float(h) / float(h_up)

        return flow


class SPyNetBasicModule(nn.Module):
    """Basic Module for SPyNet.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    """

    def __init__(self):
        super().__init__()

        self.basic_module = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
        )

    def forward(self, tensor_input):
        """
        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].
        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值