平台相关常用接口、函数

 1、接口主程序

1)接口参数数据校验:使用 jsonschema。

2)异常处理。

# Option 1: catch and handle the exception locally.
try:
    pass
except Exception as err:
    pass
# Option 2: raise the exception instead. This can be placed in a helper function
# in another module and caught further up, so that errors are handled in one place
# and returned to the front end in a uniform format.
raise ValueError('....')

 3)单/多进程 + 协程

# -*- coding:utf-8 -*-
import copy
import functools
import json
import os
import pathlib
import sys
import time

import torch

from gevent import pywsgi, monkey
# 多线程,非阻塞
monkey.patch_all()

from flask import Flask, request, jsonify
from flask_cors import CORS
from multiprocessing import cpu_count, Process
from jsonschema import validate, ValidationError

from detect import inference_main
from train import train_main
from utils.job_manager import kill_process_by_port, kill_process_by_name
from utils.logger import get_logger

# Logging
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)

# Flask service; CORS is enabled for every route.
app = Flask(__name__)
CORS(app, resources=r'/*')
# Emit non-ASCII (e.g. Chinese) characters as-is in JSON responses.
app.config['JSON_AS_ASCII'] = False
# Flask >= 2.2 reads this setting from the JSON provider instead (the config key
# above was removed in Flask 2.3); older versions have no `app.json` provider.
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
    app.json.ensure_ascii = False

# Project root: parents[2] of this file's resolved path, i.e. two directory
# levels above the directory containing this file; made importable via sys.path.
project_root = str(pathlib.Path(__file__).resolve().parents[2])
sys.path.append(project_root)


# Request-body schemas for the HTTP endpoints.
# Schema for POST http://ip:port/train: an envelope with event_id (integer),
# event_type (string) and payload; payload must carry data_config (dataset
# paths, class count/names, result path) and basic_hyp (training settings).
schema_train = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {"type": "integer"},
        "event_type": {"type": "string"},
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": {
                    "type": "object",
                    "required": ["train", "val", "nc", "names", "result_path"],
                    "properties": {
                        "train": {"type": "string"},
                        "val": {"type": "string"},
                        "nc": {"type": "integer", "minimum": 1},
                        "names": {"type": "array"},
                        "result_path": {"type": "string"},
                    },
                },
                "basic_hyp": {
                    "type": "object",
                    "required": ["epochs", "batch-size", "workers", "img-size", "device"],
                    "properties": {
                        "epochs": {"type": "integer", "minimum": 1},
                        "batch-size": {"type": "integer", "minimum": 1},
                        "workers": {"type": "integer", "minimum": 0},
                        "img-size": {"type": "array"},
                        "device": {"type": "string"},
                    },
                },
            },
        },
    },
}

# Schema for POST http://ip:port/inference: same envelope as training; payload
# must carry data_config (test data path, result path) and basic_hyp (weights
# path, device).
schema_inference = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {"type": "integer"},
        "event_type": {"type": "string"},
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": {
                    "type": "object",
                    "required": ["test", "result_path"],
                    "properties": {
                        "test": {"type": "string"},
                        "result_path": {"type": "string"},
                    },
                },
                "basic_hyp": {
                    "type": "object",
                    "required": ["weights", "device"],
                    "properties": {
                        "weights": {"type": "string"},
                        "device": {"type": "string"},
                    },
                },
            },
        },
    },
}


# Request-data validation decorator; each use can supply its own schema.
def json_validate(schema):
    """
    Build a decorator that validates the first positional argument against *schema*.

    Args:
        schema: JSON Schema dict passed to ``jsonschema.validate``.

    Returns:
        A decorator. The decorated function is only called when validation
        passes. On a ``ValidationError`` it is skipped, and
        ``{'error': True, 'msg': <validation message>}`` is returned instead.
    """
    def wrapper(func):
        # Preserve the wrapped function's __name__/__doc__ for logging and debugging.
        @functools.wraps(func)
        def inner(data, *args, **kwargs):
            try:
                validate(data, schema)
            except ValidationError as e:
                logger.error("接口参数校验失败:{}!".format(e.message))
                return {'error': True, 'msg': e.message}
            logger.info("接口参数校验通过!")
            return func(data, *args, **kwargs)
        return inner
    return wrapper


def api_result(event_id, state_code, msg_type, msg, result):
    """
    Assemble and log the JSON response body shared by every endpoint.

    Args:
        event_id: id echoed back to the caller.
        state_code: application-level status code placed in the body.
        msg_type: feedback category (e.g. 'train', 'inference').
        msg: human-readable feedback message.
        result: result payload, or None.

    Returns:
        The ``jsonify`` response with keys event_id, state_code, feed_type,
        feed_msg and feed_data. NOTE(review): ``state_code`` is only carried in
        the body; the HTTP status itself is left at jsonify's default.
    """
    body = dict(
        event_id=event_id,
        state_code=state_code,
        feed_type=msg_type,
        feed_msg=msg,
        feed_data=result,
    )
    logger.info("feed_msg: {}".format(body))
    logger.info("=====================================================\n")
    return jsonify(body)


@app.route('/train', methods=['POST'])
def train_post():
    """
    Model training endpoint (POST /train).

    Parses the JSON request body and validates it against ``schema_train``.
    When validation passes, runs ``train_main`` on it.

    Returns:
        ``api_result`` response with state_code:
        - 200 and ``{"result_path": ...}`` on success;
        - 501 when schema validation fails;
        - 500 when parsing or training raises;
        - 400 for a non-POST method.
        The event_id is echoed when present, otherwise null.
    """
    if request.method != "POST":
        # The route only accepts POST, so Flask normally rejects other methods
        # before reaching here; kept as a defensive fallback.
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "train", feed_msg, None)

    @json_validate(schema=schema_train)
    def api_parameters(msg_dict_copy):
        logger.info("启动模型训练.......")
        return msg_dict_copy

    # Bound before parsing so the error path can always reference it, even when
    # the body is not valid JSON or carries no event_id.
    event_id = None
    try:
        request_data = request.get_data().decode()
        msg_dict = json.loads(request_data)
        logger.info("request msg: {}".format(msg_dict))
        if isinstance(msg_dict, dict):
            event_id = msg_dict.get('event_id')
            # On success the validator returns this same dict, so the flag marks
            # the "valid" outcome. On failure it returns {'error': True, 'msg': ...}.
            # Non-dict bodies fail the schema's "type": "object" check.
            msg_dict['error'] = False
        validate_msg = api_parameters(msg_dict)
        if validate_msg['error']:
            return api_result(event_id, 501, 'train', '参数设置有误,请核查,错误信息:{}'.format(validate_msg['msg']), None)
        result_path = train_main(msg_dict)
        result = {"result_path": result_path}
        return api_result(event_id, 200, 'train', 'success', result)
    except Exception as e:
        # Request boundary: log with traceback and report the failure to the caller.
        logger.exception(e)
        return api_result(event_id, 500, "train", str(e), None)


@app.route('/inference', methods=['POST'])
def inference_post():
    """
    Model inference endpoint (POST /inference).

    Parses the JSON request body and validates it against ``schema_inference``.
    When validation passes, runs ``inference_main`` on it.

    Returns:
        ``api_result`` response with state_code:
        - 200 and ``{"result_path": ...}`` on success;
        - 501 when schema validation fails;
        - 500 when parsing or inference raises;
        - 400 for a non-POST method.
        The event_id is echoed when present, otherwise null.
    """
    if request.method != "POST":
        # The route only accepts POST, so Flask normally rejects other methods
        # before reaching here; kept as a defensive fallback.
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "inference", feed_msg, None)

    @json_validate(schema=schema_inference)
    def api_parameters(msg_dict_copy):
        logger.info("启动模型推理.......")
        return msg_dict_copy

    # Bound before parsing so the error path can always reference it, even when
    # the body is not valid JSON or carries no event_id.
    event_id = None
    try:
        request_data = request.get_data().decode()
        msg_dict = json.loads(request_data)
        logger.info("request msg: {}".format(msg_dict))
        if isinstance(msg_dict, dict):
            event_id = msg_dict.get('event_id')
            # On success the validator returns this same dict, so the flag marks
            # the "valid" outcome. On failure it returns {'error': True, 'msg': ...}.
            # Non-dict bodies fail the schema's "type": "object" check.
            msg_dict['error'] = False
        validate_msg = api_parameters(msg_dict)
        if validate_msg['error']:
            return api_result(event_id, 501, 'inference', '参数设置有误,请核查,错误信息:{}'.format(validate_msg['msg']), None)
        result_path = inference_main(msg_dict)
        result = {"result_path": result_path}
        return api_result(event_id, 200, 'inference', 'success', result)
    except Exception as e:
        # Request boundary: log with traceback and report the failure to the caller.
        logger.exception(e)
        return api_result(event_id, 500, "inference", str(e), None)


def start_app(MULTI_PROCESS=False, USE_CORES=1, host="0.0.0.0", port=8080):
    """
    Start the HTTP service on a gevent WSGI server.

    Args:
        MULTI_PROCESS: when False, serve from this process only; the call
            blocks until the server stops. When True, bind the listening socket
            here, then start ``USE_CORES`` child processes that accept on it.
        USE_CORES: requested number of child processes, capped at cpu_count().
        host: bind address (default "0.0.0.0").
        port: bind port (default 8080). If startup raises, the process(es)
            holding this port are killed.
    """
    # Release cached GPU memory before starting.
    torch.cuda.empty_cache()
    try:
        logger.info("\n===============================================================================")
        logger.info("deeplearn server starting...")
        if not MULTI_PROCESS:
            server = pywsgi.WSGIServer((host, port), app)
            # start() binds the socket (raising here if the port is taken), so
            # success is reported once the server is actually listening.
            # serve_forever() then blocks until the server is stopped.
            server.start()
            logger.info("deeplearn server start success.")
            print('单进程 + 协程')
            server.serve_forever()
            return

        mulserver = pywsgi.WSGIServer((host, port), app)
        mulserver.start()

        def server_forever():
            # Each child accepts on the socket bound by the parent above.
            mulserver.start_accepting()
            mulserver._stop_event.wait()

        use_cores = min(USE_CORES, cpu_count())
        for _ in range(use_cores):
            # NOTE(review): the target is a closure over the bound server. This
            # relies on the fork start method (Linux default); the closure cannot
            # be pickled under "spawn" (the Windows/macOS default).
            p = Process(target=server_forever)
            p.start()
        print('多进程 + 协程,进程数:{}+1'.format(use_cores))
        return

    except Exception as err:
        logger.error("exception in server: {}".format(err))
        logger.error("a same service port has been started. please shut down before operation.")
        try:
            logger.error("{}".format(kill_process_by_port(port)))
        except Exception as err:
            logger.error("exception in server: {}".format(err))



def stop_app(port=8080):
    """
    Stop the service by killing the process(es) that hold *port*.

    If that step raises, fall back to killing every process named "python.exe".
    Errors from either step are logged, not raised.

    Args:
        port: service port (default 8080).
    """
    logger.info("\n===============================================================================")
    logger.warning("deeplearn server stopping...")
    try:
        logger.info("stop info: {}".format(kill_process_by_port(port)))
    except Exception as err:
        logger.error("stop err: {}".format(err))
        # Best-effort fallback. WARNING: kills *every* process named python.exe,
        # not only this service.
        try:
            logger.warning("stop fallback: {}".format(kill_process_by_name("python.exe")))
        except Exception as fallback_err:
            logger.error("stop fallback err: {}".format(fallback_err))


if __name__ == "__main__":
    # Run in multi-process mode. The child-process count comes from the USE_CORES
    # environment variable and defaults to 2 (i.e. 2 children + the parent).
    MULTI_PROCESS = True
    USE_CORES = int(os.getenv('USE_CORES') or 2)
    start_app(MULTI_PROCESS=MULTI_PROCESS, USE_CORES=USE_CORES)
    # stop_app()

 2、进程相关常用函数

# job_manager.py
# -*- coding:utf-8 -*-

import os
import psutil


def get_all_process():
    """
    Map every listed pid to its process name.

    A process that exits or denies access between listing and inspection is
    skipped on its own, instead of aborting the whole scan.

    Returns:
        dict[int, str]: pid -> process name.
    """
    pid_dict = {}
    for pid in psutil.pids():
        try:
            pid_dict[pid] = psutil.Process(pid).name()
        except psutil.Error:
            # NoSuchProcess / AccessDenied / ZombieProcess for this pid only.
            continue
    return pid_dict

def find_pid_by_name(name: str):
    """
    Find the pids of all processes whose name equals *name*.

    Each matching pid is also printed. Processes that vanish or deny access
    while being inspected are skipped.

    Args:
        name: process name to match exactly.

    Returns:
        list[int]: matching pids.
    """
    print("[" + name + "]'s pid is:")
    pids = []
    for pro in psutil.process_iter():
        try:
            matched = pro.name() == name
        except psutil.Error:
            # Process exited or is inaccessible; skip it rather than abort.
            continue
        if matched:
            print(pro.pid)
            pids.append(pro.pid)
    return pids

def find_port_by_pid(pid: int):
    """
    List the local ports used by process *pid*.

    Returns:
        list[dict]: one ``{pid: local_port}`` entry per matching connection
        from ``psutil.net_connections()``, so a port may appear more than once.
    """
    return [
        {pid: conn.laddr.port}
        for conn in psutil.net_connections()
        if conn.pid == pid
    ]


def find_pid_by_port(port: int):
    """
    Find the pids of processes with a connection whose local port is *port*.

    Args:
        port: port number; numeric strings such as '8080' are also accepted.

    Returns:
        list[int]: unique pids in first-seen order. A connection whose owning
        pid psutil cannot resolve (reported as None, e.g. without root) is
        skipped, so callers never try to kill "None".
    """
    port = int(port)
    pid_list = []
    for con_info in psutil.net_connections():
        pid = con_info.pid
        if pid is None or pid in pid_list:
            continue
        # laddr may be empty for some sockets; guard before reading .port.
        if con_info.laddr and con_info.laddr.port == port:
            pid_list.append(pid)
    return pid_list


def kill_process_by_pid(pid):
    """
    Forcefully terminate process *pid*.

    On POSIX this sends SIGKILL, like ``kill -9``. Where SIGKILL does not exist
    (Windows), ``os.kill`` is given SIGTERM, which terminates the process
    unconditionally there. Failures (non-numeric pid, no such process,
    permission denied) are printed, not raised.

    Args:
        pid: process id as an int or a numeric string.
    """
    import signal  # local import keeps the module's import block unchanged

    sig = getattr(signal, "SIGKILL", signal.SIGTERM)
    try:
        target = int(pid)
        if target <= 0:
            # 0 and negative ids address process groups / every process; refuse.
            print("refusing to kill non-positive pid: {}".format(pid))
            return
        # Signal directly instead of building a shell command from the pid text.
        os.kill(target, sig)
    except (OSError, ValueError) as e:
        print(e)


def kill_process_by_name(set_name):
    """
    Kill every process whose name equals *set_name*.

    Returns:
        str: a summary message naming the target process name.
    """
    targets = [pid for pid, proc_name in get_all_process().items() if proc_name == set_name]
    for target in targets:
        kill_process_by_pid(str(target))
    return "kill process in name: {}".format(set_name)


def kill_process_by_port(port):
    """
    Kill every process that ``find_pid_by_port`` reports for *port*.

    Returns:
        str: a summary message naming the port.
    """
    for owner in find_pid_by_port(port):
        kill_process_by_pid(str(owner))
    return "kill process in port: {}".format(port)


def clean_cmd():
    """Kill all ``cmd.exe`` processes, then all ``bash.exe`` processes."""
    for shell_name in ("cmd.exe", "bash.exe"):
        kill_process_by_name(shell_name)


if __name__ == "__main__":
    # Other manual checks:
    #   kill_process_by_port(8010)
    #   kill_process_by_name("python.exe")  # or "cmd.exe", "bash.exe", "myProcess"
    #   print(find_pid_by_port(8080))
    target_name = 'myProcess'
    print(find_pid_by_name(target_name))

3、日志模块 

# logger.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import functools

# Names of loggers already configured by get_logger (used as a set).
logger_initialized = {}


@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    The first call for a given *name* configures the logger:
    - attaches a StreamHandler writing to stdout;
    - when *log_file* is given, also attaches a RotatingFileHandler
      (10 MiB per file, 15 backups, UTF-8);
    - sets the logger level to *log_level*.
    Later calls for the same name return the already-configured logger
    unchanged, whatever arguments they pass.

    Args:
        name (str): Logger name.
        log_file (str | None): Log file path. Missing parent directories are
            created. If None, only the stdout handler is attached.
        log_level (int): Level applied to the logger on first initialization.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d] %(name)s %(levelname)s: %(message)s', datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None:
        log_file_folder = os.path.dirname(log_file)
        # A bare filename has no directory part, and os.makedirs('') would raise.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = RotatingFileHandler(filename=log_file, maxBytes=10 * 1024 * 1024, backupCount=15, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    # Applied whether or not a file handler was added. Otherwise a stdout-only
    # logger would inherit the root level (WARNING by default) and drop INFO.
    logger.setLevel(log_level)

    logger_initialized[name] = True
    return logger

if __name__ == "__main__":
    # Manual check: build the same logger the service uses.
    log_file = './logs/yolov5.log'
    logger = get_logger(log_file=log_file, name='yolov5')

4、使用 Dockerfile 构建镜像

 docker build -t wood_detect:test .

# Dockerfile

# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
#FROM nvcr.io/nvidia/pytorch:21.05-py3
#FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
FROM deploy.hello.com/2020-public/yolov5_base:1.0.1

# Install linux packages
#RUN apt update && apt install -y zip htop screen libgl1-mesa-glx

# WORKDIR creates the directory if it does not exist.
WORKDIR /usr/src/app

# Install python dependencies first, from requirements.txt alone, so this layer
# stays cached until requirements.txt changes; editing source files no longer
# forces a full reinstall.
# NOTE(review): confirm the package mirror below is still reachable.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.douban.com/simple/

# Copy the rest of the project
COPY . /usr/src/app
EXPOSE 8080

## Set environment variables
#ENV HOME=/usr/src/app
#
ENTRYPOINT ["python","main_app.py"]
# ---------------------------------------------------  Extras Below  ---------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t
# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done

# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t

# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t

# Kill all
# sudo docker kill $(sudo docker ps -q)

# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)

# Bash into running container
# sudo docker exec -it 5a9b5863d93d bash

# Bash into stopped container
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash

# Send weights to GCP
# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt

# Clean up
# docker system prune -a --volumes

 5、使用 docker-compose.yml 构建镜像并部署

docker-compose up
# Docker-compose.yml
# GPU配置,参考https://docs.docker.com/compose/gpu-support/

version: "3.8"

services:
  yolov5:
    # Build from the Dockerfile in this directory and tag the result as `image`.
    build:
      context: .
    image: deploy.com/public/yolov5_server:alpha_v1.0
    container_name: yolov5_server
    restart: always
    ports:
      - "8080:8080"  # host:container
    # Reserve NVIDIA GPU 0 for this container.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ["0"]
              capabilities: [gpu]

6、重写 YAML 配置文件

# --hyp points at the hyperparameter YAML file (e.g. data/hyp.scratch.yaml).
# Read it fully and close it before reopening the same path for writing.
with open(opt.hyp, encoding='utf-8') as f:
    hyp = yaml.safe_load(f)
lr0 = yolo_hype['lr0'] if 'lr0' in yolo_hype else None
# Accept int or float (but not bool, which is an int subclass); reject negatives.
if isinstance(lr0, (int, float)) and not isinstance(lr0, bool) and lr0 >= 0.0:
    hyp['lr0'] = lr0
# The with-block guarantees the rewritten file is flushed and closed.
with open(opt.hyp, mode='w', encoding='utf-8') as f:
    yaml.safe_dump(hyp, f)

7、修改当前工作目录

import os
import pathlib

# Directory that contains this file.
PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent)
print(os.path.abspath('.'))
os.chdir(os.path.join(PROJECT_ROOT, 'token_text'))  # set the current working directory
print(os.path.abspath('.'))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值