1. 代码总流程forward()函数
class RAFT(nn.Module):
    """RAFT optical-flow model: feature encoder, context encoder and recurrent update block.

    ``args`` is a shared configuration object. ``__init__`` overwrites its
    ``corr_levels``, ``corr_radius``, ``dropout`` and ``alternate_corr`` fields;
    ``forward`` additionally reads ``args.mixed_precision``, which the caller
    must provide.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        # The settings below are pinned on the shared args object (self.args is args).
        args.corr_levels = 4
        args.corr_radius = 4
        self.args.dropout = 0
        self.args.alternate_corr = False

        # feature network, context network, and update block
        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
        self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def initialize_flow(self, img):
        ...

    def upsample_flow(self, flow, mask):
        ...

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        Args:
            image1, image2: the two input frames; assumed to hold pixel values
                in [0, 255] so that the rescaling below lands in [-1, 1].
            iters: number of recurrent refinement iterations.
            flow_init: optional initial flow added to the starting coordinates.
                Assumed to have the same shape as the low-resolution coordinate
                grid returned by ``initialize_flow`` -- TODO confirm with callers.
            upsample: accepted for interface compatibility; not used here.
            test_mode: if true, return only the final estimate.

        Returns:
            In test mode, the tuple ``(coords1 - coords0, flow_up)`` holding the
            low-resolution flow and the last upsampled flow (``flow_up`` is only
            assigned inside the loop, so ``iters`` must be at least 1).
            Otherwise the list of upsampled flows, one per iteration.
        """
        # step 1: preprocessing -- rescale to [-1, 1] and make memory contiguous
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # step 2: feature encoder -- both frames go through the same self.fnet
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        # step 3: build the correlation lookup pyramid; the number of levels is
        # taken from args so it matches what BasicMotionEncoder expects
        corr_fn = CorrBlock(fmap1, fmap2,
                            num_levels=self.args.corr_levels,
                            radius=self.args.corr_radius)

        # step 4: context encoder, applied to the first frame only
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # split along channels: `net` is the GRU hidden state, `inp` is the
            # context input that is later combined with motion features
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        # step 5: iterative flow refinement
        # coords0 is the fixed reference grid, coords1 the current estimate;
        # they start out equal, i.e. zero flow
        coords0, coords1 = self.initialize_flow(image1)

        # honour a caller-supplied starting flow instead of silently dropping it
        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # cut the graph so gradients do not flow through earlier coordinates
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # look up correlation features at current coords

            flow = coords1 - coords0  # current flow estimate
            with autocast(enabled=self.args.mixed_precision):
                # updated hidden state, upsampling mask and flow residual
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # step 6: upsample the flow after every iteration (needed for the
            # training loss; inference only uses the last one)
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
**光流初始化**
def initialize_flow(self, img):
    """Build the reference and current coordinate grids at 1/8 input resolution.

    Flow is represented as ``coords1 - coords0``; both grids are identical
    here, so the initial flow is zero.
    """
    batch, _, height, width = img.shape
    grid_h = height // 8
    grid_w = width // 8

    # e.g. a 384x512 input yields grids of shape (N, 2, 48, 64)
    reference = coords_grid(batch, grid_h, grid_w).to(img.device)
    current = coords_grid(batch, grid_h, grid_w).to(img.device)

    return reference, current
**构建坐标网格**
def coords_grid(batch, ht, wd, device=None):
    """Return a float pixel-coordinate grid of shape (batch, 2, ht, wd).

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    every location; the same grid is repeated for each batch element.

    Args:
        batch: number of copies along the leading dimension.
        ht: grid height.
        wd: grid width.
        device: optional device to create the grid on (default: PyTorch's
            default device, as before).
    """
    # Explicit broadcasting instead of torch.meshgrid: the result is the same,
    # but it avoids the deprecation warning for calling meshgrid without `indexing`.
    ys = torch.arange(ht, device=device).view(ht, 1).expand(ht, wd)
    xs = torch.arange(wd, device=device).view(1, wd).expand(ht, wd)
    coords = torch.stack([xs, ys], dim=0).float()  # (2, ht, wd), x first
    return coords[None].repeat(batch, 1, 1, 1)
2. 相关性查找表
在 step 3 初始化相关性查找表时,调用 __init__() 函数;在 step 5 查找对应特征时,调用 __call__() 函数。
class CorrBlock:
    """All-pairs correlation volume with a multi-scale lookup.

    Construction computes the correlation between every location of ``fmap1``
    and every location of ``fmap2`` and average-pools the last two dimensions
    into a pyramid of ``num_levels`` scales. Calling the object with a
    coordinate grid samples a (2*radius+1) x (2*radius+1) window around each
    coordinate at every scale.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)  # (b,h,w,1,h,w) -> (bhw,1,h,w)

        self.corr_pyramid.append(corr)
        # each extra level halves the last two dimensions by average pooling
        for _ in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample every pyramid level around ``coords``.

        Args:
            coords: tensor of shape (b, 2, h, w) with the current coordinates.

        Returns:
            Float tensor of shape (b, num_levels*(2r+1)**2, h, w).
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)  # (b,2,h,w) -> (b,h,w,2)
        batch, h1, w1, _ = coords.shape

        # The lookup window is the same for every level, so build it once,
        # directly on the target device. Channel 0 varies along the first
        # window axis and channel 1 along the second, exactly as the previous
        # stack(meshgrid(dy, dx), axis=-1) produced.
        side = 2*r + 1
        offsets = torch.linspace(-r, r, side, device=coords.device)
        delta = torch.stack(
            [offsets.view(side, 1).expand(side, side),
             offsets.view(1, side).expand(side, side)],
            dim=-1,
        )
        delta_lvl = delta.view(1, side, side, 2)
        centroid = coords.reshape(batch*h1*w1, 1, 1, 2)  # (b,h,w,2) -> (bhw,1,1,2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]  # (bhw,1,h_i,w_i)
            # coordinates shrink with the pyramid scale; broadcasting adds the
            # window to every centre: (bhw,1,1,2) + (1,2r+1,2r+1,2)
            coords_lvl = centroid / 2**i + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)  # (bhw,1,2r+1,2r+1)
            corr = corr.view(batch, h1, w1, -1)        # (b,h,w,(2r+1)**2)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the (b, h, w, 1, h, w) dot-product correlation of two feature maps."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)  # (b,c,h,w) -> (b,c,hw)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)  # (b,hw,c) @ (b,c,hw) -> (b,hw,hw)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        # scaled by sqrt(feature dim); presumably to keep the magnitude
        # independent of the channel count
        return corr / torch.sqrt(torch.tensor(dim).float())
**双线性插值**
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    Args:
        img: tensor whose last two dimensions are (H, W); both must be larger
            than 1 because coordinates are divided by ``W-1`` and ``H-1``.
        coords: tensor whose last dimension is (x, y) in pixel units.
        mode: interpolation mode forwarded to ``F.grid_sample``.
        mask: if true, also return a float mask that is 1 where the normalized
            coordinate lies strictly inside (-1, 1) on both axes.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates [0, W-1] x [0, H-1] onto grid_sample's [-1, 1] range
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # `mode` used to be accepted but ignored; forward it so callers get what they ask for
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    # `mask` used to be accepted but ignored; the default (False) path is unchanged
    if mask:
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, valid.float()

    return img
3. GRU更新光流
class BasicUpdateBlock(nn.Module):
    """One recurrent update step: motion encoding, GRU, flow head and upsampling mask.

    ``input_dim`` and the ``upsample`` argument of ``forward`` are accepted but
    not used by this implementation.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # GRU input = context features concatenated with the motion features
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # The mask head consumes the GRU hidden state, so its input width must
        # follow hidden_dim (it was hard-coded to 128, which only works for
        # the default). Output has 64*9 channels.
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        """Run one update step.

        Args:
            net: GRU hidden state.
            inp: context features.
            corr: correlation features looked up at the current coordinates.
            flow: current flow estimate.
            upsample: unused.

        Returns:
            ``(net, mask, delta_flow)``: the updated hidden state, the scaled
            upsampling mask and the flow residual.
        """
        # combine flow and correlation into motion features
        motion_features = self.encoder(flow, corr)
        # concatenate context features with the motion features
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)             # update the hidden state
        delta_flow = self.flow_head(net)     # flow residual from the hidden state

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
**BasicMotionEncoder**
class BasicMotionEncoder(nn.Module):
    """Fuse the correlation lookup and the current flow into 128 motion channels.

    The output is 126 learned channels followed by the 2 raw flow channels.
    """

    def __init__(self, args):
        super().__init__()
        window = 2 * args.corr_radius + 1
        corr_channels = args.corr_levels * window ** 2

        # correlation branch
        self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        # flow branch
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # fusion; two channels are left free for the raw flow appended in forward
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Return ``cat([features, flow], dim=1)`` for the given flow and correlation."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))  # 192 channels
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))  # 64 channels

        fused = self.conv(torch.cat([corr_feat, flow_feat], dim=1))  # 126 channels
        return torch.cat([F.relu(fused), flow], dim=1)
**SepConvGRU**
class SepConvGRU(nn.Module):
    """Convolutional GRU applied as a 1x5 pass followed by a 5x1 pass."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        gate_in = hidden_dim + input_dim

        # first pass: 1x5 kernels
        self.convz1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(gate_in, hidden_dim, (1, 5), padding=(0, 2))

        # second pass: 5x1 kernels
        self.convz2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(gate_in, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, conv_z, conv_r, conv_q):
        """Apply one GRU update to hidden state ``h`` with input ``x``."""
        hx = torch.cat([h, x], dim=1)
        update = torch.sigmoid(conv_z(hx))
        reset = torch.sigmoid(conv_r(hx))
        candidate = torch.tanh(conv_q(torch.cat([reset*h, x], dim=1)))
        # blend the old state and the candidate using the update gate
        return (1-update) * h + update * candidate

    def forward(self, h, x):
        """Return the hidden state after the horizontal and then the vertical pass."""
        # horizontal
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        return self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
**FlowHead**
class FlowHead(nn.Module):
    """Two 3x3 convolutions that turn a feature map into a 2-channel flow residual."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return a (b, 2, h, w) tensor computed as conv2(relu(conv1(x)))."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
4. 上采样光流
class RAFT(nn.Module):
    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination

        Every pixel of the 8x upsampled field is a softmax-weighted sum of the
        3x3 neighbourhood of its coarse cell, with the weights taken from ``mask``.
        """
        batch, _, height, width = flow.shape

        # (b, 9*8*8, h, w) -> (b, 1, 9, 8, 8, h, w); normalize over the 9 neighbours
        weights = mask.view(batch, 1, 9, 8, 8, height, width)
        weights = torch.softmax(weights, dim=2)

        # Gather each cell's 3x3 neighbourhood into the channel dimension.
        # The flow is multiplied by 8 because displacements grow with resolution.
        patches = F.unfold(8 * flow, [3, 3], padding=1)             # (b, 2*9, h*w)
        patches = patches.view(batch, 2, 9, 1, 1, height, width)

        upsampled = (weights * patches).sum(dim=2)                  # (b, 2, 8, 8, h, w)
        upsampled = upsampled.permute(0, 1, 4, 2, 5, 3)             # (b, 2, h, 8, w, 8)
        return upsampled.reshape(batch, 2, 8 * height, 8 * width)   # (b, 2, 8h, 8w)