Hands-on voice wake-up (keyword spotting) with TensorFlow

Simply put, voice wake-up is a classification task: samples are classified as wake-word or non-wake-word (the wake word here is "hello, xiaogua"). The task in this exercise is to classify a set of given audio clips with a trained model, in three stages: data preprocessing, model construction and training, and post-processing.
The training set contains some 11,000 audio clips and the test set some 4,000, all recorded by others. All code in this post targets Python 3.7.
I have only just started with TensorFlow and learned the classification workflow from the MNIST handwritten-digit tutorial, so the way the network is built and parameters are passed closely follows that code.
The overall pipeline follows the paper: SMALL-FOOTPRINT KEYWORD SPOTTING USING DEEP NEURAL NETWORKS

Data preprocessing (feature extraction)

  • Pre-emphasis: flatten the spectral tilt by boosting the high-frequency band
  • Framing: cut the audio into short segments for processing
  • Windowing: taper the frame edges to suppress the Gibbs effect and give each frame quasi-periodic behavior
  • Fast Fourier transform: convert the time-domain signal to the frequency domain
  • Mel filterbank: mimic the frequency resolution of human hearing
  • Frame stacking: concatenate several consecutive frames into one training (test) sample; a NumPy sketch of the first five steps follows this list
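
The features in this project are precomputed and stored as HTK-format .fbank files (read by fbank_reader below), so the code that follows starts at the frame-stacking step. For readers who want the first five steps concretely, here is a minimal NumPy sketch of a standard 40-dim log-mel filterbank front end; the 25 ms / 10 ms framing at 16 kHz and the 0.97 pre-emphasis coefficient are common defaults assumed here, not values confirmed by this project.

import numpy as np

def logfbank(signal, sample_rate=16000, n_filters=40,
             frame_len=400, frame_shift=160, n_fft=512):
    # 1. Pre-emphasis: y[t] = x[t] - 0.97 * x[t-1] boosts high frequencies.
    emphasized = np.append(signal[0], signal[1:] - 0.97 * signal[:-1])
    # 2. Framing: 25 ms frames every 10 ms (400 / 160 samples at 16 kHz);
    #    assumes the signal is at least one frame long.
    num_frames = 1 + (len(emphasized) - frame_len) // frame_shift
    idx = np.arange(frame_len)[None, :] + frame_shift * np.arange(num_frames)[:, None]
    frames = emphasized[idx]
    # 3. Windowing: a Hamming window tapers the frame edges.
    frames = frames * np.hamming(frame_len)
    # 4. FFT, then the power spectrum.
    power = np.abs(np.fft.rfft(frames, n_fft)) ** 2 / n_fft
    # 5. Mel filterbank: triangular filters evenly spaced on the mel scale.
    high_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700)
    hz = 700 * (10 ** (np.linspace(0, high_mel, n_filters + 2) / 2595) - 1)
    bins = np.floor((n_fft + 1) * hz / sample_rate).astype(int)
    fbank = np.zeros((n_filters, n_fft // 2 + 1))
    for m in range(1, n_filters + 1):
        left, center, right = bins[m - 1], bins[m], bins[m + 1]
        fbank[m - 1, left:center] = (np.arange(left, center) - left) / max(center - left, 1)
        fbank[m - 1, center:right] = (right - np.arange(center, right)) / max(right - center, 1)
    # Log filterbank energies, shape (num_frames, n_filters): one 40-dim row per frame.
    return np.log(np.maximum(power @ fbank.T, 1e-10))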
'''
fbank_reader is listed in the code block below this one.
fbank holds the extracted features; each frame is a 40-dim vector, so a
stacked sample has 41 x 40 = 1640 values.
Each sample stacks the 30 frames before and the 10 frames after the current frame.
If fewer than 30 frames precede the current frame, the 41 frames starting at the
first frame of the region are used as its stacking result.
If fewer than 10 frames follow it, the 41 frames ending at the last frame are used.
If the whole region has fewer than 41 frames, the first frame is repeated to fill
the front and the last frame to fill the back.
'''
import numpy as np
import fbank_reader  # HTK feature reader, listed below


def frame_combine(frame, file_path, start, end):
    # Read the whole utterance's feature matrix, shape (num_frames, 40).
    fbank = fbank_reader.HTKFeat_read(file_path).getall()

    if end - start + 1 < 41:
        # Region shorter than 41 frames: pad by repeating the first and/or
        # last frame until the stacked sample is exactly 41 frames long.
        if frame - start <= 30 and end - frame <= 10:
            frame_to_combine = []
            front_rest = 30 - (frame - start)
            back_rest = 10 - (end - frame)
            for i in range(front_rest):
                frame_to_combine.append(fbank[start].tolist())
            for i in range(start, end + 1):
                frame_to_combine.append(fbank[i].tolist())
            for i in range(back_rest):
                frame_to_combine.append(fbank[end].tolist())

        elif end - frame >= 10:
            frame_to_combine = []
            front_rest = 30 - (frame - start)
            for i in range(front_rest):
                frame_to_combine.append(fbank[start].tolist())
            for i in range(start, frame+11):
                frame_to_combine.append(fbank[i].tolist())

        else:
            frame_to_combine = []
            back_rest = 10 - (end - frame)
            for i in range(frame - 30, end + 1):
                frame_to_combine.append(fbank[i].tolist())
            for i in range(back_rest):
                frame_to_combine.append(fbank[end].tolist())
        combined = np.array(frame_to_combine).reshape(-1)

    else:
        # Region has at least 41 frames: take the [frame-30, frame+10]
        # window, shifting it inward when it would cross a region boundary.
        if frame - start >= 30 and end - frame >= 10:
            frame_to_combine = fbank[frame - 30: frame + 11]
            combined = frame_to_combine.reshape(-1)

        elif frame - start < 30:
            frame_to_combine = fbank[start: start+41]
            combined = frame_to_combine.reshape(-1)

        else:
            frame_to_combine = fbank[end - 40: end+1]
            combined = frame_to_combine.reshape(-1)

    return combined.tolist()
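
Incidentally, if the window-shifting behavior at region boundaries is not needed, the same 41-frame stacking can be written much more compactly by clamping indices. Note this variant repeat-pads at every boundary, so it is an alternative to, not a drop-in replacement for, frame_combine above, which shifts the window when the region itself holds at least 41 frames:

import numpy as np

def frame_combine_clamped(fbank, frame, start, end, left=30, right=10):
    # Clamp the window [frame-left, frame+right] to [start, end]; indices
    # pushed past a boundary simply repeat the edge frame.
    idx = np.clip(np.arange(frame - left, frame + right + 1), start, end)
    return fbank[idx].reshape(-1).tolist()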
# fbank_reader.py
# Copyright (c) 2007 Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""Read HTK feature files.
This module reads the acoustic feature files used by HTK
"""

__author__ = "David Huggins-Daines <dhuggins@cs.cmu.edu>"
__version__ = "$Revision $"

from struct import unpack, pack
import numpy

LPC = 1
LPCREFC = 2
LPCEPSTRA = 3
LPCDELCEP = 4
IREFC = 5
MFCC = 6
FBANK = 7
MELSPEC = 8
USER = 9
DISCRETE = 10
PLP = 11

_E = 0o0000100 # has energy
_N = 0o0000200 # absolute energy suppressed
_D = 0o0000400 # has delta coefficients
_A = 0o0001000 # has acceleration (delta-delta) coefficients
_C = 0o0002000 # is compressed
_Z = 0o0004000 # has zero mean static coefficients
_K = 0o0010000 # has CRC checksum
_O = 0o0020000 # has 0th cepstral coefficient
_V = 0o0040000 # has VQ data
_T = 0o0100000 # has third differential coefficients


class HTKFeat_read(object):
    "Read HTK format feature files"
    def __init__(self, filename=None):
        self.swap = (unpack('=i', pack('>i', 42))[0] != 42)
        if (filename != None):
            self.open(filename)

    def __iter__(self):
        self.fh.seek(12, 0)
        return self

    def open(self, filename):
        self.filename = filename
        # To run in python2, change the "open" to "file"
        self.fh = open(filename, "rb")
        self.readheader()

    def readheader(self):
        self.fh.seek(0, 0)
        spam = self.fh.read(12)
        self.nSamples, self.sampPeriod, self.sampSize, self.parmKind = unpack(">IIHH", spam)
        # Get coefficients for compressed data
        if self.parmKind & _C:
            self.dtype = 'h'
            self.veclen = self.sampSize // 2  # integer division so veclen is an int in Python 3
            if self.parmKind & 0x3f == IREFC:
                self.A = 32767
                self.B = 0
            else:
                self.A = numpy.fromfile(self.fh, 'f', self.veclen)
                self.B = numpy.fromfile(self.fh, 'f', self.veclen)
                if self.swap:
                    self.A = self.A.byteswap()
                    self.B = self.B.byteswap()
        else:
            self.dtype = 'f'
            self.veclen = self.sampSize // 4  # integer division so veclen is an int in Python 3
        self.hdrlen = self.fh.tell()

    def seek(self, idx):
        self.fh.seek(self.hdrlen + idx * self.sampSize, 0)

    def __next__(self):  # renamed from next() for the Python 3 iterator protocol
        vec = numpy.fromfile(self.fh, self.dtype, self.veclen)
        if len(vec) == 0:
            raise StopIteration
        if self.swap:
            vec = vec.byteswap()
        # Uncompress data to floats if required
        if self.parmKind & _C:
            vec = (vec.astype('f') + self.B) / self.A
        return vec

    def readvec(self):
        return self.__next__()

    def getall(self):
        self.seek(0)
        data = numpy.fromfile(self.fh, self.dtype)
        if self.parmKind & _K: # Remove and ignore checksum
            data = data[:-1]
        data = data.reshape(len(data) // self.veclen, self.veclen)
        if self.swap:
            data = data.byteswap()
        # Uncompress data to floats if required
        if self.parmKind & _C:
            data = (data.astype('f') + self.B) / self.A
        return data
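
For reference, reading one feature file is a one-liner (the path here is just an illustration):

feats = fbank_reader.HTKFeat_read("example.fbank").getall()
print(feats.shape)  # (num_frames, 40) for the 40-dim filterbank features used here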

Model construction and training

Data loading

Data are read from the given training- and test-set lists and stacked into samples before training and testing. The training set is cycled through repeatedly, while the test set only needs to be read once. The training data must also be shuffled at two levels: the order of the audio files, and the order of the stacked frames within each file. The test set, by contrast, must stay in order because post-processing needs to know where each audio file begins and ends. Since the two datasets are handled so differently, each gets its own class, TestSet and TrainSet:

class TestSet(object):
    def __init__(self, examples, labels, num_examples, fbank_end_frame):
        self._examples = examples
        self._labels = labels
        self._index_in_epochs = 0  # remembers the position after each next_batch() call
        self.num_examples = num_examples  # number of test samples
        self.fbank_end_frame = fbank_end_frame  # cumulative end frame of each file

    def next_batch(self, batch_size):
        start = self._index_in_epochs

        if start + batch_size > self.num_examples:
            self._index_in_epochs = self.num_examples
            end = self._index_in_epochs
            return self._examples[start:end], self._labels[start:end]
        else:
            self._index_in_epochs += batch_size
            end = self._index_in_epochs
            return self._examples[start:end], self._labels[start:end]


class TrainSet(object):
    def __init__(self, examples_list, position_data):
        self.examples_list = examples_list
        self.position_data = position_data
        self.fbank_position = 0   # how far into the training file list we have read
        self.index_in_epochs = 0  # remembers the position after each next_batch() call
        self.example = []
        self.labels = []
        self.num_examples = 0

    # Read ten fbank files at a time and stack their frames; the file list
    # is consumed in a circular fashion and reshuffled after each full pass.
    def read_train_set(self):
        self.example = []
        self.labels = []
        self.num_examples = 0
        step_length = 10
        start = self.fbank_position % len(self.examples_list)
        end = (self.fbank_position + step_length) % len(self.examples_list)
        if start < end:
            fbank_list = self.examples_list[start: end]
            self.fbank_position += step_length

        else:
            fbank_list = self.examples_list[start: len(self.examples_list)]
            self.fbank_position = 0
            index = np.arange(len(self.examples_list))
            np.random.shuffle(index)
            self.examples_list = np.array(self.examples_list)[index]

        for example in fbank_list:
            if example == '':
                continue
            file_path = "E://aslp_wake_up_word_data/data/positive/train/" + \
                        example + ".fbank"
            if os.path.exists(file_path):
                # Locate this file's entry in the position file; the keyword
                # positions sit between the file id and the next "positive" entry.
                start = self.position_data.find(example)
                end = self.position_data.find("positive", start + 1)
                if end != -1:
                    position_str = self.position_data[start + 15: end - 1]
                else:
                    position_str = self.position_data[start + 15: end]

                # Sample-level start/end of "hello" and of "xiao gua".
                keyword_position = position_str.split(" ")

                fbank = fbank_reader.HTKFeat_read(file_path).getall()
                length = fbank.shape[0]
                keyword_frame_position = []
                for i in range(4):
                    # Sample index -> frame index (10 ms shift = 160 samples at 16 kHz).
                    frame_position = int(keyword_position[i]) // 160
                    if frame_position >= length:
                        frame_position = length - 1
                    keyword_frame_position.append(frame_position)

                print(example)
                for frame in range(keyword_frame_position[0], keyword_frame_position[1] + 1):
                    self.example.append(
                        frame_combine(frame, file_path, keyword_frame_position[0], keyword_frame_position[1]))
                    self.labels.append('0')
                    self.num_examples += 1
                for frame in range(keyword_frame_position[2], keyword_frame_position[3] + 1):
                    self.example.append(
                        frame_combine(frame, file_path, keyword_frame_position[2], keyword_frame_position[3]))
                    self.labels.append('1')
                    self.num_examples += 1

            else:
                file_path = "E://aslp_wake_up_word_data/data/negative/train/" + \
                            example + ".fbank"

                fbank = fbank_reader.HTKFeat_read(file_path).getall()
                frame_number = fbank.shape[0]

                print(example)
                for frame in range(frame_number):
                    self.example.append(frame_combine(frame, file_path, 0, frame_number - 1))
                    self.labels.append('2')
                    self.num_examples += 1

    def next_batch(self, batch_size):
        start = self.index_in_epochs

        if start == 0:
            # Starting a new chunk: load the next ten files and shuffle the
            # stacked frames across file boundaries.
            self.read_train_set()
            index0 = np.arange(self.num_examples)
            np.random.shuffle(index0)
            self.example = np.array(self.example)[index0]
            self.labels = np.array(self.labels)[index0]

        if start + batch_size > self.num_examples:
            examples_rest_part = self.example[start: self.num_examples]
            labels_rest_part = self.labels[start: self.num_examples]
            self.index_in_epochs = 0
            return examples_rest_part, labels_rest_part

        else:
            self.index_in_epochs += batch_size
            end = self.index_in_epochs
            return self.example[start:end], self.labels[start:end]
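
As a usage sketch (the names are placeholders; train_ids is the merged positive+negative file list and position_text is the contents of the keyword-position file, exactly as read_data_sets() below wires them up):

train = TrainSet(train_ids, position_text)
batch_xs, batch_ys = train.next_batch(100)
# batch_xs: (batch, 1640) stacked-frame features; batch_ys: labels in {'0', '1', '2'}
# ('0' = "hello", '1' = "xiao gua", '2' = negative/filler).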
Model construction

The model is a fully connected network with three hidden layers of 128 units each (3×128):

# tensor_build.py
import tensorflow as tf

NUM_CLASSES = 3


def inference(speeches, hidden1_units, hidden2_units, hidden3_units):
    # Build the network: three fully connected hidden layers and a linear
    # output layer (softmax is applied later, in the loss / at inference time).
    hidden1 = tf.contrib.layers.fully_connected(speeches, hidden1_units)
    hidden1 = tf.nn.dropout(hidden1, keep_prob=0.9)  # assign the result, or dropout is a no-op
    hidden2 = tf.contrib.layers.fully_connected(hidden1, hidden2_units)
    hidden2 = tf.nn.dropout(hidden2, keep_prob=0.9)
    hidden3 = tf.contrib.layers.fully_connected(hidden2, hidden3_units)
    hidden3 = tf.nn.dropout(hidden3, keep_prob=0.9)
    # (Strictly, keep_prob should be a placeholder set to 1.0 at test time.)
    output_logits = tf.contrib.layers.fully_connected(hidden3, NUM_CLASSES)

    return output_logits


def loss(logits, labels):
    # Cross-entropy between the logits and the integer labels is the loss.
    labels = tf.to_int64(labels)
    return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)


def training(loss, learning_rate):
    tf.summary.scalar('loss', loss)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op
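
For readers on TensorFlow 2, where tf.contrib no longer exists, a rough Keras equivalent of the network above could look like this (assuming ReLU activations, which tf.contrib.layers.fully_connected used by default; the Dropout rate 0.1 corresponds to keep_prob=0.9):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(1640,)),  # 41 frames x 40 dims
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(3),  # logits: hello / xiaogua / filler
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))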
Model training
def run_training():

    train, test = input_data.read_data_sets()

    with tf.Graph().as_default():
        speeches_placeholder, labels_placeholder = placeholder_inputs()

        logits = tensor_build.inference(speeches_placeholder, FLAGS.hidden1, FLAGS.hidden2, FLAGS.hidden3)

        outputs = tf.nn.softmax(logits=logits)

        loss = tensor_build.loss(logits, labels_placeholder)

        train_op = tensor_build.training(loss, FLAGS.learning_rate)

        summary = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        saver = tf.train.Saver()

        sess = tf.Session()

        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        sess.run(init)

        test_false_alarm_rate_list = []
        test_false_reject_rate_list = []
        loss_list = []
        total_loss = []
        for step in range(FLAGS.max_steps):

            feed_dict = fill_feed_dict(train, speeches_placeholder, labels_placeholder)
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
            loss_list.append(loss_value)
            if step % 25897 == 0 and step != 0:
                # 25897 steps corresponds to roughly one pass over the training
                # data here; record the average loss per pass.
                total_loss.append(sum(loss_list[step - 25897: step]) / 25897)
            if step % 100 == 0:
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            if step + 1 == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
                # The rest can be skipped for now: evaluate on the test set and compute the false-alarm and false-reject rates.
                test_false_alarm_rate_list, test_false_reject_rate_list = do_eval(sess, speeches_placeholder,
                                                                                  labels_placeholder, test, outputs)
                print(total_loss)
    # Plot the ROC curve
    plot(test_false_alarm_rate_list, test_false_reject_rate_list)

Post-processing

Smoothing & confidence computation

The smoothing formula is as follows (the smoothed probability of label i at frame j, with frames indexed from 0; the window is the most recent 30 frames, matching the code below):

p'_{ij} = \begin{cases} \dfrac{1}{j+1}\sum\limits_{k=0}^{j} p_{ik}, & \text{if } j < 30 \\ \dfrac{1}{30}\sum\limits_{k=j-29}^{j} p_{ik}, & \text{if } j \ge 30 \end{cases}

The confidence formula is as follows (the confidence at frame j; i runs over the two keyword labels "hello" and "xiao gua", and the window is the most recent 100 smoothed frames):

confidence_j = \begin{cases} \sqrt{\prod\limits_{i=1}^{2} \max\limits_{0 \le k \le j} p'_{ik}}, & \text{if } j < 100 \\ \sqrt{\prod\limits_{i=1}^{2} \max\limits_{j-99 \le k \le j} p'_{ik}}, & \text{if } j \ge 100 \end{cases}

The confidence of an entire audio file is the maximum of its per-frame confidences; comparing that against the wake-up threshold gives the wake/no-wake decision.

def find_max(smooth_probability):
    # Max smoothed probability of each keyword label over the given window.
    length = len(smooth_probability)
    max1 = smooth_probability[0][0]
    max2 = smooth_probability[0][1]
    for i in range(length):
        if smooth_probability[i][0] > max1:
            max1 = smooth_probability[i][0]
        if smooth_probability[i][1] > max2:
            max2 = smooth_probability[i][1]
    return max1, max2


def posterior_handling(probability, fbank_end_frame):
    confidence = []
    for i in range(len(fbank_end_frame)):
        # Slice out this file's frames; fbank_end_frame holds cumulative frame counts.
        if i == 0:
            fbank_probability = probability[0: fbank_end_frame[0]]
        else:
            fbank_probability = probability[fbank_end_frame[i-1]: fbank_end_frame[i]]
        smooth_probability = []
        frame_confidence = []

        for j in range(len(fbank_probability)):
            # Smooth over the most recent 30 frames (fewer at the start).
            if j < 30:
                smooth_probability.append(np.sum(np.array(fbank_probability[0: j + 1]) / (j + 1), axis=0).tolist())
            else:
                smooth_probability.append(np.sum(np.array(fbank_probability[j - 29: j + 1]) / 30, axis=0).tolist())
        for j in range(len(fbank_probability)):
            # Confidence over the most recent 100 smoothed frames.
            if j < 100:
                max1, max2 = find_max(smooth_probability[0: j + 1])
                frame_confidence.append(max1 * max2)
            else:
                max1, max2 = find_max(smooth_probability[j - 99: j + 1])
                frame_confidence.append(max1 * max2)
        confidence.append(math.sqrt(max(frame_confidence)))
    return confidence
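
As a quick sanity check, here is a toy run on synthetic posteriors (random values, not real model output; the 0.5 threshold is an arbitrary illustration):

import math
import numpy as np

probability = np.random.dirichlet(np.ones(3), size=10).tolist()  # 10 frames of fake softmax output
fbank_end_frame = [5, 10]  # two "files" of 5 frames each (cumulative frame counts)
for conf in posterior_handling(probability, fbank_end_frame):
    print(conf, "-> wake" if conf >= 0.5 else "-> stay asleep")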

Complete code

# tensor_build.py
import tensorflow as tf

NUM_CLASSES = 3


def inference(speeches, hidden1_units, hidden2_units, hidden3_units):
    hidden1 = tf.contrib.layers.fully_connected(speeches, hidden1_units)
    hidden1 = tf.nn.dropout(hidden1, keep_prob=0.9)  # assign the result, or dropout is a no-op
    hidden2 = tf.contrib.layers.fully_connected(hidden1, hidden2_units)
    hidden2 = tf.nn.dropout(hidden2, keep_prob=0.9)
    hidden3 = tf.contrib.layers.fully_connected(hidden2, hidden3_units)
    hidden3 = tf.nn.dropout(hidden3, keep_prob=0.9)
    output_logits = tf.contrib.layers.fully_connected(hidden3, NUM_CLASSES)

    return output_logits


def loss(logits, labels):
    # Cross-entropy between the logits and the integer labels is the loss.
    labels = tf.to_int64(labels)
    return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)


def training(loss, learning_rate):
    tf.summary.scalar('loss', loss)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op
# fbank_reader.py
# Copyright (c) 2007 Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""Read HTK feature files.

This module reads the acoustic feature files used by HTK
"""

__author__ = "David Huggins-Daines <dhuggins@cs.cmu.edu>"
__version__ = "$Revision $"

from struct import unpack, pack
import numpy

LPC = 1
LPCREFC = 2
LPCEPSTRA = 3
LPCDELCEP = 4
IREFC = 5
MFCC = 6
FBANK = 7
MELSPEC = 8
USER = 9
DISCRETE = 10
PLP = 11

_E = 0o0000100 # has energy
_N = 0o0000200 # absolute energy suppressed
_D = 0o0000400 # has delta coefficients
_A = 0o0001000 # has acceleration (delta-delta) coefficients
_C = 0o0002000 # is compressed
_Z = 0o0004000 # has zero mean static coefficients
_K = 0o0010000 # has CRC checksum
_O = 0o0020000 # has 0th cepstral coefficient
_V = 0o0040000 # has VQ data
_T = 0o0100000 # has third differential coefficients


class HTKFeat_read(object):
    "Read HTK format feature files"
    def __init__(self, filename=None):
        self.swap = (unpack('=i', pack('>i', 42))[0] != 42)
        if (filename != None):
            self.open(filename)

    def __iter__(self):
        self.fh.seek(12, 0)
        return self

    def open(self, filename):
        self.filename = filename
        # To run in python2, change the "open" to "file"
        self.fh = open(filename, "rb")
        self.readheader()

    def readheader(self):
        self.fh.seek(0, 0)
        spam = self.fh.read(12)
        self.nSamples, self.sampPeriod, self.sampSize, self.parmKind = unpack(">IIHH", spam)
        # Get coefficients for compressed data
        if self.parmKind & _C:
            self.dtype = 'h'
            self.veclen = self.sampSize // 2  # integer division so veclen is an int in Python 3
            if self.parmKind & 0x3f == IREFC:
                self.A = 32767
                self.B = 0
            else:
                self.A = numpy.fromfile(self.fh, 'f', self.veclen)
                self.B = numpy.fromfile(self.fh, 'f', self.veclen)
                if self.swap:
                    self.A = self.A.byteswap()
                    self.B = self.B.byteswap()
        else:
            self.dtype = 'f'
            self.veclen = self.sampSize // 4  # integer division so veclen is an int in Python 3
        self.hdrlen = self.fh.tell()

    def seek(self, idx):
        self.fh.seek(self.hdrlen + idx * self.sampSize, 0)

    def __next__(self):  # renamed from next() for the Python 3 iterator protocol
        vec = numpy.fromfile(self.fh, self.dtype, self.veclen)
        if len(vec) == 0:
            raise StopIteration
        if self.swap:
            vec = vec.byteswap()
        # Uncompress data to floats if required
        if self.parmKind & _C:
            vec = (vec.astype('f') + self.B) / self.A
        return vec

    def readvec(self):
        return self.__next__()

    def getall(self):
        self.seek(0)
        data = numpy.fromfile(self.fh, self.dtype)
        if self.parmKind & _K: # Remove and ignore checksum
            data = data[:-1]
        data = data.reshape(len(data) // self.veclen, self.veclen)
        if self.swap:
            data = data.byteswap()
        # Uncompress data to floats if required
        if self.parmKind & _C:
            data = (data.astype('f') + self.B) / self.A
        return data
# input_data.py
import fbank_reader
import numpy as np
import os

# Test-set class
class TestSet(object):
    def __init__(self, examples, labels, num_examples, fbank_end_frame):
        self._examples = examples
        self._labels = labels
        self._index_in_epochs = 0  # remembers the position after each next_batch() call
        self.num_examples = num_examples  # number of test samples
        self.fbank_end_frame = fbank_end_frame  # cumulative end frame of each file

    def next_batch(self, batch_size):
        start = self._index_in_epochs

        if start + batch_size > self.num_examples:
            self._index_in_epochs = self.num_examples
            end = self._index_in_epochs
            return self._examples[start:end], self._labels[start:end]
        else:
            self._index_in_epochs += batch_size
            end = self._index_in_epochs
            return self._examples[start:end], self._labels[start:end]

# Training-set class
class TrainSet(object):
    def __init__(self, examples_list, position_data):
        self.examples_list = examples_list
        self.position_data = position_data
        self.fbank_position = 0   # how far into the training file list we have read
        self.index_in_epochs = 0  # remembers the position after each next_batch() call
        self.example = []
        self.labels = []
        self.num_examples = 0

    def read_train_set(self):
        self.example = []
        self.labels = []
        self.num_examples = 0
        step_length = 10
        start = self.fbank_position % len(self.examples_list)
        end = (self.fbank_position + step_length) % len(self.examples_list)
        if start < end:
            fbank_list = self.examples_list[start: end]
            self.fbank_position += step_length

        else:
            fbank_list = self.examples_list[start: len(self.examples_list)]
            self.fbank_position = 0
            index = np.arange(len(self.examples_list))
            np.random.shuffle(index)
            self.examples_list = np.array(self.examples_list)[index]

        for example in fbank_list:
            if example == '':
                continue
            file_path = "/home/disk2/internship_anytime/aslp_hotword_data/aslp_wake_up_word_data/data/positive/train/" + \
                        example + ".fbank"
            if os.path.exists(file_path):
                # Locate this file's entry in the position file; the keyword
                # positions sit between the file id and the next "positive" entry.
                start = self.position_data.find(example)
                end = self.position_data.find("positive", start + 1)
                if end != -1:
                    position_str = self.position_data[start + 15: end - 1]
                else:
                    position_str = self.position_data[start + 15: end]

                # Sample-level start/end of "hello" and of "xiao gua".
                keyword_position = position_str.split(" ")

                fbank = fbank_reader.HTKFeat_read(file_path).getall()
                length = fbank.shape[0]
                keyword_frame_position = []
                for i in range(4):
                    # Sample index -> frame index (10 ms shift = 160 samples at 16 kHz).
                    frame_position = int(keyword_position[i]) // 160
                    if frame_position >= length:
                        frame_position = length - 1
                    keyword_frame_position.append(frame_position)

                print(example)
                for frame in range(keyword_frame_position[0], keyword_frame_position[1] + 1):
                    self.example.append(
                        frame_combine(frame, file_path, keyword_frame_position[0], keyword_frame_position[1]))
                    self.labels.append('0')
                    self.num_examples += 1
                for frame in range(keyword_frame_position[2], keyword_frame_position[3] + 1):
                    self.example.append(
                        frame_combine(frame, file_path, keyword_frame_position[2], keyword_frame_position[3]))
                    self.labels.append('1')
                    self.num_examples += 1

            else:
                file_path = "E://aslp_wake_up_word_data/data/negative/train/" + \
                            example + ".fbank"

                fbank = fbank_reader.HTKFeat_read(file_path).getall()
                frame_number = fbank.shape[0]

                print(example)
                for frame in range(frame_number):
                    self.example.append(frame_combine(frame, file_path, 0, frame_number - 1))
                    self.labels.append('2')
                    self.num_examples += 1

    def next_batch(self, batch_size):
        start = self.index_in_epochs

        if start == 0:
            # Starting a new chunk: load the next ten files and shuffle the
            # stacked frames across file boundaries.
            self.read_train_set()
            index0 = np.arange(self.num_examples)
            np.random.shuffle(index0)
            self.example = np.array(self.example)[index0]
            self.labels = np.array(self.labels)[index0]

        if start + batch_size > self.num_examples:
            examples_rest_part = self.example[start: self.num_examples]
            labels_rest_part = self.labels[start: self.num_examples]
            self.index_in_epochs = 0
            return examples_rest_part, labels_rest_part

        else:
            self.index_in_epochs += batch_size
            end = self.index_in_epochs
            return self.example[start:end], self.labels[start:end]

# Frame stacking: build one 41-frame sample around the given frame
def frame_combine(frame, file_path, start, end):
    # Read the whole utterance's feature matrix, shape (num_frames, 40).
    fbank = fbank_reader.HTKFeat_read(file_path).getall()

    if end - start + 1 < 41:
        # Region shorter than 41 frames: pad by repeating the first and/or
        # last frame until the stacked sample is exactly 41 frames long.
        if frame - start <= 30 and end - frame <= 10:
            frame_to_combine = []
            front_rest = 30 - (frame - start)
            back_rest = 10 - (end - frame)
            for i in range(front_rest):
                frame_to_combine.append(fbank[start].tolist())
            for i in range(start, end + 1):
                frame_to_combine.append(fbank[i].tolist())
            for i in range(back_rest):
                frame_to_combine.append(fbank[end].tolist())

        elif end - frame >= 10:
            frame_to_combine = []
            front_rest = 30 - (frame - start)
            for i in range(front_rest):
                frame_to_combine.append(fbank[start].tolist())
            for i in range(start, frame+11):
                frame_to_combine.append(fbank[i].tolist())

        else:
            frame_to_combine = []
            back_rest = 10 - (end - frame)
            for i in range(frame - 30, end + 1):
                frame_to_combine.append(fbank[i].tolist())
            for i in range(back_rest):
                frame_to_combine.append(fbank[end].tolist())
        combined = np.array(frame_to_combine).reshape(-1)

    else:
        # Region has at least 41 frames: take the [frame-30, frame+10]
        # window, shifting it inward when it would cross a region boundary.
        if frame - start >= 30 and end - frame >= 10:
            frame_to_combine = fbank[frame - 30: frame + 11]
            combined = frame_to_combine.reshape(-1)

        elif frame - start < 30:
            frame_to_combine = fbank[start: start+41]
            combined = frame_to_combine.reshape(-1)

        else:
            frame_to_combine = fbank[end - 40: end+1]
            combined = frame_to_combine.reshape(-1)

    return combined.tolist()

# Build dataset objects that can serve batches directly via next_batch()
def read_data_sets():
    f = open("E://aslp_wake_up_word_data/positiveKeywordPosition.txt", "r")
    position_data = f.read()
    f.close()

    f = open("E://aslp_wake_up_word_data/train_positive.list", "r")
    temp = f.read()
    train_positive_list = temp.split('\n')
    f.close()

    f = open("E://aslp_wake_up_word_data/test_positive.list", "r")
    temp = f.read()
    test_positive_list = temp.split('\n')
    f.close()

    f = open("E://aslp_wake_up_word_data/train_negative.list", "r")
    temp = f.read()
    train_negative_list = temp.split('\n')
    f.close()

    f = open("E://aslp_wake_up_word_data/test_negative.list", "r")
    temp = f.read()
    test_negative_list = temp.split('\n')
    f.close()

    test_examples = []
    test_labels = []
    test_length = []
    test_num = 0

    for example in test_positive_list:
        if example == '':
            continue
        # Locate this file's entry in the position file; the keyword
        # positions sit between the file id and the next "positive" entry.
        start = position_data.find(example)
        end = position_data.find("positive", start + 1)
        if end != -1:
            position_str = position_data[start + 15: end - 1]
        else:
            position_str = position_data[start + 15: end]

        # start and end position of "hello" & start and end position of "xiao gua"
        keyword_position = position_str.split(" ")

        file_path = "E://aslp_wake_up_word_data/data/positive/test/" + \
                    example + ".fbank"

        fbank = fbank_reader.HTKFeat_read(file_path).getall()
        length = fbank.shape[0]
        keyword_frame_position = []
        for i in range(4):
            # Sample index -> frame index (10 ms shift = 160 samples at 16 kHz).
            frame_position = int(keyword_position[i]) // 160
            if frame_position >= length:
                frame_position = length - 1
            keyword_frame_position.append(frame_position)

        test_length.append(keyword_frame_position[1] - keyword_frame_position[0] + 1 +
                                    keyword_frame_position[3] - keyword_frame_position[2] + 1)

        print(example)
        for frame in range(keyword_frame_position[0], keyword_frame_position[1] + 1):
            test_examples.append(frame_combine(frame, file_path, keyword_frame_position[0], keyword_frame_position[1]))
            test_labels.append('0')
            test_num += 1
        for frame in range(keyword_frame_position[2], keyword_frame_position[3] + 1):
            test_examples.append(frame_combine(frame, file_path, keyword_frame_position[2], keyword_frame_position[3]))
            test_labels.append('1')
            test_num += 1

    for example in test_negative_list:
        if example == '':
            continue
        file_path = "/E://aslp_wake_up_word_data/data/negative/test/" + \
                    example + ".fbank"

        fbank = fbank_reader.HTKFeat_read(file_path).getall()
        frame_number = fbank.shape[0]
        test_length.append(frame_number)
        print(example)
        for frame in range(frame_number):
            test_examples.append(frame_combine(frame, file_path, 0, frame_number - 1))
            test_labels.append('2')
            test_num += 1
    fbank_end_frame = []  # cumulative frame counts: where each test file ends
    for i in range(len(test_length)):
        fbank_end_frame.append(sum(test_length[0: i+1]))

    train_list = train_positive_list + train_negative_list

    train = TrainSet(train_list, position_data)
    test = TestSet(test_examples, test_labels, test_num, fbank_end_frame)

    return train, test

# main.py
import argparse
import os
import sys
import tensorflow as tf
import input_data
import tensor_build
import matplotlib.pyplot as plt
import numpy as np
import math

FLAGS = None

# Plot the ROC curve
def plot(false_alarm_rate_list, false_reject_rate_list):

    plt.figure(figsize=(8, 4))
    plt.plot(false_alarm_rate_list, false_reject_rate_list)
    plt.xlabel('false_alarm_rate')
    plt.ylabel('false_reject_rate')
    plt.title('ROC')
    plt.show()


def placeholder_inputs():
    # Each stacked sample is 41 frames x 40 filterbank dims = 1640 values.
    speeches_placeholder = tf.placeholder(tf.float32, shape=(None, 1640))
    labels_placeholder = tf.placeholder(tf.int32, shape=(None,))
    return speeches_placeholder, labels_placeholder

# Fill the placeholders with the next batch
def fill_feed_dict(data_set, examples_pl, labels_pl):

    examples_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)

    feed_dict = {
        examples_pl: examples_feed,
        labels_pl: labels_feed,
    }
    return feed_dict


def find_max(smooth_probability):
    # Max smoothed probability of each keyword label over the given window.
    length = len(smooth_probability)
    max1 = smooth_probability[0][0]
    max2 = smooth_probability[0][1]
    for i in range(length):
        if smooth_probability[i][0] > max1:
            max1 = smooth_probability[i][0]
        if smooth_probability[i][1] > max2:
            max2 = smooth_probability[i][1]
    return max1, max2

# Posterior post-processing (smoothing + confidence)
def posterior_handling(probability, fbank_end_frame):
    confidence = []
    for i in range(len(fbank_end_frame)):
        # Slice out this file's frames; fbank_end_frame holds cumulative frame counts.
        if i == 0:
            fbank_probability = probability[0: fbank_end_frame[0]]
        else:
            fbank_probability = probability[fbank_end_frame[i-1]: fbank_end_frame[i]]
        smooth_probability = []
        frame_confidence = []

        for j in range(len(fbank_probability)):
            # Smooth over the most recent 30 frames (fewer at the start).
            if j < 30:
                smooth_probability.append(np.sum(np.array(fbank_probability[0: j + 1]) / (j + 1), axis=0).tolist())
            else:
                smooth_probability.append(np.sum(np.array(fbank_probability[j - 29: j + 1]) / 30, axis=0).tolist())
        for j in range(len(fbank_probability)):
            # Confidence over the most recent 100 smoothed frames.
            if j < 100:
                max1, max2 = find_max(smooth_probability[0: j + 1])
                frame_confidence.append(max1 * max2)
            else:
                max1, max2 = find_max(smooth_probability[j - 99: j + 1])
                frame_confidence.append(max1 * max2)
        confidence.append(math.sqrt(max(frame_confidence)))
    return confidence

# Sweep the wake-up threshold and compute false-alarm and false-reject rates as the evaluation metric
def do_eval(sess, speeches_placeholder, labels_placeholder, data_set, outputs):
    threshold_part = 10000
    steps_per_epoch = data_set.num_examples // FLAGS.batch_size
    probability = []
    label = []
    false_alarm_rate_list = []
    false_reject_rate_list = []
    for step in range(steps_per_epoch + 1):
        feed_dict = fill_feed_dict(data_set, speeches_placeholder, labels_placeholder)
        # Running labels_placeholder simply echoes the fed labels back alongside the softmax outputs.
        result_to_compare = sess.run([outputs, labels_placeholder], feed_dict=feed_dict)
        probability.extend(result_to_compare[0].tolist())
        label.extend(result_to_compare[1].tolist())

    fbank_end_frame = data_set.fbank_end_frame

    confidence = posterior_handling(probability, fbank_end_frame)

    for i in range(threshold_part):
        threshold = float(i) / threshold_part
        if threshold == 0:
            continue
        true_alarm = true_reject = false_reject = false_alarm = 0
        for j in range(len(confidence)):
            if j == 0:
                if confidence[j] < threshold:
                    if label[0] == 2:
                        true_reject += 1
                    else:
                        false_reject += 1
                if confidence[j] >= threshold:
                    if label[0] == 2:
                        false_alarm += 1
                    else:
                        true_alarm += 1
                continue
            if confidence[j] < threshold:
                if label[fbank_end_frame[j-1]] == 2:
                    true_reject += 1
                else:
                    false_reject += 1
            if confidence[j] >= threshold:
                if label[fbank_end_frame[j-1]] == 2:
                    false_alarm += 1
                else:
                    true_alarm += 1
        if false_reject + true_reject == 0 or false_alarm + true_alarm == 0:
            continue
        false_alarm_rate = float(false_alarm) / (false_alarm + true_alarm)
        false_reject_rate = float(false_reject) / (false_reject + true_reject)
        false_alarm_rate_list.append(false_alarm_rate)
        false_reject_rate_list.append(false_reject_rate)
    return false_alarm_rate_list, false_reject_rate_list


def run_training():

    train, test = input_data.read_data_sets()

    with tf.Graph().as_default():
        speeches_placeholder, labels_placeholder = placeholder_inputs()

        logits = tensor_build.inference(speeches_placeholder, FLAGS.hidden1, FLAGS.hidden2, FLAGS.hidden3)

        outputs = tf.nn.softmax(logits=logits)

        loss = tensor_build.loss(logits, labels_placeholder)

        train_op = tensor_build.training(loss, FLAGS.learning_rate)

        summary = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        saver = tf.train.Saver()

        sess = tf.Session()

        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        sess.run(init)

        test_false_alarm_rate_list = []
        test_false_reject_rate_list = []
        loss_list = []
        total_loss = []
        for step in range(FLAGS.max_steps):

            feed_dict = fill_feed_dict(train, speeches_placeholder, labels_placeholder)
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
            loss_list.append(loss_value)
            if step % 25897 == 0 and step != 0:
                # 25897 steps corresponds to roughly one pass over the training
                # data here; record the average loss per pass.
                total_loss.append(sum(loss_list[step - 25897: step]) / 25897)
            if step % 300 == 0:
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            if step + 1 == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
                test_false_alarm_rate_list, test_false_reject_rate_list = do_eval(sess, speeches_placeholder,
                                                                                  labels_placeholder, test, outputs)
                print(total_loss)
    plot(test_false_alarm_rate_list, test_false_reject_rate_list)


def main(_):

    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)
    run_training()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.001,
                        help='Initial learning rate.')
    parser.add_argument('--max_steps',
                        type=int,
                        default=78000,
                        help='Number of steps to run trainer.')
    parser.add_argument('--hidden1',
                        type=int,
                        default=128,
                        help='Number of units in hidden layer 1.')
    parser.add_argument('--hidden2',
                        type=int,
                        default=128,
                        help='Number of units in hidden layer 2.')
    parser.add_argument('--hidden3',
                        type=int,
                        default=128,
                        help='Number of units in hidden layer 3.')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='Batch size.  Must divide evenly into the dataset sizes.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=os.path.join(os.getenv('TEST_TMPDIR', 'E:\\'),
                                             'wake_up/logs/fully_connected_feed_lyh'),
                        help='Directory to put the log data.')

    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
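
With the defaults above, training can then be launched as, e.g.:

python main.py --learning_rate 0.001 --max_steps 78000 --batch_size 100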

Experimental results

The evaluation metric is the ROC curve, with the false-alarm rate and false-reject rate on the horizontal and vertical axes (comparing 3×128 against 5×128 hidden layers). [Figure: ROC curves for the 3×128 and 5×128 networks.]
