PS: the code in the first part is all in its PyTorch version; a TensorFlow version of uPIT-SiSNR is appended at the end.
The loss function is also a very important part of model training.
Common loss functions:
- Speech separation: uPIT-SiSNR
- Speech enhancement: l1, MSE
Loss function examples:
Speech separation
- 【SepFormer】: uPIT-SiSNR (https://github.com/speechbrain/speechbrain)
- 【DualPath RNN】: SiSNR
- 【TransMask】: SiSNR
- 【Conv-TasNet】: SiSNR (https://github.com/kaituoxu/Conv-TasNet)
Music separation
- 【Demucs】: l1 (https://github.com/facebookresearch/demucs)
Speech denoising
- 【Denoiser】: l1, stft_loss (https://github.com/facebookresearch/denoiser)
- 【Phasen】: SiSNR or mag_spec (https://github.com/huyanxin/phasen/blob/master/model/phasen.py)
- 【Transformer】: l1
- 【Conformer】: l1
- 【DCCRN】: SiSNR (https://github.com/huyanxin/DeepComplexCRN, https://huyanxin.github.io/DeepComplexCRN/)
- 【DCUNet】: wSDR
- 【DF-Conformer】: SNR
References
- On uPIT Si-SNR: https://blog.csdn.net/zjuPeco/article/details/106300674
- SpeechBrain losses (speechbrain/nnet/losses.py): https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/losses.py
SNR (Signal-to-Noise Ratio)
ref: https://blog.csdn.net/zjuPeco/article/details/106300674
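For reference, the standard definition (consistent with the post linked above): given a target signal $s$ and an estimate $\hat{s}$, with error $e = \hat{s} - s$,

$$\mathrm{SNR} = 10 \log_{10} \frac{\lVert s \rVert^2}{\lVert e \rVert^2}$$

Plain SNR is not invariant to rescaling the estimate, which motivates the scale-invariant variant next.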
Si-SNR (Scale-Invariant Signal-to-Noise Ratio)
See also the formulation in the paper "Optimal scale-invariant signal-to-noise ratio and curriculum learning for monaural multi-speaker speech separation in noisy environment": http://www.apsipa.org/proceedings/2020/pdfs/0000711.pdf
As this shows, there is actually more than one definition of SI-SNR.
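The variant implemented below (the one used in SpeechBrain) first subtracts the mean from both signals, then projects the estimate onto the target:

$$s_{\text{target}} = \frac{\langle \hat{s}, s \rangle\, s}{\lVert s \rVert^2}, \qquad e_{\text{noise}} = \hat{s} - s_{\text{target}}, \qquad \text{SI-SNR} = 10 \log_{10} \frac{\lVert s_{\text{target}} \rVert^2}{\lVert e_{\text{noise}} \rVert^2}$$

Scaling $\hat{s}$ scales $s_{\text{target}}$ and $e_{\text{noise}}$ by the same factor, so the ratio is unchanged.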
Here we take the SpeechBrain code as the example; see the SpeechBrain GitHub page (https://github.com/speechbrain/speechbrain), file speechbrain/nnet/losses.py: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/losses.py
def cal_si_snr(source, estimate_source):
"""Calculate SI-SNR.
Arguments:
---------
source: [T, B, C],
Where B is batch size, T is the length of the sources, C is the number of sources
the ordering is made so that this loss is compatible with the class PitWrapper.
estimate_source: [T, B, C]
The estimated source.
Example:
---------
>>> import numpy as np
>>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
>>> xhat = x[:, (1, 0)]
>>> x = x.unsqueeze(-1).repeat(1, 1, 2)
>>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
>>> si_snr = -cal_si_snr(x, xhat)
>>> print(si_snr)
tensor([[[ 25.2142, 144.1789],
[130.9283, 25.2142]]])
"""
EPS = 1e-8
assert source.size() == estimate_source.size()
device = estimate_source.device.type
source_lengths = torch.tensor(
[estimate_source.shape[0]] * estimate_source.shape[1], device=device
)
mask = get_mask(source, source_lengths)
estimate_source *= mask
num_samples = (
source_lengths.contiguous().reshape(1, -1, 1).float()
) # [1, B, 1]
mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
mean_estimate = (
torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
)
zero_mean_target = source - mean_target
zero_mean_estimate = estimate_source - mean_estimate
# mask padding position along T
zero_mean_target *= mask
zero_mean_estimate *= mask
# Step 2. SI-SNR with PIT
# reshape to use broadcast
s_target = zero_mean_target # [T, B, C]
s_estimate = zero_mean_estimate # [T, B, C]
# s_target = <s', s>s / ||s||^2
dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True) # [1, B, C]
s_target_energy = (
torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
) # [1, B, C]
proj = dot * s_target / s_target_energy # [T, B, C]
# e_noise = s' - s_target
e_noise = s_estimate - proj # [T, B, C]
# SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
torch.sum(e_noise ** 2, dim=0) + EPS
)
si_snr = 10 * torch.log10(si_snr_beforelog + EPS) # [B, C]
return -si_snr.unsqueeze(0)
def get_mask(source, source_lengths):
"""
Arguments
---------
source : [T, B, C]
source_lengths : [B]
Returns
-------
mask : [T, B, 1]
Example:
---------
>>> source = torch.randn(4, 3, 2)
>>> source_lengths = torch.Tensor([2, 1, 4]).int()
>>> mask = get_mask(source, source_lengths)
>>> print(mask)
tensor([[[1.],
[1.],
[1.]],
<BLANKLINE>
[[1.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]]])
"""
T, B, _ = source.size()
mask = source.new_ones((T, B, 1))
for i in range(B):
mask[source_lengths[i] :, i, :] = 0
return mask
Note that the mean is subtracted from both s_target and s_estimate, and that EPS is added to the denominators to avoid division-by-zero errors.
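A quick numerical check of the scale invariance (a minimal sketch, assuming the cal_si_snr and get_mask functions above are in scope; the tensors here are made up purely for illustration):

import torch

x = torch.randn(100, 1, 2)                # [T, B, C] reference sources
xhat = x + 0.1 * torch.randn(100, 1, 2)   # a noisy estimate
print(cal_si_snr(x, xhat))                # some value, shape [1, B, C]
print(cal_si_snr(x, 5.0 * xhat))          # (near-)identical value: rescaling the
                                          # estimate only perturbs it via the EPS terms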
PIT (Permutation Invariant Training)
PIT is a training method; the full name is Permutation Invariant Training. It makes end-to-end training possible, and the overall idea is intuitive: first assume an arbitrary assignment between the speakers and the model outputs and train for a while to get a model. Then, at each subsequent training step, compute a metric such as SI-SDR once for every possible assignment (for two speakers: output 1↔speaker 1 with output 2↔speaker 2, versus output 1↔speaker 2 with output 2↔speaker 1), take the assignment with the smaller loss as the ordering, and continue training with that ordering.
uPIT (utterance-level PIT)
uPIT applies this at the utterance level: among all the permutation combinations described above, a single optimal permutation is chosen for the entire utterance rather than per frame.
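Concretely, for C sources the utterance-level objective selects the permutation $\pi$ that minimizes the loss averaged over sources (this is exactly what _fast_pit computes in the code below):

$$\mathcal{L}_{\text{uPIT}} = \min_{\pi \in \mathcal{P}_C} \frac{1}{C} \sum_{c=1}^{C} \mathcal{L}\big(\hat{s}_c,\, s_{\pi(c)}\big)$$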
Implementation:
class PitWrapper(nn.Module):
"""
Permutation Invariant Wrapper to allow Permutation Invariant Training
(PIT) with existing losses.
Permutation invariance is calculated over the sources/classes axis which is
assumed to be the rightmost dimension: predictions and targets tensors are
assumed to have shape [batch, ..., channels, sources].
Arguments
---------
base_loss : function
Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes
two arguments:
predictions and targets and no reduction is performed.
(if a pytorch loss is used, the user must specify reduction="none").
Returns
---------
pit_loss : torch.nn.Module
Torch module supporting forward method for PIT.
Example
-------
>>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
>>> targets = torch.rand((2, 32, 4))
>>> p = (3, 0, 2, 1)
>>> predictions = targets[..., p]
>>> loss, opt_p = pit_mse(predictions, targets)
>>> loss
tensor([0., 0.])
"""
def __init__(self, base_loss):
super(PitWrapper, self).__init__()
self.base_loss = base_loss
def _fast_pit(self, loss_mat):
"""
Arguments
----------
loss_mat : torch.Tensor
Tensor of shape [sources, source] containing loss values for each
possible permutation of predictions.
Returns
-------
loss : torch.Tensor
Permutation invariant loss for the current batch, tensor of shape [1]
assigned_perm : tuple
Indexes for optimal permutation of the input over sources which
minimizes the loss.
"""
loss = None
assigned_perm = None
for p in permutations(range(loss_mat.shape[0])):
c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
# return loss_mat[range(loss_mat.shape[0]), p][0], p
#########################################################
### IMPORTANT ###########################################
if loss is None or loss > c_loss:
loss = c_loss
assigned_perm = p
#########################################################
return loss, assigned_perm
def _opt_perm_loss(self, pred, target):
"""
Arguments
---------
pred : torch.Tensor
Network prediction for the current example, tensor of
shape [..., sources].
target : torch.Tensor
Target for the current example, tensor of shape [..., sources].
Returns
-------
loss : torch.Tensor
Permutation invariant loss for the current example, tensor of shape [1]
assigned_perm : tuple
Indexes for optimal permutation of the input over sources which
minimizes the loss.
"""
n_sources = pred.size(-1)
pred = pred.unsqueeze(-2).repeat(
*[1 for x in range(len(pred.shape) - 1)], n_sources, 1
)
target = target.unsqueeze(-1).repeat(
1, *[1 for x in range(len(target.shape) - 1)], n_sources
)
loss_mat = self.base_loss(pred, target)
assert (
len(loss_mat.shape) >= 2
), "Base loss should not perform any reduction operation"
mean_over = [x for x in range(len(loss_mat.shape))]
loss_mat = loss_mat.mean(dim=mean_over[:-2])
return self._fast_pit(loss_mat)
def reorder_tensor(self, tensor, p):
"""
Arguments
---------
tensor : torch.Tensor
Tensor to reorder given the optimal permutation, of shape
[batch, ..., sources].
p : list of tuples
List of optimal permutations, e.g. for batch=2 and n_sources=3
            [(0, 1, 2), (0, 2, 1)].
Returns
-------
reordered : torch.Tensor
Reordered tensor given permutation p.
"""
reordered = torch.zeros_like(tensor, device=tensor.device)
for b in range(tensor.shape[0]):
reordered[b] = tensor[b][..., p[b]].clone()
return reordered
def forward(self, preds, targets):
"""
Arguments
---------
preds : torch.Tensor
Network predictions tensor, of shape
[batch, channels, ..., sources].
targets : torch.Tensor
Target tensor, of shape [batch, channels, ..., sources].
Returns
-------
loss : torch.Tensor
Permutation invariant loss for current examples, tensor of
shape [batch]
perms : list
List of indexes for optimal permutation of the inputs over
sources.
e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples
per batch.
"""
losses = []
perms = []
for pred, label in zip(preds, targets):
loss, p = self._opt_perm_loss(pred, label)
perms.append(p)
losses.append(loss)
loss = torch.stack(losses)
return loss, perms
Here, permutations is the class from Python's itertools that enumerates all permutations:
For example, when separating 2 speech streams, the two outputs may correspond to (speaker 1, speaker 2) or to (speaker 2, speaker 1).
Likewise, when separating 3 streams, there are 6 possible permutations.
Generalizing: with m input streams and n output streams (n <= m), the permutations class handles it: perms = permutations(range(m), n), as demonstrated below.
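A minimal check of what this enumeration yields:

from itertools import permutations

print(list(permutations(range(3), 2)))
# [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
print(list(permutations(range(2))))
# [(0, 1), (1, 0)]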
Its concrete use inside the uPIT-SiSNR above is:
for p in permutations(range(loss_mat.shape[0])):
c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
# return loss_mat[range(loss_mat.shape[0]), p][0], p
#########################################################
### IMPORTANT ###########################################
if loss is None or loss > c_loss:
loss = c_loss
assigned_perm = p
#########################################################
return loss, assigned_perm
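A toy run of this search (a minimal sketch: the 2x2 loss matrix is made up, and since _fast_pit only reads loss_mat, the choice of base loss does not matter here; it assumes PitWrapper above plus the usual imports — torch, torch.nn, itertools.permutations — are in scope):

import torch
import torch.nn as nn

pit = PitWrapper(nn.MSELoss(reduction="none"))
loss_mat = torch.tensor([[5.0, 1.0],   # loss_mat[i][j]: loss of prediction i
                         [2.0, 6.0]])  # against target j
loss, perm = pit._fast_pit(loss_mat)
# identity perm (0, 1): mean(5, 6) = 5.5; swapped (1, 0): mean(1, 2) = 1.5
print(loss, perm)  # tensor(1.5000) (1, 0)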
class permutations(object):
"""
permutations(iterable[, r]) --> permutations object
Return successive r-length permutations of elements in the iterable.
permutations(range(3), 2) --> (0,1), (0,2), (1,0), (1,2), (2,0), (2,1)
"""
Here, base_loss can be any specified loss function, for example SiSNR in our case. The concrete implementation is as follows.
uPIT-SiSNR
PyTorch implementation
def get_si_snr_with_pitwrapper(source, estimate_source):
"""This function wraps si_snr calculation with the speechbrain pit-wrapper.
Arguments:
---------
source: [B, T, C],
Where B is the batch size, T is the length of the sources, C is
the number of sources the ordering is made so that this loss is
compatible with the class PitWrapper.
estimate_source: [B, T, C]
The estimated source.
Example:
---------
>>> x = torch.arange(600).reshape(3, 100, 2)
>>> xhat = x[:, :, (1, 0)]
>>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
>>> print(si_snr)
tensor([135.2284, 135.2284, 135.2284])
"""
pit_si_snr = PitWrapper(cal_si_snr)
loss, perms = pit_si_snr(source, estimate_source)
return loss
TensorFlow implementation
Here is a TensorFlow (v2.4.0) implementation of the above uPIT-SiSNR.
Note: we inherit from the Keras Loss base class; some of the comments preserve the original PyTorch code.
A few PyTorch-vs-TensorFlow correspondences:
# pytorch
c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
# tensorflow
c_loss = tf.reduce_mean([loss_mat[i][p[i]] for i in range(loss_mat.shape[0])])
# pytorch
pred = pred.unsqueeze(-2).repeat(
*[1 for x in range(len(pred.shape) - 1)], n_sources, 1
)
# tensorflow
pred = tf.tile(tf.expand_dims(pred, axis=-2), [len([1 for x in range(len(pred.shape) - 1)]), n_sources, 1])
# pytorch
loss_mat = loss_mat.mean(dim=mean_over[:-2])
# tensorflow
loss_mat = tf.reduce_mean(loss_mat, axis=mean_over[:-2])
# pytorch
def forward(self, preds, targets):
# tensorflow
def call(self, preds, targets):
# pytorch
-si_snr.unsqueeze(0)
# tensorflow
-tf.expand_dims(si_snr, 0)
# pytorch
x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
xhat = x[:, (1, 0)]
# tensorflow
x = tf.constant([[1, 0], [123, 45], [34, 5], [2312, 421]], dtype=float)
xhat = tf.slice(x, [0, 1], [x.shape[0], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0], [x.shape[0], 1])], axis=1)
# pytorch
xhat = x[:, :, (1, 0)]
# tensorflow
xhat = tf.slice(x, [0, 0, 1], [x.shape[0], x.shape[1], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0, 0], [x.shape[0], x.shape[1], 1])], axis=2)
The full code:
import tensorflow as tf
from tensorflow.python.keras.losses import Loss, mse
from itertools import permutations
from tensorflow.python.keras.utils import losses_utils
class PitWrapper(Loss):
"""
Permutation Invariant Wrapper to allow Permutation Invariant Training
(PIT) with existing losses.
Permutation invariance is calculated over the sources/classes axis which is
assumed to be the rightmost dimension: predictions and targets tensors are
assumed to have shape [batch, ..., channels, sources].
Arguments
---------
base_loss : function
Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes
two arguments:
predictions and targets and no reduction is performed.
(if a pytorch loss is used, the user must specify reduction="none").
Returns
---------
pit_loss : torch.nn.Module
Torch module supporting forward method for PIT.
Example
-------
>>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
>>> targets = torch.rand((2, 32, 4))
>>> p = (3, 0, 2, 1)
>>> predictions = targets[..., p]
>>> loss, opt_p = pit_mse(predictions, targets)
>>> loss
tensor([0., 0.])
"""
def __init__(self, base_loss):
super().__init__()
self.reduction = losses_utils.ReductionV2.NONE ## IMPORTANT ##
self.base_loss = base_loss
def _fast_pit(self, loss_mat):
"""
Arguments
----------
loss_mat : torch.Tensor
Tensor of shape [sources, source] containing loss values for each
possible permutation of predictions.
Returns
-------
loss : torch.Tensor
Permutation invariant loss for the current batch, tensor of shape [1]
assigned_perm : tuple
Indexes for optimal permutation of the input over sources which
minimizes the loss.
"""
loss = None
assigned_perm = None
for p in permutations(range(loss_mat.shape[0])):
c_loss = tf.reduce_mean([loss_mat[i][p[i]] for i in range(loss_mat.shape[0])]) # loss_mat[range(loss_mat.shape[0]), p].mean()
if loss is None or loss > c_loss:
loss = c_loss
assigned_perm = p
return loss, assigned_perm
def _opt_perm_loss(self, pred, target):
"""
Arguments
---------
pred : torch.Tensor
Network prediction for the current example, tensor of
shape [..., sources].
target : torch.Tensor
Target for the current example, tensor of shape [..., sources].
Returns
-------
loss : torch.Tensor
Permutation invariant loss for the current example, tensor of shape [1]
assigned_perm : tuple
Indexes for optimal permutation of the input over sources which
minimizes the loss.
"""
n_sources = pred.shape[-1] #pred.size(-1)
# pred = pred.unsqueeze(-2).repeat(
# *[1 for x in range(len(pred.shape) - 1)], n_sources, 1
# )
pred = tf.tile(tf.expand_dims(pred, axis=-2), [len([1 for x in range(len(pred.shape) - 1)]), n_sources, 1])
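        # note: these tile multiples assume each per-example tensor is rank 2
        # ([T, sources]); after expand_dims the rank is 3, matching the 3-entry
        # multiples list. For higher-rank inputs the list would need one entry
        # per axis, e.g. [1] * (rank - 1) + [n_sources, 1].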
# target = target.unsqueeze(-1).repeat(
# 1, *[1 for x in range(len(target.shape) - 1)], n_sources
# )
target = tf.tile(tf.expand_dims(target, axis=-1), [1, len([1 for x in range(len(target.shape) - 1)]), n_sources])
loss_mat = self.base_loss(pred, target)
assert (
len(loss_mat.shape) >= 2
), "Base loss should not perform any reduction operation"
mean_over = [x for x in range(len(loss_mat.shape))]
# loss_mat = loss_mat.mean(dim=mean_over[:-2])
loss_mat = tf.reduce_mean(loss_mat, axis=mean_over[:-2])
return self._fast_pit(loss_mat)
def reorder_tensor(self, tensor, p):
"""
Arguments
---------
tensor : torch.Tensor
Tensor to reorder given the optimal permutation, of shape
[batch, ..., sources].
p : list of tuples
List of optimal permutations, e.g. for batch=2 and n_sources=3
            [(0, 1, 2), (0, 2, 1)].
Returns
-------
reordered : torch.Tensor
Reordered tensor given permutation p.
"""
        # TF tensors do not support item assignment or .clone(), so gather the
        # sources axis of each batch element and stack the results instead.
        reordered = [
            tf.gather(tensor[b], list(p[b]), axis=-1)
            for b in range(tensor.shape[0])
        ]
        return tf.stack(reordered, axis=0)
def call(self, preds, targets): #forward(self, preds, targets):
"""
Arguments
---------
preds : torch.Tensor
Network predictions tensor, of shape
[batch, channels, ..., sources].
targets : torch.Tensor
Target tensor, of shape [batch, channels, ..., sources].
Returns
-------
loss : torch.Tensor
Permutation invariant loss for current examples, tensor of
shape [batch]
perms : list
List of indexes for optimal permutation of the inputs over
sources.
e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples
per batch.
"""
losses = []
perms = []
for pred, label in zip(preds, targets):
loss, p = self._opt_perm_loss(pred, label)
perms.append(p)
losses.append(loss)
loss = tf.stack(losses)
return loss #, perms # todo?
def get_si_snr_with_pitwrapper(source, estimate_source):
"""This function wraps si_snr calculation with the speechbrain pit-wrapper.
Arguments:
---------
source: [B, T, C],
Where B is the batch size, T is the length of the sources, C is
the number of sources the ordering is made so that this loss is
compatible with the class PitWrapper.
estimate_source: [B, T, C]
The estimated source.
Example:
---------
>>> x = torch.arange(600).reshape(3, 100, 2)
>>> xhat = x[:, :, (1, 0)]
>>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
>>> print(si_snr)
tensor([135.2284, 135.2284, 135.2284])
"""
pit_si_snr = PitWrapper(cal_si_snr)
loss = pit_si_snr(estimate_source, source) # , perms
return loss
def get_mse_with_pitwrapper(source, estimate_source):
    pit_mse = PitWrapper(mse)
    loss = pit_mse(estimate_source, source)  # , perms
    return loss
def cal_si_snr(estimate_source, source):
"""Calculate SI-SNR.
Arguments:
---------
source: [T, B, C],
Where B is batch size, T is the length of the sources, C is the number of sources
the ordering is made so that this loss is compatible with the class PitWrapper.
estimate_source: [T, B, C]
The estimated source.
Example:
---------
>>> import numpy as np
>>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
>>> xhat = x[:, (1, 0)]
>>> x = x.unsqueeze(-1).repeat(1, 1, 2)
>>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
>>> si_snr = -cal_si_snr(x, xhat)
>>> print(si_snr)
tensor([[[ 25.2142, 144.1789],
[130.9283, 25.2142]]])
"""
EPS = 1e-8
source_lengths = tf.constant([estimate_source.shape[0]] * estimate_source.shape[1])
mask = get_mask(source, source_lengths)
estimate_source = tf.multiply(estimate_source, mask)
# num_samples = tf.reshape(source_lengths, [1, -1, 1])
mean_target = tf.math.reduce_mean(source, axis=0, keepdims=True) #/ num_samples
mean_estimate = (
tf.math.reduce_mean(estimate_source, axis=0, keepdims=True) #/ num_samples
)
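    # note: reduce_mean over the full T axis assumes source_lengths == T
    # (i.e., the mask is all ones); with real padding, divide the masked sums
    # by num_samples as in the PyTorch version above.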
zero_mean_target = source - mean_target
zero_mean_estimate = estimate_source - mean_estimate
# mask padding position along T
zero_mean_target *= mask
zero_mean_estimate *= mask
# Step 2. SI-SNR with PIT
# reshape to use broadcast
s_target = zero_mean_target # [T, B, C]
s_estimate = zero_mean_estimate # [T, B, C]
# s_target = <s', s>s / ||s||^2
dot = tf.math.reduce_sum(s_estimate * s_target, axis=0, keepdims=True) # [1, B, C]
s_target_energy = (
tf.math.reduce_sum(s_target ** 2, axis=0, keepdims=True) + EPS
) # [1, B, C]
proj = dot * s_target / s_target_energy # [T, B, C]
# e_noise = s' - s_target
e_noise = s_estimate - proj # [T, B, C]
# SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
si_snr_beforelog = tf.math.reduce_sum(proj ** 2, axis=0) / (
tf.reduce_sum(e_noise ** 2, axis=0) + EPS
)
si_snr = 10 * tf.math.log(si_snr_beforelog + EPS) / tf.math.log(10.0) # [B, C]
return -tf.expand_dims(si_snr, 0) #-si_snr.unsqueeze(0)
def get_mask(source, source_lengths):
"""
Arguments
---------
source : [T, B, C]
source_lengths : [B]
Returns
-------
mask : [T, B, 1]
Example:
---------
>>> source = torch.randn(4, 3, 2)
>>> source_lengths = torch.Tensor([2, 1, 4]).int()
>>> mask = get_mask(source, source_lengths)
>>> print(mask)
tensor([[[1.],
[1.],
[1.]],
<BLANKLINE>
[[1.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]]])
"""
T, B, _ = source.shape
mask = None
for i in range(B):
if mask is None:
mask = tf.concat([tf.ones(shape=[source_lengths[i], 1, 1], dtype=source.dtype),
tf.zeros(shape=[T - source_lengths[i], 1, 1], dtype=source.dtype)], axis=0)
else:
mask_i = tf.concat([tf.ones(shape=[source_lengths[i], 1, 1], dtype=source.dtype),
tf.zeros(shape=[T - source_lengths[i], 1, 1], dtype=source.dtype)], axis=0)
mask = tf.concat([mask, mask_i], axis=1)
return mask
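For reference, an equivalent vectorized construction without the Python loop (a sketch; get_mask_vectorized is a hypothetical name, and tf.sequence_mask builds the same [T, B, 1] mask):

def get_mask_vectorized(source, source_lengths):
    # tf.sequence_mask gives [B, T]; transpose to [T, B], then add a trailing axis
    T = source.shape[0]
    mask = tf.sequence_mask(source_lengths, maxlen=T, dtype=source.dtype)
    return tf.expand_dims(tf.transpose(mask), -1)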
def test_get_mask():
'''
Example:
---------
>>> source = torch.randn(4, 3, 2)
>>> source_lengths = torch.Tensor([2, 1, 4]).int()
>>> mask = get_mask(source, source_lengths)
>>> print(mask)
tensor([[[1.],
[1.],
[1.]],
<BLANKLINE>
[[1.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]],
<BLANKLINE>
[[0.],
[0.],
[1.]]])
'''
source = tf.random.uniform([4, 3, 2])
source_length = tf.constant([2, 1, 4])
mask = get_mask(source, source_length)
print(mask)
def test_cal_si_snr():
'''
Example:
---------
>>> import numpy as np
>>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
>>> xhat = x[:, (1, 0)]
>>> x = x.unsqueeze(-1).repeat(1, 1, 2)
>>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
>>> si_snr = -cal_si_snr(x, xhat)
>>> print(si_snr)
tensor([[[ 25.2142, 144.1789],
[130.9283, 25.2142]]])
'''
x = tf.constant([[1, 0], [123, 45], [34, 5], [2312, 421]], dtype=float)
xhat = tf.slice(x, [0, 1], [x.shape[0], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0], [x.shape[0], 1])], axis=1)
x = tf.expand_dims(x, axis=-1)
x = tf.concat([x, x], axis=-1)
xhat = tf.expand_dims(xhat, axis=1)
xhat = tf.concat([xhat, xhat], axis=1)
si_snr = -cal_si_snr(xhat, x)
print(si_snr)
def test_upit_sisnr():
'''
Example:
---------
>>> x = torch.arange(600).reshape(3, 100, 2)
>>> xhat = x[:, :, (1, 0)]
>>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
>>> print(si_snr)
tensor([135.2284, 135.2284, 135.2284])
'''
x = tf.random.uniform([800], 0, 800) #tf.range(800, dtype=float)
x = tf.reshape(x, [4, 100, 2])
xhat = tf.slice(x, [0, 0, 1], [x.shape[0], x.shape[1], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0, 0], [x.shape[0], x.shape[1], 1])], axis=2)
si_snr = -get_si_snr_with_pitwrapper(xhat, x)
print(si_snr)
si_snr = -get_si_snr_with_pitwrapper(x, xhat)
print(si_snr)
    # tf.Tensor(135.22835, shape=(), dtype=float32): differs from the PyTorch
    # version only by a final mean reduction
    # tf.Tensor([135.22835 135.22835 135.22835], shape=(3,), dtype=float32):
    # with reduction set to NONE
if __name__ == '__main__':
# test_get_mask() # ok
# test_cal_si_snr() # ok
    # test_upit_sisnr() # ok; the output shape depends on the reduction setting
print('done')