Installing and Testing Google UIS-RNN

Introduction

Google UIS-RNN is a library implementing the Unbounded Interleaved-State Recurrent Neural Network (UIS-RNN) algorithm.
Official repository: https://github.com/google/uis-rnn

Environment setup

First install the Python 3.6 version of Anaconda. Installation instructions can be found at
https://blog.csdn.net/weixin_41738734/article/details/85262238

Then run the demo as described in the official instructions:

python demo.py --train_iteration=1000 -l=0.001 -hl=100

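If the script stops with an error about a missing library, install the library it names and run the command again.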
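Instead of discovering missing libraries one failed run at a time, they can be listed up front. The following is a minimal sketch; the module list in ASSUMED_MODULES is my assumption for illustration, and the repository's own requirements file remains the authoritative source.

```python
import importlib.util

# Assumed dependency list for illustration only; check the repository's
# requirements file for the real one.
ASSUMED_MODULES = ["numpy", "scipy", "torch"]


def find_missing(modules):
    """Return the names in `modules` that cannot be imported in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(ASSUMED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules are importable.")
```

Run it inside the Anaconda environment you plan to use for the demo, since each environment has its own set of installed packages.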

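One library needs special handling: PyTorch. Installing it directly with pip install torch did not work here.
The steps that worked instead:
First check which CUDA version is installed and note it down.

Then go to the PyTorch website
https://pytorch.org/

Select the configuration that matches your system and CUDA version, and copy the generated command, for example
conda install pytorch torchvision cuda100 -c pytorch
Paste it into the Anaconda Prompt and press Enter to install.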
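After the install finishes, it may be worth confirming that PyTorch imports and whether it can see the GPU. This is a small sketch using torch.__version__ and torch.cuda.is_available(); the helper name describe_torch is made up for this example.

```python
import importlib.util


def describe_torch():
    """Report whether PyTorch is importable and whether it can see a CUDA GPU."""
    info = {"installed": False, "version": None, "cuda_available": False}
    if importlib.util.find_spec("torch") is None:
        return info
    # Imported lazily so this check also runs where torch is absent.
    import torch
    info["installed"] = True
    info["version"] = torch.__version__
    info["cuda_available"] = bool(torch.cuda.is_available())
    return info


if __name__ == "__main__":
    print(describe_torch())
```

If cuda_available is False even though a GPU is present, the installed build may not match the CUDA version noted earlier.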

Training the model

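Once installation is complete, open demo.py and run it.
After a while the script finishes and a trained model file, saved_model.uisrnn, appears in the current directory.
Output like the following indicates that the test passed: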

--------------------------------------------------------------------------------
Finished diarization experiment
Config:
  sigma_alpha: 1.0
  sigma_beta: 1.0
  crp_alpha: 1.0
  learning rate: 1e-05
  learning rate half life: 0
  regularization: 1e-05
  batch size: 10

Performance:
  averaged accuracy: 1.000000
  accuracy numbers for all testing sequences:
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
    1.000000
================================================================================

Using the model

model_args, training_args, inference_args = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)
model.load(SAVED_MODEL_NAME)
predicted_label = model.predict(rt_data, inference_args)

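SAVED_MODEL_NAME is the path to the trained model file.
rt_data is the input data. It must be a two-dimensional array
with dtype float64.

For example, shape = 2,256:
2 is the number of observations (here, two speech segments);
256 is the size of each observation, which has to match the model's observation_dim setting (256 by default).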
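To make the shape and dtype requirement concrete, here is a sketch that turns raw 16-bit PCM bytes into a two-dimensional float64 array, the same conversion the appendix script performs inline. The function name bytes_to_observations and its parameters are hypothetical, introduced only for this example.

```python
import numpy as np


def bytes_to_observations(raw, n_rows, dim):
    """Convert little-endian int16 PCM bytes to an (n_rows, dim) float64 array."""
    needed = n_rows * dim * 2  # 2 bytes per int16 sample
    if len(raw) < needed:
        raise ValueError("need %d bytes, got %d" % (needed, len(raw)))
    samples = np.frombuffer(raw[:needed], dtype="<i2")
    return samples.astype(np.float64).reshape(n_rows, dim)


if __name__ == "__main__":
    # Synthetic stand-in for microphone data: 512 samples = 1024 bytes.
    fake_audio = np.arange(512, dtype="<i2").tobytes()
    rt_data = bytes_to_observations(fake_audio, n_rows=2, dim=256)
    print(rt_data.shape, rt_data.dtype)  # (2, 256) float64
```

Raising an error on short input keeps a partially filled buffer from being silently reshaped into the wrong number of rows.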

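predicted_label holds the predicted speaker label for each row of rt_data; the number of distinct labels tells you how many speakers were detected.
It is returned as a list whose elements correspond to the first dimension of rt_data.
For example, if rt_data.shape = 2,256,
then predicted_label has 2 elements.
See the official demo code for details.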
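Since the result is one label per input row, the speaker count and the rows belonging to each speaker can be derived with plain Python. This sketch assumes predicted_label is a flat list of integer labels; both helper names are made up for illustration.

```python
def count_speakers(predicted_label):
    """Return the number of distinct speaker labels in a prediction."""
    return len(set(predicted_label))


def group_rows_by_speaker(predicted_label):
    """Map each speaker label to the row indices assigned to it."""
    groups = {}
    for row, label in enumerate(predicted_label):
        groups.setdefault(label, []).append(row)
    return groups


if __name__ == "__main__":
    example = [0, 0, 1, 0, 2]  # made-up labels, not real model output
    print(count_speakers(example))         # 3
    print(group_rows_by_speaker(example))  # {0: [0, 1, 3], 1: [2], 2: [4]}
```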

Problem

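For reasons I have not figured out, model.predict is extremely slow. Task Manager shows the GPU mostly idle while the CPU does all the work, even though the amount of data is small.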
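The cause is not established here. One way to narrow it down is to time predict while varying the number of input rows or the inference settings (for example the beam size, assuming your version exposes one) and see which change moves the number. Below is a generic timing helper as a sketch; time_call is a hypothetical name and the workload shown is only a stand-in.

```python
import time


def time_call(fn, *args, repeats=3, **kwargs):
    """Call fn(*args, **kwargs) `repeats` times; return (last result, best seconds)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return result, best


if __name__ == "__main__":
    # Stand-in workload. With the library installed this could instead be:
    #   time_call(model.predict, rt_data, inference_args)
    result, seconds = time_call(sum, range(100000))
    print("result=%d best=%.6fs" % (result, seconds))
```

Taking the best of several runs reduces the influence of one-off delays such as first-call initialization.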

Appendix: code

Modified from the official demo.py so that it reads the microphone data stream and feeds it to the RNN.

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A demo script showing how to use the uisrnn package on toy data."""

import numpy as np
import pyaudio
import uisrnn
import queue
import time
import torch
CHUNK = 256
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 11025
SAVED_MODEL_NAME = 'saved_model.uisrnn'
counter=1



p = pyaudio.PyAudio()
audiodata = bytes()

def audio_callback(in_data, frame_count, time_info, status):
    """PyAudio input callback: append each captured chunk to audiodata."""
    global audiodata
    audiodata += in_data
    # Stop the stream once counter drops to 0 or below; otherwise keep recording.
    if counter <= 0:
        return (None,pyaudio.paComplete)
    else:
        return (None,pyaudio.paContinue)



def diarization_experiment(model_args, training_args, inference_args):
  """Experiment pipeline.

  Load data --> train model --> test model --> output result

  Args:
    model_args: model configurations
    training_args: training configurations
    inference_args: inference configurations
  """

  predicted_labels = []
  test_record = []

  train_data = np.load('./data/training_data.npz')
  test_data = np.load('./data/testing_data.npz')
  train_sequence = train_data['train_sequence']
  train_cluster_id = train_data['train_cluster_id']
  test_sequences = test_data['test_sequences']
  test_cluster_ids = test_data['test_cluster_ids']
  
  
  for (test_sequence, test_cluster_id) in zip(test_sequences, test_cluster_ids):
      print (test_sequence.shape)
 
  
  model = uisrnn.UISRNN(model_args)

  # training
  #model.fit(train_sequence, train_cluster_id, training_args)
  #model.save(SAVED_MODEL_NAME)
  # we can also skip training by calling:
  model.load(SAVED_MODEL_NAME)

  # testing
  for (test_sequence, test_cluster_id) in zip(test_sequences, test_cluster_ids):
    predicted_label = model.predict(test_sequence, inference_args)
    predicted_labels.append(predicted_label)
    accuracy = uisrnn.compute_sequence_match_accuracy(
        test_cluster_id, predicted_label)
    test_record.append((accuracy, len(test_cluster_id)))
    print('Ground truth labels:')
    print(test_cluster_id)
    print('Predicted labels:')
    print(predicted_label)
    print('-' * 80)

  output_string = uisrnn.output_result(model_args, training_args, test_record)

  print('Finished diarization experiment')
  print(output_string)


def main():
  """The main function."""
  global audiodata
  #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  
  
  model_args, training_args, inference_args = uisrnn.parse_arguments()
  #diarization_experiment(model_args, training_args, inference_args)
  
  
  
  model = uisrnn.UISRNN(model_args)
  model.load(SAVED_MODEL_NAME)
  
  stream = p.open(format=FORMAT,
        channels=CHANNELS,
        rate=RATE,
        input=True,
        output=False,
        frames_per_buffer=CHUNK,
        stream_callback=audio_callback)
  #ad_rdy_ev=threading.Event()
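  # Two rows of CHUNK 16-bit samples each (2 bytes per sample).
  bytes_needed = 2 * CHUNK * 2
  while stream.is_active():
      if len(audiodata) >= bytes_needed:
          # Take the oldest bytes_needed bytes as little-endian int16 samples.
          rt_data = np.frombuffer(audiodata[0:bytes_needed], np.dtype('<i2'))
          audiodata = audiodata[bytes_needed:]

          # predict expects a two-dimensional float64 array.
          rt_data = rt_data.astype(np.float64)
          rt_data.shape = 2, CHUNK

          predicted_label = model.predict(rt_data, inference_args)
          print(predicted_label)
      else:
          # Wait briefly instead of spinning while the buffer fills.
          time.sleep(0.01)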
  


if __name__ == '__main__':
  main()
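One thing to watch in the script above: PyAudio invokes the callback on a separate thread, and both the callback and the main loop rebind the global audiodata, so a chunk may occasionally be lost. A possible remedy is a small lock-protected buffer. The sketch below is one way to do it, not part of the official demo; the class name ChunkBuffer is made up for this example.

```python
import threading


class ChunkBuffer:
    """Byte buffer shared between an audio callback and a consumer thread."""

    def __init__(self):
        self._data = bytearray()
        self._lock = threading.Lock()

    def append(self, chunk):
        """Add bytes; intended to be called from the audio callback."""
        with self._lock:
            self._data += chunk

    def pop(self, size):
        """Remove and return the oldest `size` bytes, or None if too few are buffered."""
        with self._lock:
            if len(self._data) < size:
                return None
            block = bytes(self._data[:size])
            del self._data[:size]
            return block


if __name__ == "__main__":
    buf = ChunkBuffer()
    buf.append(b"\x01\x00" * 300)
    print(buf.pop(1024))       # None: only 600 bytes buffered so far
    buf.append(b"\x02\x00" * 300)
    print(len(buf.pop(1024)))  # 1024
```

In the script, the callback would call append(in_data) and the main loop would call pop(2 * CHUNK * 2), sleeping briefly whenever it returns None.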
