用ThreadLocal做链路追踪(初版、升级版、最终版)

前言
1、ThreadLocal是线程变量,每个线程持有自己的一份副本,线程之间彼此隔离,天生线程安全。由于服务端的一次请求通常由同一个线程从头处理到尾,把 TraceId 放进 ThreadLocal 后,在整个处理过程中随时都能取到它,因此它很适合做链路追踪(TraceId)
2、当我们写的接口接收到其它地方(可能是前端、也可能是其它服务)发来的请求时,此刻,我们的接口所在的服务称作服务端【Server】,而请求方称作客户端【Client】;当我们的接口中再请求其他服务,此刻,我们的接口所在的服务称作客户端【Client】,而被请求方称作服务端【Server】

线程变量承载体

/**
 * Static holder for the trace-id of the request that the current thread is working on.
 *
 * <p>The backing thread-local went through three versions; the two earlier ones are kept below as
 * comments for comparison.
 */
public class TraceIdHolder {

    // Initial version: a plain ThreadLocal; only works when no child thread is started.
    //private static final ThreadLocal<String> CURRENT_TRACE_ID = new ThreadLocal<>();

    // Upgraded version: when a child thread is created with new Thread, InheritableThreadLocal is needed
    // so the child inherits the value.
    //private static final ThreadLocal<String> CURRENT_TRACE_ID = new InheritableThreadLocal<>();

    // Final version: pooled threads are reused, so a value left by a previous request would linger;
    // use TransmittableThreadLocal together with TtlExecutors.
    private static final ThreadLocal<String> CURRENT_TRACE_ID = new TransmittableThreadLocal<>();

    public TraceIdHolder() {
        // All state is static; instances carry nothing.
    }

    /**
     * Binds a trace-id to the current thread.
     *
     * @param traceId the trace-id to bind
     */
    public static void set(String traceId) {
        CURRENT_TRACE_ID.set(traceId);
    }

    /**
     * Returns the trace-id bound to the current thread.
     *
     * @return the bound trace-id, or {@code null} if nothing has been bound
     */
    public static String get() {
        return CURRENT_TRACE_ID.get();
    }

    /**
     * Unbinds the trace-id from the current thread.
     *
     * @return the trace-id that was bound just before removal, or {@code null} if none was
     */
    public static String remove() {
        try {
            return CURRENT_TRACE_ID.get();
        } finally {
            CURRENT_TRACE_ID.remove();
        }
    }
}

RequestWrapper

import org.springframework.util.StreamUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

/**
 * Request wrapper that reads the body once, in the constructor, and then lets it be read again any
 * number of times through {@link #getInputStream()}, {@link #getReader()} or {@link #getBodyString()}.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Full copy of the original request body. */
    private final byte[] body;

    /**
     * Copies the body of {@code request}.
     *
     * @param request the request to wrap; its input stream is fully consumed here
     * @throws IOException if reading the original body fails
     */
    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Keep a copy of the InputStream as a byte array so it can be replayed.
        body = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns the body decoded as UTF-8.
     *
     * @return the body as a string
     */
    public String getBodyString() {
        return new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Returns a reader over the cached body.
     *
     * <p>The body is decoded with the request's declared character encoding, falling back to UTF-8
     * (the same charset {@link #getBodyString()} uses) instead of the JVM platform default.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        String charsetName = encoding != null ? encoding : StandardCharsets.UTF_8.name();
        return new BufferedReader(new InputStreamReader(getInputStream(), charsetName));
    }

    /**
     * Returns a fresh stream over the cached body, so downstream filters and handlers can read it again.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {

        final ByteArrayInputStream bais = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() {
                return bais.read();
            }

            // Bulk read, so callers do not fall back to the byte-at-a-time default of InputStream.
            @Override
            public int read(byte[] b, int off, int len) {
                return bais.read(b, off, len);
            }

            // Finished once every cached byte has been consumed.
            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            // The whole body is already in memory, so a read never blocks.
            @Override
            public boolean isReady() {
                return true;
            }

            // NOTE(review): non-blocking (async) reading is not supported by this wrapper; the listener is ignored.
            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }
}

打印日志
服务端的日志打印很好做,用过滤器 Filter 即可:每次请求进来时,记录请求路径、请求头和参数,并在这里设置本次请求的 trace-id(优先使用请求头中传来的 trace-id,没有则生成一个新的)

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.cqf.config.constant.Constant;
import com.cqf.config.wrapper.RequestWrapper;
import com.cqf.threadLocal.TraceIdHolder;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;

/**
 * Server-side logging filter.
 *
 * <p>Binds a trace-id to the current thread, taken from the incoming {@code trace-id} header or newly
 * generated. It echoes the trace-id in the response header and logs the request before and after the
 * filter chain runs. The trace-id is always unbound when the request ends.
 */
@Order(1)
@WebFilter(filterName = "logFilter", urlPatterns = "/*")
public class LogFilter implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(LogFilter.class);
    private static final String GET = "GET";
    private static final String TRACE_ID_HEADER = "trace-id";
    /** Lower-case names of headers (mostly added by Postman) that are left out of the log. */
    private static final Set<String> USELESS_HEADER = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
            "cookie", "postman-token", "host", "user-agent", "accept", "accept-encoding", "connection", "content-length")));

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        String url = request.getRequestURL().toString();
        String method = request.getMethod();
        try {
            String headers = collectHeadersAndBindTraceId(request);
            // NOTE(review): reading the body here consumes the original stream; for form-encoded POSTs the
            // servlet container may then no longer expose request parameters. Confirm this is acceptable.
            RequestWrapper requestWrapper = new RequestWrapper(request);

            // before request
            StringBuilder msg = new StringBuilder("[Server] Before request [")
                    .append(method)
                    .append(" uri=")
                    .append(url)
                    .append("; ")
                    .append("headers={")
                    .append(headers)
                    .append("}");
            if (!GET.equalsIgnoreCase(method)) {
                msg.append("; payload=").append(formatPayload(requestWrapper.getBodyString()));
            }
            msg.append("]");
            LOGGER.info(msg.toString());

            // The trace-id must be put on the response before the chain runs; once the response is
            // committed, headers can no longer be added.
            response.setHeader(TRACE_ID_HEADER, TraceIdHolder.get());
            filterChain.doFilter(requestWrapper, servletResponse);
            int status = response.getStatus();

            // after request
            LOGGER.info("[Server] After request, status=" +
                    status +
                    " [" +
                    method +
                    " uri=" +
                    url +
                    "; " +
                    "headers={" +
                    headers +
                    "}]");
        } finally {
            // Container threads are pooled: unbind so this request's trace-id cannot leak into later work on this thread.
            TraceIdHolder.remove();
        }
    }

    /**
     * Builds the "name=value, ..." header string for the log, skipping {@link #USELESS_HEADER}.
     * Binds the incoming trace-id, or a newly generated one, to the current thread.
     */
    private static String collectHeadersAndBindTraceId(HttpServletRequest request) {
        StringBuilder headers = new StringBuilder();
        boolean hasTraceId = false;
        Enumeration<String> headerNames = request.getHeaderNames();
        // getHeaderNames() may return null when the container does not allow header access.
        while (headerNames != null && headerNames.hasMoreElements()) {
            String name = headerNames.nextElement();
            // Skip the headers Postman adds by itself.
            if (USELESS_HEADER.contains(name.toLowerCase())) {
                continue;
            }
            String value = request.getHeader(name);
            // Each incoming external request sets its trace-id here.
            if (TRACE_ID_HEADER.equalsIgnoreCase(name) && StringUtils.isNotBlank(value)) {
                hasTraceId = true;
                TraceIdHolder.set(value);
            }
            appendHeader(headers, name, value);
        }
        if (!hasTraceId) {
            String traceId = UUID.randomUUID().toString().replace("-", "");
            appendHeader(headers, TRACE_ID_HEADER, traceId);
            TraceIdHolder.set(traceId);
        }
        return headers.toString();
    }

    /** Appends one "name=value" entry, comma-separated from any previous entry. */
    private static void appendHeader(StringBuilder headers, String name, String value) {
        if (headers.length() > 0) {
            headers.append(", ");
        }
        headers.append(name).append("=").append(value);
    }

    /**
     * Pretty-prints a JSON object or array body. An empty body yields {@code ""}. Any other body
     * (form data, plain text, malformed JSON) is returned unchanged, so logging can never fail the request.
     */
    private static String formatPayload(String body) {
        String trimmed = StringUtils.trimToEmpty(body);
        if (trimmed.isEmpty()) {
            return trimmed;
        }
        try {
            Object json = trimmed.startsWith("[") ? JSONObject.parseArray(trimmed) : JSONObject.parseObject(trimmed);
            return JSON.toJSONString(json,
                    SerializerFeature.PrettyFormat,
                    SerializerFeature.WriteMapNullValue,
                    SerializerFeature.WriteDateUseDateFormat);
        } catch (RuntimeException e) {
            // Not JSON: log the raw body instead.
            return body;
        }
    }
}

作为客户端时,自己是主动发起方,所以要给 RestTemplate 设置拦截器:发送请求前把当前线程的 trace-id 放进请求头传给下游服务,并打印日志

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.cqf.config.constant.Constant;
import com.cqf.threadLocal.TraceIdHolder;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
 * Client-side HTTP configuration.
 *
 * <p>Provides a {@link RestTemplate} whose interceptor forwards the current trace-id downstream and
 * logs every outgoing request.
 */
@Configuration
public class WebConfig {

    private static final Logger LOGGER = LoggerFactory.getLogger(WebConfig.class);
    private static final String GET = "GET";
    private static final String TRACE_ID_HEADER = "trace-id";

    /**
     * Builds the interceptor that sets the common {@code trace-id} request header. It logs each
     * outgoing request before it is sent and again after the response status is known.
     */
    private ClientHttpRequestInterceptor logClientHttpRequestInterceptor() {
        return (httpRequest, body, clientHttpRequestExecution) -> {
            HttpHeaders headers = httpRequest.getHeaders();
            // Common request header: propagate the trace-id. Skip it when none is bound to this thread
            // (e.g. a call made outside an HTTP request) instead of sending a null header value.
            String traceId = TraceIdHolder.get();
            if (StringUtils.isNotBlank(traceId)) {
                headers.set(TRACE_ID_HEADER, traceId);
            }

            String methodValue = httpRequest.getMethodValue();
            String headerString = formatHeaders(headers);

            // before request
            StringBuilder msg = new StringBuilder("[Client] Before request [")
                    .append(methodValue)
                    .append(" uri=")
                    .append(httpRequest.getURI())
                    .append("; ")
                    .append("headers={")
                    .append(headerString)
                    .append("}");
            if (!GET.equalsIgnoreCase(methodValue)) {
                msg.append("; payload=").append(formatPayload(body));
            }
            msg.append("]");
            LOGGER.info(msg.toString());

            ClientHttpResponse response = clientHttpRequestExecution.execute(httpRequest, body);

            // after request
            LOGGER.info("[Client] After request, status=" +
                    response.getStatusCode().value() +
                    " [" +
                    methodValue +
                    " uri=" +
                    httpRequest.getURI() +
                    "; " +
                    "headers={" +
                    headerString +
                    "}]");
            return response;
        };
    }

    /** Renders headers as "name=firstValue, ..." and skips {@code Constant.USELESS_HEADER} and empty value lists. */
    private static String formatHeaders(HttpHeaders headers) {
        StringBuilder sb = new StringBuilder();
        headers.forEach((key, values) -> {
            if (Constant.USELESS_HEADER.contains(key.toLowerCase()) || values.isEmpty()) {
                return;
            }
            if (sb.length() > 0) {
                sb.append(", ");
            }
            sb.append(key).append("=").append(values.get(0));
        });
        return sb.toString();
    }

    /**
     * Decodes the body as UTF-8 and pretty-prints it when it is a JSON object or array. An empty body
     * yields {@code ""}; any other body is returned raw so logging never aborts the outgoing call.
     */
    private static String formatPayload(byte[] body) {
        String raw = new String(body, StandardCharsets.UTF_8);
        String trimmed = StringUtils.trimToEmpty(raw);
        if (trimmed.isEmpty()) {
            return trimmed;
        }
        try {
            // Pretty-print the JSON string so it is easier to read in the log.
            Object json = trimmed.startsWith("[") ? JSONObject.parseArray(trimmed) : JSONObject.parseObject(trimmed);
            return JSON.toJSONString(json,
                    SerializerFeature.PrettyFormat,
                    SerializerFeature.WriteMapNullValue,
                    SerializerFeature.WriteDateUseDateFormat);
        } catch (RuntimeException e) {
            // Not JSON (e.g. form data or plain text): log it as-is.
            return raw;
        }
    }

    /**
     * RestTemplate with 5-second connect/read timeouts and the logging interceptor installed.
     */
    @Bean
    public RestTemplate restTemplate() {
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(5000);
        requestFactory.setReadTimeout(5000);
        RestTemplate restTemplate = new RestTemplate(requestFactory);
        // Install the interceptor so each request is logged before it is sent.
        List<ClientHttpRequestInterceptor> interceptorList = new ArrayList<>();
        interceptorList.add(logClientHttpRequestInterceptor());
        restTemplate.setInterceptors(interceptorList);
        return restTemplate;
    }
}

Controller的编写

    @Resource
    private RestTemplate restTemplate;

    // 最终版使用,配合TransmittableThreadLocal一起用才生效。线程池需要包装一下。线程数量设置为1是为了发起第二次请求时就能暴露出线程复用带来的问题
    private static final ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(1));

    /**
     * Demo endpoint that calls {@code http://localhost:5555/test02} and waits for the call to finish.
     * The trace-id propagation can then be seen in both services' logs.
     *
     * @return always {@code "1"}
     */
    @GetMapping("/method1")
    public String method1() {
        // Initial version: plain ThreadLocal, only works when no child thread is started
        //restTemplate.exchange("http://localhost:5555/test02", HttpMethod.GET, new HttpEntity<>(null, null), String.class);

        // Upgraded version: when a child thread is created with new Thread, use InheritableThreadLocal so it inherits the value
        //new Thread(() -> {
        //    restTemplate.exchange("http://localhost:5555/test02", HttpMethod.GET, new HttpEntity<>(null, null), String.class);
        //}).start();

        // Final version: pooled threads are reused (a value left by the previous request would linger),
        // so use TransmittableThreadLocal + TtlExecutors
        CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
            restTemplate.exchange("http://localhost:5555/test02", HttpMethod.GET, new HttpEntity<>(null, null), String.class);
        }, executorService);
        try {
            // Block until the downstream call has completed.
            future.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the container can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            // Best-effort, as in the original demo: the failure is printed and "1" is still returned.
            e.printStackTrace();
        }
        return "1";
    }

最终版:用 Postman 连续发送两次请求时的日志如下
(原文此处为日志截图,图片已缺失)
总结
初版(ThreadLocal)的缺点:当new新线程时,子线程获取不到父线程的变量,导致trace-id丢失。
于是出现了升级版(InheritableThreadLocal):解决了初版的问题。但 InheritableThreadLocal 只在创建子线程的那一刻复制一次父线程的值,而线程池中的线程会被复用,后续提交的任务拿到的仍是创建该线程时那次请求的 trace-id,导致 trace-id 残留、混乱不准。
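于是出现了最终版(TransmittableThreadLocal + TtlExecutors):用 TtlExecutors 包装后的线程池,会在提交任务时捕获当前线程的 trace-id,在池中线程执行任务时传递进去,任务结束后再恢复原值,从而避免复用线程残留上一次请求的 trace-id。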

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值