基于Mask的音频降噪
参考代码:Noise reduction using spectral gating in python
算法步骤:
- 对音频信号进行FFT得到语谱图
- 用Mask算法对语谱图像素进行降噪处理
- 进行IFFT(逆FFT)得到恢复的音频信号
语谱图
选取一个音频信号进行分析和处理,可以看到音频信号的语谱图上有很多白色或接近白色的像素。白色像素表示信号的平均功率或其他统计特性的值为0或者接近0。一般在加性高斯白噪声的情况下,白色像素代表的区域是噪声。
滤波器
如频谱图所示,信号在8kHz处有一个需要被去除的噪声,可以使用一个带阻(陷波)滤波器进行滤波,不影响其他频段,只衰减8kHz处的频谱,得到以下频谱图。
同理,我们如果要通过语谱图达到降噪效果,需要做的就是不影响彩色的有效信号像素,衰减白色或接近白色的噪声像素。就像滤波器把频谱图低频信号保留,抠出高频噪声扔掉,语谱图则是保留彩色像素,抠出浅色像素扔掉。像是一个二维滤波器。
Mask
Mask是图像处理里比较常用的一种抠图算法。简要原理如下图所示(图片有参考)。
- 0:Mask
- 1:Copy
可见当mask bitmap区域是0的时候,原图相应区域被操作成为0;mask bitmap区域是1的时候,原图相应区域原封不动被复制。
达到的效果也就是抠出了原图相应mask bitmap取值1的那些区域。
在音频处理中,mask由信号FFT和阈值比较得到。阈值由信号的统计特性和算法所期望的灵敏度决定。如下图所示,是该音频信号的阈值分析。
平滑处理
如图所示,用于平滑处理mask的滤波器。
音频降噪的mask的值不只是0和1,通过平滑处理之后取值范围在[0,1]区间。这样恢复的音频更流畅。
信号恢复
通过原来的语谱图和mask bitmap做运算得到masked map如图所示。可见相比于原来的语谱图,大量浅色像素被掩蔽。
经过恢复得到恢复后信号语谱图。
做逆运算即可得到信号的时域情况,对比如下。
1.原始音频
2.加噪声后音频
3.降噪恢复音频
播放生成的.WAV文件可以听到:加噪信号有明显的噪声干扰,降噪信号几乎没有噪声,恢复效果良好。对于加性噪声,mask降噪具有很不错的效果。
完整代码如下(Google colab):
import IPython
import matplotlib
from prompt_toolkit import output
from scipy.io import wavfile
import scipy.signal
import numpy as np
import matplotlib.pyplot as plt
import librosa
from IPython.display import Audio, display
from scipy.fftpack import fft
from google.colab import files
# --- Load the audio and inspect it in the time and frequency domains ---
samplerate, data = wavfile.read('audio1.wav')  # audio1 is 3sec.
rate = samplerate
original_data = data
data_array_size = len(data)
duration = len(data)/samplerate
# Use linspace rather than arange with a float step: arange(0, duration,
# 1/samplerate) can produce len(data)+1 samples due to floating-point
# rounding, which makes plt.plot(time, data) fail with a length mismatch.
time = np.linspace(0, duration, data_array_size, endpoint=False)  # time vector (s)
plt.plot(time,data)
plt.xlabel('Time(s)')
plt.ylabel('Amplitude')
plt.title('audio1.wav')
plt.show()
# Single-sided amplitude spectrum of the raw signal
frequency = np.linspace(0.0,samplerate/2,int(data_array_size/2))
freq_data = fft(data)
y = 2/data_array_size * np.abs(freq_data[0:int(data_array_size/2)])
plt.plot(frequency, y)
plt.title('Frequency domain Signal')
plt.xlabel('Frequency in Hz')
plt.ylabel('Amplitude')
plt.show()
# Normalise the 16-bit PCM samples into the [-1, 1) float range
# (assumes audio1.wav is int16 — TODO confirm)
data = data / 32768
def fftnoise(f):
    """Synthesise real-valued noise whose magnitude spectrum matches |f|.

    Each positive-frequency bin gets a uniformly random phase; the
    negative-frequency bins are set to the conjugate mirror so the inverse
    FFT is (numerically) real. The real part of the inverse FFT is returned.
    """
    spectrum = np.array(f, dtype="complex")
    half = (len(spectrum) - 1) // 2
    angles = 2 * np.pi * np.random.rand(half)
    spectrum[1:half + 1] *= np.cos(angles) + 1j * np.sin(angles)
    # Mirror the randomised bins so the spectrum stays Hermitian.
    spectrum[-1:-1 - half:-1] = np.conj(spectrum[1:half + 1])
    return np.fft.ifft(spectrum).real
def band_limited_noise(min_freq, max_freq, samples=1024, samplerate=1):
    """Generate noise whose energy is confined to [min_freq, max_freq] Hz."""
    bin_freqs = np.abs(np.fft.fftfreq(samples, 1 / samplerate))
    in_band = (bin_freqs >= min_freq) & (bin_freqs <= max_freq)
    # A flat (0/1) magnitude spectrum inside the band, random phases outside.
    return fftnoise(in_band.astype(float))
# Listen to the clean (normalised) signal in the notebook.
IPython.display.Audio(data=data, rate=rate)
# Plot the clean signal's time-domain waveform.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(data)
plt.show()
# Add band-limited (4-12 kHz) noise to the clean signal.
noise_len = 2  # seconds of noise set aside as the noise profile
noise = band_limited_noise(min_freq=4000, max_freq = 12000, samples=len(data), samplerate=rate)*10
noise_clip = noise[:rate*noise_len]  # first 2 s; used later as removeNoise's noise estimate
audio_clip_band_limited = data+noise  # NOTE: the full-length noise is added, not just noise_clip
# Plot and play the noisy signal.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(audio_clip_band_limited)
IPython.display.Audio(data=audio_clip_band_limited, rate=rate)
#降噪
import time
from datetime import timedelta as td
def _stft(y, n_fft, hop_length, win_length):
    """Short-time Fourier transform: thin wrapper around librosa.stft."""
    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(y, hop_length, win_length):
    """Inverse STFT: thin wrapper around librosa.istft.

    hop_length/win_length are passed by keyword: in librosa >= 0.10 every
    parameter after stft_matrix is keyword-only, so the original positional
    call would raise a TypeError.
    """
    return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def _amp_to_db(x):
    """Convert an amplitude spectrogram to dB.

    top_db=80.0 floors every value at 80 dB below the spectrogram's peak.
    Uses the top-level librosa API (consistent with _stft's librosa.stft);
    the legacy librosa.core namespace is deprecated in modern librosa.
    """
    return librosa.amplitude_to_db(x, ref=1.0, amin=1e-20, top_db=80.0)
def _db_to_amp(x):
    """Invert _amp_to_db: convert a dB-scaled spectrogram back to amplitude.

    Uses the top-level librosa API (the legacy librosa.core namespace is
    deprecated in modern librosa); also drops the stray trailing comma that
    was in the original signature.
    """
    return librosa.db_to_amplitude(x, ref=1.0)
def plot_spectrogram(signal, title):
    """Display a 2-D array as a spectrogram image with a symmetric colour scale.

    The colour limits are +/- max(|signal|) so zero maps to the centre of the
    seismic colormap.
    """
    limit = np.max(np.abs(signal))
    fig, ax = plt.subplots(figsize=(20, 4))
    image = ax.matshow(
        signal,
        origin="lower",
        aspect="auto",
        cmap=plt.cm.seismic,
        vmin=-limit,
        vmax=limit,
    )
    fig.colorbar(image)
    ax.set_title(title)
    plt.tight_layout()
    plt.show()
def plot_statistics_and_filter(
    mean_freq_noise, std_freq_noise, noise_thresh, smoothing_filter
):
    """Plot the per-frequency noise statistics/threshold and the mask-smoothing filter.

    Left panel: mean noise power, its standard deviation, and the derived
    gating threshold, each as a function of frequency bin. Right panel: the
    2-D filter used to smooth the binary mask.
    """
    fig, ax = plt.subplots(ncols=2, figsize=(20, 4))
    # The original unpacked each line handle into `plt_std` (reassigned twice,
    # never used); the legend only needs the labels, so the handles are dropped.
    ax[0].plot(mean_freq_noise, label="Mean power of noise")
    ax[0].plot(std_freq_noise, label="Std. power of noise")
    ax[0].plot(noise_thresh, label="Noise threshold (by frequency)")
    ax[0].set_title("Threshold for mask")
    ax[0].legend()
    cax = ax[1].matshow(smoothing_filter, origin="lower")
    fig.colorbar(cax)
    ax[1].set_title("Filter for smoothing Mask")
    plt.show()
def removeNoise(
    audio_clip,
    noise_clip,
    n_grad_freq=2,
    n_grad_time=4,
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    n_std_thresh=1.5,
    prop_decrease=1.0,
    verbose=False,
    visual=False,
):
    """Denoise audio_clip by spectral gating against a noise profile.

    A per-frequency threshold is estimated from the STFT statistics of
    noise_clip; time-frequency cells of audio_clip below that threshold are
    attenuated (after smoothing the mask), and the signal is reconstructed
    with an inverse STFT.

    Args:
        audio_clip: 1-D signal to denoise.
        noise_clip: 1-D sample containing only noise (the noise profile).
        n_grad_freq: frequency channels over which the mask is smoothed.
        n_grad_time: time steps over which the mask is smoothed.
        n_fft, win_length, hop_length: STFT parameters.
        n_std_thresh: standard deviations above the mean noise power a cell
            must exceed to escape the mask.
        prop_decrease: fraction by which masked cells are attenuated
            (1.0 = full attenuation to the minimum gain).
        verbose: print per-stage timing.
        visual: plot the intermediate spectrograms, statistics and mask.

    Returns:
        The denoised time-domain signal.
    """
    if verbose:
        start = time.time()
    # STFT of the noise profile, in dB
    noise_stft = _stft(noise_clip, n_fft, hop_length, win_length)
    noise_stft_db = _amp_to_db(np.abs(noise_stft))
    # Per-frequency statistics of the noise -> gating threshold
    mean_freq_noise = np.mean(noise_stft_db, axis=1)
    std_freq_noise = np.std(noise_stft_db, axis=1)
    noise_thresh = mean_freq_noise + std_freq_noise * n_std_thresh
    if verbose:
        print("STFT on noise:", td(seconds=time.time() - start))
        start = time.time()
    # STFT of the signal to denoise, in dB
    sig_stft = _stft(audio_clip, n_fft, hop_length, win_length)
    sig_stft_db = _amp_to_db(np.abs(sig_stft))
    if verbose:
        print("STFT on signal:", td(seconds=time.time() - start))
        start = time.time()
    # Gain assigned to fully-masked cells: the quietest level in the signal.
    # (A stray unconditional debug print of noise_thresh/mask_gain_dB that was
    # here has been removed.)
    mask_gain_dB = np.min(_amp_to_db(np.abs(sig_stft)))
    # Triangular smoothing kernel over frequency x time for the mask
    smoothing_filter = np.outer(
        np.concatenate(
            [
                np.linspace(0, 1, n_grad_freq + 1, endpoint=False),
                np.linspace(1, 0, n_grad_freq + 2),
            ]
        )[1:-1],
        np.concatenate(
            [
                np.linspace(0, 1, n_grad_time + 1, endpoint=False),
                np.linspace(1, 0, n_grad_time + 2),
            ]
        )[1:-1],
    )
    smoothing_filter = smoothing_filter / np.sum(smoothing_filter)
    # Tile the per-frequency threshold across every time frame
    db_thresh = np.repeat(
        np.reshape(noise_thresh, [1, len(mean_freq_noise)]),
        np.shape(sig_stft_db)[1],
        axis=0,
    ).T
    # Binary mask: True where the signal is below the noise threshold
    sig_mask = sig_stft_db < db_thresh
    if verbose:
        print("Masking:", td(seconds=time.time() - start))
        start = time.time()
    # Smooth the mask in time/frequency, then scale by prop_decrease,
    # giving a soft mask with values in [0, 1]
    sig_mask = scipy.signal.fftconvolve(sig_mask, smoothing_filter, mode="same")
    sig_mask = sig_mask * prop_decrease
    if verbose:
        print("Mask convolution:", td(seconds=time.time() - start))
        start = time.time()
    # Apply the mask: blend each dB cell towards the minimum gain
    sig_stft_db_masked = (
        sig_stft_db * (1 - sig_mask)
        + np.ones(np.shape(mask_gain_dB)) * mask_gain_dB * sig_mask
    )
    sig_imag_masked = np.imag(sig_stft) * (1 - sig_mask)
    sig_stft_amp = (_db_to_amp(sig_stft_db_masked) * np.sign(sig_stft)) + (
        1j * sig_imag_masked
    )
    if verbose:
        print("Mask application:", td(seconds=time.time() - start))
        start = time.time()
    # Reconstruct the time-domain signal and its spectrogram
    recovered_signal = _istft(sig_stft_amp, hop_length, win_length)
    recovered_spec = _amp_to_db(
        np.abs(_stft(recovered_signal, n_fft, hop_length, win_length))
    )
    if verbose:
        print("Signal recovery:", td(seconds=time.time() - start))
    if visual:
        plot_spectrogram(noise_stft_db, title="Noise")
        plot_statistics_and_filter(
            mean_freq_noise, std_freq_noise, noise_thresh, smoothing_filter
        )
        plot_spectrogram(sig_stft_db, title="Signal")
        plot_spectrogram(sig_mask, title="Mask applied")
        plot_spectrogram(sig_stft_db_masked, title="Masked signal")
        plot_spectrogram(recovered_spec, title="Recovered spectrogram")
    return recovered_signal
# Run the denoiser on the noisy clip, with timing output and all plots.
output = removeNoise(audio_clip=audio_clip_band_limited, noise_clip=noise_clip,verbose=True, visual=True)
# Plot the recovered signal's time-domain waveform.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 4))
plt.plot(output, color='black')
ax.set_xlim((0, len(output)))
plt.show()
# Rescale the float signals back to the 16-bit PCM range.
# NOTE(review): this rebinds original_data (the clean int samples read
# earlier) to the noisy float signal — the name no longer matches its content.
original_data = audio_clip_band_limited * 32768
new_original_data = original_data.astype(np.int16)
filtered_data = output * 32768
new_filtered_data = filtered_data.astype(np.int16)
# Write both versions to .wav, play them in the notebook, and download them.
wavfile.write('original_audio1_with_noise.wav', samplerate, new_original_data)
display(Audio('original_audio1_with_noise.wav', autoplay=True))
files.download('original_audio1_with_noise.wav')
wavfile.write('filtered_audio1.wav', samplerate, new_filtered_data)
display(Audio('filtered_audio1.wav', autoplay=True))
files.download('filtered_audio1.wav')