使用 Flask 创建 EDGE 推理接口

1. 用 Flask 创建一个简单的接口:访问 http://127.0.0.1:5000/api/hello,直接返回 {"data":"show"}

from flask import Flask, jsonify

app = Flask(__name__)


@app.route('/api/hello', methods=['GET'])
def hello():
    """Return the constant JSON payload {"data": "show"}."""
    return jsonify(data="show")


if __name__ == '__main__':
    # Start Flask's built-in server with its default host/port.
    app.run()

2. 将 EDGE 推理代码改写成类

import glob
import os
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory
import random
import jukemirlib
import numpy as np
import torch
from tqdm import tqdm
from data.slice import slice_audio
from log.EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract
import os,yaml
from flask import Flask, jsonify
import logging
import logging.handlers
from flask import Flask, render_template, request, session, send_file, make_response
import warnings
import subprocess

# Suppress warnings of every category for the whole process.
warnings.simplefilter("ignore")

# Flask application that exposes the HTTP API.
app = Flask(__name__)

# EDGE inference wrapper class.
class EDGE_api:
    """Run EDGE motion inference on an audio file and convert the result to FBX.

    The model is constructed once in ``__init__``; each call to ``test`` handles
    one request.

    Args:
        data: the ``test_options`` mapping loaded from ``log/config.yaml``; it
            must contain every key read in ``__init__``.
    """

    def __init__(self, data):
        # Sort helpers for slice files whose stem ends in "_slice<N>".
        self.key_func = self._slice_index
        self.stringintkey = cmp_to_key(self.stringintcmp)
        self.feature_func = juke_extract if data['feature_type'] == "jukebox" else baseline_extract
        self.sample_length = data['out_length']
        # 2.5 matches the first argument passed to slice_audio in test()
        # (presumably the slice stride in seconds — confirm against slice_audio).
        self.sample_size = int(self.sample_length / 2.5) - 1
        self.model = EDGE(data['feature_type'], data['checkpoint'])
        self.model.eval()
        self.render_dir = data['render_dir']
        self.motion_save_dir = data['motion_save_dir']
        self.cache_features = data['cache_features']
        self.conda_name = data['conda_name']
        self.fbx_save_dir = data['fbx_save_dir']
        self.no_render = data['no_render']
        self.feature_cache_dir = data['feature_cache_dir']
        # Create the rotating log file.
        self.logger = self.logging_save(data['logging_path'], data['when'], data['interval'], data['backupCount'])
        self.logger.info('starting EDGE')
        self.logger.info(data)

    @staticmethod
    def _slice_index(path):
        """Return the integer N from a path whose stem ends in ``_slice<N>``."""
        return int(os.path.splitext(path)[0].split("_")[-1].split("slice")[-1])

    @staticmethod
    def _parse_flag(value, default):
        """Interpret an optional boolean query parameter.

        ``None`` (not sent) returns ``default``. The strings "0", "false", "no",
        and "off" (case-insensitive) mean False. Any other value, including an
        empty string from a bare ``?flag``, means True.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() not in ('0', 'false', 'no', 'off')

    @staticmethod
    def _sample_start(n_files, sample_size):
        """Return a random start index for a window of at most ``sample_size`` files.

        The upper bound is clamped at 0 so that inputs shorter than one full
        window start at index 0 instead of making ``random.randint`` raise.
        """
        return random.randint(0, max(0, n_files - sample_size))

    # Configure logging.
    def logging_save(self, save_path, when, interval, backupCount):
        """Return the shared logger named 'logger', writing to a rotating file.

        ``when``, ``interval`` and ``backupCount`` are passed unchanged to
        ``TimedRotatingFileHandler``. If a handler for the same file is already
        attached, it is reused rather than duplicated; a duplicate would write
        every record twice.
        """
        logger = logging.getLogger('logger')
        logger.setLevel(logging.DEBUG)

        target = os.path.abspath(save_path)
        if any(isinstance(h, logging.handlers.TimedRotatingFileHandler) and h.baseFilename == target
               for h in logger.handlers):
            return logger

        # utf-8 so non-ASCII paths in log messages do not depend on the platform locale.
        handler = logging.handlers.TimedRotatingFileHandler(save_path, when=when, interval=interval,
                                                            backupCount=backupCount, encoding='utf-8')
        handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(handler)
        return logger

    def stringintcmp(self, a, b):
        """Compare slice paths by prefix (all text before the last '_'), then by numeric slice index.

        The numeric step makes ``x_slice10`` sort after ``x_slice2``.
        """
        aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
        ka, kb = self.key_func(a), self.key_func(b)
        if aa != bb:
            return -1 if aa < bb else 1
        if ka != kb:
            return -1 if ka < kb else 1
        return 0

    # Run inference.
    def test(self, parm):
        """Run one inference request, then convert the generated motion .pkl to FBX.

        Args:
            parm: dict with ``wav_file`` (required) and optional overrides
                ``cache_features``, ``render_dir``, ``no_render``,
                ``feature_cache_dir`` and ``motion_save_dir``. A missing or
                ``None`` override falls back to the configured value. For the
                two flags, "0"/"false"/"no"/"off" mean False and any other value
                means True.

        Returns:
            A Flask JSON response ``{'output_pkl': ..., 'output_fbx': ...}``.
            It is built with ``jsonify``, so an app context must be active.

        Raises:
            ValueError: if slicing ``wav_file`` produced no ``.wav`` files.
        """
        self.logger.info('start... {}'.format(parm))

        cache_features = self._parse_flag(parm.get('cache_features'), self.cache_features)
        no_render = self._parse_flag(parm.get('no_render'), self.no_render)
        render_dir = self.render_dir if parm.get('render_dir') is None else parm['render_dir']
        feature_cache_dir = self.feature_cache_dir if parm.get('feature_cache_dir') is None else parm['feature_cache_dir']
        motion_save_dir = self.motion_save_dir if parm.get('motion_save_dir') is None else parm['motion_save_dir']
        wav_file = parm['wav_file']
        # os.path.basename accepts both separators on Windows, unlike a hard-coded '\\' split.
        songname = os.path.splitext(os.path.basename(wav_file))[0]

        temp_dir = None
        if cache_features:
            dirname = os.path.join(feature_cache_dir, songname)
            Path(dirname).mkdir(parents=True, exist_ok=True)
        else:
            temp_dir = TemporaryDirectory()
            dirname = temp_dir.name

        try:
            slice_audio(wav_file, 2.5, 5.0, dirname)
            # Escape the directory so song names containing glob metacharacters ([, *, ?) still match.
            file_list = sorted(glob.glob(os.path.join(glob.escape(dirname), "*.wav")), key=self.stringintkey)
            if not file_list:
                raise ValueError('no audio slices produced from {}'.format(wav_file))

            # randomly sample a chunk of length at most sample_size
            rand_idx = self._sample_start(len(file_list), self.sample_size)
            window = range(rand_idx, rand_idx + self.sample_size)
            cond_list = []
            for idx, file in enumerate(file_list):
                in_window = idx in window
                # Without caching, only the sampled window needs features.
                if not cache_features and not in_window:
                    continue
                reps, _ = self.feature_func(file)
                if cache_features:
                    np.save(os.path.splitext(file)[0] + ".npy", reps)
                if in_window:
                    cond_list.append(reps)
            cond_list = torch.from_numpy(np.array(cond_list))
            data_tuple = None, cond_list, file_list[rand_idx: rand_idx + self.sample_size]
            self.model.render_sample(
                data_tuple, "test", render_dir, render_count=-1, fk_out=motion_save_dir, render=not no_render
            )
        finally:
            torch.cuda.empty_cache()
            # Remove the temporary slice directory on success and on failure.
            if temp_dir is not None:
                temp_dir.cleanup()

        name = 'test_' + songname + '.pkl'
        outpath = os.path.join(motion_save_dir, name)

        self.logger.info('end - out path : {}'.format(outpath))
        os.makedirs(self.fbx_save_dir, exist_ok=True)

        # Run the converter script inside the separate conda environment (env2).
        self.logger.info('start - pkl2fbx : {}'.format(outpath))
        result = subprocess.run(["conda", "run", "-n", self.conda_name, "python", "smpl2fbx/Convert.py",
                                 '--input_dir', outpath, '--output_dir', self.fbx_save_dir])
        if result.returncode != 0:
            self.logger.error('pkl2fbx failed (exit code {}) for {}'.format(result.returncode, outpath))

        # Assumes Convert.py names its output after the input stem — confirm against smpl2fbx/Convert.py.
        fbx_path = os.path.join(self.fbx_save_dir, os.path.splitext(name)[0] + '.fbx')
        self.logger.info('end - pkl2fbx : {}'.format(fbx_path))
        return jsonify({'output_pkl': outpath,
                        'output_fbx': fbx_path})

# Read the fixed parameters (the test_options section of log/config.yaml).
with open("log/config.yaml", "r", encoding="utf-8") as f:
    # safe_load is enough for a plain key/value config and cannot construct arbitrary Python objects.
    config = yaml.safe_load(f)['test_options']

# os.environ values must be str; YAML parses an unquoted `cuda: 0` as int.
# Set before the model is built, since CUDA reads this variable when it is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config['cuda'])
edge = EDGE_api(config)

# API endpoint
@app.route('/api/hello', methods=['GET'])
def show():
    """Run EDGE inference using the query-string parameters and return output paths as JSON.

    Query parameters: ``wav_file`` (required; a path on the server) and
    optional ``cache_features``, ``render_dir``, ``no_render``,
    ``feature_cache_dir``, ``motion_save_dir``. Omitted parameters are passed
    as ``None`` and fall back to the configured defaults in ``edge.test``.

    Returns HTTP 400 if the request is invalid, and HTTP 500 with a JSON error
    if inference fails.
    """
    keys = ['cache_features', 'render_dir', 'no_render', 'wav_file', 'feature_cache_dir', 'motion_save_dir']
    parm = {key: request.args.get(key) for key in keys}

    if parm['wav_file'] is None:
        return jsonify({'error': '必须传入输入音乐地址 wav_file 字段'}), 400
    if not os.path.isfile(parm['wav_file']):
        return jsonify({'error': '输入音乐文件不存在: {}'.format(parm['wav_file'])}), 400

    try:
        return edge.test(parm)
    except Exception as err:
        # Request boundary: write the full traceback to the rotating log file
        # and return a JSON error instead of Flask's default HTML 500 page.
        edge.logger.exception('inference failed for {}'.format(parm['wav_file']))
        return jsonify({'error': 'inference failed: {}'.format(err)}), 500


if __name__ == "__main__":
    # Serve the API on the host/port from config.yaml.
    app.run(host=config['host'], port=config['port'], debug=False)
    # Offline alternative to the HTTP server. test() takes a parameter dict and
    # returns a jsonify() response, so it must run inside an app context:
    # with app.app_context():
    #     edge.test({'wav_file': r'te/kemu3.wav', 'cache_features': None, 'render_dir': None,
    #                'no_render': None, 'feature_cache_dir': None, 'motion_save_dir': None})

3. 配置文件 config.yaml

test_options:
  feature_type: jukebox            # "jukebox" -> jukebox feature extractor; any other value -> baseline extractor
  out_length: 30.0                 # output length (presumably seconds); int(out_length / 2.5) - 1 slices are fed to the model
  render_dir: renders/             # default render dir; overridden by the render_dir query parameter
  checkpoint: log/checkpoint.pt    # EDGE model checkpoint loaded at startup
  motion_save_dir: out_data/pkl    # default dir for the generated motion .pkl
  fbx_save_dir: out_data/fbx       # output dir passed to smpl2fbx/Convert.py
  cache_features: false            # true: keep audio slices and .npy features under feature_cache_dir/<song>
  no_render: false                 # true: skip rendering (render=not no_render)
  feature_cache_dir: cached_features/
  logging_path: log/logging.txt    # TimedRotatingFileHandler log file
  when: D                          # TimedRotatingFileHandler rotation unit (D = days)
  interval: 1                      # rotate every `interval` units of `when`
  backupCount: 7                   # number of rotated log files to keep
  host: '0.0.0.0'                  # Flask bind address
  port: 8888                       # Flask port
  cuda: '0'                        # value written to CUDA_VISIBLE_DEVICES
  conda_name: zrx_edge2            # conda env used to run smpl2fbx/Convert.py (pkl -> fbx)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值