python multiprocessing多进程导致数据库连接不可用问题

问题背景

公司的算法服务为了能够管理每次算法服务请求的计算耗时同时使算法服务能够充分利用CPU的多核处理能力,在响应算法请求时使用了pebble包中的concurrent.process注解,使每个请求的处理由单开的进程来完成,并能够设置进程处理的超时时长。

产生的问题

在算法的处理中需要与pg库操作,使用了psycopg2包来建立pg库连接池。在linux系统下默认使用fork模式创建新进程。
在实际算法服务运行时,发现当出现并发操作时数据库连接会报如下异常:

psycopg2.OperationalError: lost synchronization with server: got message type "

type后面的内容每次可能不太一样。
经过研究发现出现这种情况的问题在于同一个连接同时被两个以上的进程操作导致。这其中包含对以下几个问题的研究,做一下记录。

涉及的知识点

psycopg2的数据库连接池

psycopg2的AbstractConnectionPool类中包含了对连接池中获取连接和连接放回的实现。连接池是由python的list结构维持,取连接使用pop()方法,放回连接使用append()方法。因此连接池的取放操作类似于栈的行为。如果一个连接池只有一个线程在使用,那么每次取到的一定是列表末尾的那个连接。

进程的初始化方式

multiprocessing包提供三种进程初始化方式,spawn, fork, forkserver。其中:

spawn方式不会继承原进程中的文件描述符、网络连接等,所以大部分资源需要重新初始化,所以效率低一些,但是比较安全;
fork方式会以copy on write的方式继承原进程的所有数据,包括文件描述符、网络连接等。所以效率比较高,但是不安全;
forkserver方式的表现和spawn比较接近,但是速度会更快一些;

三种方式的具体解释可以参考python文档:https://docs.python.org/3.10/library/multiprocessing.html#contexts-and-start-methods

问题的原因

所以,问题的原因就是算法服务启动新进程来响应算法请求时使用了fork方式初始化子进程,从而继承了父进程的连接池数据。此时如果有并发的请求发生,另一个子进程也会继承父进程的同一份连接池数据,因此两个子进程存在同时使用数据库连接的情况,而根据对psycopg2的连接池实现的了解,两个子进程只要存在并发情况,必然会同时操作同一个数据库连接。就会造成上面的异常
这个问题在psycopg2的文档中也有描述,参考https://www.psycopg.org/docs/usage.html#thread-and-process-safety。

解决思路

  1. 使用fork方式,需要注意新进程启动后关闭之前的连接(该操作不会造成连接在父进程中被关闭,参考文档:https://www.python.org/dev/peps/pep-0433/#inherited-file-descriptors-issues),重新获取新的连接,在进程结束前注意关闭连接;
  2. 使用spawn或者fork server方式创建进程;

python中的id()的返回

定位问题时发现子进程中继承的父进程中的数据库连接池以及所有连接在日志中打印出的id内容是完全相等的。这是因为在cpython中,id方法返回的是对象在进程中的地址,是相对于进程初始地址的偏移量,不是内存中的绝对地址,所以,只能保证进程内唯一,进程间是有可能出现id值相等的对象的。

TODO:要满足多核计算、流程耗时可控、资源共享最大化。算法服务应如何实现

代码附录

出现数据库连接冲突问题的简化代码如下:

# rds_io.py
from typing import Dict, Tuple
from functools import lru_cache, wraps
from datetime import datetime
import time
import random
import traceback
import os
import multiprocessing as mp

import psycopg2
from psycopg2 import pool
from loguru import logger
import pandas as pd

from pebble import concurrent
from pebble.common import ProcessExpired

from _timing import timing

logger.add('my_log.log')

# FIXME: 同时间 有多个请求时,会报错
#   报错内容: lost synchronization with server
#   当前临时解决方案: 等待一段时间,防止冲突 -> 基本解决
#   造成原因: 未知,可能是
#       + psycopg2 包问题
class RDS_IO:
    def __init__(self):
        self._pool = None

    def is_init(self):
        "返回rds连接是否初始化"

        if self._pool is None:
            return False
        return True

    def init_pg(
        self,
        host: str,
        port: str,
        user: str,
        password: str,
        dbname: str,
        minconn: int = 8,
        maxconn: int = 16,
    ):
        self._pool = psycopg2.pool.ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=host,
            port=port,
            dbname=dbname,
            user=user,
            password=password,
        )
        logger.info(f"成功初始化pg数据库连接池,连接数{minconn} - {maxconn}")
        self.minconn = minconn
        self.maxconn = maxconn
        self.host = host
        self.port = port
        self.dbname = dbname
        self.user = user
        self.password = password

    def reinit_pg(self):
        self.close()
        self.init_pg(
            self.host,
            self.port,
            self.user,
            self.password,
            self.dbname,
            self.minconn,
            self.maxconn,
        )
        logger.info(f"重连数据库成功")

    def pool(self):
        return self._pool

    def execute_sql(
        self,
        sql: str,
        data_type: Dict = {},
    ):
        time_tag = datetime.today().strftime("%y%m%d")

        # change `data_type` to immutable type for `lru_cache` use
        data_type = tuple(sorted(data_type.items()))

        ret = self._execute_sql(
            time_tag=time_tag,
            sql=sql,
            data_type=data_type,
        )

        if ret is None:
            msg = f"取数据为空\n\n{sql}"
            raise Exception(msg)

        ret = ret.copy(deep=True)  # 复制一份,防止缓存的结果被修改
        return ret

    def _execute_sql(
        self,
        time_tag,
        sql: str,
        data_type: Tuple,
    ) -> pd.DataFrame:
        """
        time_tag: 用作时间标记,代表了缓存的有效期
        """

        data_type = dict(data_type)

        if not self.is_init():
            msg = "pg数据库连接 未初始化,请用init_pg(*)初始化"
            logger.critical(msg)
            raise RuntimeError(msg)

        logger.debug(f"query {sql}")

        ok = False
        # max_retry_cnt = 3
        max_retry_cnt = 5
        for iid in range(max_retry_cnt):
            # 极端处理:试3次失败,重连数据库,再试2次,但会导致其他连接也被关闭失败,不稳妥
            if iid == 3:
                self.reinit_pg()
            if iid > 0:
                # 重试时 等待 随机的一段时间
                sleep_sec = random.random() * 10
                time.sleep(sleep_sec)

            try:
                print(f'[getconn前] 线程: {os.getpid()} 连接池锁id和状态: {id(self._pool._lock)}, {self._pool._lock.locked()}')
                conn = self._pool.getconn()
                print(f'[getconn后] 线程: {os.getpid()} 连接池锁id和状态: {id(self._pool._lock)}, {self._pool._lock.locked()}')
                logger.info(f'[取] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
                logger.info(
                    f'池状态: {self._pool.closed}, 池连接状态, {[enum.get_transaction_status() for enum in self._pool._pool]}, 池key-连接: {self._pool._used}, 池连接-key: {self._pool._rused}, 池key: {self._pool._keys}')
                flag = False
                with conn.cursor() as cur:
                    try:
                        cur.execute(sql)
                        rows = cur.fetchall()
                        col_names = []
                        for elt in cur.description:
                            col_names.append(elt[0])
                        ok = True
                        if iid > 0:
                            logger.info(f"第{iid+1}次 重试成功")
                    except Exception as e:
                        logger.error(f"第{iid+1}次 执行 sql 失败\n{e}, 丢弃该连接")
                        # 失败则关闭连接,必须经过putconn,否则连接数无法恢复
                        logger.error(f'[异常] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
                        raise e
                    # conn.close()
                    # flag = True
                    # logger.error(f"丢弃该连接")
                logger.info(f'[放前] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
                self._pool.putconn(conn, close=flag)
                logger.info(f'[放后] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
                logger.info(
                    f'池状态: {self._pool.closed}, 池连接状态, {[enum.get_transaction_status() for enum in self._pool._pool]}, 池key-连接: {self._pool._used}, 池连接-key: {self._pool._rused}, 池key: {self._pool._keys}')
                logger.info(str(self._pool._pool))
            except Exception as err:
                logger.error(str(err))
                logger.error(traceback.format_exc())
            if ok:
                break
        else:
            err_msg = f"pg库请求失败 累计重试{max_retry_cnt}次"
            logger.error(err_msg)
            return None

        if len(rows) < 1:
            return None

        ret = pd.DataFrame(
            rows,
            columns=col_names,
        )
        ret = ret.astype(dtype=data_type)

        return ret

    def close(self):
        self._pool.closeall()
        self._pool = None
        logger.info("数据库连接清理完毕!!!")


rds_io_ins = RDS_IO()
rds_host = "192.168.XXX.XXX"
rds_port = 5432
rds_dbname = "XXX"
rds_user = "XXX"
rds_password = "XXX"
ddd = {1:1, 2:2, 3:3}
logger.info("=======================准备初始化连接池")
rds_io_ins.init_pg(host=rds_host, port=rds_port, user=rds_user, password=rds_password, dbname=rds_dbname, minconn=2, maxconn=10)


def new_process(f):
    @wraps(f)
    def func_wrapper(*args, **kwargs):
        # rds_io_ins.reinit_pg()如果要解决并发冲突,放开这行注释
        result = f(*args, **kwargs)
        # rds_io_ins.close()如果要解决并发冲突,放开这行注释
        return result
    return func_wrapper
    
@concurrent.process(timeout=None, context=mp.get_context('fork'))
@new_process
def select_test():
    for i in range(0,60):
        logger.info(f"第{i}次执行sql")
        data = rds_io_ins.execute_sql("select * from road_segment_dongguan limit 500")
        logger.info(f"{data}")

def action():
    future = select_test()
    result = future.result()
    logger.info(f"{result}")

# timed_select_test = 

def timed_select_test():
    return timing.timeout(20)(select_test)
# concurrent_test.py
from functools import wraps
from math import radians
from threading import Thread
import os
from multiprocessing import managers, Value

from loguru import logger

from rds_io import rds_io_ins
from rds_io import select_test
from rds_io import action

if __name__ == '__main__':
    for i in range(10):
        thread = Thread(
            target=action,
            args=(
            ),
        )
        thread.start()
# requirements.txt
loguru
pebble
psycopg2
pandas
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值