package com.xquant.xpms.commons.config;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.logging.jdbc.ConnectionLogger;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
/**
 * Defines the application's {@link DataSource} as a Hikari pool wrapped in a JDK proxy that
 * decorates connections with MyBatis's {@link ConnectionLogger} when debug logging is enabled.
 *
 * @author: guanglai.zhou
 * @date: 2021/9/1 21:57
 */
@Configuration
public class DataSourceConfig {

    /**
     * Builds a data source of the given implementation type from Spring Boot's
     * {@link DataSourceProperties}.
     *
     * @param properties the bound {@code spring.datasource.*} properties
     * @param type       the concrete data source implementation to build
     * @param <T>        the type the caller receives; the cast is unchecked, so it must match {@code type}
     * @return the newly built data source
     */
    @SuppressWarnings("unchecked") // caller chooses T to match 'type'; the builder returns that type
    protected static <T> T createDataSource(DataSourceProperties properties, Class<? extends DataSource> type) {
        return (T) properties.initializeDataSourceBuilder().type(type).build();
    }

    /**
     * Creates the Hikari pool and returns it behind a logging proxy.
     *
     * <p>NOTE(review): {@code @ConfigurationProperties} binds to the object this method returns, which
     * is a JDK proxy exposing only {@link DataSource}/{@link AutoCloseable}, not {@link HikariDataSource};
     * Hikari-specific keys under {@code spring.datasource.hikari} are likely not applied - confirm, and
     * consider binding them onto the HikariDataSource before wrapping.
     *
     * @param properties the bound {@code spring.datasource.*} properties
     * @return a proxy delegating to the configured {@link HikariDataSource}
     */
    @Bean
    @ConfigurationProperties(prefix = "spring.datasource.hikari")
    DataSource dataSource(DataSourceProperties properties) {
        HikariDataSource dataSource = createDataSource(properties, HikariDataSource.class);
        if (StringUtils.hasText(properties.getName())) {
            dataSource.setPoolName(properties.getName());
        }
        // AutoCloseable is exposed so Spring's inferred @Bean destroy method (a public close()) can
        // reach the pool; a proxy implementing only DataSource has no close() and the pool would leak.
        return (DataSource) Proxy.newProxyInstance(ClassUtils.getDefaultClassLoader(),
                new Class<?>[]{DataSource.class, AutoCloseable.class},
                new HikariDataSourceProxy(dataSource));
    }

    /**
     * Delegates every call to the wrapped {@link HikariDataSource}; connections returned by
     * {@code getConnection} are wrapped in a {@link ConnectionLogger} when debug logging is enabled.
     */
    static class HikariDataSourceProxy implements InvocationHandler {
        /** Logger name used for connection logging. */
        private static final String LOG_ID = HikariDataSourceProxy.class.getName();

        private final HikariDataSource target;

        HikariDataSourceProxy(HikariDataSource target) {
            this.target = target;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            Object result;
            try {
                result = method.invoke(target, args);
            } catch (java.lang.reflect.InvocationTargetException ex) {
                // Rethrow the target's own exception (e.g. SQLException); letting the reflective
                // wrapper escape would surface it as an UndeclaredThrowableException.
                throw ex.getTargetException();
            }
            if ("getConnection".equals(method.getName())) {
                // The log is looked up per call so a logging implementation chosen later still applies.
                return getConnection(LogFactory.getLog(LOG_ID), (Connection) result);
            }
            return result;
        }

        /**
         * Wraps the connection in a MyBatis {@link ConnectionLogger} when {@code statementLog} has
         * debug enabled; otherwise returns it unchanged.
         *
         * @param statementLog the log that receives SQL statements
         * @param connection   the connection obtained from the pool
         * @return the (possibly wrapped) connection
         * @throws SQLException declared for subclasses; not thrown here
         */
        protected Connection getConnection(Log statementLog, Connection connection) throws SQLException {
            if (statementLog.isDebugEnabled()) {
                return ConnectionLogger.newInstance(connection, statementLog, 0);
            }
            return connection;
        }
    }
}
以上方式具有侵入性（需要自行定义并替换数据源 Bean），且只适用于 HikariDataSource，不支持其他类型的数据源。在 Spring 中可以改用如下方式：通过 BeanPostProcessor 对容器中的 DataSource Bean 统一进行代理增强。
import com.xquant.xams.commons.jdbc.datasource.DataSourceProxyFactory;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.Ordered;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Replaces each {@link DataSource} bean found in the bean factory with the logging proxy produced by
 * {@link DataSourceProxyFactory#enhancedDataSource(DataSource)}.
 *
 * <p>NOTE(review): the replacement is a JDK proxy implementing only {@link DataSource}; injection
 * points that request the concrete pool type (e.g. {@code HikariDataSource}) will likely no longer
 * match - confirm none exist.
 */
@Component
public class RepositoryLoggerProcessor implements BeanPostProcessor, BeanFactoryAware, Ordered {
    private ConfigurableListableBeanFactory beanFactory;
    /** Names of beans whose type resolves to DataSource, captured when the bean factory is injected. */
    private Set<String> dataSourceBeanNames;

    /**
     * Returns the logging proxy for DataSource beans captured in {@link #setBeanFactory(BeanFactory)};
     * every other bean is returned unchanged.
     *
     * @param bean     the fully initialized bean instance
     * @param beanName the name of the bean
     * @return the proxy, or {@code bean} itself
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        // The instanceof guard matters: for a FactoryBean that produces a DataSource, the FactoryBean
        // instance itself is post-processed under the same bean name and must not be cast.
        if (bean instanceof DataSource
                && CollectionUtils.isNotEmpty(dataSourceBeanNames)
                && dataSourceBeanNames.contains(beanName)) {
            return DataSourceProxyFactory.enhancedDataSource((DataSource) bean);
        }
        return bean;
    }

    /**
     * Captures the names of all beans whose type resolves to {@link DataSource}.
     *
     * @param beanFactory the owning bean factory; expected to be a {@link ConfigurableListableBeanFactory}
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
        String[] names = this.beanFactory.getBeanNamesForType(DataSource.class);
        dataSourceBeanNames = Arrays.stream(names).collect(Collectors.toSet());
    }

    /**
     * Runs this processor ahead of lower-precedence post-processors, so those presumably see the
     * logging proxy rather than the raw data source.
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}
import com.xquant.xams.commons.jdbc.strength.JdbcRepositoryStrengthor;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.logging.jdbc.ConnectionLogger;
import org.springframework.aop.framework.ReflectiveMethodInvocation;
import org.springframework.aop.interceptor.ExposeInvocationInterceptor;
import org.springframework.data.repository.Repository;
import org.springframework.util.ClassUtils;
import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
/**
 * Creates logging proxies around {@link DataSource} instances.
 */
public class DataSourceProxyFactory {

    /**
     * Wraps a data source so that connections obtained while a Spring Data repository method is
     * executing are decorated with MyBatis's {@link ConnectionLogger}. The repository interface and
     * method name form the log id.
     *
     * @param dataSource the data source to wrap
     * @return a JDK proxy implementing {@link DataSource} that delegates to {@code dataSource}
     */
    public static DataSource enhancedDataSource(DataSource dataSource) {
        return (DataSource) Proxy.newProxyInstance(ClassUtils.getDefaultClassLoader(), new Class<?>[]{DataSource.class},
                new DataSourceProxy(dataSource));
    }

    /**
     * Delegates every call to the target data source, adding SQL logging to connections requested
     * from within a repository invocation.
     */
    static class DataSourceProxy implements InvocationHandler {
        public static final String GET_CONNECTION = "getConnection";
        public static final String DOT = ".";
        private final DataSource target;

        DataSourceProxy(DataSource target) {
            this.target = target;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (GET_CONNECTION.equals(method.getName())) {
                String logId = resolveLogIdIfRepository();
                if (StringUtils.isNotBlank(logId)) {
                    Connection connection = (Connection) invokeTarget(method, args);
                    return getConnection(LogFactory.getLog(logId), connection);
                }
            }
            return invokeTarget(method, args);
        }

        /**
         * Invokes {@code method} on the target, rethrowing the target's own exception.
         * Letting {@code InvocationTargetException} escape would make the proxy throw
         * {@code UndeclaredThrowableException} instead of, e.g., the original {@link SQLException}.
         */
        private Object invokeTarget(Method method, Object[] args) throws Throwable {
            try {
                return method.invoke(target, args);
            } catch (java.lang.reflect.InvocationTargetException ex) {
                throw ex.getTargetException();
            }
        }

        /**
         * Builds a log id of the form {@code <repository interface>.<method>} when the current AOP
         * invocation is a call on a Spring Data {@link Repository} proxy.
         *
         * @return the log id, or {@code null} when no repository invocation is in progress
         */
        private String resolveLogIdIfRepository() {
            try {
                Object current = ExposeInvocationInterceptor.currentInvocation();
                if (!(current instanceof ReflectiveMethodInvocation)) {
                    return null;
                }
                ReflectiveMethodInvocation methodInvocation = (ReflectiveMethodInvocation) current;
                Object invocationProxy = methodInvocation.getProxy();
                if (!(invocationProxy instanceof Repository)) {
                    return null;
                }
                Class<?> repositoryInterface = JdbcRepositoryStrengthor.getRepositoryInterface((Repository) invocationProxy);
                if (repositoryInterface == null) {
                    return null;
                }
                String logId = repositoryInterface.getName();
                String methodName = methodInvocation.getMethod().getName();
                return StringUtils.isNotBlank(methodName) ? logId + DOT + methodName : logId;
            } catch (Exception ignored) {
                // Best effort: currentInvocation() throws IllegalStateException when no AOP invocation is
                // exposed on this thread; any failure means the connection is returned without logging.
                return null;
            }
        }

        /**
         * Wraps the connection in a MyBatis {@link ConnectionLogger} when {@code statementLog} has
         * debug enabled; otherwise returns it unchanged.
         *
         * @param statementLog the log that receives SQL statements
         * @param connection   the connection obtained from the target
         * @return the (possibly wrapped) connection
         * @throws SQLException declared for subclasses; not thrown here
         */
        protected Connection getConnection(Log statementLog, Connection connection) throws SQLException {
            if (statementLog.isDebugEnabled()) {
                return ConnectionLogger.newInstance(connection, statementLog, 0);
            }
            return connection;
        }
    }
}