The loss function used by DPRNN is SI-SNR.
SI-SNR stands for scale-invariant source-to-noise ratio, i.e. a signal-to-noise ratio that does not change when the estimated signal is rescaled.
Formula
The formula is as follows:
$$
\begin{cases}
s_{target} = \dfrac{\langle \hat{s}, s \rangle \, s}{\|s\|^2} \\
e_{noise} = \hat{s} - s_{target} \\
\text{SI-SNR} = 10 \log_{10} \dfrac{\|s_{target}\|^2}{\|e_{noise}\|^2}
\end{cases}
$$
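The three lines of the formula translate almost one-to-one into code. The following is a minimal sketch that uses plain Python lists and the standard library only. The names `dot` and `si_snr` are hypothetical, and unlike the DPRNN source it does no mean subtraction and has no `eps` guard, so it assumes `s` is not all zeros and `s_hat` is not an exact multiple of `s`.

```python
import math


def dot(a, b):
    """<a, b>: multiply element-wise, then sum over all samples."""
    return sum(x * y for x, y in zip(a, b))


def si_snr(s_hat, s):
    """SI-SNR in dB, line by line from the formula (no mean removal, no eps)."""
    # s_target = <s_hat, s> s / ||s||^2
    scale = dot(s_hat, s) / dot(s, s)
    s_target = [scale * x for x in s]
    # e_noise = s_hat - s_target
    e_noise = [a - b for a, b in zip(s_hat, s_target)]
    # SI-SNR = 10 log10(||s_target||^2 / ||e_noise||^2)
    return 10 * math.log10(dot(s_target, s_target) / dot(e_noise, e_noise))


print(round(si_snr([1.0, 2.0, 3.0], [1.0, 0.0, 1.0]), 4))  # 1.2494
```

For these made-up vectors, $s_{target} = (2, 0, 2)$ and $e_{noise} = (-1, 2, 1)$, so the result is $10 \log_{10}(8/6) \approx 1.2494$ dB.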
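where $\hat{s}$ is the estimated (separated) signal and $s$ is the clean reference signal; $\langle \hat{s}, s \rangle$ is the inner product, i.e. the element-wise product summed over all samples; and $\|s\|^2$ is the squared L2 norm (2-norm) of $s$, which equals $\langle s, s \rangle$.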
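A small worked example of the two operations just defined, on arbitrary made-up vectors. It also checks a property implied by the formula: $s_{target}$ is the orthogonal projection of $\hat{s}$ onto $s$, so the residual $e_{noise}$ has zero inner product with $s$. The helper name `inner` is hypothetical.

```python
def inner(a, b):
    """<a, b>: multiply element-wise, then sum over all samples."""
    return sum(x * y for x, y in zip(a, b))


s_hat = [1.0, 2.0, 3.0]   # estimated signal (made-up numbers)
s = [1.0, 0.0, 1.0]       # clean signal (made-up numbers)

print(inner(s_hat, s))    # 4.0, i.e. 1*1 + 2*0 + 3*1
print(inner(s, s))        # 2.0, i.e. ||s||^2 = 1^2 + 0^2 + 1^2

# s_target = <s_hat, s> s / ||s||^2 and the residual e_noise = s_hat - s_target
scale = inner(s_hat, s) / inner(s, s)
s_target = [scale * x for x in s]
e_noise = [a - b for a, b in zip(s_hat, s_target)]
print(s_target, e_noise)  # [2.0, 0.0, 2.0] [-1.0, 2.0, 1.0]
print(inner(e_noise, s))  # 0.0: the residual is orthogonal to s
```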
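The 2-norm is defined as $\|x\|_2 = \sqrt{\sum_{i=1}^{n} x_i^2}$. Intuitively, it is the Euclidean length of the vector $x$, i.e. the distance from the point $x$ to the origin; in two dimensions it reduces to the familiar distance formula $\sqrt{x_1^2 + x_2^2}$.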
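A quick numeric check of the 2-norm definition, using a hypothetical helper `l2_norm` and comparing it with the standard library's `math.hypot` (which accepts any number of coordinates from Python 3.8 on):

```python
import math


def l2_norm(x):
    """||x||_2 = sqrt(x_1^2 + ... + x_n^2)."""
    return math.sqrt(sum(v * v for v in x))


p = [3.0, 4.0]
print(l2_norm(p))       # 5.0: distance from the point (3, 4) to the origin
print(math.hypot(*p))   # 5.0: the standard library's Euclidean norm agrees
print(l2_norm(p) ** 2)  # 25.0: squaring gives <p, p> = 3*3 + 4*4
```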
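SNR is the ratio of the power of the clean signal to the power of the noise, whereas SI-SNR uses the normalization by $\|s\|^2$ to remove the effect of the signal's scale: because $s_{target}$ is the projection of $\hat{s}$ onto $s$, multiplying $\hat{s}$ by any nonzero constant $c$ multiplies both $s_{target}$ and $e_{noise}$ by $c$, so their ratio, and therefore SI-SNR, is unchanged.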
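The scale argument can be checked numerically. This sketch assumes the plain SNR is $10 \log_{10}(\|s\|^2 / \|\hat{s} - s\|^2)$, a common definition that the text above does not spell out, and uses made-up signals. The function names `snr` and `si_snr` are hypothetical. Rescaling the estimate by `c` moves the plain SNR by many dB, while SI-SNR stays the same up to floating-point rounding.

```python
import math


def inner(a, b):
    """<a, b>: element-wise product, summed."""
    return sum(x * y for x, y in zip(a, b))


def snr(s_hat, s):
    """Plain SNR in dB (assumed definition): everything in s_hat - s counts as noise."""
    noise = [a - b for a, b in zip(s_hat, s)]
    return 10 * math.log10(inner(s, s) / inner(noise, noise))


def si_snr(s_hat, s):
    """SI-SNR in dB: project s_hat onto s first, then compare."""
    alpha = inner(s_hat, s) / inner(s, s)
    target = [alpha * x for x in s]
    noise = [a - b for a, b in zip(s_hat, target)]
    return 10 * math.log10(inner(target, target) / inner(noise, noise))


s = [1.0, -1.0, 2.0, 0.5]       # clean signal (made-up numbers)
s_hat = [0.9, -1.2, 2.1, 0.4]   # estimate (made-up numbers)

for c in (1.0, 0.5, 3.0):
    scaled = [c * x for x in s_hat]
    print(f"c={c}: SNR={snr(scaled, s):6.2f} dB, SI-SNR={si_snr(scaled, s):6.2f} dB")
```

Because SI-SNR ignores overall gain, a network trained on it is not penalized for producing output at the "wrong" volume, only for distortion relative to the reference.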
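The DPRNN source code also gives the signals a "special treatment": it first subtracts the mean from both the separated signal and the reference signal, and only then applies the formula.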
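The source does not say why the mean is removed. One effect that is easy to verify is that a constant DC offset in the estimate no longer counts as noise. The sketch below is a hypothetical pure-Python variant of SI-SNR with a `remove_mean` switch, applied to made-up signals. Without mean removal, the offset lowers SI-SNR by several dB. With mean removal, the result is identical to the offset-free case up to rounding.

```python
import math


def inner(a, b):
    """<a, b>: element-wise product, summed."""
    return sum(x * y for x, y in zip(a, b))


def zero_mean(x):
    """Subtract the signal's mean from every sample."""
    m = sum(x) / len(x)
    return [v - m for v in x]


def si_snr(s_hat, s, remove_mean=True):
    """SI-SNR in dB, optionally after making both signals zero-mean."""
    if remove_mean:
        s_hat, s = zero_mean(s_hat), zero_mean(s)
    alpha = inner(s_hat, s) / inner(s, s)
    target = [alpha * x for x in s]
    noise = [a - b for a, b in zip(s_hat, target)]
    return 10 * math.log10(inner(target, target) / inner(noise, noise))


s = [1.0, -1.0, 2.0, -2.0]              # clean signal (made-up numbers)
s_hat = [0.9, -1.2, 2.1, -1.8]          # estimate (made-up numbers)
offset = [v + 0.5 for v in s_hat]       # same estimate plus a DC offset of 0.5

for flag in (False, True):
    print(flag, si_snr(s_hat, s, flag), si_snr(offset, s, flag))
```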
The source code is as follows:
```python
import torch


def sisnr(x, s, eps=1e-8):
    """
    calculate training loss
    input:
          x: separated signal, N x S tensor
          s: reference signal, N x S tensor
    Return:
          sisnr: N tensor
    """

    def l2norm(mat, keepdim=False):
        # L2 norm over the last (sample) dimension
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    if x.shape != s.shape:
        raise RuntimeError(
            "Dimension mismatch when calculate si-snr, {} vs {}".format(
                x.shape, s.shape))
    # remove the mean of each signal before applying the formula
    x_zm = x - torch.mean(x, dim=-1, keepdim=True)
    s_zm = s - torch.mean(s, dim=-1, keepdim=True)
    # t = s_target = <x_zm, s_zm> s_zm / ||s_zm||^2 (eps avoids division by zero)
    t = torch.sum(
        x_zm * s_zm, dim=-1,
        keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
    # 20 log10(||t|| / ||e_noise||) equals 10 log10(||t||^2 / ||e_noise||^2)
    return 20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
```
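The `sisnr()` function above returns one SI-SNR value per utterance rather than a scalar loss. Since a higher SI-SNR is better, a common convention is to minimize the negative batch mean. The DPRNN training script is not shown here, so treat that reduction as an assumption. The sketch below is a hypothetical pure-Python mirror (`sisnr_py`, `neg_sisnr_loss`) that follows the same steps, including the `eps` placement, so it runs without PyTorch. It also shows that the source's `20 * log10(||t|| / ||e||)` matches the formula's `10 * log10(||t||^2 / ||e||^2)`.

```python
import math


def sisnr_py(x, s, eps=1e-8):
    """Pure-Python mirror of sisnr(): x and s are N lists of S samples; returns N values in dB."""
    if len(x) != len(s) or any(len(a) != len(b) for a, b in zip(x, s)):
        raise RuntimeError("Dimension mismatch when calculate si-snr")
    out = []
    for xr, sr in zip(x, s):
        # zero-mean each signal, as in the source
        mx, ms = sum(xr) / len(xr), sum(sr) / len(sr)
        x_zm = [v - mx for v in xr]
        s_zm = [v - ms for v in sr]
        # t = <x_zm, s_zm> s_zm / (||s_zm||^2 + eps)
        alpha = sum(a * b for a, b in zip(x_zm, s_zm)) / (sum(b * b for b in s_zm) + eps)
        t = [alpha * b for b in s_zm]
        e = [a - b for a, b in zip(x_zm, t)]
        norm_t = math.sqrt(sum(v * v for v in t))
        norm_e = math.sqrt(sum(v * v for v in e))
        # same expression as the source's return line
        out.append(20 * math.log10(eps + norm_t / (norm_e + eps)))
    return out


def neg_sisnr_loss(x, s):
    """Hypothetical training loss: negative mean SI-SNR over the batch."""
    vals = sisnr_py(x, s)
    return -sum(vals) / len(vals)


s = [[1.0, -1.0, 2.0, -2.0], [0.5, 1.0, -0.5, -1.0]]   # references (made-up numbers)
x = [[0.9, -1.2, 2.1, -1.8], [1.0, 2.2, -0.9, -2.1]]   # estimates (made-up numbers)
print(sisnr_py(x, s))
print(neg_sisnr_loss(x, s))
```

For the first made-up pair, $s_{target} = 0.99\,s$, $\|s_{target}\|^2 = 9.801$ and $\|e_{noise}\|^2 = 0.099$, so its SI-SNR is $10 \log_{10} 99 \approx 19.96$ dB; the tiny `eps` terms change this only far below the second decimal.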