本文接上一篇文章,主要讲述网络代码结构(main_test)
1、定义flags,分别为整型、布尔型、浮点型、字符串型
flags.DEFINE_integer
flags.DEFINE_boolean
flags.DEFINE_float
flags.DEFINE_string
2、选择CPU/GPU
# Collect the names of the devices the model will be built on.
# `devices` is the list of local devices and `udevices` the output list; both
# are defined outside this snippet.
for device in devices:
    # Use cpu only when we dont have gpus.
    # The name is lower-cased before matching so that both '/cpu:0' and
    # '/device:CPU:0' style names are recognised as CPU devices.
    if len(devices) > 1 and 'cpu' in device.name.lower():
        continue
    print('Using device: ', device.name)
    udevices.append(device.name)
3、选择模型(GAN/AE)以及训练/测试
with tf.Session(config=config) as sess:
    # Supported values of FLAGS.model, each paired with the message printed
    # before construction and the class that is instantiated.
    model_choices = {
        'gan': ('Creating GAN model', SEGAN),
        'ae': ('Creating AE model', SEAE),
    }
    if FLAGS.model not in model_choices:
        raise ValueError('{} model type not understood!'.format(FLAGS.model))
    banner, model_cls = model_choices[FLAGS.model]
    print(banner)
    se_model = model_cls(sess, FLAGS, udevices)

    # No test wav given -> train; otherwise trained weights are mandatory.
    if FLAGS.test_wav is None:
        se_model.train(FLAGS, udevices)
    else:
        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        print('Loading model weights...')
4、对单通道带噪音频进行clean处理
# Restore the trained weights, then enhance every file listed in the
# FLAGS.test_wav directory and write each result, under the same file name,
# into FLAGS.save_clean_path.
se_model.load(FLAGS.save_path, FLAGS.weights)
for wavname in os.listdir(FLAGS.test_wav):
    # os.path.join gives a valid path whether or not FLAGS.test_wav ends with
    # a path separator (plain string concatenation required the separator).
    wav_path = os.path.join(FLAGS.test_wav, wavname)
    fm, wav_data = wavfile.read(wav_path)
    if fm != 16000:
        raise ValueError('16kHz required! Test file is different')
    # Affine rescale: -32768 -> -1.0 and 32767 -> 1.0.
    # NOTE(review): assumes 16-bit PCM samples - confirm for other formats.
    wave1 = (2./65535.) * (wav_data.astype(np.float32) - 32767) + 1.
    if FLAGS.preemph > 0:
        print('preemph test wave with {}'.format(FLAGS.preemph))
        # pre_emph_test returns a placeholder and an op sized to this wave.
        x_pholder_1, preemph_op_1 = pre_emph_test(FLAGS.preemph, wave1.shape[0])
        wave1 = sess.run(preemph_op_1, feed_dict={x_pholder_1: wave1})
    print('test wave1 shape: ', wave1.shape)
    print('test wave1 min:{} max:{}'.format(np.min(wave1), np.max(wave1)))
    c_wave1 = se_model.clean(wave1)
    wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), 16000, c_wave1)
5、进入clean函数
* 对数据进行分段处理,进入Gs
* 进而选择模型(AE/GAN)
def clean(self, x):
""" clean a utterance x
x: numpy array containing the normalized noisy waveform
"""
c_res = None
for beg_i in range(0, x.shape[0], self.canvas_size):
if x.shape[0] -