前言
本博客对小白不太友好,需要熟悉其声纹识别代码架构和有编程基础的同学。
本篇博客参考的是夜雨飘零的开源项目,但是由于其模型需要付费下载,所以本篇指南不提供Torch、Onnx和Rknn模型,如有需要请自行前往原作者处购买。
写本篇博客是因为本人项目上的需要,顺便巩固一下模型转换的相关内容。在本篇中依然需要Torch框架,因为其声纹提取使用的是Torchaudio的Kaldi,当然有可能也可以使用类似libsora之类的音频信息提取库进行替代,但是其训练的模型依然是使用Kaldi,所以如果盲目使用其它音频分析工具可能导致最后的声纹对比出现异常,当然这里也涉及到特征形状转换的工作,有兴趣的朋友可以试一试。
本次只是对其提供的对比模型进行Onnx和Rknn化,在Ubuntu22.04下,Onnx版本能够比Torch版本快大约100-200ms(6s音频),精度浮动不大。
Torch转换Onnx
PS:项目刚开始的时候应该是默认需要GPU才能进行推理的,要改成CPU需要将其部分代码改一下,因为比较简单,这里不进行说明了。
在转换之前,我们先要了解其输入形状,我们在其mvector/predict.py中可以找到模型推理代码:
def predict(self,
audio_data,
sample_rate=16000):
"""预测一个音频的特征
:param audio_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy,AudioSegment对象。如果是字节的话,必须是完整并带格式的字节文件
:param sample_rate: 如果传入的事numpy数据,需要指定采样率
:return: 声纹特征向量
"""
# 加载音频文件,并进行预处理
input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
input_data = torch.tensor(input_data.samples, dtype=torch.float32).unsqueeze(0)
audio_feature = self._audio_featurizer(input_data).to(self.device)
# 执行预测
feature = self.predictor(audio_feature).data.cpu().numpy()[0]
return feature
其中,audio_feature就是预处理后的tensor数据,我们将其形状打印出来,那么其类似于:
torch.Size([1, 184, 80])
torch.Size([1, 201, 80])
那么就可以清楚,其输入数据的第二个维度是动态的,不断变动的值就和你的音频长度有关,那么一会转换Onnx的时候我们就需要对其进行设置。
def convert_to_onnx(model, input_tensor, output_path="models/model.onnx", opset_version=11):
model.eval()
torch.onnx.export(
model, # 要转换的模型
input_tensor, # 示例输入张量
output_path, # 输出 ONNX 文件的路径
export_params=True, # 是否导出参数
opset_version=opset_version, # ONNX opset 版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入张量的名称
output_names=["output"], # 输出张量的名称
dynamic_axes={"input": {1:"sequence_length"}, "output": {0: "batch_size"} } # 动态轴支持
)
将Torch模型设置为评估模式后,进行Onnx转换,input_tensor随便填一个符合要求的就可以了,opset_version这里填的是11(这个与后面选择哪个Rknntoolkit有关系,详情请看),dynamic_axes需要设置输入的第二个维度为动态的。
注意这个model不能直接将其使用torch.load,而是应该通过mvector/utils/checkpoint.py中的load_pretrained函数获得,如下:
predictor = load_pretrained(predictor, model_path, use_gpu=False).eval()
convert_to_onnx(predictor, input_tensor)
这样我们就可以顺利转换为Onnx模型了。
使用Onnx进行推理的代码如下,在mvector/predict.py中修改:
def inference_onnx(self,onnx_model_path, input_data):
ort_session = ort.InferenceSession(onnx_model_path)
input_name = ort_session.get_inputs()[0].name
input_data = input_data.cpu().numpy()
ort_inputs = {input_name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
return ort_outs[0]
def predict(self,
audio_data,
sample_rate=16000):
input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
input_data = torch.tensor(input_data.samples, dtype=torch.float32).unsqueeze(0)
audio_feature = self._audio_featurizer(input_data).to(self.device)
feature = self.inference_onnx(config.audio_print_model_path_onnx, audio_feature.cpu().numpy())
feature = feature.squeeze()
同样的predict_batch也需要修改:
def predict_batch(self, audios_data, sample_rate=16000, batch_size=32):
"""此处省略部分代码"""
inputs = torch.tensor(inputs, dtype=torch.float32)
input_lens_ratio = torch.tensor(input_lens_ratio, dtype=torch.float32)
audio_feature = self._audio_featurizer(inputs, input_lens_ratio).to(self.device)
for i in range(0, input_size, batch_size):
batch_audio_feature = audio_feature[i:i + batch_size]
separated_features = [batch_audio_feature[j:j + 1] for j in range(batch_audio_feature.shape[0])]
for separated_feature in separated_features:
feature = self.inference_onnx(config.audio_print_model_path_onnx, separated_feature)
features.extend(feature)
features = np.array(features)
如有需要请自行修改其它功能的代码
Onnx转换Rknn
因为Rknn目前好像还不支持动态输入,所以需要指定输入shape,这里就比较烦了,我这里指定了[1,184,80],这里不建议使用Rknn进行推理。
rknn = RKNN(verbose=True)
rknn.config(mean_values=None, std_values=None, target_platform='rk3588',dynamic_input=[[[1,184,80]]])
ret = rknn.load_onnx(model=ONNX_MODEL)
if ret != 0:
print('Load model failed!')
print('done')
ret = rknn.build(do_quantization=False)
if ret != 0:
print('Load model failed!')
print('done')
ret = rknn.init_runtime()
# ret = rknn.init_runtime('rk3566')
if ret != 0:
print('Init runtime environment failed!')
# print('done')
print('--> Export rknn model')
ret = rknn.export_rknn(RKNN_MODEL)
if ret != 0:
print('Export rknn model failed!')
# exit(ret)
print('done')