import io
import os
import tempfile
import threading

import numpy as np
import torch
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# Load the Whisper model once, at import time, so every request handled by
# this module reuses the same instance. "medium" is the checkpoint size.
model = whisper.load_model(name="medium")
# FastAPI application
app = FastAPI()


def pad_or_trim(array, length: int = 16000 * 30, *, axis: int = -1):
    """Split ``array`` into equal, zero-padded segments of ``length`` samples.

    The input is cut along ``axis`` into consecutive chunks of ``length``
    samples. The final chunk is zero-padded at its end, so every chunk has
    exactly ``length`` samples. Chunks are stacked along a new leading axis 0.
    For example, a 1-D input of N samples becomes shape
    ``(ceil(N / length), length)``. An input that is empty along ``axis``
    yields a single all-zero segment.

    Despite the name, nothing is ever trimmed: long inputs are split into
    several segments rather than truncated.

    The default ``length`` is 30 s of audio, assuming a 16 kHz sample rate
    (whisper.load_audio's default).

    Args:
        array: A ``torch.Tensor`` or NumPy ndarray.
        length: Samples per segment; must be positive.
        axis: Axis to split along (keyword-only).

    Returns:
        Same kind as the input (tensor or ndarray), with one extra leading
        axis indexing the segments.

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError(f"length must be positive, got {length}")

    total = array.shape[axis]
    # Ceiling division, but never fewer than one segment: an empty input then
    # becomes one zero segment instead of making stack() fail on an empty list.
    n_segments = max(1, -(-total // length))
    # Pad the tail once so the axis splits into n_segments equal chunks. Each
    # chunk equals slicing [i, i + length) and zero-padding it individually.
    pad_amount = n_segments * length - total
    pad_widths = [(0, 0)] * array.ndim
    pad_widths[axis] = (0, pad_amount)

    if torch.is_tensor(array):
        if pad_amount:
            # F.pad wants flat (before, after) pairs starting from the LAST dim.
            flat_pad = [p for pair in reversed(pad_widths) for p in pair]
            array = torch.nn.functional.pad(array, flat_pad)
        return torch.stack(torch.split(array, length, dim=axis))

    if pad_amount:
        array = np.pad(array, pad_widths)
    return np.stack(np.split(array, n_segments, axis=axis))
# Serializes use of the shared `model`. Inference below runs in worker
# threads, and Whisper's decoder appears to keep per-decode state (kv-cache
# forward hooks) on the model's own modules, so concurrent decodes on one
# model are presumably unsafe.
# NOTE(review): drop the lock only if that is confirmed not to be the case.
_MODEL_LOCK = threading.Lock()


def _load_audio_bytes(data: bytes) -> np.ndarray:
    """Decode uploaded audio bytes with whisper.load_audio via a temp file.

    The temp file is deleted whether writing or decoding succeeds or fails.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(data)
        # The file is closed (and flushed) before load_audio reopens it by path.
        return whisper.load_audio(tmp.name)
    finally:
        os.remove(tmp.name)


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file, streaming results as Server-Sent Events.

    The audio is split into fixed-length segments by ``pad_or_trim`` and
    processed as one batch. The stream contains one
    ``data: Detected language: <lang>`` event per segment, followed by one
    ``data: <text>`` event per decoded segment.
    """
    # Await the async read, and decode in a worker thread, so neither blocks
    # the event loop. The temp file is removed inside the helper, so it cannot
    # leak if decoding fails or the client disconnects mid-stream.
    data = await file.read()
    audio = await run_in_threadpool(_load_audio_bytes, data)
    # Split into fixed-length segments; the last one is zero-padded.
    audio = pad_or_trim(audio)

    def generate():
        # A plain (sync) generator: StreamingResponse iterates sync iterators
        # in a threadpool, so the blocking model calls run off the event loop.
        # Log-Mel spectrogram for every segment, on the model's device.
        mel = whisper.log_mel_spectrogram(audio).to(model.device)

        # The lock is held only around model calls, never across a `yield`,
        # so a slow or disconnected client cannot stall other requests.
        with _MODEL_LOCK:
            _, language_probs = model.detect_language(mel)
        for segment_probs in language_probs:
            detected_language = max(segment_probs, key=segment_probs.get)
            yield f"data: Detected language: {detected_language}\n\n"

        # Whisper's own transcribe() disables FP16 on CPU ("FP16 is not
        # supported on CPU"). Do the same; GPU behaviour is unchanged.
        options = whisper.DecodingOptions(fp16=model.device != torch.device("cpu"))
        with _MODEL_LOCK:
            results = whisper.decode(model, mel, options)
        for result in results:
            yield f"data: {result.text}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")


# Run the FastAPI application
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=6006)
测试方法
1. 使用 curl
curl -N -X POST "http://127.0.0.1:6006/transcribe/" -H "accept: text/event-stream" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_audio_file"
2. 使用 requests Python 脚本
import requests
url = "http://localhost:6006/transcribe/"
file_path = "瓦解--史塔克.mp3"  # Replace with the path to your own audio file

# stream=True makes post() return once the response headers arrive, so
# iter_lines() prints each server-sent event as the server emits it instead
# of only after the whole body has been downloaded.
with open(file_path, "rb") as f:
    files = {"file": f}
    response = requests.post(
        url, files=files, headers={"accept": "text/event-stream"}, stream=True
    )

# Closing the streamed response releases the underlying connection.
with response:
    print(response)
    for line in response.iter_lines():
        # Skip the blank separator lines the server emits between events.
        if line:
            print(line.decode("utf-8"))