Annotated RAFT Code

1. Overall pipeline: the forward() function

class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args
        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        self.args.dropout = 0
        self.args.alternate_corr = False

        # feature network, context network, and update block
        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
        self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def initialize_flow(self, img):
        ...

    def upsample_flow(self, flow, mask):
        ...

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        # step 1: preprocessing
        image1 = 2 * (image1 / 255.0) - 1.0  # normalize images to [-1, 1]
        image2 = 2 * (image2 / 255.0) - 1.0  # normalize images to [-1, 1]

        image1 = image1.contiguous()  # make the memory layout contiguous
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # step 2: Feature Encoder extracts features from both images (shared weights)
        with autocast(enabled=self.args.mixed_precision):  # mixed precision for speed
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        # step 3: build the correlation volumes (the correlation lookup table)
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # step 4: Context Encoder extracts features from the first frame only
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)  # split cnet along the channel dim: net is the GRU hidden state; inp is later combined with other features as the GRU's regular input
            net = torch.tanh(net)
            inp = torch.relu(inp)

        # step 5: iteratively update the flow
        # initialize the flow coordinates: coords0 holds the initial grid, coords1 the current iterate; at this point the two are equal
        coords0, coords1 = self.initialize_flow(image1)

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()  # update values without backpropagating through earlier iterations
            corr = corr_fn(coords1)  # look up correlation features at the current coordinates

            flow = coords1 - coords0  # flow estimate at the current iteration
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)  # GRU returns the updated hidden state, the upsampling mask, and the residual flow

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow  # update the current coordinates

            # step 6: upsample the flow (during training every iteration's flow is upsampled; at inference only the final iteration's upsampled flow is needed)
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)  # append the upsampled flow to the prediction list

        if test_mode:
            return coords1 - coords0, flow_up  # at inference, return the low-resolution flow and the upsampled flow flow_up

        return flow_predictions  # return all intermediate flow predictions
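
A minimal usage sketch (hypothetical argument object and random inputs; real code would load an image pair and trained weights):

import argparse
import torch

args = argparse.Namespace(mixed_precision=False, dropout=0, alternate_corr=False)
model = RAFT(args)
image1 = torch.rand(1, 3, 384, 512) * 255  # raw pixel values in [0, 255]
image2 = torch.rand(1, 3, 384, 512) * 255
flow_predictions = model(image1, image2, iters=12)
print(flow_predictions[-1].shape)  # torch.Size([1, 2, 384, 512])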

**Flow initialization**

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape                                  # e.g. N,3,384,512
        coords0 = coords_grid(N, H//8, W//8).to(img.device)     # N,2,48,64 — features live at 1/8 resolution
        coords1 = coords_grid(N, H//8, W//8).to(img.device)
 
        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

**Building the coordinate grid**

def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))  # two ht*wd index grids (rows, then columns)
    coords = torch.stack(coords[::-1], dim=0).float()  # reverse to (x, y) order and stack into a 2*ht*wd coordinate grid
    return coords[None].repeat(batch, 1, 1, 1)  # add a batch dim and repeat along it: N,2,ht,wd
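
A tiny sanity check of the resulting layout (channel 0 holds x/column indices, channel 1 holds y/row indices):

g = coords_grid(batch=1, ht=2, wd=3)
print(g.shape)  # torch.Size([1, 2, 2, 3])
print(g[0, 0])  # x: tensor([[0., 1., 2.], [0., 1., 2.]])
print(g[0, 1])  # y: tensor([[0., 0., 0.], [1., 1., 1.]])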

2. Correlation lookup table

The __init__() function runs when the lookup table is built in step 3; the __call__() function runs when features are looked up in step 5.

class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = [] 

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)  # all-pairs correlation via matrix multiplication of the two feature maps

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)  # (b,h,w,1,h,w) -> (bhw,1,h,w)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):  # pool the last two dims three times to build 1/2, 1/4, 1/8 scale levels
            corr = F.avg_pool2d(corr, 2, stride=2)  # average pooling yields the multi-scale lookup tables
            self.corr_pyramid.append(corr)  # the pyramid now holds 4 scales

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)  # (b,2,h,w) -> (b,h,w,2) current coordinates with x and y components, built by the meshgrid()-based coords_grid() (details in Sec. 3.3)
        batch, h1, w1, _ = coords.shape  # _ is unused

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]  # (bhw,1,h,w) correlation map at pyramid level i
            dx = torch.linspace(-r, r, 2*r+1)  # (2r+1,) lookup offsets along x: -r, -r+1, ..., r
            dy = torch.linspace(-r, r, 2*r+1)  # (2r+1,) lookup offsets along y
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)  # lookup window (2r+1,2r+1,2)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i  # (b,h,w,2) -> (bhw,1,1,2) coordinates at this scale
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)  # (2r+1,2r+1,2) -> (1,2r+1,2r+1,2) lookup window
            coords_lvl = centroid_lvl + delta_lvl  # (bhw,1,1,2) + (1,2r+1,2r+1,2) -> (bhw,2r+1,2r+1,2); intuitively, each of the b*h*w query points searches its (2r+1)*(2r+1) neighborhood, and every point carries an x and a y coordinate

            corr = bilinear_sampler(corr, coords_lvl)  # (bhw,1,2r+1,2r+1) sample each point's neighborhood features from the lookup table
            corr = corr.view(batch, h1, w1, -1)  # (bhw,1,2r+1,2r+1) -> (b,h,w,(2r+1)*(2r+1))
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)  # (b,h,w,4*(2r+1)*(2r+1)) concatenate the 4 levels
        return out.permute(0, 3, 1, 2).contiguous().float()  # (b,4*(2r+1)*(2r+1),h,w), i.e. (b,4*9*9,h,w)

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)  # frame-1 features (b,c,h,w) -> (b,c,hw)
        fmap2 = fmap2.view(batch, dim, ht*wd)  # frame-2 features (b,c,h,w) -> (b,c,hw)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)  # (b,hw,c) @ (b,c,hw) -> (b,hw,hw); matmul over the last two dims, batched over the first
        corr = corr.view(batch, ht, wd, 1, ht, wd)  # (b,hw,hw) -> (b,h,w,1,h,w)
        return corr / torch.sqrt(torch.tensor(dim).float())  # dividing by sqrt(dim) likely keeps corr values from growing too large and destabilizing gradients
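
A small end-to-end check (hypothetical feature maps): each entry of corr is the scaled dot product of one frame-1 feature vector with one frame-2 feature vector, and a lookup at zero flow returns 4*9*9 = 324 channels per pixel:

import torch

fmap1 = torch.randn(1, 256, 48, 64)
fmap2 = torch.randn(1, 256, 48, 64)

corr = CorrBlock.corr(fmap1, fmap2)  # (1, 48, 64, 1, 48, 64)
v = (fmap1[0, :, 3, 5] * fmap2[0, :, 7, 9]).sum() / 256 ** 0.5
assert torch.allclose(corr[0, 3, 5, 0, 7, 9], v, atol=1e-4)

corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
out = corr_fn(coords_grid(1, 48, 64))  # identity (zero-flow) lookup
print(out.shape)  # torch.Size([1, 324, 48, 64])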

**Bilinear sampling**

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)  # (bhw,2r+1,2r+1,1) each
    xgrid = 2*xgrid/(W-1) - 1  # normalize x to [-1, 1]
    ygrid = 2*ygrid/(H-1) - 1  # normalize y to [-1, 1]

    grid = torch.cat([xgrid, ygrid], dim=-1)  # (bhw,2r+1,2r+1,2)
    img = F.grid_sample(img, grid, align_corners=True)  # img: (bhw,1,h,w) -> (bhw,1,2r+1,2r+1); sample features from lookup table img at the locations in grid

    return img
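
A worked one-point example (hypothetical tensors): sampling the 3x4 map below at pixel coordinates (x=1.5, y=1.0) lands exactly on row 1, halfway between the values 5 and 6:

import torch

img = torch.arange(12.).view(1, 1, 3, 4)  # row 1 is [4., 5., 6., 7.]
coords = torch.tensor([[[[1.5, 1.0]]]])   # one query point, (x, y) in pixel units
print(bilinear_sampler(img, coords))      # tensor([[[[5.5000]]]])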

3. Updating the flow with the GRU

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) #(128,256)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)  #(128,256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)  # flow=(b,2,h,w), corr=(b,4*9*9,h,w); fuse flow and correlation into motion features (b,128,h,w)
        inp = torch.cat([inp, motion_features], dim=1)  # concatenate the Context Encoder features with the motion features: (b,256,h,w)

        net = self.gru(net, inp)  # GRU step updates the hidden state net: (net=128, inp=256) -> 128
        delta_flow = self.flow_head(net)  # hidden state -> residual flow (b,2,h,w)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)  # hidden state -> upsampling mask
        return net, mask, delta_flow
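
A minimal shape walk-through of one update step (hypothetical tensors; assumes an args object with corr_levels=4 and corr_radius=4, as set in Sec. 1):

block = BasicUpdateBlock(args, hidden_dim=128)
net = torch.randn(1, 128, 48, 64)   # GRU hidden state
inp = torch.randn(1, 128, 48, 64)   # context features
corr = torch.randn(1, 324, 48, 64)  # 4*9*9 correlation channels
flow = torch.zeros(1, 2, 48, 64)
net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)  # torch.Size([1, 576, 48, 64]) torch.Size([1, 2, 48, 64])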

**BasicMotionEncoder**

class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2  #4*9*9
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
 
    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
 
        cor_flo = torch.cat([cor, flo], dim=1)  # 192+64 channels
        out = F.relu(self.conv(cor_flo))        # 126 channels
        return torch.cat([out, flow], dim=1)    # 126+2 = 128 channels
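
Channel bookkeeping: cor_planes = 4*(2*4+1)**2 = 324 correlation input channels, and the final concatenation re-attaches the raw 2-channel flow so the output is exactly 128 channels. A quick shape check (hypothetical tensors; assumes the same args object as in Sec. 1):

enc = BasicMotionEncoder(args)
flow = torch.zeros(1, 2, 48, 64)
corr = torch.randn(1, 324, 48, 64)  # 4 * 9 * 9 correlation channels
print(enc(flow, corr).shape)        # torch.Size([1, 128, 48, 64])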

**SepConvGRU**

class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
 
        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
 
 
    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1) #128+256
        z = torch.sigmoid(self.convz1(hx)) #128+256--->128
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) #128+256--->128       
        h = (1-z) * h + z * q   #128
 
        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       
        h = (1-z) * h + z * q
 
        return h
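
Each pass is a standard GRU cell realized with separable convolutions: the 1x5 kernel mixes horizontally and the 5x1 kernel vertically, so together they cover a 5x5 receptive field with fewer parameters than a full 5x5 convolution. A minimal shape check (hypothetical tensors):

gru = SepConvGRU(hidden_dim=128, input_dim=256)
h = torch.tanh(torch.randn(1, 128, 48, 64))  # hidden state, as produced by the context split
x = torch.randn(1, 256, 48, 64)              # motion + context input
h = gru(h, x)
print(h.shape)  # torch.Size([1, 128, 48, 64]) — the hidden size is preserved across steps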

**FlowHead**

class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
 
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))  #(b,2,h,w)

4. Upsampling the flow

class RAFT(nn.Module):
    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)  # (b,9*8*8,h,w) -> (b,1,9,8,8,h,w)
        mask = torch.softmax(mask, dim=2)  # normalize the 9 weights so each output pixel is a convex combination

        up_flow = F.unfold(8 * flow, [3,3], padding=1)  # (b,2,h,w) -> (b,2*3*3,h*w)
        # unfold gathers each pixel together with its 8 neighbors (9 pixels total) into the channel dim
        # flow is multiplied by 8 because the upsampled image is 8x larger; pixel displacements must scale by the same factor
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)  # (b,2*3*3,h*w) -> (b,2,9,1,1,h,w)

        up_flow = torch.sum(mask * up_flow, dim=2)  # (b,1,9,8,8,h,w) * (b,2,9,1,1,h,w) -> (b,2,9,8,8,h,w) ->(sum) (b,2,8,8,h,w)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # (b,2,8,8,h,w) -> (b,2,h,8,w,8)
        return up_flow.reshape(N, 2, 8*H, 8*W)  # (b,2,h,8,w,8) -> (b,2,8h,8w)
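
In effect, every full-resolution pixel is a softmax-weighted (convex) combination of the nine coarse flow vectors around its parent cell. A minimal shape sketch (hypothetical tensors; assumes a RAFT model constructed as in Sec. 1):

flow = torch.randn(1, 2, 48, 64)       # 1/8-resolution flow
mask = torch.randn(1, 64 * 9, 48, 64)  # mask predicted by the update block
up = model.upsample_flow(flow, mask)
print(up.shape)  # torch.Size([1, 2, 384, 512])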
