python常用工具类

# Get the number of GPUs available on this machine.
def get_gpu_count():
    """Return the number of GPUs visible to this process.

    The physical count is the number of "GPU <n>:" lines printed by
    ``nvidia-smi -L`` (0 when the command is missing or prints nothing).
    If ``CUDA_VISIBLE_DEVICES`` is set, the result is capped at the number of
    non-empty entries it lists, so ``CUDA_VISIBLE_DEVICES=""`` yields 0.
    """
    # Close the pipe when done; raw-string regex avoids the invalid "\d" escape.
    with os.popen("nvidia-smi -L") as pipe:
        listing = pipe.read()
    # Anchor on the line prefix rather than the vendor string so older
    # device names (e.g. "GPU 0: Tesla ...") are counted too.
    num_default = len(re.findall(r"^GPU \d+:", listing, flags=re.MULTILINE))

    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return num_default
    # "".split(",") would give [""], i.e. one phantom device; skip blank entries.
    num_specified = len([dev for dev in visible.split(",") if dev.strip()])
    return min(num_default, num_specified)
def get_gpu_memory(device_id=0):
    """Return the total memory, in MiB, of GPU ``device_id``.

    Parses the ``<used>MiB / <total>MiB`` column of plain ``nvidia-smi``
    output. On any failure (command missing, no such device, unexpected
    output) the error is logged and 0 is returned.
    """
    try:
        with os.popen("nvidia-smi") as pipe:
            output = pipe.read()
        # One (used, total) pair per matching row, presumably in GPU index order.
        usages = re.findall(r"(\d+)MiB\s*/\s*(\d+)MiB", output)
        return int(usages[device_id][1])
    except Exception as e:
        logger.error(e)
        return 0


def get_available_memory(device_id=0):
    """Return the free memory (total - used), in MiB, of GPU ``device_id``.

    Parses the ``<used>MiB / <total>MiB`` column of plain ``nvidia-smi``
    output. On any failure (command missing, no such device, unexpected
    output) the error is logged and 0 is returned.
    """
    try:
        with os.popen("nvidia-smi") as pipe:
            output = pipe.read()
        # Both numbers are required, so int("") can no longer occur.
        usages = re.findall(r"(\d+)MiB\s*/\s*(\d+)MiB", output)
        used, total = map(int, usages[device_id])
        return total - used
    except Exception as e:
        logger.error(e)
        return 0

# Submit a result payload.
def submit_result(api, msg, retry=3, timeout=5):
    """POST ``msg`` as JSON to ``api``, retrying when the request raises.

    Args:
        api: URL to post to.
        msg: JSON-serialisable payload.
        retry: maximum number of attempts.
        timeout: per-request timeout passed to ``requests.post``.

    Returns:
        True as soon as any response is received (the HTTP status code is
        not checked); False after ``retry`` attempts all raised.
    """
    for _ in range(retry):
        try:
            r = requests.post(api, json=msg, timeout=timeout)
            logger.info(r.text)
            return True
        except Exception as e:
            logger.error(e)

    logger.error("结果提交失败!")
    logger.error(f"{api},{msg}")
    return False
def cv2_base64(image):
    """JPEG-encode an OpenCV image and return its base64 encoding as bytes.

    Note: the success flag returned by ``cv2.imencode`` is not checked.
    """
    _, jpeg_buf = cv2.imencode('.jpg', image)
    # ndarray.tostring() is a deprecated alias of tobytes().
    return base64.b64encode(jpeg_buf.tobytes())


def base64_cv2(base64_str):
    """Decode base64-encoded image bytes (e.g. from ``cv2_base64``) into an image array.

    The bytes are decoded with ``cv2.IMREAD_UNCHANGED``.
    """
    raw = base64.b64decode(base64_str)
    # np.fromstring on binary data is deprecated; frombuffer wraps the bytes directly.
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)

def align_faces(dets, img_raw):
    '''
    Align every detected face in ``img_raw`` to a 112x112 crop.

    dets: raw RetinaFace detections, one row per face. Columns 0-3 are the
        box (x1, y1, x2, y2); column 4 is not used here (presumably the
        score); columns 5: are reshaped to five (x, y) landmark points, so
        each row must have 15 values.
    img_raw: the original image the detections refer to.

    Returns a list with one aligned image per detection row.
    '''
    def align_face(img, bb, landmark, image_size):
        """Warp one face to image_size (h, w); crop+resize when no landmarks are given."""
        M = None
        if landmark is not None:
            # Reference 5-point landmark template (presumably the standard
            # ArcFace/InsightFace one) in target-crop coordinates.
            src = np.array([
                [30.2946, 51.6963],
                [65.5318, 51.5014],
                [48.0252, 71.7366],
                [33.5493, 92.3655],
                [62.7299, 92.2041]], dtype=np.float32)
            if image_size[1] == 112:
                # Shift template x by 8 px for a 112-wide target.
                src[:, 0] += 8.0
            dst = landmark.astype(np.float32)

            # `trans` is presumably skimage.transform, imported elsewhere.
            tform = trans.SimilarityTransform()
            tform.estimate(dst, src)
            M = tform.params[0:2, :]

        if M is None:
            # Slicing needs ints; clamp at 0 so negative coords don't wrap around.
            x1, y1, x2, y2 = (max(int(v), 0) for v in bb[:4])
            ret = img[y1:y2, x1:x2, :]
            if len(image_size) > 0:
                ret = cv2.resize(
                    ret, (image_size[1], image_size[0]), interpolation=cv2.INTER_CUBIC)
            return ret
        return cv2.warpAffine(
            img, M, (image_size[1], image_size[0]), borderValue=0.0)

    face_boxes = dets[:, :4]
    face_landmarks = dets[:, 5:]
    return [
        align_face(img_raw, box, landmark.reshape((5, 2)), (112, 112))
        for box, landmark in zip(face_boxes, face_landmarks)
    ]
def try_except(func):
    """Decorator: call ``func``; if it raises an Exception, log it and return None.

    ``functools.wraps`` preserves the wrapped function's ``__name__``,
    ``__doc__`` and other metadata.
    """
    from functools import wraps

    @wraps(func)
    def handler(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(e)
            return None

    return handler

保存视频


class MyVideoWriter:
    """Lazily-opened ``cv2.VideoWriter`` wrapper.

    The underlying writer is created on the first ``save`` call. If
    ``height``/``width`` were not given, they are taken from that first
    frame. Call ``release()`` (or let the object be collected) to finish the
    file.
    """

    def __init__(self, save_path, height=None, width=None, fps=25, fourcc='MP4V'):
        # Set first so release()/__del__ work even if anything below raises.
        self.videoWriter = None
        self.save_path = save_path
        self.height = height
        self.width = width
        self.fourcc = cv2.VideoWriter_fourcc(*fourcc)
        self.cur_frame_id = -1
        self.fps = fps

    def save(self, frame):
        """Append ``frame``; opens the output file on the first call.

        Raises:
            RuntimeError: if called after ``release()``.
        """
        self.cur_frame_id += 1
        if self.cur_frame_id == 0:
            if self.height is None or self.width is None:
                # shape[:2] also works for single-channel (2-D) frames.
                self.height, self.width = frame.shape[:2]
            self.videoWriter = cv2.VideoWriter(
                self.save_path, self.fourcc, self.fps, (self.width, self.height))
        elif self.videoWriter is None:
            raise RuntimeError(f"video writer for {self.save_path} has already been released")
        self.videoWriter.write(frame)

    def release(self):
        """Finalize the output file; safe to call more than once."""
        writer = getattr(self, "videoWriter", None)
        if writer is not None:
            writer.release()
        self.videoWriter = None

    def __del__(self):
        self.release()


class MyVideoWiterImgIo:
    """Video writer backed by ``imageio.get_writer``.

    ``save`` reverses the last (channel) axis before writing, presumably
    converting OpenCV BGR frames to RGB for imageio.
    """

    def __init__(self, save_path, fps=25):
        # Set first so close()/__del__ work even if get_writer raises.
        self.videowriter = None
        self.videowriter = imageio.get_writer(save_path, fps=fps)

    def save(self, frame):
        """Append one frame.

        Raises:
            RuntimeError: if called after ``close()``.
        """
        if self.videowriter is None:
            raise RuntimeError("video writer has already been closed")
        self.videowriter.append_data(frame[..., ::-1])

    def __del__(self):
        self.close()

    def close(self):
        """Close the underlying writer; safe to call more than once."""
        writer = getattr(self, "videowriter", None)
        if writer is not None:
            writer.close()
        self.videowriter = None

图片画多个roi

from itertools import chain
import os
import cv2
import numpy as np

# Directory containing this script (not referenced by the visible code below).
cur_folder = os.path.dirname(__file__)


class RoiDrawer:
    """Interactive OpenCV tool for drawing polygon ROIs on an image.

    Keys in the "ori" window:
        left click  add a vertex to the current polygon
        Esc (27)    undo the last vertex of the current polygon
        Enter (13)  close the current polygon and record it
        q           stop and return all recorded polygons
    """

    def __init__(self, img):
        self.img = img  # image to annotate; draw() works on a copy

    def draw(self, dst_size=(640, 480)):
        """Run the drawing loop and return ``{"Polygons": [...]}``.

        Each entry is ``{"Id": n, "Point": "x1,y1,x2,y2,...,"}``. Vertices are
        rescaled from image pixels to a ``dst_size`` = (width, height) canvas
        (640x480 by default) and truncated to ints.
        """
        def onMouse(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                # Snapshot before drawing so Esc can restore it.
                cache_imgs.append(img.copy())
                cv2.circle(img, (x, y), 3, (0, 0, 255), -1)
                if polygon:
                    cv2.line(img, tuple(polygon[-1]), (x, y), (255, 0, 0), 1)
                polygon.append([x, y])
                cv2.imshow(winname, img)

        cache_imgs = []  # one image snapshot per vertex of the current polygon
        polygons = []  # finished polygons
        polygon = []  # vertices of the polygon being drawn, in image pixels
        winname = "ori"
        # Bind img before the callback is registered; onMouse reads it.
        img = self.img.copy()
        img_h, img_w, *_ = img.shape
        dst_w, dst_h = dst_size
        cv2.namedWindow(winname)
        cv2.setMouseCallback(winname, onMouse)

        while True:
            cv2.imshow(winname, img)
            key = cv2.waitKey()

            if key == 27:  # Esc: undo the last vertex
                if polygon:
                    polygon.pop(-1)
                    img = cache_imgs.pop(-1)

            elif key == 13:  # Enter: close and record the current polygon
                if not polygon:
                    # Nothing drawn yet; polygon[0] would raise IndexError.
                    continue
                cv2.line(img, tuple(polygon[0]), tuple(polygon[-1]), (255, 0, 0), 1)
                contour = np.array(polygon) / [img_w, img_h] * [dst_w, dst_h]
                contour = contour.astype(np.int32).tolist()
                contour = list(chain.from_iterable(contour))
                contour = ",".join(map(str, contour)) + ','
                polygons.append({"Id": len(polygons), "Point": contour})
                cache_imgs.clear()
                polygon.clear()
                print(polygons)

            elif key == ord('q'):
                break

        cv2.destroyWindow(winname)
        return {"Polygons": polygons}


if __name__ == '__main__':

    root_folder = "test_imgs"
    img_names = sorted(os.listdir(root_folder))

    for img_name in img_names:
        print(img_name)
        frame = cv2.imread(os.path.join(root_folder, img_name), cv2.IMREAD_COLOR)
        if frame is None:
            # cv2.imread returns None for files it cannot read/decode.
            print(f"skip {img_name}: not a readable image")
            continue
        drawer = RoiDrawer(frame)
        polygons = drawer.draw()
        print(polygons)
        # for polygon in polygons:
        #     print(polygon)

        # # polygon = [[536, 188], [1239, 697], [1399, 906], [1018, 690], [537, 277]]
        #
        # img_h, img_w, *_ = img.shape
        # contour = np.array(polygon) / [img_w, img_h] * [640, 480]
        # contour = contour.astype(np.int32).tolist()
        # contour = list(chain.from_iterable(contour))
        # contour = ",".join(map(str, contour)) + ','
        # mon_parm = {"Polygons": [{"Id": 999999, "Point": contour}]}
        # result = processTest(mon_parm, None, None, frame)

操作minio

import minio
import os
import cv2
from loguru import logger


class MinioClient:
    """Small MinIO helper: upload (creating the bucket if needed), download,
    or build an object URL.

    If the client cannot be constructed, the error is logged and
    ``self.client`` is None. In that case upload() returns False and
    download() returns (False, "").
    """

    def __init__(self,
                 endpoint="192.168.4.78:9000",
                 access_key="minioadmin",
                 secret_key="minioadmin",
                 secure=False
                 ):
        # NOTE: the defaults are hard-coded development values; pass real
        # ones (e.g. from config) in deployment.
        self.endpoint = endpoint
        self.secure = secure  # also selects http vs https in download(only_path=True)
        try:
            self.client = minio.Minio(
                endpoint=endpoint, access_key=access_key, secret_key=secret_key, secure=secure)
        except Exception as e:
            self.client = None
            logger.error(e)

    def upload(self, bucket, file_path, dst_path=None, retry_times=3):
        """Upload local ``file_path`` to ``bucket`` (creating the bucket if missing).

        dst_path: object name; defaults to the file's basename.
        Returns True on success. Returns False if the bucket cannot be
        checked/created or all ``retry_times`` attempts fail.
        """
        if self.client is None:
            return False

        try:
            if not self.client.bucket_exists(bucket):
                self.client.make_bucket(bucket)
        except Exception as e:
            logger.error("创建bucket失败")
            logger.error(e)
            return False

        if dst_path is None:
            dst_path = os.path.basename(file_path)

        for _ in range(retry_times):
            try:
                self.client.fput_object(bucket, dst_path, file_path)
                return True
            except Exception as e:
                logger.error(f"{file_path} 上传失败")
                logger.error(e)

        return False

    def download(self, bucket, filename, dst_folder="/tmp", retry_times=3, only_path=True):
        """Return a URL for, or download, object ``filename`` in ``bucket``.

        only_path=True (default): nothing is downloaded. Returns (True, URL),
            with the URL built from the endpoint, bucket and object name.
        only_path=False: saves the object to ``dst_folder/filename``, retrying
            up to ``retry_times`` times. Returns (True, local_path).
        Returns (False, "") on failure or when the client could not be created.
        """
        if self.client is None:
            return False, ""

        if only_path:
            # Join with "/" rather than os.path.join, which would discard the
            # endpoint and bucket if filename started with "/".
            url = "/".join(part.strip("/\\") for part in (self.endpoint, bucket, filename))
            url = url.replace("\\", "/")
            # Prefix check (not substring), so a bucket or object name that
            # merely contains "http" still gets a scheme.
            if not url.startswith(("http://", "https://")):
                scheme = "https" if getattr(self, "secure", False) else "http"
                url = f"{scheme}://{url}"
            return True, url

        dst_path = os.path.join(dst_folder, filename)

        for _ in range(retry_times):
            try:
                self.client.fget_object(bucket, filename, dst_path)
                return True, dst_path
            except Exception as e:
                logger.error(f"{dst_path} 下载失败")
                logger.error(e)

        return False, ""


if __name__ == '__main__':
    # Manual smoke test: upload one local file to the "test" bucket and
    # print whether the upload succeeded.
    client = MinioClient()

    # s = time()
    print(client.upload(
        "test",
        "/workspace_wjr/develop/projects/continuous_model_server/test.py",
    ))
    # print(time()-s)

    # ret, dst_path = client.download("test", "1.mp4",only_path=False)
    # print(dst_path)
    # cap = cv2.VideoCapture(dst_path)
    # # while cap.isOpened():
    # #     ret, frame = cap.read()
    # #     cv2.imshow("frame", frame)
    # #     cv2.waitKey(5)
    # logger.info(cap.isOpened())

    # from utils.db_utils import load_config

    # config = load_config()
    # minio_config = config["minio"]

    # client = MinioClient(minio_config["endpoint"],
    #                      minio_config["access_key"],
    #                      minio_config["secret_key"],
    #                      minio_config["secure"])

    # client.upload("test", "minio_client.py")
 

操作mq

import pika
from loguru import logger
from pika.exceptions import AMQPConnectionError
from container_utils.common import try_except


class MqSender:
    """Publish text messages to a single RabbitMQ queue via the default exchange.

    Connection failures are logged rather than raised. Afterwards every call
    logs a warning and returns False.
    """

    @try_except
    def __init__(self, config):
        # Give the attributes values first: @try_except swallows exceptions
        # raised below, and __call__/close()/__del__ read these attributes.
        self.conn = None
        self.channel = None
        self.target_queue = None

        host = config.get("host", "127.0.0.1")
        username = config.get("username", "admin")
        pwd = config.get("pwd", "admin")
        port = config.get("port", 5672)
        virtual_host = config.get("virtual_host", "/")
        self.target_queue = config.get("target_queue", "test_queue")

        try:
            credentials = pika.PlainCredentials(username, pwd)
            self.conn = pika.BlockingConnection(
                parameters=pika.ConnectionParameters(host=host, port=port, virtual_host=virtual_host,
                                                     credentials=credentials))
            self.channel = self.conn.channel()
            # exclusive=True would make the queue be deleted automatically when the MQ connection drops.
            self.channel.queue_declare(queue=self.target_queue)
        except AMQPConnectionError as e:
            logger.error("mq连接异常!")
            for se in e.args:
                # Not every arg is guaranteed to carry `.exception`; fall back to the arg itself.
                logger.error(getattr(se, "exception", se))
            # Also releases a connection that opened before channel setup failed.
            self.close()

    def __call__(self, body):
        """Publish ``body`` (str, UTF-8 encoded; or bytes) to the target queue.

        Returns True on success, False if not connected or publishing failed.
        """
        if self.channel is None or self.conn is None:
            logger.warning("mq写入异常!")
            logger.warning(f"{self.target_queue}: {body}")
            return False

        try:
            payload = body if isinstance(body, bytes) else body.encode(encoding="utf-8")
            self.channel.basic_publish(
                exchange="",
                routing_key=self.target_queue,
                body=payload
            )
            return True

        except Exception as e:
            logger.error("mq写入异常!")
            logger.error(e)
            logger.error(f"{self.target_queue}: {body}")
            return False

    def __del__(self):
        self.close()

    def close(self):
        """Close the connection if it is open; safe to call repeatedly."""
        conn = getattr(self, "conn", None)
        self.conn = None
        self.channel = None
        if conn is not None and conn.is_open:
            try:
                conn.close()
            except Exception as e:
                logger.error(e)

if __name__ == '__main__':
    from container_utils.db_utils import load_config

    mq_config = load_config("../container_config/config.yml").get("mq", {"host": "127.0.0.1", "port": 5672, "username": "admin",
                                                                     "pwd": "admin"})

    # MqSender takes a single config dict; the queue name goes in "target_queue".
    sender = MqSender(dict(mq_config, target_queue="test_ch"))
    for content in ["good morning!", "good afternoon!", "good evening!"]:
        # Messages are sent one at a time; the next starts after the previous finishes.
        sender(content)
    sender.close()

操作mysql

import pymysql
from loguru import logger


class MysqlDb:
    """pymysql wrapper using one autocommit connection and dict rows.

    If connecting fails, the error is logged and ``self.db`` is None. After
    that, query() returns None and update() returns False.
    """

    def __init__(self, config):
        self.db = None  # set first so close()/__del__ work if config lookup raises
        host = config.get("host", "127.0.0.1")
        user = config.get("user", "root")
        pwd = config.get("pwd", "admin")
        port = config.get("port", 3306)
        db = config.get("db", "mvp")

        # DictCursor returns each row as a dict; without cursorclass rows are tuples.
        try:
            self.db = pymysql.connect(host=host, user=user, password=pwd, database=db,
                                      port=port, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor, autocommit=True)
        except Exception as e:
            logger.error(f"mysql连接异常!")
            logger.error(e)
            self.db = None

    def __del__(self):
        self.close()

    def close(self):
        """Close the connection if open; safe to call repeatedly."""
        db = getattr(self, "db", None)
        if db is not None:
            db.close()
        self.db = None

    def query(self, cmd, type="one", args=None):
        """Run a SELECT and return one row (type="one") or all rows (any other value).

        args: optional parameters for %s placeholders in ``cmd``. They are
            passed to ``cursor.execute`` so pymysql escapes them; prefer this
            to formatting values into ``cmd``.
        Returns the fetched row(s), or None when not connected or on error.
        """
        if self.db is None:
            logger.warning("mysql连接异常!")
            return None
        try:
            with self.db.cursor() as cur:
                cur.execute(cmd, args)
                if type == "one":
                    return cur.fetchone()
                return cur.fetchall()
        except Exception as e:
            logger.error(e)
            return None

    def update(self, cmd, args=None):
        """Run an INSERT/UPDATE/DELETE.

        args: if not None, a sequence of parameter tuples; ``cmd`` runs once
            per tuple via executemany. Otherwise ``cmd`` runs once as-is.
        Returns True on success. On error it logs, rolls back and returns False.
        """
        if self.db is None:
            logger.warning("mysql连接异常!")
            return False
        try:
            with self.db.cursor() as cur:
                if args is not None:
                    cur.executemany(cmd, args)
                else:
                    cur.execute(cmd)
                self.db.commit()
                return True
        except Exception as e:
            logger.error(e)
            try:
                self.db.rollback()
            except Exception as rollback_err:
                # e.g. the connection itself is gone; don't let this escape.
                logger.error(rollback_err)
            return False


if __name__ == '__main__':
    from container_utils.db_utils import load_config
    import datetime
    import json

    mysql_config = load_config("../container_config/config.yml").get("mysql", {"host": "127.0.0.1", "port": 3306, "user": "admin",
                                                                     "pwd": "", "db": "mvp"})

    db = MysqlDb(mysql_config)

    # Right after the model service starts, fetch every task whose status is 0 or 1.
    select_sql = 'SELECT * FROM continuous_tasks where status in(0,1)'
    # # Once running, only tasks with status=0 need to be fetched.
    # select_sql = 'SELECT * FROM continuous_tasks where status=0'

    result_tmp = [[1,2,3,4,0.5],[5,6,7,8,0.9]]
    resultUrl = "a/b/c/d.avi"
    # Update one task. Values are bound through %s placeholders, so pymysql
    # escapes them, instead of being formatted into the SQL text.
    update_sql1 = 'UPDATE continuous_tasks SET status=%s, updateTime=%s, timestamp=%s, resultUrl=%s, result=%s WHERE id=%s'
    update_args1 = [("1", datetime.datetime.now(), datetime.datetime.now(), resultUrl, json.dumps(result_tmp), 4)]

    # Update several tasks in one call (executemany).
    # Note: every placeholder is %s, whatever the column type.
    update_sql2 = 'UPDATE continuous_tasks SET status=%s, updateTime=%s, result=%s,timestamp=%s,resultUrl=%s WHERE id=%s'

    # *************************************** queries
    data = db.query(select_sql)
    print(data)
    print(">"*30)

    data = db.query(select_sql, "all")
    if data is not None:
        for d in data:
            print(d)
    print(">" * 30)

    # *************************************** updates
    res = db.update(update_sql1, update_args1)
    print(res)
    print(">" * 30)

    res = db.update(update_sql2,[(2, datetime.datetime.now(),json.dumps(result_tmp), datetime.datetime.now(),"/data/result/res.mp4",4),
                                 (2, datetime.datetime.now(),json.dumps(result_tmp), datetime.datetime.now(),"/data/result/res2.mp4",5)])
    print(res)
    print(">" * 30)

    db.close()

    #
    # # insert_sql = 'INSERT INTO continuous_tasks(id, username, password) VALUES(11, "王五", "333333")'
    # # delete_sql = 'DELETE FROM continuous_tasks WHERE id = 11'
    #
    # # data = db.query(select_sql)
    # # print(data)
    # data = db.query(select_sql, type="all")
    # for d in data:
    #     print(d)
    # # db.update(update_sql1)
    # # db.update(update_sql2, [("aaa", 123456, 0), ("bbb", 654321, 1)])
    # # db.update(update_sql2,[("smoke",json.dumps({"a":10, "b":20}),4)])
    #
    # # db.update(update_sql1.format("modelName","fire","modelConfig",json.dumps({"a":10, "b":20}),"id", 5))
    # # modelName = "fire"
    # # modelConfig = {'a':10, 'b':20}
    # # db.update(f'UPDATE continuous_tasks SET modelName="{modelName}", modelConfig="{modelConfig}" WHERE id=5')
    # # data = db.query(select_sql, type="all")
    # # print(data)
    # # print(json.loads(data[0]["modelConfig"])["a"])
    # # print(json.loads(data[0]["modelConfig"])["b"])

操作redis

import redis
from loguru import logger

class RedisDb:
    """``redis.Redis`` wrapper that dispatches commands by method name.

    ``set``/``get`` call ``getattr(conn, func_name)(*args, **kwargs)``. A
    failure is logged and turned into a sentinel return value (False for set,
    None for get).
    """

    def __init__(self, redis_config):
        self.conn = None  # set first so close()/__del__ work if config lookup raises
        host = redis_config.get("host", "127.0.0.1")
        port = redis_config.get("port", 6379)
        pwd = redis_config.get("pwd", "")
        db_id = redis_config.get("db_id", 0)

        try:
            # NOTE(review): redis.Redis presumably connects lazily, so an
            # unreachable server may only surface on the first command — confirm.
            self.conn = redis.Redis(host=host, port=port, db=db_id, password=pwd)
        except Exception as e:
            logger.error(e)
            logger.error("redis连接异常!")
            self.conn = None

    def __del__(self):
        self.close()

    def close(self):
        """Close the connection if open; safe to call repeatedly."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
        self.conn = None

    def _command(self, func_name):
        """Return the bound client method ``func_name``, or None (with a warning) if it doesn't exist."""
        if not hasattr(self.conn, func_name):
            logger.warning(f"redis不存在{func_name}函数")
            return None
        return getattr(self.conn, func_name)

    def set(self, func_name, *args, **kwargs):
        """Run a (typically write) command; return its result, or False on failure."""
        if self.conn is None:
            return False

        func = self._command(func_name)
        if func is None:
            return False
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(f"redis执行{func_name}错误!")
            logger.error(e)
            return False

    def get(self, func_name, *args, **kwargs):
        """Run a (typically read) command; bytes results are UTF-8 decoded to str.

        Non-bytes results (e.g. ints or lists) are returned unchanged instead
        of failing on ``.decode``. Returns None on failure.
        """
        if self.conn is None:
            return None

        func = self._command(func_name)
        if func is None:
            return None
        try:
            res = func(*args, **kwargs)
            if isinstance(res, bytes):
                res = res.decode("utf8")
            return res
        except Exception as e:
            logger.error(f"redis执行{func_name}错误!")
            logger.error(e)
            return None


if __name__ == '__main__':
    from container_utils.db_utils import load_config

    # Fallback used when the config file has no "redis" section.
    fallback = {"host": "127.0.0.1", "port": 6379, "pwd": "", "db_id": 0}
    redis_config = load_config("../container_config/config.yml").get("redis", fallback)
    db = RedisDb(redis_config)
    # res = db.conn.setex("name", 60, "wjr")
    # print(db.conn.get("name").decode("utf8"))
    # Write "name" with setex, then read back "name" and "name2".
    print(db.set("setex", "name", 60, "wjr"))
    for key in ("name", "name2"):
        print(db.get("get", key))


马士兵全栈mysql封装

from pymysql import cursors, connect


class MysqlHelper(object):
    """Connection-per-call pymysql helper.

    Every public method opens a connection, runs one statement and closes it
    again. Errors are printed rather than raised.
    """
    # Connection parameter sets. Several can be defined (conn_params1,
    # conn_params2, ...) to target different databases; pick one when
    # instantiating the class.
    conn_params1 = {'host': '192.168.4.60', 'port': 3306, 'user': 'root',
                    'passwd': 'admin', 'db': 'testdb', 'charset': 'utf8'}
    conn_params2 = {'host': 'localhost', 'port': 3306, 'user': 'root',
                    'passwd': 'root', 'db': 'mytestdb2', 'charset': 'utf8'}

    def __init__(self, conn_params):
        """Store connection settings from ``conn_params`` (keys as in conn_params1)."""
        self.__host = conn_params['host']
        self.__port = conn_params['port']
        self.__db = conn_params['db']
        self.__user = conn_params['user']
        self.__passwd = conn_params['passwd']
        self.__charset = conn_params['charset']
        self.__conn = None
        self.__cursor = None

    def __connect(self):
        """Open a connection and a DictCursor (rows come back as dicts, not tuples)."""
        self.__conn = connect(host=self.__host, port=self.__port, database=self.__db,
                              user=self.__user, password=self.__passwd, charset=self.__charset,
                              cursorclass=cursors.DictCursor,
                              # autocommit=True would commit every statement automatically
                              )
        self.__cursor = self.__conn.cursor()

    def __close(self):
        """Close the cursor and connection if open; errors are printed, not raised."""
        cursor, conn = self.__cursor, self.__conn
        self.__cursor = None
        self.__conn = None
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as e:
                    print(e)

    def get_one(self, sql, params=None):
        """Return the first row of ``sql`` (a dict), or None if there is none or on error."""
        result = None
        try:
            self.__connect()
            self.__cursor.execute(sql, params)
            result = self.__cursor.fetchone()
        except Exception as e:
            print(e)
        finally:
            # Always release the connection, including when execute fails.
            self.__close()
        return result

    def get_all(self, sql, params=None):
        """Return all rows of ``sql``, or [] on error."""
        lst = []
        try:
            self.__connect()
            self.__cursor.execute(sql, params)
            lst = self.__cursor.fetchall()
        except Exception as e:
            print(e)
        finally:
            self.__close()
        return lst

    def insert(self, sql, params=None):
        """Insert rows; see __edit for ``params`` handling and the return value."""
        return self.__edit(sql, params)

    def update(self, sql, params=None):
        """Update rows; see __edit for ``params`` handling and the return value."""
        return self.__edit(sql, params)

    def delete(self, sql, params=None):
        """Delete rows; see __edit for ``params`` handling and the return value."""
        return self.__edit(sql, params)

    def __edit(self, sql, params):
        """Run a write statement and commit it.

        params: a list runs ``sql`` once per element (executemany); anything
            else is passed to a single execute.
        Returns the affected-row count. On error it prints the error, rolls
        back and returns 0.
        """
        count = 0
        try:
            self.__connect()
            if isinstance(params, list):
                count = self.__cursor.executemany(sql, params)
            else:
                count = self.__cursor.execute(sql, params)
            self.__conn.commit()
        except Exception as e:
            print(e)
            count = 0
            # Undo any partial work; only roll back on failure.
            if self.__conn is not None:
                try:
                    self.__conn.rollback()
                except Exception as rollback_err:
                    print(rollback_err)
        finally:
            self.__close()
        return count


if __name__ == "__main__":

    conn = MysqlHelper(MysqlHelper.conn_params1)
    # Insert data: a list of tuples runs executemany, a single tuple runs execute.
    sql = "insert into student values (%s,%s,%s,%s)"
    params1 = [(3, "wjr", "900120", "male"), (4, "wjr", "900120", "male")]
    # Use an id not already in params1 (reusing 3 would presumably collide
    # if the first column is a key).
    params2 = (5, "wjr", "900120", "male")
    count = conn.insert(sql, params1)
    print("插入了"+str(count)+"条数据")
    count = conn.insert(sql, params2)
    print("插入了"+str(count)+"条数据")

    # Update data
    sql = "update student set name=%s where idx=%s"
    params = ("wm", 1)
    count = conn.update(sql, params)
    print("更新了" + str(count)+"条数据")

    # Delete data. A one-element tuple needs the trailing comma; ("wm") is just a str.
    sql = "delete from student where name=%s"
    params = ("wm",)
    count = conn.delete(sql, params)
    print("删除了"+str(count)+"条数据")

    # Fetch one row
    sql = "select * from student where name=%s"
    params = ("wjr",)
    data = conn.get_one(sql, params)
    print(data)

    # Fetch all rows
    sql = "select * from student where name=%s"
    params = ("wjr",)
    data = conn.get_all(sql, params)
    print(data)

    # Fetch only selected columns
    sql = "select idx, passwd from student where name=%s"
    params = ("wjr",)
    data = conn.get_all(sql, params)
    print(data)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值