Processing a single utterance
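The script below loads a trained multi-channel D2Net model (adapted from Kaituo XU's FaSNet + TAC inference code) and runs it on a single multi-channel WHAMR! mixture, writing one wav file per separated speaker.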

#!/usr/bin/env python

# Created on 2018/12
# Author: Kaituo XU
# Edited by: yoonsanghyu

import argparse
import os
from nbss.NBSS import NBSS
import soundfile as sf
from scipy.io import wavfile
import librosa
import torch
import numpy as np
from collections import OrderedDict
from tqdm import tqdm

# from FaSNet import FaSNet_TAC
from D2Net.mc_bss_D2NET import Net

parser = argparse.ArgumentParser('Separate speech using mc-D2Net (originally FaSNet + TAC)')
parser.add_argument('--model_path', type=str, default='exp/tmp/mc_D2Net_whamr_seperation/temp_best.pth.tar',
                    help='Path to model file created by training')
parser.add_argument('--test_dir', type=str,
                    default="/home/weiWB/dataset/speech_seperation/mc_whamr/wav8k/S1_S2_rir_noise_tt_list",
                    help='path to test/2mic/samples')
parser.add_argument('--out_dir', type=str,
                    default='/home/weiWB/dataset/speech_seperation/mc_whamr/wav8k/test/',
                    help='Directory for the separated wav files')
parser.add_argument('--use_cuda', type=int, default=1,
                    help='Whether to use the GPU for separation')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
parser.add_argument('--batch_size', default=1, type=int,
                    help='Batch size')

# Network architecture
parser.add_argument('--enc_dim', default=64, type=int, help='Number of filters in autoencoder')
parser.add_argument('--win_len', default=4, type=int, help='Window length (ms)')
parser.add_argument('--context_len', default=16, type=int, help='Context window size (ms)')
parser.add_argument('--feature_dim', default=64, type=int, help='Feature dimension')
parser.add_argument('--hidden_dim', default=128, type=int, help='Hidden dimension')
parser.add_argument('--layer', default=5, type=int, help='Number of DPRNN layers')
parser.add_argument('--segment_size', default=50, type=int, help='Segment size')
parser.add_argument('--nspk', default=2, type=int, help='Maximum number of speakers')
parser.add_argument('--mic', default=6, type=int, help='Number of microphones')
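
# Example invocation (the script name is illustrative; the paths are the
# defaults defined above):
#   python separate.py --model_path exp/tmp/mc_D2Net_whamr_seperation/temp_best.pth.tar \
#       --out_dir /home/weiWB/dataset/speech_seperation/mc_whamr/wav8k/test/ --use_cuda 1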

def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for inp, length in zip(inputs, inputs_lengths):  # avoid shadowing builtin 'input'
        if dim == 3:  # [B, C, T]
            results.append(inp[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(inp[:length].view(-1).cpu().numpy())
    return results
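
# Minimal usage sketch for remove_pad (shapes are illustrative):
#   padded  = torch.zeros(2, 2, 100)        # [B, C, T], zero-padded batch
#   lengths = torch.tensor([80, 100])       # valid samples per item
#   outs    = remove_pad(padded, lengths)   # [numpy (2, 80), numpy (2, 100)]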

def separate(args):
    # Load FaSNet model
    # model = FaSNet_TAC(enc_dim=args.enc_dim, feature_dim=args.feature_dim, hidden_dim=args.hidden_dim, layer=args.layer, segment_size=args.segment_size,
    #                        nspk=args.nspk, win_len=args.win_len, context_len=args.context_len, sr=args.sample_rate)

    # Load mc_D2Net
    model = Net()

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # Checkpoints saved from a DataParallel model prefix every parameter key
    # with 'module.'; if direct loading fails, strip the prefix and retry.
    model_info = torch.load(args.model_path)
    try:
        model.load_state_dict(model_info['model_state_dict'])
    except RuntimeError:
        state_dict = OrderedDict()
        for k, v in model_info['model_state_dict'].items():
            name = k.replace("module.", "")  # remove 'module.' prefix
            state_dict[name] = v
        model.load_state_dict(state_dict)

    print(model)
    model.eval()



    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        sf.write(filename, inputs, sr)

    with torch.no_grad():
        # t = tqdm(total=len(eval_dataset), mininterval=0.5)

        # wavfile.read returns (sample_rate, data); for a multi-channel wav,
        # data has shape [T, C] with int16 PCM samples.
        sr, mixture = wavfile.read("/home/dataset/speech_separation/mc_whamr/wav8k/min/tt/mix_both_reverb/443c0212_1.8957_421o0301_-1.8957.wav")
        padded_mixture = torch.tensor(mixture, dtype=torch.float32) / 32768.0  # int16 -> float in [-1, 1]
        padded_mixture = padded_mixture.permute(1, 0)     # [T, C] -> [C, T]
        padded_mixture = padded_mixture[:, :32000]        # first 32000 samples (4 s at 8 kHz)
        padded_mixture = padded_mixture.unsqueeze(dim=0)  # [C, T] -> [1, C, T]
        if args.use_cuda:
            padded_mixture = padded_mixture.cuda()
            # mixture_lengths = mixture_lengths.cuda()
            # padded_source = padded_source.cuda()

        # FaSNet_TAC needs a dummy tensor and a 'none_mic' flag; D2Net does
        # not, so these stay commented out together with the FaSNet call.
        # x = torch.rand(2, 6, 32000)
        # none_mic = torch.zeros(1).type(x.type())
        # estimate_source = model(padded_mixture, none_mic.long())  # [M, C, T]

        # D2Net
        estimate_source = model(padded_mixture)  # [M, C, T]

        for j in range(estimate_source.size()[0]):

            scs = estimate_source[j].cpu().numpy()

            # Rescale each separated source so its RMS power matches the
            # mixture's RMS power; divide by the sample count (.size), not
            # len(), which would give the batch dimension.
            mix_np = padded_mixture.cpu().numpy()
            power = np.sqrt((mix_np ** 2).sum() / mix_np.size)
            for k, src in enumerate(scs):
                this_dir = os.path.join(args.out_dir, 'utt{0}'.format(j + 1))
                os.makedirs(this_dir, exist_ok=True)
                source = src * (power / np.sqrt((src ** 2).sum() / src.size))
                write(source, os.path.join(this_dir, 's{0}.wav'.format(k + 1)))

        # t.update()


if __name__ == '__main__':
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # 指定使用第几个显卡
    args = parser.parse_args()
    print(args)
    separate(args)
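
With the defaults above, the separated signals for the single test mixture are written to out_dir/utt1/s1.wav and out_dir/utt1/s2.wav, each rescaled to the RMS power of the input mixture.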


