1、背景
在我们的实际项目中,尤其以Web服务为例,经常遇到要做日志跟踪的场景。我们经常采用的方式是,生成一个trace_id, 在全链路的调用中都输出这个id进行跟踪。
这里需要处理的几个问题是:
(1)请求间的隔离
(2)全链路同id
(3)跟踪的独立性,不涉及业务代码(日志打印)改造
解决以上三个问题,我们需要借助请求会话和logging扩展。
2、请求会话
每一次的request请求,对应一次会话,请求与请求之间本身就是隔离的。所以,每次会话开始时读取一次request的trace_id作为当前请求的id
(1)可以在请求任意位置,比如Header,存放trace_id, 则:
(2)也可以在接受到requet的入口开始,生成一个trace_id(可结合uuid)
3、logging扩展
Python的logging支持Filter扩展,做法为:
(1)新增一个自定义的Filter,在这个Filter中记录,trace_id要会话保持一致
(2)Filter只需要重写filter方法,并返回True即可
(3)重写日志输出格式,日志输出中新增自定义的 trace_id
4、多线程多进程
在Web中,我们使用logging都是使用的单例模式,服务启动构造一个logger,于是:
(1)服务init时初始化一个logger
(2)每次会话,从当前的request中取出trace_id给到logger
但是在实际业务中,由于业务计算量较大,往往需要开启多线程/多进程的异步计算,这个时候在这些异步线程中也需要打印日志。
由于request对象是线程隔离的,如果在Filter保存了request对象,那么在异步线程中,无法获取到真正的request。这也是为什么我们在filter中只能保存trace_id的原因。因此涉及多线程/多进程时,需要将主线程中的trace_id传到异步线程中,这里采用local技术。
5、完整代码
trace_logger.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file : trace_logger.py
@author: Rico Wu
@create: 2021/10/20 15:29
@desc : 在logging的基础上封装自定义Logger,提供了链路跟踪功能
"""
import uuid
import sys
import logging
import threading
from concurrent_log_handler import ConcurrentRotatingFileHandler
# 当前线程的local_trace, 需要添加全局trace_id, 使用示例:trace.trace_id
local_trace = threading.local()
# 每个日志文件的最大字节数
MAX_BYTES = 1024 * 1024 * 10
# 最大日志文件备份数
BACKUP_COUNT = 30
# trace_id header
TRACE_ID_HEADER = "trace_id"
# 自定义trace_filter属性名
TRACE_FILTER_ATTR = "trace_filter"
# 默认格式化输出
DEFAULT_FORMATTER_STR = "%(asctime)s::%(thread)s::TRACE::%(trace_id)s::%(" \
"filename)s::%(lineno)s::%(levelname)s::%(message)s"
class TraceLogger:
@staticmethod
def get_logger(log_file, log_level=logging.INFO, formatter_str=""):
"""
生成带全链路trace_id的logger
@param log_file: 日志文件名
@param log_level: 日志级别
@param formatter_str: 格式化字符串
@return:
"""
# 这里注册session 上下文追踪一次就可以了
log_file = log_file
logger = logging.getLogger(log_file)
logger.setLevel(log_level)
# 添加日志跟踪filter
trace_filter = TraceFilter()
logger.addFilter(trace_filter)
# 自定义格式日志格式,添加trace_id
f_str = formatter_str if formatter_str else DEFAULT_FORMATTER_STR
formatter = logging.Formatter(f_str)
file_handler = ConcurrentRotatingFileHandler(filename=log_file,
maxBytes=MAX_BYTES,
backupCount=BACKUP_COUNT,
encoding="utf-8",
delay=False)
file_handler.suffix = '%Y-%m-%d.log'
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 终端显示日志
console_handler = logging.StreamHandler(stream=sys.stdout)
console_handler.setFormatter(formatter)
console_handler.setLevel(log_level)
logger.addHandler(console_handler)
# 扩展 trace_filter属性
setattr(logger, TRACE_FILTER_ATTR, trace_filter)
return logger
class TraceFilter(logging.Filter):
"""
通过在record中添加trace_id, 实现调用跟踪和日志打印的分离
"""
Default_Trace_Id = f"DEFAULT_{str(uuid.uuid1())}"
def __init__(self, name=""):
"""
init
@param name: filter name
"""
super().__init__(name)
def filter(self, record):
"""
重写filter方法
@param record: record
@return:
"""
trace_id = local_trace.trace_id
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = TraceFilter.Default_Trace_Id
return True
app.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file : scheduler_client_demo.py
@author: Rico Wu
@create: 2021/10/20 15:37
@desc : Demo for client side of flask service
"""
import json
import time
import logging
import threading
from flask import Flask, request
from trace_logger import TraceLogger, local_trace, \
TRACE_ID_HEADER
# flask app
app = Flask(__name__)
logger = TraceLogger.get_logger("logs.log", logging.DEBUG)
def thread_worker(trace_id: str):
"""
thread worker
@param trace_id: trace_id
@return:
"""
local_trace.trace_id = trace_id
time.sleep(3)
logger.info("【多线程】日志打印演示..")
def proc_worker(trace_id: str):
"""
real proc worker
@param trace_id: trace_id
@return:
"""
local_trace.trace_id = trace_id
time.sleep(3)
logger.info("【多进程】日志打印演示..")
@app.before_request
def set_trace_id():
"""
写在before request中, 对原路由无侵入
@return:
"""
local_trace.trace_id = request.headers.get(TRACE_ID_HEADER, "")
@app.route("/req", methods=["GET", "POST"])
def index():
"""
请求入口
@return:
"""
# logger示例
logger.debug("主线程:debug示例..")
logger.info("主线程:info示例..")
logger.warning("主线程:warning示例..")
logger.error("主线程:error示例..")
# 线程任务演示
t = threading.Thread(target=thread_worker, args=(local_trace.trace_id,))
t.start()
# 多进程任务演示
from multiprocessing import Process
p = Process(target=proc_worker, args=(local_trace.trace_id,))
p.start()
# 返回值
result = {"status": 0, "data": None, "message": "asynchronous task "
"running.."}
return json.dumps(result)
if __name__ == "__main__":
app.run(port=8880, threaded=True)