【代码小记】赏析《RAFT:运动属性的光流感知》

一.方法:预测当前帧各像素’运动’(‘位移量’ ΔF)

简化架构:

在这里插入图片描述

补充:

在这里插入图片描述

方法:

3.1特征提取

key:相邻帧;映射;低分辨率(单分辨率);残差块

在这里插入图片描述

消融实验1-组件性能

单分辨率提取特征

key:简化网络;大偏移匹配(‘感受野一样’)
在这里插入图片描述

上下文网络

key:更新融合空间信息;
在这里插入图片描述

3.2计算视觉相关性(key)

key:序列模型输入之一;通过‘局部位移信息’,推理出该点位移信息
在这里插入图片描述

key:特征向量对的点积
[c,hw]*[hw,c]=[hw,hw]

key:后两维均池化(保留高分辨率像素信息;获得’不同对象‘大小位移信息)

key:映射1(点->点**【光流估计序列(f1,f2)】);映射2(点->区域【领域】**);四种不同的局部大小位移信息

在这里插入图片描述

消融实验2-组件性能

全像素对点积:满足对指定范围的需要;便于矩阵运算
在这里插入图片描述

3.3 迭代更新

实际架构:

在这里插入图片描述

目标:预测当前帧各像素‘运动’(‘位移量’ΔF)

key-key-key:光流估计序列(获得‘局部位移信息’);更新预测‘位移’量;更新光流估计序列
在这里插入图片描述

实现:

输入:上下文特征;相关性特征((‘局部位移信息’)【领域】;光流特征(点->预测光流点)【初始化/迭代中的预测‘位移’量Δf】(最后,通过输出ΔF与label建立监督

输出:ΔF=(Δf1,Δf2,…,Δfn) n:像素点数
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

key:各像素点的更新位移量Δf

key:上采样;凸组合;‘反卷积’
在这里插入图片描述

上采样:获得原输入分辨度
在这里插入图片描述

3.4 监督问题

key:l1 distance;loss;BP;验证训练有效性
在这里插入图片描述

二.代码解毒

整体框图
在这里插入图片描述

raft.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.extractor import BasicEncoder#, SmallEncoder#feature network, context network
from core.corr import CorrBlock, AlternateCorrBlock#相关性查找表
from core.update import BasicUpdateBlock#, SmallUpdateBlock#ConVGRU
from core.utils.utils import bilinear_sampler, coords_grid, upflow8#计算flow、上采样
#
import matplotlib.pyplot as plt
from core.utils import flow_viz
import cv2
#https://www.cnblogs.com/jimchen1218/p/14315008.html
try:
    #利用with语句,在autocast实例的上下文范围内,进行模型的前向推理和loss计算
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

#①初始化参数
        """
        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
        """
        #这里选用4.8M的RAFT框架
        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        #相关体数
        args.corr_levels = 4
        args.corr_radius = 4
        #dropout值
        if 'dropout' not in self.args:
            self.args.dropout = 0
        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

#②子网络架构
        # feature network, context network, and update block
        """
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)        
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
        """
        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)        
        self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    # def freeze_bn(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.BatchNorm2d):#如果m的类型与nn.BatchNorm2d的类型相同则返回 True,否则返回 False
    #             m.eval()

#③初始化光流f0
    #feature encoder network outputs:features at 1/8 resolution(R^H*W*3->G^H/8*W/8*256)
    # Flow is represented as difference between two coordinate grids flow = coords1 - coords0
    def initialize_flow(self, img):
        N, C, H, W = img.shape#这里cam[0]:torch.Size([1, 3, 160, 320])
        coords0 = coords_grid(N, H//8, W//8).to(img.device)# (1,1,20,40)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)# (1,1,20,40)
        return coords0, coords1

#④上采样:mask * up_flow
    #Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination
    #提取每个像素点以及周围的 8 邻域像素点特征(总共 9 个像素点)重新排列到 channel 维度上
    def upsample_flow(self, flow, mask):
        #1)mask
        N, _, H, W = flow.shape#这里torch.Size([1, 2, 20, 40])
        mask = mask.view(N, 1, 9, 8, 8, H, W) # (N,9*8*8,20,40) -> (N,1,9,8,8,20,40)
        mask = torch.softmax(mask, dim=2)  # 权重归一化
        #2)up_flow
        #8*flow:上采样后图像的尺度变大了,为了匹配尺度增大的像素坐标,光流(flow=coords1 - coords0)也要按同样的倍率(8 倍)上采样
        up_flow = F.unfold(8 * flow, [3,3], padding=1) #每一列的元素为滑动窗口(只卷不积)依次所覆盖的内容 (b,2,h,w) -> (b,2*3*3,h*w)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) # (b,2*3*3,h*w) -> (b,2,9,1,1,h,w)
        up_flow = torch.sum(mask * up_flow, dim=2) # (b,1,9,8,8,h,w) * (b,2,9,1,1,h,w) -> (b,2,9,8,8,h,w) ->(sum) (b,2,8,8,h,w)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # (b,2,8,8,h,w) -> (b,2,h,8,w,8)
        return up_flow.reshape(N, 2, 8*H, 8*W)  # (b,2,h,8,w,8) -> (b,2,8h,8w)




    ##Estimate optical flow between pair of frames
    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):

        # step 1:网络输入-图像预处理
        image1 = 2 * (image1 / 255.0) - 1.0#图像归一化
        image2 = 2 * (image2 / 255.0) - 1.0#图像归一化
        image1 = image1.contiguous()#类似深拷贝;调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一模一样,但是两个tensor完全没有联系
        image2 = image2.contiguous()
        hdim = self.hidden_dim#这里128
        cdim = self.context_dim#这里128

        #step 2:Feature Encoder 提取两图特征(权值共享)
        with autocast():
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        # step 3:初始化相关性查找表时,调用 __init__() 函数;
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)#这里self.args.corr_radius=4
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # step 4:Context Encoder提取第一帧图特征
        #GRU输入特征之一:net为GRU的隐状态,inp后续与其他特征结合作为 GRU 的一般输入
        with autocast():
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        # step 5:更新光流
        # 初始化光流的坐标信息
        # coords0 为初始时刻的坐标,coords1 为当前迭代的坐标
        coords0, coords1 = self.initialize_flow(image1)#此处两坐标数值相等
        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = [] #key:对每次迭代的光流都进行了上采样
        for itr in range(iters):
            #1)初始的coords1
            coords1 = coords1.detach()#key:使之切断反向传播,不具有更新属性
            #2)更新的coords1(coords0做基准一直不变)
            corr = corr_fn(coords1) #从相关性查找表中获取当前坐标的对应特征(查找对应特征时,调用 __call__() 函数)
            flow = coords1 - coords0

            #a)self.update_block的输入:context网络的输出(net, inp)-帧1信息;从相关性查找表中获取各坐标的对应特征-帧1、2信息;flow-包含抽象更新趋势delta_flow
            with autocast():
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)# 计算当前迭代的光流
            #b)F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow#更新光流

            #3)upsample predictions(目的:训练网络)
            # step 6:上采样光流(此处为了训练网络,对每次迭代的光流都进行了上采样,实际 inference 时,只需要保留最后一次迭代后的上采样)
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up
        return flow_predictions

extractor.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()
        #1)定义残差内块:卷积运算、标准化函数和激活函数
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)#inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同,类似于C语言的址传递
        # [N,C,H*W]
        # nn.BatchNorm2d():统计C_i上所有N张图片(H*W)的均值和方差
        # nn.LayerNorm2d():统计N_i上所有C通道图片(H*W)的均值和方差
        # nn.InstanceNorm2d():统计(C_i,N_j)上图片(H*W)的均值和方差
        # nn.GroupNorm():统计(C_i1,C_i2,······;N_j)上图片(H*W)的均值和方差,介于<IN,LN>
        num_groups = planes // 8
        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)#GroupNorm 不会改变输入张量的shape,它只是按照group做normalization
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        #2)定义残差短路块:nn.Sequential() create a small sequential model
        if stride == 1:
            self.downsample = None
        else:    
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride)
                , self.norm3)

    def forward(self, x):
        #残差内块
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        #残差短路块
        if self.downsample is not None:
            x = self.downsample(x)
        #输出
        return self.relu(x+y)


class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        #定义标准化函数
        self.norm_fn = norm_fn
        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()
        #定义dropout
        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        #1)key-key-key:网络架构
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)
        self.in_planes = 64

        self.layer1 = self._make_layer(64,  stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        #output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        #2)key-key-key:参数初始化的不同选择(对于卷积结构、标准化结构)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                #using a normal distribution. The resulting tensor will have values sampled from N(0,std^2)
                #mode='fan_out/in':要么选择保留了向前传递中权重方差的大小。选择保留了向后传递的大小。
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        #1)调用ResidualBlock,并完成(输出-下次输入)的维度转换
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        self.in_planes = dim
        #2)序列化Block
        layers = (layer1, layer2)
        return nn.Sequential(*layers)


    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        #1)
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        #2)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        #3)
        x = self.conv2(x)
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        # if input is list, devide batch dimension
        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x

update.py

import torch
import torch.nn as nn
import torch.nn.functional as F

#Flow与Corr经过卷积而结合
class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        #
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        #
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        #
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        #1)
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        #2)
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
        #3)
        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        #4)
        return torch.cat([out, flow], dim=1)



class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    #h=net    注:context网络的输出(net, inp)-帧1信息;
    #x=[inp, motion_features]  注:motion_features从BasicMotionEncoder模块得到
    #两层GRU:ht-1->ht
    def forward(self, h, x):
        # 将 3x3 卷积替换成 1x5 和 5x1 的两次卷积,在不提高参数量的情况下增大感受野,下面使用的数学计算见 GRU 公式
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       
        h = (1-z) * h + z * q
        return h


class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

#key-key-key
class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        #
        self.encoder = BasicMotionEncoder(args)
        #
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        #
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    #self.update_block的输入:
    # context网络的输出(net, inp)-帧1信息;
    # 从相关性查找表中获取各坐标的对应特征 - 帧1、2信息;
    # flow - 包含抽象更新趋势delta_flow
    def forward(self, net, inp, corr, flow, upsample=True):
        #1)BasicMotionEncoder模块
        motion_features = self.encoder(flow, corr)# 结合光流和相关性图提取特征
        #2)SepConvGRU模块
        inp = torch.cat([inp, motion_features], dim=1)# 连接 Context Encoder 提取的特征和上面提取的特征
        net = self.gru(net, inp) # GRU 迭代(iters次),更新隐状态 net
        #3)FlowHead模块
        delta_flow = self.flow_head(net) # 由隐状态得到光流残差
        #4)上采样
        # scale mask to balence gradients
        mask = .25 * self.mask(net) # 由隐状态得到上采样 mask
        return net, mask, delta_flow
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值