光流(Optical Flow) 是一种图像处理技术,用于表示视频或图像序列中物体或像素的运动信息。具体来说,光流定义了在两个连续帧之间,每个像素从一个位置移动到另一个位置的位移矢量。光流矢量场提供了从当前帧到下一帧的每个像素的运动轨迹。
定义
光流描述了图像帧中每个像素的运动。假设我们有两个连续的图像帧 和
,光流估计出这些帧之间每个像素的位移:
- 前向光流
:描述了帧
中每个像素在帧
中的移动位置。
- 后向光流
:描述了帧
中每个像素在帧
中的移动位置。
光流图中的每个像素是一个二维矢量:
- 水平方向的位移(x 方向),即每个像素水平移动了多少像素。
- 垂直方向的位移(y 方向),即每个像素垂直移动了多少像素。
使用光流进行 Warp 的流程
Warp 是通过光流对图像进行变换的操作,旨在将前一帧的像素根据光流场映射到下一帧的对应位置。使用光流进行 Warp 操作的详细流程:
Step 1: 计算光流
我们首先需要通过光流网络(如 RAFT)来估算两帧之间的光流。我们可以获得前向光流(从帧 到帧
)和后向光流(从帧
到帧
)。
# 通过 RAFT 模型计算光流
flow_fwd, flow_bwd = model(image1, image2), model(image2, image1)
- flow_fwd:表示从
image1
(当前帧)到image2
(下一帧)的光流,记录了每个像素从image1
移动到image2
的位移。 - flow_bwd:表示从
image2
到image1
的光流。
Step 2: 生成网格坐标
每个图像的像素位置可以表示为一个标准的二维网格坐标。我们需要生成一个与图像大小相匹配的网格坐标,用于记录图像中每个像素的 (x, y)
位置。
# 获取图像的高度和宽度
_, _, H, W = image1.shape
# 生成网格坐标 (H, W)
x = torch.linspace(0, 1, W)
y = torch.linspace(0, 1, H)
grid_x, grid_y = torch.meshgrid(x, y)
# 生成包含每个像素位置的网格 (2, H, W)
grid = torch.stack([grid_x, grid_y], dim=0).to(DEVICE)
grid = grid.permute(0, 2, 1) # 交换维度以匹配图像格式
grid[0] *= W # 将网格坐标转换为图像坐标范围
grid[1] *= H
grid:表示每个像素在图像中的原始坐标。大小为 (2, H, W)
,grid[0]
包含所有像素的 x 坐标,grid[1]
包含所有像素的 y 坐标。
Step 3: 通过光流偏移网格坐标
我们将光流场与标准网格坐标相加,从而得到每个像素在目标图像中的新位置。此操作的含义是,根据光流的位移信息更新每个像素的坐标。
# 通过光流偏移网格坐标
grid_1to2 = grid + flow_fwd.squeeze()
grid_1to2
表示在光流作用下,image1
中的每个像素移动到image2
中的位置。
例如,某个像素在 image1
中的位置为 (x, y)
,其光流位移为 (dx, dy)
。则该像素在 image2
中的新位置为 (x + dx, y + dy)
。这种映射关系通过 grid_ + flow_fwd
实现。
Step 4: 归一化网格坐标
为了使用 PyTorch 的 grid_sample
函数进行图像采样,我们需要将这些新的像素坐标归一化到 [-1, 1] 的范围(这是 grid_sample
函数要求的输入格式)。
# 归一化网格坐标到 [-1, 1]
grid_norm_1to2 = grid_1to2.clone()
grid_norm_1to2[0, ...] = 2 * grid_norm_1to2[0, ...] / (W - 1) - 1
grid_norm_1to2[1, ...] = 2 * grid_norm_1to2[1, ...] / (H - 1) - 1
grid_norm_1to2
是经过归一化后的网格坐标,适合 grid_sample
使用。
Step 5: 使用 grid_sample
进行 Warp 操作
我们使用 PyTorch 的 grid_sample
函数,基于新的坐标对图像进行 Warp 操作。grid_sample
会根据新的坐标在原图像中插值采样,生成 Warp 后的图像。
# 使用新的网格坐标对图像进行 Warp 操作
warped_image = F.grid_sample(image1, grid_norm_1to2.unsqueeze(0).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
warped_image
:是image1
经过光流场flow_fwd
Warp 后,映射到image2
中的新图像。grid_sample
:根据新坐标grid_norm_1to2
,对image1
进行采样,生成新图像。
整合流程:
import torch
import torch.nn.functional as F
def warp_image_with_flow(raft_model, image1, image2):
# 1. 使用 RAFT 模型计算前向光流和后向光流
flow_fwd = raft_model(image1, image2) # 前向光流:从 image1 到 image2
flow_bwd = raft_model(image2, image1) # 后向光流:从 image2 到 image1
# 2. 生成标准网格坐标
_, _, H, W = image1.shape # 获取图像尺寸
x = torch.linspace(0, 1, W)
y = torch.linspace(0, 1, H)
grid_x, grid_y = torch.meshgrid(x, y)
grid = torch.stack([grid_x, grid_y], dim=0).to(image1.device) # (2, H, W)
grid = grid.permute(0, 2, 1) # 调整为 (2, H, W)
grid[0] *= W # 将网格坐标从 [0, 1] 转换到图像坐标范围
grid[1] *= H
# 3. 通过光流偏移网格
grid_1to2 = grid + flow_fwd.squeeze() # 偏移网格:将 image1 中的每个像素移动到 image2 中的位置
# 4. 归一化网格坐标到 [-1, 1](为了 grid_sample 函数)
grid_norm_1to2 = grid_1to2.clone()
grid_norm_1to2[0, ...] = 2 * grid_norm_1to2[0, ...] / (W - 1) - 1
grid_norm_1to2[1, ...] = 2 * grid_norm_1to2[1, ...] / (H - 1) - 1
# 5. 使用 grid_sample 进行 Warp 操作
warped_image = F.grid_sample(image1, grid_norm_1to2.unsqueeze(0).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
return warped_image