Speech Enhancement Generative Adversarial Network (code analysis)

This post continues the previous one and walks through the code structure of the network (main_test).


1. Define the flags: integer, boolean, float and string

flags.DEFINE_integer
flags.DEFINE_boolean
flags.DEFINE_float
flags.DEFINE_string
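To show roughly what these four kinds of definitions produce, the sketch below uses Python's standard argparse module as a stand-in for tf.app.flags (this swap is ours, not the original code's). The flag names model, test_wav, weights, save_path, save_clean_path and preemph come from the code in this post. canvas_size and the boolean verbose flag are assumptions, and so are all default values.

```python
import argparse


def str2bool(value):
    """Interpret common true/false spellings, like a boolean flag parser would."""
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')


def build_parser():
    """Sketch of main_test's flags using argparse instead of tf.app.flags."""
    parser = argparse.ArgumentParser(description='SEGAN test flags (sketch)')
    # Integer flag: window length fed to the generator (default is an assumption).
    parser.add_argument('--canvas_size', type=int, default=2 ** 14)
    # Boolean flag: hypothetical example, not taken from the original code.
    parser.add_argument('--verbose', type=str2bool, default=False)
    # Float flag: pre-emphasis coefficient; 0 disables pre-emphasis (default assumed).
    parser.add_argument('--preemph', type=float, default=0.95)
    # String flags used in steps 3 and 4 below (defaults assumed).
    parser.add_argument('--model', type=str, default='gan')
    parser.add_argument('--save_path', type=str, default='segan_results')
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_wav', type=str, default=None)
    parser.add_argument('--save_clean_path', type=str, default='test_clean_results')
    return parser
```

For example, `build_parser().parse_args(['--model', 'ae'])` returns a namespace whose attributes are read the same way the code below reads FLAGS.model.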

2. Select CPU/GPU

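from tensorflow.python.client import device_lib

devices = device_lib.list_local_devices()
udevices = []
for device in devices:
    # Use the CPU only when we don't have GPUs; lower() also matches '/device:CPU:0'
    if len(devices) > 1 and 'cpu' in device.name.lower():
        continue
    print('Using device: ', device.name)
    udevices.append(device.name)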
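The filtering rule can be exercised without TensorFlow. In this sketch, Device is a hypothetical stand-in for the objects returned by device_lib.list_local_devices() (only .name is read), and select_devices is a hypothetical helper. It skips CPUs only when at least one GPU is listed, a slightly stricter test than len(devices) > 1 that still falls back to the CPU on a machine without GPUs.

```python
from collections import namedtuple

# Hypothetical stand-in for TensorFlow's device descriptors; only .name is used.
Device = namedtuple('Device', ['name'])


def select_devices(devices):
    """Return GPU names if any GPU is present, otherwise the CPU name(s)."""
    names = [d.name for d in devices]
    # Case-insensitive so both '/gpu:0' and '/device:GPU:0' are recognised.
    has_gpu = any('gpu' in n.lower() for n in names)
    udevices = []
    for name in names:
        if has_gpu and 'cpu' in name.lower():
            # A GPU exists, so the CPU is not used for the model.
            continue
        print('Using device: ', name)
        udevices.append(name)
    return udevices
```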

3. Select the model (GAN/AE) and choose between training and testing

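# Grow GPU memory on demand; place ops without a GPU kernel on the CPU
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
with tf.Session(config=config) as sess:
    if FLAGS.model == 'gan':
        print('Creating GAN model')
        se_model = SEGAN(sess, FLAGS, udevices)
    elif FLAGS.model == 'ae':
        print('Creating AE model')
        se_model = SEAE(sess, FLAGS, udevices)
    else:
        raise ValueError('{} model type not understood!'.format(FLAGS.model))
    if FLAGS.test_wav is None:
        # No test input given: train the model
        se_model.train(FLAGS, udevices)
    else:
        # Test mode: a checkpoint must be provided
        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        print('Loading model weights...')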
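The control flow of this step can be checked in isolation. In the sketch below, StubModel, MODEL_FACTORIES, make_flags and build_and_run are hypothetical names: the stubs stand in for SEGAN/SEAE and only record whether train or load was called, and the dictionary replaces the if/elif chain without changing its behaviour. For compactness the sketch calls load right away, whereas the original code does so at the start of step 4.

```python
from types import SimpleNamespace


class StubModel:
    """Hypothetical placeholder for SEGAN / SEAE that only records calls."""

    def __init__(self, kind):
        self.kind = kind
        self.calls = []

    def train(self, flags, udevices):
        self.calls.append('train')

    def load(self, save_path, weights):
        self.calls.append(('load', save_path, weights))


# Table-driven version of the if/elif chain: model name -> constructor.
MODEL_FACTORIES = {
    'gan': lambda: StubModel('gan'),  # stands in for SEGAN(sess, FLAGS, udevices)
    'ae': lambda: StubModel('ae'),    # stands in for SEAE(sess, FLAGS, udevices)
}


def make_flags(model='gan', test_wav=None, weights=None, save_path='segan_results'):
    """Build a FLAGS-like object for this sketch (hypothetical helper)."""
    return SimpleNamespace(model=model, test_wav=test_wav,
                           weights=weights, save_path=save_path)


def build_and_run(flags, udevices):
    """Pick the model from flags.model, then train or load weights for testing."""
    if flags.model not in MODEL_FACTORIES:
        raise ValueError('{} model type not understood!'.format(flags.model))
    se_model = MODEL_FACTORIES[flags.model]()
    if flags.test_wav is None:
        se_model.train(flags, udevices)
    else:
        if flags.weights is None:
            raise ValueError('weights must be specified!')
        se_model.load(flags.save_path, flags.weights)
    return se_model
```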

4. Clean (enhance) single-channel noisy audio

import os
import numpy as np
from scipy.io import wavfile

se_model.load(FLAGS.save_path, FLAGS.weights)
for wavname in os.listdir(FLAGS.test_wav):
    # os.path.join works whether or not test_wav ends with a path separator
    wav_single_full = os.path.join(FLAGS.test_wav, wavname)
    fm, wav_data = wavfile.read(wav_single_full)
    if fm != 16000:
        raise ValueError('16kHz required! Test file is different')
    # Map 16-bit PCM samples in [-32768, 32767] to float32 values in [-1, 1]
    wave1 = (2./65535.) * (wav_data.astype(np.float32) - 32767) + 1.
    if FLAGS.preemph > 0:
        print('preemph test wave with {}'.format(FLAGS.preemph))
        x_pholder_1, preemph_op_1 = pre_emph_test(FLAGS.preemph, wave1.shape[0])
        wave1 = sess.run(preemph_op_1, feed_dict={x_pholder_1: wave1})
    print('test wave1 shape: ', wave1.shape)
    print('test wave1 min:{}  max:{}'.format(np.min(wave1), np.max(wave1)))
    # Enhance the waveform and save it at 16 kHz under the same file name
    c_wave1 = se_model.clean(wave1)
    wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), 16000, c_wave1)
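The two numeric transforms in this loop can be checked with NumPy alone. The sketch assumes that pre_emph_test builds the standard first-order pre-emphasis filter y[0] = x[0], y[n] = x[n] - a*x[n-1] as a TensorFlow op, and adds the matching de-emphasis (inverse) filter that undoes it; where the original code applies de-emphasis is not shown in this excerpt. The function names int16_to_float, pre_emph and de_emph are hypothetical here.

```python
import numpy as np


def int16_to_float(wav_data):
    """Map int16 PCM samples in [-32768, 32767] to float32 in [-1, 1]."""
    return (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.


def pre_emph(x, coeff=0.95):
    """First-order pre-emphasis: y[0] = x[0], y[n] = x[n] - coeff * x[n-1]."""
    x = np.asarray(x, dtype=np.float32)
    return np.concatenate([x[:1], x[1:] - coeff * x[:-1]])


def de_emph(y, coeff=0.95):
    """Inverse filter: x[0] = y[0], x[n] = y[n] + coeff * x[n-1]."""
    y = np.asarray(y, dtype=np.float32)
    x = np.zeros_like(y)
    if y.size == 0:
        return x
    x[0] = y[0]
    for n in range(1, y.shape[0]):
        x[n] = y[n] + coeff * x[n - 1]
    return x
```

The int16 mapping sends -32768 to -1 and 32767 to 1, and because |a| < 1 the de-emphasis recursion is stable, so de_emph(pre_emph(x)) recovers x up to float32 rounding.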

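5. Inside the clean function
* Split the data into segments and feed each segment to the generator (Gs)
* Then branch according to the model type (AE/GAN)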

def clean(self, x):
    """ Clean an utterance x.
        x: numpy array containing the normalized noisy waveform
    """
    c_res = None
    for beg_i in range(0, x.shape[0], self.canvas_size):
        if x.shape[0] - 
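The listing above breaks off, so the sketch below only illustrates the segmentation described in the bullet points, under stated assumptions: the waveform is cut into windows of canvas_size samples (2**14 assumed here), the last window is zero-padded, each window is placed in row 0 of a fixed-size batch, and enhance_chunk is a hypothetical callable standing in for running the generator Gs in the session. The padded samples are dropped before the pieces are concatenated.

```python
import numpy as np


def clean_chunks(x, enhance_chunk, canvas_size=2 ** 14, batch_size=1):
    """Enhance x window by window, zero-padding the last window.

    enhance_chunk is a hypothetical stand-in for the generator call: it takes
    a (batch_size, canvas_size) array and returns an array of the same shape.
    """
    x = np.asarray(x, dtype=np.float32)
    pieces = []
    for beg_i in range(0, x.shape[0], canvas_size):
        # Number of real samples in this window (smaller for the last one).
        length = min(canvas_size, x.shape[0] - beg_i)
        # Only row 0 carries data; the remaining rows and the tail stay zero.
        batch = np.zeros((batch_size, canvas_size), dtype=np.float32)
        batch[0, :length] = x[beg_i:beg_i + length]
        out = np.asarray(enhance_chunk(batch))[0].reshape(canvas_size)
        # Keep only the samples that correspond to real input, not padding.
        pieces.append(out[:length])
    if not pieces:
        return x.copy()
    return np.concatenate(pieces)
```

With an identity function as enhance_chunk the output equals the input, which is a quick check that padding and trimming line up. Slicing by length also covers the full-window case, where no padding needs removing.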