import soundfile as sf
import torch
from scipy.io import wavfile
import os
import numpy as np
# !/usr/bin/env python
# Created on 2018/12
# Author: Kaituo XU
# Edited by: yoonsanghyu
import argparse
import os
from nbss.NBSS import NBSS
import soundfile as sf
from scipy.io import wavfile
import librosa
import torch
import numpy as np
from collections import OrderedDict
from tqdm import tqdm
# from FaSNet import FaSNet_TAC
from D2Net.mc_bss_D2NET import Net
# Command-line interface of the separation script.
# NOTE: the first positional argument of ArgumentParser is ``prog``, so the
# string below is what appears as the program name in usage/help output.
parser = argparse.ArgumentParser('Separate speech using FaSNet + TAC')

# Every option is declared as one (flag, type, default, help) row and
# registered in a single loop further down; row order is the order shown by
# ``--help``.
_IO_OPTIONS = (
    ('--model_path', str,
     'exp/tmp/mc_D2Net_whamr_seperation/temp_best.pth.tar',
     'Path to model file created by training'),
    ('--test_dir', str,
     "/home/weiWB/dataset/speech_seperation/mc_whamr/wav8k/S1_S2_rir_noise_tt_list",
     'path to test/2mic/samples'),
    ('--out_dir', str,
     '/home/weiWB/dataset/speech_seperation/mc_whamr/wav8k/test/',
     'Directory putting separated wav files'),
    ('--use_cuda', int, 1, 'Whether use GPU to separate speech'),
    ('--sample_rate', int, 8000, 'Sample rate'),
    ('--batch_size', int, 1, 'Batch size'),
)

# Network architecture
_NETWORK_OPTIONS = (
    ('--enc_dim', int, 64, 'Number of filters in autoencoder'),
    ('--win_len', int, 4, 'Number of convolutional blocks in each repeat'),
    ('--context_len', int, 16, 'context window size'),
    ('--feature_dim', int, 64, 'feature dimesion'),
    ('--hidden_dim', int, 128, 'Hidden dimension'),
    ('--layer', int, 5, 'Number of layer in dprnn step'),
    ('--segment_size', int, 50, "segment_size"),
    ('--nspk', int, 2, 'Maximum number of speakers'),
    ('--mic', int, 6, 'number of microphone'),
)

for _flag, _type, _default, _help in _IO_OPTIONS + _NETWORK_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
def remove_pad(inputs, inputs_lengths):
    """Strip the trailing padding from every item of a padded batch.

    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor or sequence of ints, [B]; number of
            valid samples along the last axis for each batch item
    Returns:
        results: a list containing B numpy arrays; each item is [C, T_i]
            for a 3-D input or [T_i] for a 2-D input, T_i varies
    Raises:
        ValueError: if `inputs` is neither 2-D nor 3-D (such inputs used to
            produce an empty list silently).
    """
    rank = inputs.dim()
    if rank not in (2, 3):
        raise ValueError(
            'remove_pad expects a [B, C, T] or [B, T] tensor, '
            'got a {0}-D tensor'.format(rank))

    results = []
    for item, length in zip(inputs, inputs_lengths):
        # Slicing the last axis covers both layouts ([C, T] and [T]) and,
        # unlike an extra `.view(C, -1)`, also works when the length is 0.
        trimmed = item[..., :int(length)]
        # detach() makes the conversion work for tensors that require grad.
        results.append(trimmed.detach().cpu().numpy())
    return results
# Recording separated when `separate` is called without an explicit path.
_DEFAULT_MIXTURE_PATH = "/home/dataset/speech_separation/mc_whamr/wav8k/min/tt/mix_both_reverb/443c0212_1.8957_421o0301_-1.8957.wav"


def _load_weights(model, state_dict):
    """Load `state_dict` into `model`, tolerating a DataParallel prefix mismatch."""
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strict loading reports missing/unexpected keys as RuntimeError (not
        # KeyError). The usual cause is the "module." prefix that
        # torch.nn.DataParallel adds to every key, so retry with the prefix
        # adjusted to match how `model` is wrapped right now.
        prefix = 'module.'
        wrapped = isinstance(model, torch.nn.DataParallel)
        remapped = OrderedDict()
        for key, value in state_dict.items():
            bare = key[len(prefix):] if key.startswith(prefix) else key
            remapped[prefix + bare if wrapped else bare] = value
        model.load_state_dict(remapped)


def _load_mixture(path, max_samples):
    """Read a WAV file into a float32 tensor of shape [1, channels, samples]."""
    # The file's own sample rate is ignored; outputs use args.sample_rate.
    _, samples = wavfile.read(path)
    if samples.ndim == 1:
        # Mono files come back 1-D; give them an explicit channel axis.
        samples = samples[:, np.newaxis]
    kind = samples.dtype.kind
    if kind == 'i':
        # Signed integer PCM -> [-1, 1) so the model receives floats.
        samples = samples / float(np.iinfo(samples.dtype).max + 1)
    elif kind == 'u':
        # 8-bit PCM is unsigned with its midpoint at 128.
        samples = (samples.astype(np.float32) - 128.0) / 128.0
    # wavfile.read yields [samples, channels]; transpose, then crop in time.
    channels_first = np.ascontiguousarray(samples.T[:, :max_samples],
                                          dtype=np.float32)
    return torch.from_numpy(channels_first).unsqueeze(0)


def _rms(signal):
    """Return the root-mean-square level of a numpy array."""
    return float(np.sqrt(np.mean(np.square(signal, dtype=np.float64))))


def separate(args, mixture_path=_DEFAULT_MIXTURE_PATH, max_samples=32000):
    """Separate one multi-channel mixture and write each estimated source.

    Builds the D2Net model, restores the checkpoint at `args.model_path`,
    runs it on a single recording and writes every estimated source to
    `<args.out_dir>/utt<j>/s<k>.wav`, rescaled to the RMS level of the
    mixture.

    Args:
        args: parsed command-line namespace; reads `use_cuda`, `model_path`,
            `out_dir` and `sample_rate`.
        mixture_path: WAV file to separate (defaults to the recording that
            was previously hard-coded here).
        max_samples: number of leading samples kept from the recording;
            `None` keeps the full recording.
    """
    # Load mc_D2Net
    model = Net()
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # map_location keeps a checkpoint saved on GPU loadable on CPU-only
    # runs; load_state_dict then copies onto the model's own device.
    model_info = torch.load(args.model_path, map_location='cpu')
    _load_weights(model, model_info['model_state_dict'])
    print(model)
    model.eval()

    os.makedirs(args.out_dir, exist_ok=True)

    padded_mixture = _load_mixture(mixture_path, max_samples)
    if args.use_cuda:
        padded_mixture = padded_mixture.cuda()

    with torch.no_grad():
        # NOTE(review): the output is indexed as [batch, source, ...] below;
        # confirm against Net.forward.
        estimate_source = model(padded_mixture)

    # Every source is rescaled to the RMS level of the input mixture. The
    # previous code divided the energy by len() of a [1, C, T] tensor (i.e.
    # by 1), so it was not an RMS and the level grew with channel count.
    target_level = _rms(padded_mixture.cpu().numpy())
    floor = float(np.finfo(np.float32).eps)  # avoids 0/0 on silent outputs

    for j, sources in enumerate(estimate_source.cpu().numpy()):
        # One directory per batch item (previously always "utt1").
        this_dir = os.path.join(args.out_dir, 'utt{0}'.format(j + 1))
        os.makedirs(this_dir, exist_ok=True)
        for k, src in enumerate(sources):
            gain = target_level / max(_rms(src), floor)
            sf.write(os.path.join(this_dir, 's{0}.wav'.format(k + 1)),
                     src * gain, args.sample_rate)
if __name__ == '__main__':
    # Pick a default GPU, but do not override values the caller has already
    # exported in the environment.
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")  # which GPU to use
    args = parser.parse_args()
    print(args)
    separate(args)
# Processes a single utterance.
# First published on 2022-09-26 21:57:57.