Audio Classification with PSLA

Original author's GitHub: https://github.com/YuanGongND/psla

I recently worked on an audio classification task and went through the basic pipeline: usually MFCC/Fbank features are extracted first and then fed into a CNN/attention model. The public datasets used most often are FSD50K and AudioSet. I used my own dataset here; it only needs to be converted into the AudioSet format.
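
For reference, here is a rough sketch of the AudioSet-style metadata PSLA expects: a label csv (the 'index' and 'mid' columns match what make_index_dict below reads; the 'display_name' column follows AudioSet's class_labels_indices.csv layout) plus a data json listing each wav with its label mids. The field names 'data', 'wav', and 'labels' are what the PSLA dataloader uses, but verify against the sample files in the repo; every path, mid, and class name below is made up for illustration.

import csv
import json

# hypothetical two-class label file in the AudioSet class_labels_indices.csv layout
with open('my_labels.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['index', 'mid', 'display_name'])
    writer.writerow([0, '/m/custom0', 'normal'])
    writer.writerow([1, '/m/custom1', 'abnormal'])

# hypothetical data json: one entry per clip, pointing to the wav and its label mid(s)
data = {'data': [
    {'wav': '/data/audio/clip_0001.wav', 'labels': '/m/custom0'},
    {'wav': '/data/audio/clip_0002.wav', 'labels': '/m/custom1'},
]}
with open('my_data.json', 'w') as f:
    json.dump(data, f, indent=2)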

Below are my notes on how the PSLA code works:

AudiosetDataset:
This class does the data augmentation and preprocessing: mixup of audio clips, fbank extraction, input normalization, padding to a fixed length, plus FrequencyMasking and TimeMasking.
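
Mixup here means blending two clips (and their label vectors) with a random coefficient. A minimal sketch of the idea, not the exact PSLA code (the Beta parameter alpha is a placeholder; both waveforms are assumed to already have the same length and sample rate):

import numpy as np

def mixup_waveforms(wav1, wav2, label1, label2, alpha=10.0):
    # draw a mixing coefficient from a Beta distribution
    lam = np.random.beta(alpha, alpha)
    wav = lam * wav1 + (1.0 - lam) * wav2        # blend the raw audio
    label = lam * label1 + (1.0 - lam) * label2  # blend the multi-hot label vectors
    return wav, label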

FrequencyMasking and TimeMasking work as in the figure below: after the fbank is computed, random bands along the frequency and time axes are masked out. Whether to use them depends on your task. In mine, accuracy was consistently poor with them enabled and improved a lot once I turned them off: the random masking occluded exactly the feature regions my task relies on, so training went wrong.

[Figure: example of FrequencyMasking and TimeMasking applied to a spectrogram. Image source: "Building an End-to-End Speech Recognition Model in PyTorch with AssemblyAI", AI科技大本营, CSDN blog]
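
A minimal sketch of applying the two masks to an fbank with torchaudio (mask sizes are arbitrary; the transforms expect the frequency axis second-to-last, so the (time, mel) fbank is transposed before and after):

import torch
import torchaudio

# fake (time, mel) fbank just to show the calls; the real one comes from kaldi.fbank
fbank = torch.randn(300, 128)

freqm = torchaudio.transforms.FrequencyMasking(freq_mask_param=48)  # mask up to 48 mel bins
timem = torchaudio.transforms.TimeMasking(time_mask_param=60)       # mask up to 60 frames

x = fbank.transpose(0, 1).unsqueeze(0)        # (1, mel, time)
x = timem(freqm(x))                           # zero out a random frequency band and time span
fbank_masked = x.squeeze(0).transpose(0, 1)   # back to (time, mel)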

The author's backbone options are EfficientNet, ResNet, and MobileNet (mbnet). For training I used the launch script ./egs/fsd50k/run.sh; since a shell script cannot be stepped through in a debugger, I moved all of its arguments into run.py to make debugging easier.
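
The change is mechanical: give every run.sh flag a default in run.py's argparse (or parse a fixed list) so the script runs with no command-line arguments. The flag names below are placeholders, not the actual run.sh options:

import argparse

parser = argparse.ArgumentParser()
# hypothetical flags standing in for whatever run.sh normally passes
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--batch-size", type=int, default=24)
parser.add_argument("--n-epochs", type=int, default=30)

# parsing an empty list forces the defaults, so `python run.py` works
# directly inside an IDE debugger with no shell wrapper
args = parser.parse_args([])
print(args)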

The author has not provided standalone inference code yet, but you can adapt the preprocessing code yourself and turn off the augmentation parts. My code is below.

What it does: reads a wav file, runs a prediction every 1.5 s, and saves each 1.5 s segment as a separate audio file.

import csv
import os
import time
import torchaudio
import torch
import ast
import models
import argparse
from collections import OrderedDict
import numpy as np


def make_index_dict(label_csv):
    # build a lookup from each class 'mid' to its integer index using the label csv
    index_lookup = {}
    with open(label_csv, 'r') as f:
        csv_reader = csv.DictReader(f)
        line_count = 0
        for row in csv_reader:
            index_lookup[row['mid']] = row['index']
            line_count += 1
    return index_lookup


class Relay:
    def __init__(self, audio_conf, label_csv):
        """
        Preprocessing helper adapted from the PSLA AudiosetDataset (inference only)
        :param audio_conf: Dictionary containing the audio loading and preprocessing settings
        :param label_csv: csv file that maps class mids to indices
        """
        self.audio_conf = audio_conf
        print('---------------the {:s} dataloader---------------'.format(self.audio_conf.get('mode')))
        self.melbins = self.audio_conf.get('num_mel_bins')
        self.freqm = self.audio_conf.get('freqm')
        self.timem = self.audio_conf.get('timem')
        print('now using following mask: {:d} freq, {:d} time'.format(self.audio_conf.get('freqm'),
                                                                      self.audio_conf.get('timem')))
        self.mixup = self.audio_conf.get('mixup')
        print('now using mix-up with rate {:f}'.format(self.mixup))
        self.dataset = self.audio_conf.get('dataset')
        print('now process ' + self.dataset)
        # dataset spectrogram mean and std, used to normalize the input
        self.norm_mean = self.audio_conf.get('mean')
        self.norm_std = self.audio_conf.get('std')
        self.skip_norm = self.audio_conf.get('skip_norm') if self.audio_conf.get('skip_norm') else False
        if self.skip_norm:
            print('now skip normalization (use it ONLY when you are computing the normalization stats).')
        else:
            print(
                'use dataset mean {:.3f} and std {:.3f} to normalize the input.'.format(self.norm_mean, self.norm_std))
        self.noise = self.audio_conf.get('noise')
        if self.noise:
            print('now use noise augmentation')

        self.index_dict = make_index_dict(label_csv)
        self.label_num = len(self.index_dict)
        print('number of classes is {:d}'.format(self.label_num))

    def _wav2fbank(self, waveform, sr):
        # the waveform is passed in by the caller (no torchaudio.load here),
        # and mixup is skipped since this is inference-only preprocessing
        waveform = waveform - waveform.mean()  # remove the DC offset

        fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                                  window_type='hanning', num_mel_bins=self.melbins, dither=0.0,
                                                  frame_shift=10)

        target_length = self.audio_conf.get('target_length')
        n_frames = fbank.shape[0]

        p = target_length - n_frames

        # cut and pad
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0:target_length, :]

        return fbank, 0  # second value is the mixup lambda, fixed at 0 because mixup is disabled

    def preprocess(self, data, sr):
        """
        returns a normalized fbank FloatTensor of shape (1, target_length, num_mel_bins),
        ready to be fed to the model as a single-sample batch
        """
        fbank, mix_lambda = self._wav2fbank(data, sr)

        # normalize with the dataset-level mean/std unless skip_norm is set
        if not self.skip_norm:
            fbank = (fbank - self.norm_mean) / self.norm_std
        fbank = fbank[None, :]  # add the batch dimension
        return fbank


_eff_b = 2
_target_length = 300

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--weight_path", type=str,
                    default="/home/zhouhe/workspace/acv/acv/audio/psla-main/exp/2/models/best_audio_model.pth",
                    help="model file path")
parser.add_argument("--model", type=str, default="efficientnet", help="audio model architecture",
                    choices=["efficientnet", "resnet", "mbnet"])
parser.add_argument("--eff_b", type=int, default=_eff_b,
                    help="which efficientnet to use, the larger number, the more complex")
parser.add_argument("--n_class", type=int, default=2, help="number of classes")
parser.add_argument('--impretrain', help='if use imagenet pretrained CNNs', type=ast.literal_eval, default='True')
parser.add_argument("--att_head", type=int, default=4, help="number of attention heads")
parser.add_argument("--target_length", type=int, default=_target_length, help="the input length in frames")
parser.add_argument("--dataset_mean", type=float, default=-4.6476,
                    help="the dataset mean, used for input normalization")
parser.add_argument("--dataset_std", type=float, default=4.5699, help="the dataset std, used for input normalization")
parser.add_argument("--dataset", type=str, default="audioset", help="the dataset used",
                    choices=["audioset", "esc50", "speechcommands"])
parser.add_argument("--data-val", type=str, default='../wz_relay/wz_relay.json', help="validation data json")
parser.add_argument("--label-csv", type=str, default='../wz_relay/wz_relay.csv', help="csv with class labels")

args = parser.parse_args()

if args.model == 'efficientnet':
    audio_model = models.EffNetAttention(label_dim=args.n_class, b=args.eff_b, pretrain=args.impretrain,
                                         head_num=args.att_head)
elif args.model == 'resnet':
    audio_model = models.ResNetAttention(label_dim=args.n_class, pretrain=args.impretrain)
# elif args.model == 'mbnet':
#     audio_model = models.MBNet(label_dim=args.n_class, pretrain=args.effpretrain)

val_audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': 0, 'timem': 0, 'mixup': 0,
                  'dataset': args.dataset, 'mode': 'evaluation', 'mean': args.dataset_mean,
                  'std': args.dataset_std, 'noise': False}
state_dictBA = torch.load(args.weight_path)
new_state_dictBA = OrderedDict()
# strip the 'module.' prefix that nn.DataParallel adds to checkpoint keys
for k, v in state_dictBA.items():
    if k[:7] == 'module.':
        name = k[7:]  # remove `module.`
    else:
        name = k
    new_state_dictBA[name] = v
audio_model.load_state_dict(new_state_dictBA)
print(audio_model)
audio_model.float().eval().cuda()
# audio_model.half()

# dummy zero input (unused below; kept only as a shape reference / possible warm-up tensor)
data = np.zeros([1, 300, 128], dtype=np.float32)
data = torch.from_numpy(data).cuda()
r = Relay(val_audio_conf, args.label_csv)

torch.set_grad_enabled(False)  # inference only, no gradients needed
os.makedirs("./test_ret", exist_ok=True)

ori_data, ori_sr = torchaudio.load('./test.wav')
all_second = ori_data.shape[1] / ori_sr  # total duration in seconds (ori_data is (channels, samples))
interval_rate = int(ori_sr * 1.5)  # samples per 1.5 s window
for idx, i in enumerate(range(0, ori_data.shape[1], interval_rate)):
    if i + interval_rate < ori_data.shape[1]:
        start_ = time.time()
        out_data = ori_data[:, i:i + interval_rate]
        data = r.preprocess(out_data, ori_sr)
        ret = audio_model(data.cuda())
        # print('predict:', ret.argmax(), label.argmax())
        save_name = os.path.join("./test_ret", f"{ret.argmax()}_{idx * 1.5}_{idx * 1.5 + 1.5}.wav")
        # librosa.output.write_wav was removed in librosa 0.8; torchaudio.save writes the clip directly
        torchaudio.save(save_name, out_data, ori_sr)
        print(time.time() - start_)  # per-segment inference time
print(ret)  # raw scores for the last segment
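
If you also want a readable class name instead of just the index, a small lookup at the end of the script can do it. This assumes the label csv follows the AudioSet class_labels_indices.csv layout and therefore also has a 'display_name' column (the code above only relies on 'index' and 'mid'):

# map the predicted index back to a class name ('display_name' column is an assumption)
with open(args.label_csv, 'r') as f:
    name_lookup = {int(row['index']): row['display_name'] for row in csv.DictReader(f)}
print('predicted class:', name_lookup[int(ret.argmax())])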
