import os
import pickle
import time

import torch
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Annotation, Segment
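# Pipeline overview:
#   1. Transcribe each WAV in ./audios_wav with Whisper (timestamped segments).
#   2. Run pyannote speaker diarization on the same audio.
#   3. Tag each transcript segment with the dominant speaker in its time span,
#      then merge consecutive segments from the same speaker.
#   4. Pickle the resulting list of {speaker, start, end, text} dicts
#      to ./audios_txt.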


def get_text_with_timestamp(transcribe_res):
    """Convert Whisper segments into (Segment(start, end), text) tuples."""
    timestamp_texts = []
    for item in transcribe_res["segments"]:
        start = item["start"]
        end = item["end"]
        text = item["text"]
        timestamp_texts.append((Segment(start, end), text))
    return timestamp_texts


def add_speaker_info_to_text(timestamp_texts, ann):
    """Label each transcript segment with its dominant speaker.

    Annotation.argmax() returns the label with the longest duration inside
    the cropped window, or None when diarization found no speaker there.
    """
    spk_text = []
    for seg, text in timestamp_texts:
        spk = ann.crop(seg).argmax()
        spk_text.append((seg, spk, text))
    return spk_text


def merge_cache(text_cache):
    """Collapse a run of same-speaker segments into one (Segment, spk, text)."""
    sentence = ''.join(item[-1] for item in text_cache)
    spk = text_cache[0][1]
    start = round(text_cache[0][0].start, 1)
    end = round(text_cache[-1][0].end, 1)
    return Segment(start, end), spk, sentence


# Sentence-ending punctuation (ASCII and fullwidth). Not used by
# merge_sentence below, which merges on speaker changes only; kept for
# callers that want to re-split merged utterances into sentences.
PUNC_SENT_END = ['.', '?', '!', '。', '？', '！']


def merge_sentence(spk_text):
    """Merge consecutive segments from the same speaker into one utterance.

    Whisper sometimes emits the same text twice in a row (a known
    repetition artifact), so identical consecutive texts from the same
    speaker are skipped instead of being merged twice.
    """
    merged_spk_text = []
    pre_spk = None
    text_cache = []
    for seg, spk, text in spk_text:
        if spk != pre_spk and len(text_cache) > 0:
            # Speaker changed: flush the cached run and start a new one.
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = [(seg, spk, text)]
            pre_spk = spk
        elif spk == pre_spk and text == text_cache[-1][2]:
            # Same speaker repeating identical text: drop the duplicate.
            continue
        else:
            text_cache.append((seg, spk, text))
            pre_spk = spk
    if len(text_cache) > 0:
        merged_spk_text.append(merge_cache(text_cache))
    return merged_spk_text


def diarize_text(transcribe_res, diarization_result):
    """Align a Whisper transcript with a pyannote diarization Annotation."""
    timestamp_texts = get_text_with_timestamp(transcribe_res)
    spk_text = add_speaker_info_to_text(timestamp_texts, diarization_result)
    return merge_sentence(spk_text)
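

# Minimal standalone usage sketch (paths are hypothetical; it assumes the
# same local models as the batch loop below):
#
#   asr = whisper.load_model("large")
#   sd = Pipeline.from_pretrained("./models/speaker-diarization-3.1/config.yaml")
#   for seg, spk, text in diarize_text(asr.transcribe("interview.wav"),
#                                      sd("interview.wav")):
#       print(f"{spk} [{seg.start:.2f}s -> {seg.end:.2f}s] {text}")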


if __name__ == "__main__":
    sd_config_path = "./models/speaker-diarization-3.1/config.yaml"
    asr_model = whisper.load_model("large")
    asr_model.to(torch.device("cuda"))
    speaker_diarization = Pipeline.from_pretrained(sd_config_path)
    speaker_diarization.to(torch.device("cuda"))

    # The Chinese initial_prompt primes Whisper to transcribe the interview as
    # punctuated Simplified Chinese, keep other languages in their original
    # script, avoid fabricated or repeated text, and follow the dialogue
    # format illustrated at the end of the prompt.
    initial_prompt = "输入的音频是关于一个采访内容,接下来您将扮演一个优秀记录能力的听众,通过倾听语音内容,将语音信息通过文字的方式记录下来。请你首先要判断语音中讲话者的讲话内容和语气,根据内容和语气记录带有标点符号的文本信息。具体要求为:1、中文语音的文本字体为简体中文,其他类型语音根据语音中说话的语种类型记录;2、文本信息的标点符号和文本内容要准确,不能捏造信息,同一段语音不能重复识别,不能捏造语音的语种类型;示例输出格式:-就AI的研发和部署而言,为什么你觉得中国很快就能赶上、甚至赶超美国?-首要原因是AI已经完成了从探索阶段到应用阶段的转型。在探索阶段,最先取得探索成果的人拥有绝对优势;然而现在AI算法已为诸多业内实践人士所熟知。所以,现在的关键在于速度、执行、资本以及对海量数据的获取,而中国在以上每个层面都具有优势。"

    files = os.listdir("./audios_wav")
    for file in files:
        start_time = time.time()
        print(file)
        # os.path.splitext keeps filenames containing extra dots intact.
        dialogue_path = "./audios_txt/" + os.path.splitext(file)[0] + ".pkl"
        audio = "./audios_wav/" + file
        asr_result = asr_model.transcribe(audio, initial_prompt=initial_prompt)
        asr_time = time.time()
        print("ASR time: " + str(asr_time - start_time))
        diarization_result: Annotation = speaker_diarization(audio)
        final_result = diarize_text(asr_result, diarization_result)
        dialogue = []
        for segment, spk, sent in final_result:
            content = {'speaker': spk, 'start': segment.start,
                       'end': segment.end, 'text': sent}
            dialogue.append(content)
            print("%s [%.2fs -> %.2fs] %s" % (spk, segment.start, segment.end, sent))
        with open(dialogue_path, 'wb') as f:
            pickle.dump(dialogue, f)
        end_time = time.time()
        print(file + " total time: " + str(end_time - start_time))
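
# To reload a saved transcript later (hypothetical path):
#
#   with open("./audios_txt/interview.pkl", "rb") as f:
#       dialogue = pickle.load(f)
#   for turn in dialogue:
#       print(turn["speaker"], turn["text"])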