import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torchvision.models.resnet import Bottleneck
from typing import List
from IPython import embed
# torchvision Bottleneck has expansion 4, so planes = c // 4 keeps the
# block's input and output channel width equal to c (a width-preserving
# residual refinement block for the BEV decoder stages).
def ResNetBottleNeck(c): return Bottleneck(c, c // 4)
def generate_grid(height: int, width: int):
    """Return a (1, 3, height, width) grid of homogeneous pixel coordinates.

    Channel 0 holds x and channel 1 holds y, each evenly spaced over
    [0, 1]; channel 2 is the constant homogeneous coordinate 1.
    """
    xs = torch.linspace(0, 1, width)
    ys = torch.linspace(0, 1, height)
    # 'xy' indexing: both meshes come out as (height, width).
    coords = torch.stack(torch.meshgrid((xs, ys), indexing='xy'), 0)
    # Append the homogeneous-one channel: (2, h, w) -> (3, h, w).
    coords = F.pad(coords, (0, 0, 0, 0, 0, 1), value=1)
    # Leading batch axis: (1, 3, h, w).
    return coords.unsqueeze(0)
def get_view_matrix(h=200, w=200, h_meters=100.0, w_meters=100.0, offset=0.0):
    """Build the 3x3 matrix mapping ego-frame meters to BEV pixel coordinates.

    Duplicated from ..data.common so the models stay standalone.
    `offset` shifts the ego origin along the BEV height axis (fraction of h).
    """
    pixels_per_meter_h = h / h_meters
    pixels_per_meter_w = w / w_meters
    # Axes are swapped and negated so that forward/left in the ego frame
    # map to up/left in BEV image coordinates, centered on the image.
    top_row = [0., -pixels_per_meter_w, w / 2.]
    mid_row = [-pixels_per_meter_h, 0., h * offset + h / 2.]
    bottom_row = [0., 0., 1.]
    return [top_row, mid_row, bottom_row]
class Normalize(nn.Module):
    """Per-channel normalization with ImageNet statistics (NCHW input)."""

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        super().__init__()
        # Stored as (1, C, 1, 1) buffers so they broadcast over NCHW batches
        # and follow .to(device); not persisted in the state dict.
        self.register_buffer(
            'mean', torch.tensor(mean).view(1, -1, 1, 1), persistent=False)
        self.register_buffer(
            'std', torch.tensor(std).view(1, -1, 1, 1), persistent=False)

    def forward(self, x):
        return x.sub(self.mean).div(self.std)
class RandomCos(nn.Module):
    """A frozen, randomly-initialized convolution followed by cos().

    The weights come from nn.Conv2d's default init and are stored as
    non-trainable buffers. `stride`/`padding` are intercepted here and
    applied at call time via F.conv2d instead of being forwarded to the
    Conv2d constructor.
    """

    def __init__(self, *args, stride=1, padding=0, **kwargs):
        super().__init__()
        seed_conv = nn.Conv2d(*args, **kwargs)
        self.register_buffer('weight', seed_conv.weight)
        self.register_buffer('bias', seed_conv.bias)
        self.kwargs = dict(stride=stride, padding=padding)

    def forward(self, x):
        response = F.conv2d(x, self.weight, self.bias, **self.kwargs)
        return torch.cos(response)
class BEVEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        sigma: int,
        bev_height: int,
        bev_width: int,
        h_meters: int,
        w_meters: int,
        offset: int,
        decoder_blocks: list,
    ):
        """Learned BEV query embedding plus its ego-frame coordinate grid.

        Only two arguments affect the embedding itself:
            dim: embedding size
            sigma: scale applied to the random initialization
        The remaining arguments just parameterize the view matrix that
        maps BEV pixels into the ego frame.
        """
        super().__init__()
        # Each decoder block upsamples by 2, so the embedding lives at the
        # correspondingly reduced resolution.
        factor = 2 ** len(decoder_blocks)
        h = bev_height // factor
        w = bev_width // factor
        # Normalized grid -> BEV pixel coordinates, shape (3, h, w).
        grid = generate_grid(h, w).squeeze(0)
        grid[0] = bev_width * grid[0]
        grid[1] = bev_height * grid[1]
        # Invert the BEV view matrix to map pixel coords into the ego frame.
        V = get_view_matrix(bev_height, bev_width, h_meters, w_meters, offset)
        V_inv = torch.FloatTensor(V).inverse()
        flat = grid.reshape(3, h * w)            # 3 (h w)
        grid = (V_inv @ flat).reshape(3, h, w)   # 3 h w
        # Ego-frame coordinates of every BEV cell (not saved in checkpoints).
        self.register_buffer('grid', grid, persistent=False)
        # Learnable BEV prior, shape (dim, h, w).
        self.learned_features = nn.Parameter(sigma * torch.randn(dim, h, w))

    def get_prior(self):
        """Return the learned (dim, h, w) BEV prior."""
        return self.learned_features
class KernelAttention(nn.Module):
    """Multi-head attention from BEV queries to per-camera kernel features.

    Each BEV cell's query attends jointly over all cameras and all kernel
    grid points: the softmax runs over the flattened (camera x grid) axis.
    """

    def __init__(self, dim, heads, dim_head, qkv_bias, norm=nn.LayerNorm):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        # Pre-norm + linear projections for queries, keys, values.
        self.to_q = nn.Sequential(norm(dim), nn.Linear(
            dim, heads * dim_head, bias=qkv_bias))
        self.to_k = nn.Sequential(norm(dim), nn.Linear(
            dim, heads * dim_head, bias=qkv_bias))
        self.to_v = nn.Sequential(norm(dim), nn.Linear(
            dim, heads * dim_head, bias=qkv_bias))
        self.proj = nn.Linear(heads * dim_head, dim)
        self.prenorm = norm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, dim))
        self.postnorm = norm(dim)

    def forward(self, q, k, v, skip=None, mask=None):
        """
        q: (b n d H W)   BEV queries, one per camera
        k: (b n k g d)   keys: k = H*W BEV cells, g kernel points each
        v: (b n k g d)   values, same layout as keys
        skip: (b d H W)  optional residual added before the MLP block
        mask: (b n k 1)  True where the sampled kernel point is valid
        Returns: (b d H W)
        """
        _, _, _, H, W = q.shape
        num_points = k.shape[-2]
        # Move feature dim to last for multi-head proj: (b, n, H*W, d).
        q = rearrange(q, 'b n d H W -> b n (H W) d')
        # Project with multiple heads.
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)
        # Fold the head dim into the batch dim.
        q = rearrange(q, 'b n q (m d) -> (b m) n q 1 d',
                      m=self.heads, d=self.dim_head)
        k = rearrange(k, 'b n q g (m d) -> (b m) n q g d',
                      m=self.heads, d=self.dim_head)
        # Values are flattened over (camera, grid) to line up with the
        # flattened attention axis below.
        v = rearrange(v, 'b n q g (m d) -> (b m) q (n g) d',
                      m=self.heads, d=self.dim_head)
        # Scaled dot product: one query (c == 1) against g kernel keys.
        dot = self.scale * \
            torch.einsum('b n Q c d, b n Q K d -> b n Q c K', q, k)
        # Flatten cameras and kernel points into a single attention axis.
        dot = rearrange(dot, 'b n Q c K -> b Q (n c K)')
        if mask is not None:
            # Broadcast the validity mask over heads and kernel points,
            # then suppress invalid positions before the softmax.
            mask = mask.unsqueeze(1).repeat(1, self.heads, 1, 1, num_points)
            mask = rearrange(mask, 'b h n Q g -> (b h) Q (n g)')
            dot[~mask] = -10**9
        att = dot.to(q).softmax(dim=-1)
        # Weighted sum of values: ((b m), Q, d).
        a = torch.einsum('b Q K, b Q K d -> b Q d', att, v)
        a = rearrange(a, '(b m) Q d -> b Q (m d)',
                      m=self.heads, d=self.dim_head)
        # Combine the heads back into a single dim-sized feature.
        z = self.proj(a)
        # Optional skip connection from the incoming BEV state.
        if skip is not None:
            z = z + rearrange(skip, 'b d H W -> b (H W) d')
        z = self.prenorm(z)
        z = z + self.mlp(z)
        z = self.postnorm(z)
        z = rearrange(z, 'b (H W) d -> b d H W', H=H, W=W)
        return z
@torch.no_grad()
def bev2image_sampling(points, I, E, height, width):
    """Project ego-frame 3D points into every camera image.

    Args:
        points: (k, 3) xyz coordinates in the ego/BEV frame.
        I: (b, n, 3, 3) camera intrinsics.
        E: (b, n, 4, 4) camera extrinsics.
        height, width: image size in pixels, used to normalize coordinates.

    Returns:
        uv: (b, n, k, 2) image coordinates normalized to [0, 1].
        mask: (b, n, k, 1) True for points in front of the camera that
            project strictly inside the image bounds.
    """
    k = points.shape[0]
    b, n = I.shape[:2]
    # Homogeneous coordinates: (k, 3) -> (k, 4).
    homog = torch.cat([points, torch.ones_like(points[..., :1])], -1)
    # Embed the 3x3 intrinsics into a 4x4 matrix [[K, 0], [0, 1]].
    K4 = F.pad(I, (0, 1, 0, 1), value=0)
    K4[..., -1, -1] = 1.0
    # (k, 4) -> (b, n, k, 4, 1) for batched matmul.
    homog = homog.view(1, 1, k, 4).repeat(b, n, 1, 1).unsqueeze(-1)
    # Full projection (intrinsics @ extrinsics), broadcast over points.
    proj = (K4 @ E).view(b, n, 1, 4, 4).repeat(1, 1, k, 1, 1)
    cam_pts = (proj @ homog).squeeze(-1)  # (b, n, k, 4)
    eps = 1e-5
    # Keep only points with positive depth (in front of the camera).
    mask = (cam_pts[..., 2:3] > eps)
    # Perspective divide, then normalize by image size to [0, 1].
    uv = cam_pts[..., 0:2] / cam_pts[..., 2:3].clamp(min=eps)
    uv[..., 0] /= width
    uv[..., 1] /= height
    # Reject projections outside the image.
    mask = (mask
            & (uv[..., 0:1] > 0.0) & (uv[..., 0:1] < 1.0)
            & (uv[..., 1:2] > 0.0) & (uv[..., 1:2] < 1.0))
    # No-op on bool tensors; kept for parity with the original behavior.
    mask = torch.nan_to_num(mask)
    return uv, mask
class IndexBEVProjector(nn.Module):
    """BEV projector based on integer index sampling (nearest neighbour).

    For every BEV cell, projects its 3D point into each camera image and
    gathers the features of a (grid_w x grid_h) patch of pixels around
    the projection via a single flat index gather.
    """

    def __init__(self, image_size, grid_size=(3, 3), height=-1.0):
        super().__init__()
        # (height, width) of the input images in pixels.
        self.image_size = image_size
        self.grid_size = grid_size
        grid_h, grid_w = grid_size
        # Integer pixel offsets of the kernel patch, centred on the
        # projected point; stored as a (grid_h, grid_w, 2) buffer.
        y = torch.arange(grid_h) - grid_h // 2
        x = torch.arange(grid_w) - grid_w // 2
        offsets = torch.stack(torch.meshgrid(
            x, y, indexing="xy")).permute(1, 2, 0)
        self.register_buffer("grid_offsets", offsets, persistent=False)
        # Fixed z (ego-frame height) assigned to every BEV point.
        self.bev_height = height

    def forward(self, bev_grids, images, I, E):
        """
        bev_grids: (3, H, W) BEV coordinates in the ego frame
        images: (b*n, c, h, w) feature maps
        I: (b, n, 3, 3) intrinsics
        E: (b, n, 4, 4) extrinsics
        Returns:
            sample_feats: (b, n, k, g, c) with k = H*W BEV cells and
                g = grid_h*grid_w kernel points per cell
            sample_mask: (b, n, k, 1) validity mask (detached)
        """
        b, n = I.shape[:2]
        bn, c, h, w = images.shape
        # (3, H, W) -> (k, 3) list of BEV points, k = H*W.
        bev_points = bev_grids.reshape(3, -1).transpose(0, 1)
        # NOTE(review): this in-place write can alias the caller's grid when
        # reshape returns a view — callers currently pass a detached clone.
        bev_points[:, -1] = self.bev_height
        # Project to normalized image coords: (b, n, k, 2), mask (b, n, k, 1).
        sample_points, sample_mask = bev2image_sampling(
            bev_points, I, E, self.image_size[0], self.image_size[1])
        num_grid_points = self.grid_size[0] * self.grid_size[1]
        # Normalized [0, 1] coords -> integer feature-map pixel indices.
        sample_points[..., 0] *= w
        sample_points[..., 1] *= h
        sample_points = sample_points.round().long()
        # (1, 1, 1, g, 2) patch offsets around each projected point.
        grid_offsets = self.grid_offsets.view(1, 1, 1, num_grid_points, 2)
        # (b, n, k, g, 2): pixel coordinates of every kernel point.
        sample_points = sample_points.unsqueeze(-2) + grid_offsets
        # Clamp coordinates to the feature-map bounds.
        sample_points[..., 0].clamp_(min=0, max=w-1)
        sample_points[..., 1].clamp_(min=0, max=h-1)
        # Flatten (x, y) to a linear index x + y*w within one feature map.
        k = sample_points.shape[2]
        sample_points_inds = sample_points[..., 0] + sample_points[..., 1] * w
        # (b*n, k*g)
        sample_points_inds = sample_points_inds.view(
            b * n, k * num_grid_points)
        # (b*n*h*w, c): flatten all feature maps so one gather suffices.
        images = rearrange(images, "b c h w -> (b h w) c")
        # Per-image offset of each feature map inside the flattened tensor.
        ind_offsets = (torch.arange(b * n, device=images.device)
                       * (h * w)).view(b * n, 1)
        # Global flat indices for every kernel point of every BEV cell.
        sample_points_inds = (sample_points_inds + ind_offsets).view(-1)
        # Gather features and restore shape (b, n, k, g, c).
        sample_feats = images[sample_points_inds].reshape(
            b, n, k, num_grid_points, c)
        return sample_feats, sample_mask.detach()
class UnfoldBEVProjector(nn.Module):
    """BEV projector based on nn.Unfold + nearest-neighbour grid_sample.

    Pre-extracts a (grid_h x grid_w) patch around every feature-map pixel
    with nn.Unfold, then samples one whole patch per projected BEV point.
    """

    def __init__(self, image_size, grid_size=(3, 3), height=-1.0):
        super().__init__()
        # (height, width) of the input images in pixels.
        self.image_size = image_size
        self.grid_size = grid_size
        # Same-size padding so every pixel has a full patch.
        self.pad_size = (grid_size[0] // 2, grid_size[1] // 2)
        self.unfold = nn.Unfold(
            kernel_size=self.grid_size,
            padding=self.pad_size
        )
        # Fixed z (ego-frame height) assigned to every BEV point.
        self.bev_height = height

    def forward(self, bev_grids, images, I, E):
        """
        bev_grids: (3, H, W) BEV coordinates in the ego frame
        images: (b*n, c, h, w) feature maps
        I: (b, n, 3, 3) intrinsics
        E: (b, n, 4, 4) extrinsics
        Returns:
            sample_feats: (b, n, k, g, c), k = H*W, g = grid_h*grid_w
            sample_mask: (b, n, k, 1) validity mask (detached)
        """
        # (3, H, W) -> (k, 3) list of BEV points, k = H*W.
        bev_points = bev_grids.reshape(
            3, -1).transpose(0, 1).requires_grad_(False)
        # z: fixed bev height for every point.
        bev_points[:, -1] = self.bev_height
        # Normalized image coords (b, n, k, 2) and mask (b, n, k, 1).
        sample_points, sample_mask = bev2image_sampling(
            bev_points, I, E, self.image_size[0], self.image_size[1])
        # grid_sample expects coordinates in [-1, 1].
        sample_points = sample_points * 2.0 - 1.0
        b, n = I.shape[:2]
        bn, c, h, w = images.shape
        # (b*n, c*g, h, w): each spatial position now carries its whole patch.
        unfold_images = self.unfold(images).view(bn, -1, h, w)
        # (b, n, k, 2) -> (b*n, k, 1, 2) sampling grid.
        k = sample_points.shape[2]
        sample_points = sample_points.reshape(b * n, k, 1, 2)
        num_grid_points = self.grid_size[0] * self.grid_size[1]
        # Nearest-neighbour sample: (b*n, c*g, k, 1) -> (b, n, c, g, k).
        sample_feats = F.grid_sample(
            unfold_images, sample_points, mode='nearest').reshape(b, n, c, num_grid_points, k)
        # Reorder to (b, n, k, g, c).
        sample_feats = sample_feats.permute(0, 1, 4, 3, 2)
        return sample_feats, sample_mask.detach()
class GeometryKernelAttention(nn.Module):
    """One cross-view attention stage of GKT.

    Projects every BEV cell into each camera, gathers a local kernel of
    image features around the projection, and lets the BEV queries attend
    over those kernels with KernelAttention.
    """

    def __init__(
        self,
        feat_height: int,
        feat_width: int,
        feat_dim: int,
        dim: int,
        bev_z: int,
        kernel_h: int,
        kernel_w: int,
        image_height: int,
        image_width: int,
        qkv_bias: bool,
        heads: int = 4,
        dim_head: int = 32,
        no_image_features: bool = False,
        skip: bool = True,
        sampling_type: str = "index",
        use_kernel_conv: bool = True,
        kernel_conv_h: int = 1,
        kernel_conv_w: int = 7
    ):
        super().__init__()
        # Pixel-coordinate plane of the feature map, scaled to the input
        # image resolution: 1 1 3 h w.
        image_plane = generate_grid(feat_height, feat_width)[None]
        image_plane[:, :, 0] *= image_width
        image_plane[:, :, 1] *= image_height
        self.register_buffer('image_plane', image_plane, persistent=False)
        # Kernel sampler around each projected BEV point.
        if sampling_type == "unfold":
            self.sampling = UnfoldBEVProjector(
                (image_height, image_width), grid_size=(kernel_h, kernel_w), height=bev_z)
        elif sampling_type == "index":
            self.sampling = IndexBEVProjector(
                (image_height, image_width), grid_size=(kernel_h, kernel_w), height=bev_z)
        else:
            raise NotImplementedError()
        # Value projection of sampled image features.
        self.feature_linear = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, dim, bias=False)
        )
        # Optional key projection of sampled image features.
        if no_image_features:
            self.feature_proj = None
        else:
            self.feature_proj = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.ReLU(),
                nn.Linear(feat_dim, dim, bias=False)
            )
        # Optional conv over the raw feature maps before sampling.
        if use_kernel_conv:
            self.conv = nn.Conv2d(
                feat_dim, feat_dim, (kernel_conv_h, kernel_conv_w),
                padding=(kernel_conv_h // 2, kernel_conv_w // 2))
        else:
            self.conv = lambda x: x
        # Positional embeddings: BEV xy, image rays, camera centers.
        self.bev_embed = nn.Conv2d(2, dim, 1)
        self.img_embed = nn.Linear(4, dim, bias=False)
        self.cam_embed = nn.Conv2d(4, dim, 1, bias=False)
        self.cross_attn = KernelAttention(dim, heads, dim_head, qkv_bias)
        self.skip = skip

    def forward(
        self,
        x: torch.FloatTensor,
        bev: BEVEmbedding,
        feature: torch.FloatTensor,
        I_inv: torch.FloatTensor,
        E_inv: torch.FloatTensor,
        I_: torch.FloatTensor,
        E_: torch.FloatTensor
    ):
        """
        x: (b, c, H, W)                current BEV state
        bev: BEVEmbedding              provides the ego-frame BEV grid
        feature: (b, n, dim_in, h, w)  per-camera image features
        I_inv: (b, n, 3, 3)            inverse intrinsics
        E_inv: (b, n, 4, 4)            inverse extrinsics
        I_: (b, n, 3, 3)               intrinsics (for the kernel sampler)
        E_: (b, n, 4, 4)               extrinsics (for the kernel sampler)
        Returns: (b, d, H, W)
        """
        b, n, _, _, _ = feature.shape
        # b n 3 h w
        pixel = self.image_plane
        _, _, _, h, w = pixel.shape
        # Camera-center embedding: the translation column of E_inv.
        # b n 4 1
        c = E_inv[..., -1:]
        # (b n) 4 1 1
        c_flat = rearrange(c, 'b n ... -> (b n) ...')[..., None]
        # (b n) d 1 1
        c_embed = self.cam_embed(c_flat)
        # 1 1 3 (h w)
        pixel_flat = rearrange(pixel, '... h w -> ... (h w)')
        # Unproject pixels to rays in the camera frame: b n 3 (h w).
        cam = I_inv @ pixel_flat
        # Append the homogeneous 1 to each ray direction.
        cam = F.pad(cam, (0, 0, 0, 1, 0, 0, 0, 0), value=1)
        # Rays expressed in the ego frame: b n 4 (h w).
        d = E_inv @ cam
        # (b n) 4 h w
        d_flat = rearrange(d, 'b n d (h w) -> (b n) d h w', h=h, w=w)
        # Query positional embedding: BEV xy embedding minus the camera
        # position embedding — the attention learns from this difference —
        # then L2-normalized over the feature dim.
        # 2 H W
        world = bev.grid[:2]
        # 1 d H W
        w_embed = self.bev_embed(world[None])
        # (b n) d H W
        bev_embed = w_embed - c_embed
        bev_embed = bev_embed / (bev_embed.norm(dim=1, keepdim=True) + 1e-7)
        # b n d H W
        query_pos = rearrange(bev_embed, '(b n) ... -> b n ...', b=b, n=n)
        # (b n) d h w
        feature_flat = rearrange(feature, 'b n ... -> (b n) ...')
        feature_flat = self.conv(feature_flat)
        # Concatenate features with ray directions so a single sampler call
        # gathers both; they are split apart again below.
        d_feature = feature_flat.shape[1]
        feature_embed = torch.cat([feature_flat, d_flat], dim=1)
        feature_embed, mask = self.sampling(
            bev.grid.detach().clone(), feature_embed, I_, E_)
        # Split sampled channels back into features and ray directions,
        # each (b, n, q, num_points, c').
        feature_flat = feature_embed[..., :d_feature]
        d_flat = feature_embed[..., d_feature:]
        # Camera-aware positional embedding of the sampled rays relative to
        # the camera center, L2-normalized.
        # (b n) q, num_points, d
        d_embed = self.img_embed(d_flat)
        img_embed = d_embed - c_embed.view(b, n, 1, 1, d_embed.shape[-1])
        img_embed = img_embed / (img_embed.norm(dim=-1, keepdim=True) + 1e-7)
        # Keys: ray positional embedding (+ projected image features).
        if self.feature_proj is not None:
            key_flat = img_embed + self.feature_proj(feature_flat)
        else:
            key_flat = img_embed
        # Values: projected image features only.
        val_flat = self.feature_linear(feature_flat)
        # Queries: BEV positional embedding + current BEV state.
        # b, n, d, H, W
        query = query_pos + x[:, None]
        return self.cross_attn(query, key_flat, val_flat, mask=mask, skip=x if self.skip else None)
class GeometryKernelEncoder(nn.Module):
    """Image backbone + stacked GeometryKernelAttention stages -> BEV map."""

    def __init__(
        self,
        backbone,
        cross_view: dict,
        bev_embedding: dict,
        dim: int = 128,
        middle: List[int] = [2, 2],
        scale: float = 1.0,
    ):
        super().__init__()
        self.norm = Normalize()
        self.backbone = backbone
        # Optional downscaling of backbone features before attention.
        if scale < 1.0:
            self.down = lambda x: F.interpolate(
                x, scale_factor=scale, recompute_scale_factor=False)
        else:
            self.down = lambda x: x
        assert len(self.backbone.output_shapes) == len(middle)
        cross_views = []
        layers = []
        for feat_shape, num_layers in zip(self.backbone.output_shapes, middle):
            # Probe the (possibly downscaled) feature shape with a dummy tensor.
            dummy = self.down(torch.zeros(feat_shape))
            _, feat_dim, feat_height, feat_width = dummy.shape
            cross_views.append(GeometryKernelAttention(
                feat_height, feat_width, feat_dim, dim, **cross_view))
            # num_layers bottleneck blocks refine the BEV map per stage.
            blocks = [ResNetBottleNeck(dim) for _ in range(num_layers)]
            layers.append(nn.Sequential(*blocks))
        self.bev_embedding = BEVEmbedding(dim, **bev_embedding)
        self.cross_views = nn.ModuleList(cross_views)
        self.layers = nn.ModuleList(layers)

    def forward(self, batch):
        """batch carries 'image' (b, n, 3, h, w), 'intrinsics' (b, n, 3, 3)
        and 'extrinsics' (b, n, 4, 4); returns a (b, dim, H, W) BEV feature."""
        b, n, _, _, _ = batch['image'].shape
        # Collapse cameras into the batch dim for the backbone: (b*n, 3, h, w).
        image = batch['image'].flatten(0, 1)
        I_inv = batch['intrinsics'].inverse()
        E_inv = batch['extrinsics'].inverse()
        features = [self.down(y) for y in self.backbone(self.norm(image))]
        # Start from the learned BEV prior, one copy per batch element.
        x = self.bev_embedding.get_prior()
        x = repeat(x, '... -> b ...', b=b)
        # Alternate cross-view attention with bottleneck refinement.
        for cross_view, feature, layer in zip(self.cross_views, features, self.layers):
            feature = rearrange(feature, '(b n) ... -> b n ...', b=b, n=n)
            x = cross_view(x, self.bev_embedding, feature, I_inv,
                           E_inv, batch['intrinsics'], batch['extrinsics'])
            x = layer(x)
        return x
# Reference implementation of "GKT: Efficient and Robust 2D-to-BEV
# Representation Learning via Geometry-guided Kernel Transformer".