1. Create a simple Flask endpoint: calling http://127.0.0.1:5000/api/hello returns {"data":"show"}
from flask import Flask, jsonify

app = Flask(__name__)

@app.route('/api/hello', methods=['GET'])
def hello():
    return jsonify({'data': "show"})

if __name__ == '__main__':
    app.run()
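To check the endpoint, run the script and issue a GET request from another process. A minimal client sketch, assuming the server runs locally on Flask's default port 5000 and the requests package is available:

import requests

# Call the hello endpoint and print the JSON body
resp = requests.get('http://127.0.0.1:5000/api/hello')
print(resp.json())  # expected: {'data': 'show'}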
2. Rewriting the EDGE inference code as a class
import glob
import logging
import logging.handlers
import os
import random
import subprocess
import warnings
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory

import jukemirlib
import numpy as np
import torch
import yaml
from flask import Flask, jsonify, request
from tqdm import tqdm

from data.slice import slice_audio
from log.EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract

# Ignore all warnings
warnings.filterwarnings("ignore")
# Flask application that exposes the API
app = Flask(__name__)
# EDGE inference wrapper class
class EDGE_api:
    def __init__(self, data):
        # Read all fixed parameters from the config dict
        self.key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
        self.stringintkey = cmp_to_key(self.stringintcmp)
        self.feature_func = juke_extract if data['feature_type'] == "jukebox" else baseline_extract
        self.sample_length = data['out_length']
        self.sample_size = int(self.sample_length / 2.5) - 1
        self.model = EDGE(data['feature_type'], data['checkpoint'])
        self.model.eval()
        self.render_dir = data['render_dir']
        self.motion_save_dir = data['motion_save_dir']
        self.cache_features = data['cache_features']
        self.conda_name = data['conda_name']
        self.fbx_save_dir = data['fbx_save_dir']
        self.no_render = data['no_render']
        self.feature_cache_dir = data['feature_cache_dir']
        # Create the log file
        self.logger = self.logging_save(data['logging_path'], data['when'], data['interval'], data['backupCount'])
        self.logger.info('starting EDGE')
        self.logger.info(data)
    # Configure the logger
    def logging_save(self, save_path, when, interval, backupCount):
        # Create a logger
        logger = logging.getLogger('logger')
        logger.setLevel(logging.DEBUG)
        # Handler that writes to a log file and rotates it on a time schedule (e.g. a new file every day)
        handler = logging.handlers.TimedRotatingFileHandler(save_path, when=when, interval=interval,
                                                            backupCount=backupCount)
        # Output format for the handler
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        # Attach the handler to the logger
        logger.addHandler(handler)
        return logger
    def stringintcmp(self, a, b):
        aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
        ka, kb = self.key_func(a), self.key_func(b)
        if aa < bb:
            return -1
        if aa > bb:
            return 1
        if ka < kb:
            return -1
        if ka > kb:
            return 1
        return 0
    # Run inference on a single wav file
    def test(self, parm):
        self.logger.info('start... {}'.format(parm))
        # Fall back to the config defaults when a request parameter is not supplied;
        # passing cache_features / no_render at all turns them on
        cache_features = self.cache_features if parm['cache_features'] is None else True
        render_dir = self.render_dir if parm['render_dir'] is None else parm['render_dir']
        no_render = self.no_render if parm['no_render'] is None else True
        wav_file = parm['wav_file']
        feature_cache_dir = self.feature_cache_dir if parm['feature_cache_dir'] is None else parm['feature_cache_dir']
        motion_save_dir = self.motion_save_dir if parm['motion_save_dir'] is None else parm['motion_save_dir']
        if cache_features:
            songname = os.path.splitext(os.path.basename(wav_file))[0]
            save_dir = os.path.join(feature_cache_dir, songname)
            Path(save_dir).mkdir(parents=True, exist_ok=True)
            dirname = save_dir
        else:
            temp_dir = TemporaryDirectory()
            dirname = temp_dir.name
        # Slice the input audio into clips for feature extraction
        slice_audio(wav_file, 2.5, 5.0, dirname)
        file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=self.stringintkey)
        # Randomly sample a chunk of length at most sample_size
        rand_idx = random.randint(0, len(file_list) - self.sample_size)
        cond_list = []
        for idx, file in enumerate(file_list):
            if (not cache_features) and (not (rand_idx <= idx < rand_idx + self.sample_size)):
                continue
            reps, _ = self.feature_func(file)
            # Save the extracted features
            if cache_features:
                featurename = os.path.splitext(file)[0] + ".npy"
                np.save(featurename, reps)
            if rand_idx <= idx < rand_idx + self.sample_size:
                cond_list.append(reps)
        cond_list = torch.from_numpy(np.array(cond_list))
        data_tuple = None, cond_list, file_list[rand_idx : rand_idx + self.sample_size]
        self.model.render_sample(
            data_tuple, "test", render_dir, render_count=-1, fk_out=motion_save_dir, render=not no_render
        )
        torch.cuda.empty_cache()
        # Clean up the temporary directory if one was created
        if 'temp_dir' in locals():
            temp_dir.cleanup()
        name = 'test_' + os.path.basename(wav_file).replace('.wav', '.pkl')
        outpath = os.path.join(motion_save_dir, name)
        self.logger.info('end - out path : {}'.format(outpath))
        os.makedirs(self.fbx_save_dir, exist_ok=True)
        # Call the conversion script inside the second conda environment
        self.logger.info('start - pkl2fbx : {}'.format(outpath))
        # Convert the pkl output to fbx
        subprocess.run(["conda", "run", "-n", self.conda_name, "python", "smpl2fbx/Convert.py",
                        '--input_dir', outpath, '--output_dir', self.fbx_save_dir])
        self.logger.info('end - pkl2fbx : {}'.format(self.fbx_save_dir + '/' + name.replace('.pkl', '.fbx')))
        return jsonify({'output_pkl': outpath,
                        'output_fbx': self.fbx_save_dir + '/' + name.replace('.pkl', '.fbx')})
# Read the fixed parameters from the config file
with open("log/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)['test_options']
os.environ["CUDA_VISIBLE_DEVICES"] = config['cuda']
edge = EDGE_api(config)
# API endpoint
@app.route('/api/hello', methods=['GET'])
def show():
    parm = {}
    keys = ['cache_features', 'render_dir', 'no_render', 'wav_file', 'feature_cache_dir', 'motion_save_dir']
    for key in keys:
        parm[key] = request.args.get(key)
    if parm['wav_file'] is None:
        return jsonify({'error': 'the wav_file field (path to the input audio) is required'})
    return edge.test(parm)
if __name__ == "__main__":
    app.run(host=config['host'], port=config['port'], debug=False)
    # edge.test(wav_file=r'te/kemu3.wav')
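Once the service is up, the inference endpoint can be called with only the wav_file parameter; the remaining query parameters fall back to the values in config.yaml. A minimal client sketch, assuming the host and port from config.yaml and that te/kemu3.wav exists on the server (requests package required):

import requests

# Only wav_file is mandatory; the other keys are optional overrides
resp = requests.get('http://127.0.0.1:8888/api/hello', params={'wav_file': 'te/kemu3.wav'})
print(resp.json())  # {'output_pkl': '...', 'output_fbx': '...'}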
3. config.yaml
test_options:
  feature_type: jukebox
  out_length: 30.0
  render_dir: renders/
  checkpoint: log/checkpoint.pt
  motion_save_dir: out_data/pkl
  fbx_save_dir: out_data/fbx
  cache_features: false
  no_render: false
  feature_cache_dir: cached_features/
  logging_path: log/logging.txt
  when: D
  interval: 1
  backupCount: 7
  host: '0.0.0.0'
  port: 8888
  cuda: '0'
  conda_name: zrx_edge2
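Because EDGE_api.__init__ reads every option with data[...], a key missing from test_options only fails later with a KeyError. An optional startup check, run before constructing EDGE_api, makes that failure explicit. A minimal sketch based on the keys listed above (the check itself is illustrative, not part of the original code):

import yaml

REQUIRED_KEYS = ['feature_type', 'out_length', 'render_dir', 'checkpoint', 'motion_save_dir',
                 'fbx_save_dir', 'cache_features', 'no_render', 'feature_cache_dir',
                 'logging_path', 'when', 'interval', 'backupCount', 'host', 'port', 'cuda', 'conda_name']

with open('log/config.yaml', 'r', encoding='utf-8') as f:
    options = yaml.load(f, Loader=yaml.FullLoader)['test_options']

# Report any missing option before the server starts
missing = [key for key in REQUIRED_KEYS if key not in options]
if missing:
    raise SystemExit('config.yaml is missing keys: {}'.format(missing))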