In computer vision, and in face recognition in particular, a "probe" refers to the probe set: a collection of face images whose identities are to be determined. Each probe image is compared against a set of face images with known identities, called the gallery set (or simply gallery), to infer who it depicts. The probe set therefore contains the unknown faces that need to be verified or identified, and at test time a model's performance is measured by how well queries from the probe set are resolved against the gallery. In short, the probe set is the collection of images used to test and evaluate a face recognition system by matching each of them against the gallery set.
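As a concrete illustration, rank-1 identification assigns each probe image the identity of its most similar gallery image. The following is a minimal sketch, not code from any particular library: the rank1_identify helper, the random embeddings, and the cosine-similarity metric are all illustrative assumptions.

import torch
from torch.nn.functional import normalize

def rank1_identify(probe_emb, gallery_emb, gallery_ids):
    """Hypothetical helper: match each probe to its nearest gallery identity."""
    # For L2-normalized embeddings, cosine similarity is a plain dot product.
    sim = probe_emb @ gallery_emb.T   # (P, G) similarity matrix
    best = sim.argmax(dim=1)          # index of the closest gallery entry per probe
    return [gallery_ids[i] for i in best]

# Toy usage: random vectors stand in for embeddings from a real face model.
probe = normalize(torch.randn(4, 128), dim=1)
gallery = normalize(torch.randn(10, 128), dim=1)
print(rank1_identify(probe, gallery, [f"person_{i}" for i in range(10)]))

The prediction-head code below is in the same spirit of probing: lightweight heads are attached to frozen backbone features and evaluated on dense tasks.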
import math

import torch
import torch.nn as nn
from torch.nn.functional import interpolate
class SurfaceNormalHead(nn.Module):
def __init__(
self,
feat_dim,
head_type="multiscale",
uncertainty_aware=False,
hidden_dim=512,
kernel_size=1,
):
super().__init__()
self.uncertainty_aware = uncertainty_aware
output_dim = 4 if uncertainty_aware else 3
self.kernel_size = kernel_size
assert head_type in ["linear", "multiscale", "dpt"]
name = f"snorm_{head_type}_k{kernel_size}"
self.name = f"{name}_UA" if uncertainty_aware else name
if head_type == "linear":
self.head = Linear(feat_dim, output_dim, kernel_size)
elif head_type == "multiscale":
self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
elif head_type == "dpt":
self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
else:
raise ValueError(f"Unknown head type: {self.head_type}")
def forward(self, feats):
return self.head(feats)
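For the uncertainty-aware variant, the head emits four channels per pixel. A minimal usage sketch follows; the dummy feature shapes and the normalize/split post-processing are assumptions about how the output is typically consumed (three normal channels plus one confidence channel), not behavior enforced by the head itself:

head = SurfaceNormalHead(feat_dim=[256, 256, 256, 256], uncertainty_aware=True)
feats = [torch.randn(1, 256, 32, 32) for _ in range(4)]  # dummy multiscale features
out = head(feats)                                        # (1, 4, 256, 256)
normals = nn.functional.normalize(out[:, :3], dim=1)     # unit-length surface normals
confidence = out[:, 3:]                                  # raw uncertainty channel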
class DepthHead(nn.Module):
def __init__(
self,
feat_dim,
head_type="multiscale",
min_depth=0.001,
max_depth=10,
prediction_type="bindepth",
hidden_dim=512,
kernel_size=1,
):
super().__init__()
self.kernel_size = kernel_size
self.name = f"{prediction_type}_{head_type}_k{kernel_size}"
if prediction_type == "bindepth":
output_dim = 256
self.predict = DepthBinPrediction(min_depth, max_depth, n_bins=output_dim)
elif prediction_type == "sigdepth":
output_dim = 1
self.predict = DepthSigmoidPrediction(min_depth, max_depth)
else:
            raise ValueError(f"Unknown prediction type: {prediction_type}")
if head_type == "linear":
self.head = Linear(feat_dim, output_dim, kernel_size)
elif head_type == "multiscale":
self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
elif head_type == "dpt":
self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
else:
raise ValueError(f"Unknown head type: {self.head_type}")
def forward(self, feats):
"""Prediction each pixel."""
feats = self.head(feats)
depth = self.predict(feats)
return depth
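A parallel sketch for the depth head, under the same dummy-shape assumptions; with "bindepth", the output is a convex combination of bin values and therefore stays within [min_depth, max_depth]:

depth_head = DepthHead(feat_dim=[256, 256, 256, 256], prediction_type="bindepth")
feats = [torch.randn(1, 256, 32, 32) for _ in range(4)]
depth = depth_head(feats)  # (1, 1, 256, 256), values in [min_depth, max_depth]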
class DepthBinPrediction(nn.Module):
def __init__(
self,
min_depth=0.001,
max_depth=10,
n_bins=256,
bins_strategy="UD",
norm_strategy="linear",
):
super().__init__()
self.n_bins = n_bins
self.min_depth = min_depth
self.max_depth = max_depth
self.norm_strategy = norm_strategy
self.bins_strategy = bins_strategy
def forward(self, prob):
if self.bins_strategy == "UD":
bins = torch.linspace(
self.min_depth, self.max_depth, self.n_bins, device=prob.device
)
elif self.bins_strategy == "SID":
bins = torch.logspace(
self.min_depth, self.max_depth, self.n_bins, device=prob.device
)
# following Adabins, default linear
if self.norm_strategy == "linear":
prob = torch.relu(prob)
eps = 0.1
prob = prob + eps
prob = prob / prob.sum(dim=1, keepdim=True)
elif self.norm_strategy == "softmax":
prob = torch.softmax(prob, dim=1)
elif self.norm_strategy == "sigmoid":
prob = torch.sigmoid(prob)
prob = prob / prob.sum(dim=1, keepdim=True)
        # Expected depth: probability-weighted sum over the bin values.
        depth = torch.einsum("ikhw,k->ihw", prob, bins)
        depth = depth.unsqueeze(dim=1)
return depth
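The bin head converts a per-pixel score vector over n_bins depth values into a scalar by taking an expectation. A toy numeric check with three bins (shapes chosen purely for illustration):

pred = DepthBinPrediction(min_depth=1.0, max_depth=3.0, n_bins=3)
scores = torch.tensor([[[[1.0]], [[1.0]], [[2.0]]]])  # (1, 3, 1, 1) raw scores
# linear norm: relu + 0.1 -> [1.1, 1.1, 2.1]; UD bins -> [1.0, 2.0, 3.0]
# expected depth: (1.1*1 + 1.1*2 + 2.1*3) / 4.3 ≈ 2.233
print(pred(scores).item())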
class DepthSigmoidPrediction(nn.Module):
def __init__(self, min_depth=0.001, max_depth=10):
super().__init__()
self.min_depth = min_depth
self.max_depth = max_depth
def forward(self, pred):
depth = pred.sigmoid()
depth = self.min_depth + depth * (self.max_depth - self.min_depth)
return depth
class FeatureFusionBlock(nn.Module):
def __init__(self, features, kernel_size, with_skip=True):
super().__init__()
self.with_skip = with_skip
if self.with_skip:
self.resConfUnit1 = ResidualConvUnit(features, kernel_size)
self.resConfUnit2 = ResidualConvUnit(features, kernel_size)
def forward(self, x, skip_x=None):
if skip_x is not None:
assert self.with_skip and skip_x.shape == x.shape
x = self.resConfUnit1(x) + skip_x
x = self.resConfUnit2(x)
return x
class ResidualConvUnit(nn.Module):
def __init__(self, features, kernel_size):
super().__init__()
        assert kernel_size % 2 == 1, "Kernel size needs to be odd"
padding = kernel_size // 2
self.conv = nn.Sequential(
nn.Conv2d(features, features, kernel_size, padding=padding),
nn.ReLU(True),
nn.Conv2d(features, features, kernel_size, padding=padding),
nn.ReLU(True),
)
def forward(self, x):
return self.conv(x) + x
class DPT(nn.Module):
def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=3):
super().__init__()
assert len(input_dims) == 4
self.conv_0 = nn.Conv2d(input_dims[0], hidden_dim, 1, padding=0)
self.conv_1 = nn.Conv2d(input_dims[1], hidden_dim, 1, padding=0)
self.conv_2 = nn.Conv2d(input_dims[2], hidden_dim, 1, padding=0)
self.conv_3 = nn.Conv2d(input_dims[3], hidden_dim, 1, padding=0)
self.ref_0 = FeatureFusionBlock(hidden_dim, kernel_size)
self.ref_1 = FeatureFusionBlock(hidden_dim, kernel_size)
self.ref_2 = FeatureFusionBlock(hidden_dim, kernel_size)
self.ref_3 = FeatureFusionBlock(hidden_dim, kernel_size, with_skip=False)
self.out_conv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.ReLU(True),
nn.Conv2d(hidden_dim, output_dim, 3, padding=1),
)
def forward(self, feats):
"""Prediction each pixel."""
assert len(feats) == 4
        # Project each scale to hidden_dim without mutating the caller's list.
        feats = [
            self.conv_0(feats[0]),
            self.conv_1(feats[1]),
            self.conv_2(feats[2]),
            self.conv_3(feats[3]),
        ]
feats = [interpolate(x, scale_factor=2) for x in feats]
out = self.ref_3(feats[3], None)
out = self.ref_2(feats[2], out)
out = self.ref_1(feats[1], out)
out = self.ref_0(feats[0], out)
out = interpolate(out, scale_factor=4)
out = self.out_conv(out)
out = interpolate(out, scale_factor=2)
return out
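A shape-level sanity check for the DPT head; the dimensions are dummy values standing in for four same-resolution transformer feature maps:

dpt = DPT(input_dims=[256, 256, 256, 256], output_dim=1)
feats = [torch.randn(1, 256, 32, 32) for _ in range(4)]
out = dpt(feats)  # upsampled 2x, then 4x, then 2x overall: (1, 1, 512, 512)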
def make_conv(input_dim, hidden_dim, output_dim, num_layers, kernel_size=1):
if num_layers == 1:
conv = nn.Conv2d(input_dim, output_dim, kernel_size)
else:
assert num_layers > 1
modules = [nn.Conv2d(input_dim, hidden_dim, kernel_size), nn.ReLU(inplace=True)]
for i in range(num_layers - 2):
modules.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size))
modules.append(nn.ReLU(inplace=True))
modules.append(nn.Conv2d(hidden_dim, output_dim, kernel_size))
conv = nn.Sequential(*modules)
return conv
class Linear(nn.Module):
def __init__(self, input_dim, output_dim, kernel_size=1):
super().__init__()
        if not isinstance(input_dim, int):
            input_dim = sum(input_dim)  # list of per-scale dims: features get concatenated
        assert isinstance(input_dim, int)
padding = kernel_size // 2
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding)
def forward(self, feats):
        if isinstance(feats, list):
            feats = torch.cat(feats, dim=1)
feats = interpolate(feats, scale_factor=4, mode="bilinear")
return self.conv(feats)
class MultiscaleHead(nn.Module):
def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=1):
super().__init__()
self.convs = nn.ModuleList(
[make_conv(in_d, None, hidden_dim, 1, kernel_size) for in_d in input_dims]
)
interm_dim = len(input_dims) * hidden_dim
self.conv_mid = make_conv(interm_dim, hidden_dim, hidden_dim, 3, kernel_size)
self.conv_out = make_conv(hidden_dim, hidden_dim, output_dim, 2, kernel_size)
def forward(self, feats):
num_feats = len(feats)
feats = [self.convs[i](feats[i]) for i in range(num_feats)]
h, w = feats[-1].shape[-2:]
feats = [interpolate(feat, (h, w), mode="bilinear") for feat in feats]
feats = torch.cat(feats, dim=1).relu()
# upsample
feats = interpolate(feats, scale_factor=2, mode="bilinear")
feats = self.conv_mid(feats).relu()
feats = interpolate(feats, scale_factor=4, mode="bilinear")
return self.conv_out(feats)
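Unlike DPT, which requires same-resolution inputs, MultiscaleHead resamples every feature map to the resolution of the last one before fusing, so mixed resolutions are fine. A quick sketch with dummy backbone shapes:

head = MultiscaleHead(input_dims=[192, 384, 768, 768], output_dim=3)
feats = [
    torch.randn(1, 192, 64, 64),
    torch.randn(1, 384, 32, 32),
    torch.randn(1, 768, 16, 16),
    torch.randn(1, 768, 16, 16),
]
out = head(feats)  # fused at 16x16, then upsampled 2x and 4x: (1, 3, 128, 128)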