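Uses optimal transport (OT) to solve the crowd counting problem.
The training loss combines an OT loss, a count loss, and a TV loss.
The paper shows that the generalization error bound of the OT loss is tighter than those of the density-map (Gaussian-smoothed) loss and the Bayesian loss.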
OT
Consider two distributions supported on the point sets $\mathcal{X}=\left\{\mathbf{x}_i \mid \mathbf{x}_i \in \mathbb{R}^d\right\}_{i=1}^n$ and $\mathcal{Y}=\left\{\mathbf{y}_j \mid \mathbf{y}_j \in \mathbb{R}^d\right\}_{j=1}^n$.
Let $\boldsymbol{\mu},\boldsymbol{\nu}$ be two measures on these sets, normalized so that $\mathbf{1}_n^T \boldsymbol{\mu}=\mathbf{1}_n^T \boldsymbol{\nu}=1$.
Let the cost function be $c: \mathcal{X} \times \mathcal{Y} \mapsto \mathbb{R}_{+}$.
The cost matrix is $\mathbf{C}_{ij}=c\left(\mathbf{x}_i,\mathbf{y}_j\right)$.
The set of feasible transport matrices is

$$\Gamma=\left\{\boldsymbol{\gamma} \in \mathbb{R}_{+}^{n \times n}: \boldsymbol{\gamma} \mathbf{1}=\boldsymbol{\mu},\ \boldsymbol{\gamma}^T \mathbf{1}=\boldsymbol{\nu}\right\}$$
The OT cost (primal form) is

$$\mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu})=\min _{\boldsymbol{\gamma} \in \Gamma}\langle\mathbf{C}, \boldsymbol{\gamma}\rangle$$
Its dual form is

$$\begin{aligned} \mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu}) & =\max _{\boldsymbol{\alpha}, \boldsymbol{\beta} \in \mathbb{R}^n}\langle\boldsymbol{\alpha}, \boldsymbol{\mu}\rangle+\langle\boldsymbol{\beta}, \boldsymbol{\nu}\rangle\\ &\quad \text{s.t. } \alpha_i+\beta_j \leq c\left(\mathbf{x}_i, \mathbf{y}_j\right),\ \forall i, j \end{aligned}$$
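To make the primal problem concrete, here is a minimal sketch for the special case of uniform weights, $\boldsymbol{\mu}=\boldsymbol{\nu}=(1/n,\dots,1/n)$. In that case an optimal plan can be taken to be a permutation matrix scaled by $1/n$, so a brute-force search over permutations gives the exact OT cost for tiny $n$. The function names (`sq_dist`, `ot_uniform`) and the squared Euclidean cost are my own choices for illustration, not something taken from the DM-Count code.

```python
from itertools import permutations

def sq_dist(p, q):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(p, q))

def ot_uniform(xs, ys):
    """Exact OT cost between two uniform measures on n points each.

    Assumes mu = nu = (1/n, ..., 1/n), so some optimal plan is a
    permutation matrix scaled by 1/n. Brute force: only for very small n.
    """
    n = len(xs)
    best_cost, best_perm = float("inf"), None
    for perm in permutations(range(n)):
        # <C, gamma> for the plan that sends x_i to y_perm[i] with mass 1/n
        cost = sum(sq_dist(xs[i], ys[perm[i]]) for i in range(n)) / n
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    return best_cost, best_perm

xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(0.0, 1.0), (0.0, 0.0), (1.0, 0.0)]  # the same points, shuffled
cost, perm = ot_uniform(xs, ys)
print(cost, perm)
```

Because `ys` is just a reordering of `xs`, the best plan matches every point to its copy and the transport cost is zero. For general (non-uniform) weights this shortcut no longer applies and one needs an LP solver or Sinkhorn.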
DM-count
Let the predicted density map be $\hat{\mathbf{z}}\in\mathbb{R}_+^n$ and the ground-truth (GT) density map be $\mathbf{z}\in\mathbb{R}_+^n$.
count loss
Why the count loss is needed: the OT loss is computed on normalized density maps, so it carries no information about the total count.

$$\ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\|\mathbf{z}\|_1-\|\hat{\mathbf{z}}\|_1\right|$$
Since $\mathbf{z},\hat{\mathbf{z}}\ge 0$, each 1-norm can be replaced by a plain sum, i.e.

$$\ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\sum_{i=1}^n \mathbf{z}_i-\sum_{i=1}^{n}\hat{\mathbf{z}}_i\right|$$
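A small sketch of the count loss on flattened density maps, together with a check of the point made above: scaling a density map changes its count but not its normalized version, which is all the OT term sees. The helper names and the toy values are made up for illustration.

```python
def count_loss(z, z_hat):
    """|sum(z) - sum(z_hat)| for non-negative density maps (flat lists)."""
    return abs(sum(z) - sum(z_hat))

def normalize(z):
    """Divide by the total mass so the entries sum to 1."""
    s = sum(z)
    return [v / s for v in z]

z = [0.0, 1.0, 0.0, 2.0]      # ground truth: 3 people
z_hat = [0.5, 0.5, 0.5, 0.5]  # prediction: 2 people
loss = count_loss(z, z_hat)
print(loss)

# Doubling z doubles the count, yet the normalized maps coincide.
z_doubled = [2.0 * v for v in z]
gap = max(abs(a - b) for a, b in zip(normalize(z), normalize(z_doubled)))
print(count_loss(z, z_doubled), gap)
```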
OT loss
$$\ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})=\mathcal{W}\left(\frac{\mathbf{z}}{\|\mathbf{z}\|_1}, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right)=\left\langle\boldsymbol{\alpha}^*, \frac{\mathbf{z}}{\|\mathbf{z}\|_1}\right\rangle+\left\langle\boldsymbol{\beta}^*, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\rangle$$
where $\boldsymbol{\alpha}^*,\boldsymbol{\beta}^*$ are the optimal solutions of the dual OT problem.
The cost matrix uses the squared Euclidean distance $c(\mathbf{z}(i), \hat{\mathbf{z}}(j))=\|\mathbf{z}(i)-\hat{\mathbf{z}}(j)\|_2^2$, where, as I understand the paper, $\mathbf{z}(i)$ and $\hat{\mathbf{z}}(j)$ denote the 2D coordinates of locations $i$ and $j$.
$$\frac{\partial \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}=\frac{\boldsymbol{\beta}^*}{\|\hat{\mathbf{z}}\|_1}-\frac{\left\langle\boldsymbol{\beta}^*, \hat{\mathbf{z}}\right\rangle}{\|\hat{\mathbf{z}}\|_1^2}$$
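For reference, the gradient above can be recovered with the quotient rule. This is a sketch under the assumption that $\boldsymbol{\beta}^*$ is held fixed when differentiating, and it uses $\|\hat{\mathbf{z}}\|_1=\sum_j \hat{z}_j$ (valid because $\hat{\mathbf{z}}\ge 0$). The $\boldsymbol{\alpha}^*$ term does not depend on $\hat{\mathbf{z}}$, so only the second inner product contributes. For the $k$-th component:

$$\frac{\partial}{\partial \hat{z}_k}\left\langle\boldsymbol{\beta}^*, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\rangle=\frac{\partial}{\partial \hat{z}_k}\,\frac{\sum_j \beta_j^* \hat{z}_j}{\sum_j \hat{z}_j}=\frac{\beta_k^*}{\|\hat{\mathbf{z}}\|_1}-\frac{\left\langle\boldsymbol{\beta}^*, \hat{\mathbf{z}}\right\rangle}{\|\hat{\mathbf{z}}\|_1^2}$$

Read this way, the second term of the gradient is a scalar that is subtracted from every component.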
One thing to watch out for: in the code, the OT loss is actually computed as

$$\ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})= \left\langle \frac{\partial \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}, \hat{\mathbf{z}}\right\rangle$$
https://github.com/cvlab-stonybrook/DM-Count/issues/29
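A quick numerical check of what this surrogate evaluates to. The sketch assumes the gradient factor is treated as a constant (detached from the graph), which is my reading of the code rather than something stated in these notes. Under that assumption the surrogate has the intended gradient with respect to $\hat{\mathbf{z}}$, but its value is algebraically zero: $\langle \boldsymbol{\beta}^*, \hat{\mathbf{z}}\rangle/\|\hat{\mathbf{z}}\|_1 - \langle \boldsymbol{\beta}^*, \hat{\mathbf{z}}\rangle\|\hat{\mathbf{z}}\|_1/\|\hat{\mathbf{z}}\|_1^2 = 0$. The values of `beta` and `z_hat` below are made up.

```python
def ot_grad(beta, z_hat):
    """Gradient of the OT loss w.r.t. z_hat for a fixed dual potential beta."""
    s = sum(z_hat)                                 # ||z_hat||_1, since z_hat >= 0
    dot = sum(b * v for b, v in zip(beta, z_hat))  # <beta, z_hat>
    return [b / s - dot / (s * s) for b in beta]

def surrogate_ot_loss(beta, z_hat):
    """<grad, z_hat>, with grad treated as a constant (assumption)."""
    grad = ot_grad(beta, z_hat)
    return sum(g * v for g, v in zip(grad, z_hat))

beta = [0.3, -1.2, 0.7, 2.0]    # made-up dual potential
z_hat = [0.5, 1.5, 0.25, 0.75]  # made-up non-negative prediction
value = surrogate_ot_loss(beta, z_hat)
print(value)  # zero up to floating-point rounding
```

So the number this term contributes to the logged loss is not the Wasserstein distance itself; only its gradient is meaningful.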
The OT problem is solved with the most basic Sinkhorn iteration (no log-domain stabilization).
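For orientation, here is a minimal pure-Python sketch of plain Sinkhorn for entropy-regularized OT. It is not the DM-Count implementation; the function name, the default `eps`, and the iteration count are assumptions chosen for the toy example. The dual potentials are read off from the scaling vectors as $\boldsymbol{\alpha}=\varepsilon\log\mathbf{u}$ and $\boldsymbol{\beta}=\varepsilon\log\mathbf{v}$.

```python
import math

def sinkhorn(mu, nu, C, eps=1.0, n_iters=200):
    """Plain (non-log-domain) Sinkhorn for entropy-regularized OT.

    mu, nu: probability vectors (lists); C: cost matrix (list of lists).
    Returns the transport plan and the dual potentials (alpha, beta).
    """
    n, m = len(mu), len(nu)
    # Gibbs kernel; for small eps these entries can underflow to 0,
    # which is the weakness that log-domain variants address.
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = [mu[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [nu[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    alpha = [eps * math.log(ui) for ui in u]
    beta = [eps * math.log(vj) for vj in v]
    return plan, alpha, beta

mu = [0.5, 0.5]
nu = [0.25, 0.75]
C = [[0.0, 1.0], [1.0, 0.0]]
plan, alpha, beta = sinkhorn(mu, nu, C)
row_sums = [sum(row) for row in plan]
col_sums = [sum(plan[i][j] for i in range(2)) for j in range(2)]
print(row_sums, col_sums)
```

After convergence the row sums match `mu` and the column sums match `nu`. With a much smaller `eps`, `math.exp(-C/eps)` shrinks toward zero and the divisions become unstable, which is the usual motivation for the log-domain version.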
TV loss
This term is mainly there to stabilize the results.
$$\ell_{TV}(\mathbf{z}, \hat{\mathbf{z}})=\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_{TV}=\frac{1}{2}\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_1$$
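The TV loss is simply half the L1 distance between the two normalized maps. A minimal sketch on flattened maps, with made-up toy values and hypothetical helper names:

```python
def normalize(z):
    """Divide by the total mass so the entries sum to 1."""
    s = sum(z)
    return [v / s for v in z]

def tv_loss(z, z_hat):
    """Total variation distance between the normalized density maps."""
    p, q = normalize(z), normalize(z_hat)
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

z = [0.0, 1.0, 0.0, 1.0]      # mass concentrated on two cells
z_hat = [1.0, 1.0, 1.0, 1.0]  # mass spread uniformly
loss = tv_loss(z, z_hat)
print(loss)
```

The value always lies in $[0, 1]$: it is 0 when the normalized maps coincide and 1 when their supports are disjoint.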
Results
On UCF-QNRF:
Authors' model: MAE 85.76006602669905, MSE 150.3385868782564
My run (best_model_7.pth): MAE 89.24010239104311, MSE 155.59441664755747