贴下原作者github:https://github.com/YuanGongND/psla
最近弄了下音频分类的工作,了解了下基本流程。一般都是MFCC/Fbank特征提取后再进行CNN/attention。数据集公共的一般是用FSD50K、audioset居多。我这里是用自己的数据集,需要转换到audioset格式即可。
下面说下我对PSLA代码的理解:
AudiosetDataset:
这个是数据增强和预处理,包括声音混合、fbank、数据标准化以及数据填充到固定长度,还有FrequencyMasking、TimeMasking。
FrequencyMasking、TimeMasking效果如下图,就是在fbank后对频域和时域进行随机遮盖,这个要根据自己的任务特点进行使用,我的任务刚开始使用这个精度一直有问题后来关闭了这个功能就好很多了,异物遮盖导致我的特征点被弄没了,训练出错了。
图片出处:利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型_AI科技大本营-CSDN博客
作者模型backbone主要是efficientnet和resnet、mbnet可选,训练启动脚本我用的.\egs\fsd50k\run.sh,因为sh文件无法debug所以我把所有参数都改到了run.py里面去方便调试。
单独预测的代码作者还没有写,可以自己改下预处理的代码关闭部分数据增强的部分即可。代码如下:
功能:读取一个wav文件预测结果并保存截断(wav文件每隔1.5s预测一次)的音频
import csv
import os
import time
import librosa
import numpy
import torchaudio
import torch
import ast
import models
import argparse
from collections import OrderedDict
import numpy as np
def make_index_dict(label_csv):
    """Build a lookup from label 'mid' identifier to its class index.

    :param label_csv: path to a CSV file with at least 'mid' and 'index' columns
    :return: dict mapping each row's 'mid' value to its 'index' value
             (both kept as strings, exactly as read from the CSV)
    """
    with open(label_csv, 'r') as f:
        # csv.DictReader keys each row by the header line; the unused
        # line_count accumulator from the original has been dropped.
        return {row['mid']: row['index'] for row in csv.DictReader(f)}
class Relay:
    """Single-clip inference preprocessor mirroring PSLA's AudiosetDataset
    pipeline: kaldi fbank extraction, mean/std normalization, and padding or
    truncation to a fixed frame length. Training-time augmentations (mixup,
    freq/time masking, noise) are read from the config but not applied here.
    """

    def __init__(self, audio_conf, label_csv):
        """
        :param audio_conf: dict of audio loading and preprocessing settings
            (keys used: mode, num_mel_bins, freqm, timem, mixup, dataset,
            mean, std, skip_norm, noise, target_length)
        :param label_csv: path to the CSV mapping label 'mid' to class index
        """
        self.audio_conf = audio_conf
        print('---------------the {:s} dataloader---------------'.format(self.audio_conf.get('mode')))
        self.melbins = self.audio_conf.get('num_mel_bins')
        self.freqm = self.audio_conf.get('freqm')
        self.timem = self.audio_conf.get('timem')
        print('now using following mask: {:d} freq, {:d} time'.format(self.audio_conf.get('freqm'),
                                                                      self.audio_conf.get('timem')))
        self.mixup = self.audio_conf.get('mixup')
        print('now using mix-up with rate {:f}'.format(self.mixup))
        self.dataset = self.audio_conf.get('dataset')
        print('now process ' + self.dataset)
        # dataset spectrogram mean and std, used to normalize the input
        self.norm_mean = self.audio_conf.get('mean')
        self.norm_std = self.audio_conf.get('std')
        # default to False when 'skip_norm' is absent or falsy
        self.skip_norm = bool(self.audio_conf.get('skip_norm'))
        if self.skip_norm:
            print('now skip normalization (use it ONLY when you are computing the normalization stats).')
        else:
            print(
                'use dataset mean {:.3f} and std {:.3f} to normalize the input.'.format(self.norm_mean, self.norm_std))
        self.noise = self.audio_conf.get('noise')
        if self.noise:
            print('now use noise augmentation')
        self.index_dict = make_index_dict(label_csv)
        self.label_num = len(self.index_dict)
        print('number of classes is {:d}'.format(self.label_num))

    def _wav2fbank(self, waveform, sr):
        """Convert a loaded waveform to a fixed-length kaldi fbank matrix.

        :param waveform: tensor as returned by torchaudio.load (channels, samples)
        :param sr: sample rate of the waveform
        :return: (fbank, 0) where fbank is (target_length, num_mel_bins);
                 the second value is a mixup-lambda placeholder kept for
                 interface compatibility with the training dataset code.
        """
        # remove DC offset before the filterbank computation
        waveform = waveform - waveform.mean()
        fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                                  window_type='hanning', num_mel_bins=self.melbins, dither=0.0,
                                                  frame_shift=10)
        target_length = self.audio_conf.get('target_length')
        n_frames = fbank.shape[0]
        p = target_length - n_frames
        # zero-pad short clips at the bottom, truncate long ones
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0:target_length, :]
        return fbank, 0

    def preprocess(self, data, sr):
        """Produce a normalized, batched fbank tensor ready for the model.

        :param data: waveform tensor (channels, samples)
        :param sr: sample rate of the waveform
        :return: FloatTensor of shape (1, target_length, num_mel_bins)
        """
        fbank, mix_lambda = self._wav2fbank(data, sr)
        if not self.skip_norm:
            fbank = (fbank - self.norm_mean) / (self.norm_std)
        # prepend a batch dimension for the model forward pass
        fbank = fbank[None, :]
        return fbank
# -------- command-line configuration for single-file prediction --------
_eff_b = 2
_target_length = 300

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    "--weight_path", type=str,
    help="model file path",
    default="/home/zhouhe/workspace/acv/acv/audio/psla-main/exp/2/models/best_audio_model.pth")
parser.add_argument(
    "--model", type=str, default="efficientnet",
    choices=["efficientnet", "resnet", "mbnet"],
    help="audio model architecture")
parser.add_argument(
    "--eff_b", type=int, default=_eff_b,
    help="which efficientnet to use, the larger number, the more complex")
parser.add_argument(
    "--n_class", type=int, default=2,
    help="number of classes")
parser.add_argument(
    '--impretrain', type=ast.literal_eval, default='True',
    help='if use imagenet pretrained CNNs')
parser.add_argument(
    "--att_head", type=int, default=4,
    help="number of attention heads")
parser.add_argument(
    "--target_length", type=int, default=_target_length,
    help="the input length in frames")
parser.add_argument(
    "--dataset_mean", type=float, default=-4.6476,
    help="the dataset mean, used for input normalization")
parser.add_argument(
    "--dataset_std", type=float, default=4.5699,
    help="the dataset std, used for input normalization")
parser.add_argument(
    "--dataset", type=str, default="audioset",
    choices=["audioset", "esc50", "speechcommands"],
    help="the dataset used")
parser.add_argument(
    "--data-val", type=str, default='../wz_relay/wz_relay.json',
    help="validation data json")
parser.add_argument(
    "--label-csv", type=str, default='../wz_relay/wz_relay.csv',
    help="csv with class labels")
args = parser.parse_args()
# ---- build the network architecture selected on the command line ----
if args.model == 'efficientnet':
    audio_model = models.EffNetAttention(label_dim=args.n_class, b=args.eff_b, pretrain=args.impretrain,
                                         head_num=args.att_head)
elif args.model == 'resnet':
    audio_model = models.ResNetAttention(label_dim=args.n_class, pretrain=args.impretrain)
else:
    # 'mbnet' is advertised in --model choices but no builder is wired up
    # here; fail fast instead of hitting a NameError further below.
    raise ValueError('unsupported model architecture: ' + args.model)

# evaluation-time config: all augmentations (freqm/timem/mixup/noise) disabled
val_audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': 0, 'timem': 0, 'mixup': 0,
                  'dataset': args.dataset, 'mode': 'evaluation', 'mean': args.dataset_mean,
                  'std': args.dataset_std, 'noise': False}

# load the checkpoint on CPU first so loading also works on CPU-only machines;
# the model is moved to the GPU afterwards
state_dictBA = torch.load(args.weight_path, map_location='cpu')
new_state_dictBA = OrderedDict()
for k, v in state_dictBA.items():
    # checkpoints saved from nn.DataParallel prefix every key with 'module.'
    if k.startswith('module.'):
        name = k[7:]  # remove `module.`
    else:
        name = k
    new_state_dictBA[name] = v
audio_model.load_state_dict(new_state_dictBA)
print(audio_model)
audio_model.float().eval().cuda()

r = Relay(val_audio_conf, args.label_csv)
# ---- slide a 1.5 s window over the wav and classify each chunk ----
ori_data, ori_sr = torchaudio.load('./test.wav')  # ori_data: (channels, samples)
interval_rate = int(ori_sr * 1.5)  # window length in samples
# total duration in seconds; samples live on dim 1, not dim 0 (channels)
all_second = ori_data.shape[1] / ori_sr
os.makedirs('./test_ret', exist_ok=True)
for idx, i in enumerate(range(0, ori_data.shape[1], interval_rate)):
    # skip the trailing partial window
    if i + interval_rate < ori_data.shape[1]:
        start_ = time.time()
        out_data = ori_data[:, i:i + interval_rate]
        data = r.preprocess(out_data, ori_sr)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            ret = audio_model(data.cuda())
        # .item() puts the plain class index in the filename, not 'tensor(k)'
        pred = ret.argmax().item()
        save_name = os.path.join("./test_ret", f"{pred}_{idx * 1.5}_{idx * 1.5 + 1.5}.wav")
        # librosa.output.write_wav was removed in librosa 0.8; torchaudio.save
        # writes the (channels, samples) tensor chunk directly
        torchaudio.save(save_name, out_data, ori_sr)
        print(time.time() - start_)
        print(ret)