import os
import json
from flask import Flask, jsonify, request, render_template
from datetime import datetime
import logging
import glob
import time
import re
app = Flask(__name__)

# Logging setup: DEBUG level, one line per record with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
# Root directory holding the ".track" (original) and "json" (new) data trees.
# Defaults to the author's local folder; set the PERF_BASE_PATH environment
# variable to point the service elsewhere without editing the code.
BASE_PATH = os.environ.get("PERF_BASE_PATH") or r"C:\Users\l30078648\Desktop\250730"

# Chip types and model groups served by the API; every (chip, group) pair has
# its own sub-directory under both data trees.
CHIP_TYPES = ("Ascend610Lite", "BS9SX1A")
MODEL_GROUPS = ("rl_nn", "rsc_nn", "prediction_nn")

# chip -> group -> directory of the original per-model JSON files
# (read by get_prebuild_id_data for pre-build IDs).
TRACK_PATHS = {
    chip: {group: os.path.join(BASE_PATH, ".track", chip, group) for group in MODEL_GROUPS}
    for chip in CHIP_TYPES
}

# chip -> group -> directory of the new per-model performance JSON files.
JSON_PATHS = {
    chip: {group: os.path.join(BASE_PATH, "json", chip, group) for group in MODEL_GROUPS}
    for chip in CHIP_TYPES
}
# Keys under which a pre-build ID may be stored in a .track JSON file, tried
# in this order. "prebuild_id" is the primary key; the others cover common
# alternative spellings so such files are not reported as invalid.
PREBUILD_ID_KEYS = (
    'prebuild_id',
    'prebuildId',
    'pre_build_id',
    'prebuild id',
    'Pre Build ID',
)

# Accepted ID shape: a letter/digit followed by letters, digits, '_', '.' or
# '-'. This admits 32-char hex in either case, UUIDs with dashes and numeric or
# build-tag IDs, while still rejecting empty values and values containing
# whitespace or other punctuation. Tighten it if the real format is known.
PREBUILD_ID_PATTERN = re.compile(r'[A-Za-z0-9][A-Za-z0-9_.\-]*')


def extract_prebuild_id(data):
    """Return the raw pre-build ID stored in *data*, or None if absent.

    Keys from PREBUILD_ID_KEYS are tried at the top level first, then inside
    each directly nested dict. Keys whose value is None are skipped so that a
    null placeholder does not hide a real value elsewhere. Non-dict input
    yields None.
    """
    if not isinstance(data, dict):
        return None
    for key in PREBUILD_ID_KEYS:
        if data.get(key) is not None:
            return data[key]
    for value in data.values():
        if isinstance(value, dict):
            for key in PREBUILD_ID_KEYS:
                if value.get(key) is not None:
                    return value[key]
    return None


def normalize_prebuild_id(value):
    """Return *value* as a clean ID string, or None if it is not a valid ID.

    Strings are whitespace-stripped; integers (but not bools) are converted
    with str(). The result must fully match PREBUILD_ID_PATTERN.
    """
    # bool is a subclass of int; True/False are never meaningful IDs.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        value = str(value)
    if not isinstance(value, str):
        return None
    value = value.strip()
    if PREBUILD_ID_PATTERN.fullmatch(value):
        return value
    return None


def get_prebuild_id_data(chip, group):
    """Collect pre-build IDs from the ".track" directory of *chip*/*group*.

    Every ``*.json`` file in TRACK_PATHS[chip][group] is read. The file name
    without extension is used as the model name. The ID is located with
    extract_prebuild_id and validated with normalize_prebuild_id.

    Returns:
        dict mapping model name -> pre-build ID string, in file-name order.
        Files that cannot be read or parsed, or whose ID is missing or
        malformed, are logged and skipped. An unknown chip/group or a
        missing directory yields an empty dict.
    """
    prebuild_data = {}
    # Validate the requested chip/group and its directory.
    if chip not in TRACK_PATHS or group not in TRACK_PATHS[chip]:
        logger.error(f"无效路径: {chip}/{group}")
        return prebuild_data
    group_path = TRACK_PATHS[chip][group]
    if not os.path.exists(group_path):
        logger.error(f"原始路径不存在: {group_path}")
        return prebuild_data

    # Sorted so the resulting mapping (and anything paginated from it) has a
    # stable order; glob's own order is filesystem-dependent.
    for json_file in sorted(glob.glob(os.path.join(group_path, "*.json"))):
        try:
            # Explicit encoding instead of the platform default (e.g. GBK on
            # Chinese Windows); utf-8-sig also tolerates a leading BOM.
            with open(json_file, 'r', encoding='utf-8-sig') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"解析原始文件 {json_file} 时出错: {str(e)}")
            continue

        model_name = os.path.splitext(os.path.basename(json_file))[0]
        raw_id = extract_prebuild_id(data)
        if raw_id is None:
            logger.warning(f"未找到prebuild_id: {json_file}")
            continue
        prebuild_id = normalize_prebuild_id(raw_id)
        if prebuild_id is None:
            # repr + type make the offending value visible in the log.
            logger.warning(
                f"无效prebuild_id格式: {raw_id!r} ({type(raw_id).__name__}) in {json_file}"
            )
            continue
        prebuild_data[model_name] = prebuild_id
    return prebuild_data
def extract_quantization_rate(data):
    """Return the quantization rate stored in *data*, or None if absent.

    Known key spellings are tried at the top level first (in list order),
    then inside each directly nested dict (in the dict's iteration order).
    Non-dict input yields None instead of raising.
    """
    # Known spellings of the quantization-rate key. The list previously held
    # 'quantization_ratio' twice; the second entry is presumably meant to be
    # the space-separated variant, parallel to 'quantization rate'.
    keys_to_try = (
        'quantization_rate',
        'quantization rate',
        'quant_rate',
        'quantization_ratio',
        'quantization ratio',
    )
    if not isinstance(data, dict):
        return None
    for key in keys_to_try:
        if key in data:
            return data[key]
    # Fall back to one level of nesting, e.g. {"summary": {"quant_rate": ...}}.
    for value in data.values():
        if isinstance(value, dict):
            for key in keys_to_try:
                if key in value:
                    return value[key]
    return None
def parse_json_file(file_path):
    """Parse one performance JSON file and extract its metrics.

    Extracted fields:
        model_name        -- file name without extension
        latency           -- value of the first top-level key ending in ".om"
        bandwidth         -- top-level "mean_ddr", else "mean_ddr" from the
                             first nested dict that has it
        quantization_rate -- via extract_quantization_rate
        last_modified     -- file mtime as an ISO-8601 local timestamp

    Fields that are not found are None. Returns None (after logging) when the
    file cannot be read, is not valid JSON, or its top level is not a JSON
    object.
    """
    try:
        # utf-8-sig also accepts files saved with a BOM.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"读取文件出错 {file_path}: {e}")
        return None

    # A top-level array or scalar would make the key scans below raise and
    # abort the whole API request; skip such a file instead.
    if not isinstance(data, dict):
        logger.error(f"JSON顶层不是对象 {file_path}: {type(data).__name__}")
        return None

    model_name = os.path.splitext(os.path.basename(file_path))[0]

    # Latency is stored under a key named after the .om model file.
    latency = next((value for key, value in data.items() if key.endswith('.om')), None)

    if 'mean_ddr' in data:
        bandwidth = data['mean_ddr']
    else:
        bandwidth = next(
            (value['mean_ddr'] for value in data.values()
             if isinstance(value, dict) and 'mean_ddr' in value),
            None,
        )

    quantization_rate = extract_quantization_rate(data)

    # The file may disappear between reading and stat-ing it.
    try:
        last_modified = datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
    except OSError as e:
        logger.warning(f"获取修改时间出错 {file_path}: {e}")
        last_modified = None

    return {
        "model_name": model_name,
        "latency": latency,
        "bandwidth": bandwidth,
        "quantization_rate": quantization_rate,
        "last_modified": last_modified
    }
def get_performance_data(chip, group):
    """Build the /api/performance payload for one chip/model group.

    Returns a dict with:
        status            -- "success" or "error" (with an "error" reason)
        models            -- one parse_json_file() entry per readable file,
                             plus "prebuild_id" ("NA" when no valid ID exists
                             for that model in the .track tree), ordered by
                             file name
        timestamp         -- generation time (ISO-8601, local)
        chip_type / group -- the request parameters
        json_path / track_path -- resolved directories ("" if unknown)
        file_count        -- number of *.json files found in json_path
        prebuild_id_count -- number of valid pre-build IDs found

    An unknown chip/group or a missing JSON directory returns status "error".
    """
    performance_data = {
        "status": "success",
        "models": [],
        "timestamp": datetime.now().isoformat(),
        "chip_type": chip,
        "group": group,
        "json_path": JSON_PATHS[chip].get(group, "") if chip in JSON_PATHS else "",
        "track_path": TRACK_PATHS[chip].get(group, "") if chip in TRACK_PATHS else "",
        "file_count": 0,
        "prebuild_id_count": 0
    }

    # 1. Pre-build IDs come from the separate .track tree, keyed by model name.
    prebuild_id_map = get_prebuild_id_data(chip, group)
    performance_data["prebuild_id_count"] = len(prebuild_id_map)

    # 2. Performance data from the json tree.
    if chip not in JSON_PATHS or group not in JSON_PATHS[chip]:
        performance_data["status"] = "error"
        performance_data["error"] = "无效芯片或模型组"
        return performance_data

    group_path = JSON_PATHS[chip][group]
    if not os.path.exists(group_path):
        performance_data["status"] = "error"
        performance_data["error"] = "JSON路径不存在"
        return performance_data

    # Sorted for a deterministic model order; glob's order is filesystem-dependent.
    json_files = sorted(glob.glob(os.path.join(group_path, "*.json")))
    performance_data["file_count"] = len(json_files)

    for json_file in json_files:
        model_data = parse_json_file(json_file)
        if model_data is None:
            continue  # already logged by parse_json_file
        model_data["prebuild_id"] = prebuild_id_map.get(model_data["model_name"], "NA")
        performance_data["models"].append(model_data)
    return performance_data
@app.route('/api/performance', methods=['GET'])
def performance_api():
    """GET /api/performance: performance data for one chip/model group.

    Query parameters:
        device -- chip type (default 'Ascend610Lite')
        type   -- model group (default 'rl_nn')

    Responds with the get_performance_data() payload plus "process_time"
    (seconds, rounded to 4 decimals) and a 5-minute public Cache-Control
    header. Any unexpected exception yields HTTP 500 with an error body.
    """
    t0 = time.time()
    try:
        args = request.args
        device = args.get('device', 'Ascend610Lite')
        group = args.get('type', 'rl_nn')
        logger.info(f"性能API请求 - 设备: {device}, 组: {group}")

        payload = dict(get_performance_data(device, group))
        payload["process_time"] = round(time.time() - t0, 4)

        response = jsonify(payload)
        # Allow browsers/proxies to reuse the result for 5 minutes.
        response.headers['Cache-Control'] = 'public, max-age=300'
        return response
    except Exception as exc:
        logger.exception("处理请求时出错")
        error_body = {
            "status": "error",
            "error": "服务器内部错误",
            "details": str(exc),
            "process_time": round(time.time() - t0, 4)
        }
        return jsonify(error_body), 500
@app.route('/api/prebuild_ids', methods=['GET'])
def prebuild_ids_api():
    """GET /api/prebuild_ids: paginated list of valid pre-build IDs.

    Query parameters:
        device   -- chip type (default 'Ascend610Lite')
        type     -- model group (default 'rl_nn')
        page     -- 1-based page number (default 1)
        per_page -- page size (default 50)

    IDs are ordered by model name so page boundaries stay stable across
    requests. Non-integer or non-positive page/per_page values yield HTTP 400
    rather than crashing (per_page=0 previously divided by zero). Unexpected
    exceptions yield HTTP 500.
    """
    start_time = time.time()
    try:
        device = request.args.get('device', 'Ascend610Lite')
        group = request.args.get('type', 'rl_nn')

        # Validate paging up front so bad input becomes a client error.
        try:
            page = int(request.args.get('page', 1))
            per_page = int(request.args.get('per_page', 50))
            valid_paging = page >= 1 and per_page >= 1
        except ValueError:
            valid_paging = False
        if not valid_paging:
            return jsonify({
                "status": "error",
                "error": "分页参数无效: page和per_page必须为正整数",
                "process_time": round(time.time() - start_time, 4)
            }), 400

        logger.info(f"Prebuild ID请求 - 设备: {device}, 组: {group}")
        prebuild_id_map = get_prebuild_id_data(device, group)
        # Deterministic order (by model name) for consistent pagination.
        all_ids = [prebuild_id_map[name] for name in sorted(prebuild_id_map)]
        total = len(all_ids)

        start_idx = (page - 1) * per_page
        paginated_ids = all_ids[start_idx:start_idx + per_page]

        response = {
            "status": "success",
            "prebuild_ids": paginated_ids,
            "page": page,
            "per_page": per_page,
            "total": total,
            # Ceiling division; per_page >= 1 is guaranteed above.
            "total_pages": (total + per_page - 1) // per_page,
            "timestamp": datetime.now().isoformat(),
            "process_time": round(time.time() - start_time, 4)
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("获取prebuild_id时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """GET /health: report whether the configured data directories exist.

    The response always has "status": "ok". Per-directory existence flags
    are nested under "components". Note that "components.disk_space" is
    only whether BASE_PATH exists, not a free-space measurement.
    """
    def existence_map(path_map):
        # chip -> group -> whether that directory exists
        return {
            chip: {group: os.path.exists(path) for group, path in groups.items()}
            for chip, groups in path_map.items()
        }

    timestamp = datetime.now().isoformat()
    components = {
        "disk_space": os.path.exists(BASE_PATH),
        "track_paths": existence_map(TRACK_PATHS),
        "json_paths": existence_map(JSON_PATHS),
    }
    return jsonify({
        "status": "ok",
        "timestamp": timestamp,
        "components": components,
    })
@app.route('/debug/paths', methods=['GET'])
def debug_paths():
    """GET /debug/paths: dump the path configuration for troubleshooting.

    For every chip/group directory in TRACK_PATHS and JSON_PATHS, the
    response reports the path and whether it exists. For existing
    directories it also gives the number of *.json files and up to three of
    their paths as samples.
    """
    def describe(path):
        # Summary for a single directory.
        info = {"path": path, "exists": os.path.exists(path)}
        if info["exists"]:
            matches = glob.glob(os.path.join(path, "*.json"))
            info["file_count"] = len(matches)
            info["sample_files"] = matches[:3]  # first three as examples
        return info

    def describe_all(path_map):
        # chip -> group -> directory summary
        return {
            chip: {group: describe(path) for group, path in groups.items()}
            for chip, groups in path_map.items()
        }

    return jsonify({
        "base_path": BASE_PATH,
        "track_paths": describe_all(TRACK_PATHS),
        "json_paths": describe_all(JSON_PATHS),
    })
@app.route('/')
def home():
    """GET /: serve the front-end page from the index.html template."""
    page_template = 'index.html'
    return render_template(page_template)
if __name__ == '__main__':
    # Development server. Host, port and debug mode can be overridden with
    # APP_HOST, APP_PORT and APP_DEBUG; the defaults match the previous
    # hard-coded values. Set APP_DEBUG=0 outside local development: the
    # Werkzeug debugger allows arbitrary code execution from the browser.
    _debug = os.environ.get("APP_DEBUG", "1").strip().lower() in ("1", "true", "yes", "on")
    app.run(
        host=os.environ.get("APP_HOST", "127.0.0.1"),
        port=int(os.environ.get("APP_PORT", "8080")),
        debug=_debug,
    )
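# Reported issue: "这个后端脚本提示我的prebuild_id格式无效,解决这个问题"
# (English: this backend script reports that my prebuild_id format is invalid; please fix it.)
# 最新发布 (latest release)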