bytekit 是阿里巴巴开源的字节码增强工具，提供了简单易用的字节码增强 API。
pom.xml 依赖
<!-- bytekit core: provides the com.alibaba.bytekit.* APIs (AsmUtils, MethodProcessor, interceptor annotations) used by the code below. -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>bytekit-core</artifactId>
<version>0.0.4</version>
<optional>true</optional>
</dependency>
<!-- NOTE(review): byte-buddy-agent is not referenced by any active code in this document; presumably it backs runtime Instrumentation installation. Confirm it is still required. -->
<dependency>
<groupId>net.bytebuddy</groupId>
<artifactId>byte-buddy-agent</artifactId>
<version>1.10.18</version>
<optional>true</optional>
</dependency>
<!-- NOTE(review): cfr (a Java decompiler) is not referenced by any active code in this document; presumably kept for decompiling enhanced bytes while debugging. Confirm it is still required. -->
<dependency>
<groupId>org.benf</groupId>
<artifactId>cfr</artifactId>
<version>0.150</version>
<optional>true</optional>
</dependency>
工具类
import com.alibaba.bytekit.agent.inst.NewField;
import com.alibaba.bytekit.asm.MethodProcessor;
import com.alibaba.bytekit.asm.interceptor.InterceptorProcessor;
import com.alibaba.bytekit.asm.interceptor.annotation.InterceptorParserHander;
import com.alibaba.bytekit.asm.interceptor.parser.DefaultInterceptorClassParser;
import com.alibaba.bytekit.utils.AsmUtils;
import com.alibaba.deps.org.objectweb.asm.ClassReader;
import com.alibaba.deps.org.objectweb.asm.Type;
import com.alibaba.deps.org.objectweb.asm.tree.AnnotationNode;
import com.alibaba.deps.org.objectweb.asm.tree.ClassNode;
import com.alibaba.deps.org.objectweb.asm.tree.FieldNode;
import com.alibaba.deps.org.objectweb.asm.tree.MethodNode;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
/**
 * Applies bytekit interceptors to a class file.
 *
 * <p>The interceptor class supplies the bytekit-annotated advice methods. It must also declare
 * three public static selector methods, invoked reflectively:
 * {@code constructMatch(String desc)}, {@code staticMatch(String desc)} and
 * {@code methodMatch(String name, String desc)}. A target method is enhanced only when the
 * matching selector returns a value whose {@code toString()} is {@code "true"}.
 */
public class EnhanceUtil {
    private static final String CONSTRUCT_MATCH_METHOD = "constructMatch";
    private static final String STATIC_MATCH_METHOD = "staticMatch";
    private static final String METHOD_MATCH_METHOD = "methodMatch";

    /** JVM-internal name of instance constructors. */
    private static final String CONSTRUCTOR_NAME = "<init>";
    /** JVM-internal name of the static initializer. */
    private static final String STATIC_INIT_NAME = "<clinit>";

    private static final String CONSTRUCT_MATCH_ERROR = "get construct math method error";
    private static final String STATIC_MATCH_ERROR = "get static math method error";
    private static final String METHOD_MATCH_ERROR = "get method math method error";

    /**
     * Enhances {@code bytes} with the interceptors declared on {@code interceptorClass}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy fields annotated with {@code @NewField}, and their {@code getXxx}/{@code setXxx}
     *       accessors, from the interceptor into the target.</li>
     *   <li>Run every parsed {@link InterceptorProcessor} on each method accepted by the
     *       interceptor's selector methods.</li>
     *   <li>Serialize the result.</li>
     * </ol>
     *
     * @param bytes            original class-file bytes of the target class
     * @param interceptorClass class carrying the bytekit interceptor annotations and the selector methods
     * @param classLoader      loader handed to {@code AsmUtils.toBytes} when writing the enhanced class
     * @return the enhanced class-file bytes
     * @throws Exception if reading the interceptor, running a processor or writing the class fails
     */
    public static byte[] enhanceClass(byte[] bytes, Class interceptorClass, ClassLoader classLoader) throws Exception {
        DefaultInterceptorClassParser interceptorClassParser = new DefaultInterceptorClassParser();
        ClassNode originClassNode = AsmUtils.toClassNode(bytes);
        ClassNode interceptorNode = getInterceptorClassNode(interceptorClass, classLoader);
        ClassNode targetClassNode = AsmUtils.copy(originClassNode);

        // Rename the interceptor to the target's name so that the copied @NewField accessors
        // reference the target class rather than the interceptor class.
        byte[] renamed = AsmUtils.renameClass(AsmUtils.toBytes(interceptorNode),
                Type.getObjectType(originClassNode.name).getClassName());
        interceptorNode = AsmUtils.toClassNode(renamed);

        copyNewFields(interceptorNode, targetClassNode);

        List<InterceptorProcessor> processors = interceptorClassParser.parse(interceptorClass);

        // The selector methods do not depend on the target method, so resolve them once.
        Method constructMatch = findMatchMethod(interceptorClass, CONSTRUCT_MATCH_METHOD, CONSTRUCT_MATCH_ERROR, String.class);
        Method staticMatch = findMatchMethod(interceptorClass, STATIC_MATCH_METHOD, STATIC_MATCH_ERROR, String.class);
        Method methodMatch = findMatchMethod(interceptorClass, METHOD_MATCH_METHOD, METHOD_MATCH_ERROR, String.class, String.class);

        for (MethodNode methodNode : targetClassNode.methods) {
            boolean matched;
            // Exact comparison: a plain method named e.g. "init" must not be mistaken for a constructor.
            if (CONSTRUCTOR_NAME.equals(methodNode.name)) {
                matched = invokeMatch(constructMatch, CONSTRUCT_MATCH_ERROR, methodNode.desc);
            } else if (STATIC_INIT_NAME.equals(methodNode.name)) {
                matched = invokeMatch(staticMatch, STATIC_MATCH_ERROR, methodNode.desc);
            } else {
                matched = invokeMatch(methodMatch, METHOD_MATCH_ERROR, methodNode.name, methodNode.desc);
            }
            if (!matched) {
                continue;
            }
            MethodProcessor methodProcessor = new MethodProcessor(targetClassNode, methodNode);
            for (InterceptorProcessor interceptor : processors) {
                System.out.println("----------------[SF Agent bytekit]EnhanceUtil interceptor className:"+targetClassNode.name +" method:"+methodNode.name +" methodDesc:"+methodNode.desc +" methodProcessor:"+methodProcessor.getMethodNode().name );
                // Weave the interceptor's advice into this method.
                interceptor.process(methodProcessor);
            }
        }
        ClassReader classReader = AsmUtils.toClassNode(bytes, originClassNode);
        return AsmUtils.toBytes(targetClassNode, classLoader, classReader);
    }

    /**
     * Copies every {@code @NewField}-annotated field of {@code source} into {@code target}. Also
     * copies the {@code source} methods named {@code getXxx}/{@code setXxx} for each such field.
     */
    private static void copyNewFields(ClassNode source, ClassNode target) {
        Type newFieldType = Type.getType(NewField.class);
        List<String> accessorNames = new ArrayList<>();
        List<FieldNode> fieldNodes = source.fields;
        if (fieldNodes != null) {
            for (FieldNode fieldNode : fieldNodes) {
                if (fieldNode.visibleAnnotations == null) {
                    continue;
                }
                for (AnnotationNode annotationNode : fieldNode.visibleAnnotations) {
                    if (newFieldType.equals(Type.getType(annotationNode.desc))) {
                        String name = fieldNode.name;
                        // Locale-independent capitalization (String.toUpperCase() depends on the default locale).
                        String suffix = Character.toUpperCase(name.charAt(0)) + name.substring(1);
                        accessorNames.add("set" + suffix);
                        accessorNames.add("get" + suffix);
                        AsmUtils.addField(target, fieldNode);
                        break;
                    }
                }
            }
        }
        for (MethodNode methodNode : source.methods) {
            if (accessorNames.contains(methodNode.name)) {
                AsmUtils.addMethod(target, methodNode);
            }
        }
    }

    /** Looks up a public selector method, logging and returning {@code null} if it cannot be resolved. */
    private static Method findMatchMethod(Class<?> interceptorClass, String name, String errorMessage, Class<?>... parameterTypes) {
        try {
            return interceptorClass.getMethod(name, parameterTypes);
        } catch (Throwable e) {
            System.err.println(errorMessage);
            e.printStackTrace();
            return null;
        }
    }

    /** Invokes a static selector; returns true only if it yields a value whose string form is "true". */
    private static boolean invokeMatch(Method matchMethod, String errorMessage, Object... args) {
        if (matchMethod == null) {
            return false;
        }
        try {
            Object flag = matchMethod.invoke(null, args);
            return flag != null && "true".equals(flag.toString());
        } catch (Throwable e) {
            System.err.println(errorMessage);
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Reads the interceptor class into a {@link ClassNode}.
     *
     * <p>Classes with a defining loader go through {@code AsmUtils.loadClass}. For a
     * bootstrap-loaded interceptor ({@code getClassLoader() == null}), the class file is read
     * from the system class loader instead.
     *
     * @param classLoader currently unused; kept for signature compatibility
     * @throws IllegalStateException if the class-file resource cannot be found
     */
    private static ClassNode getInterceptorClassNode(Class interceptorClass, ClassLoader classLoader) throws Exception {
        if (interceptorClass.getClassLoader() != null) {
            return AsmUtils.loadClass(interceptorClass);
        }
        String resource = interceptorClass.getName().replace('.', '/') + ".class";
        // try-with-resources: the stream was previously never closed.
        try (InputStream is = ClassLoader.getSystemClassLoader().getResourceAsStream(resource)) {
            if (is == null) {
                throw new IllegalStateException("interceptor class resource not found: " + resource);
            }
            ClassReader cr = new ClassReader(is);
            ClassNode classNode = new ClassNode();
            cr.accept(classNode, ClassReader.SKIP_FRAMES);
            return classNode;
        }
    }
}
使用
BytekitTransformer 字节码增强入口类
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import java.util.*;
/**
 * Class-file transformer that routes each class listed in {@code enhanceModels} to its plugin class.
 *
 * <p>The plugin is resolved by name and its
 * {@code enhance(ClassLoader, String, Class, ProtectionDomain, byte[])} method is invoked
 * reflectively.
 *
 * <p>When {@code GOV_RETRANSORM_TIME} is non-negative, {@link #start()} launches a daemon thread.
 * That thread waits that many milliseconds and then retransforms the target classes loadable
 * from {@code ClassLoaderUtil.getClassLoader()}.
 */
public class BytekitTransformer implements ClassFileTransformer,Runnable {
    /** Internal name of each target class mapped to the fully qualified name of its plugin class. */
    private static Map<String, String> enhanceModels = new HashMap<>();
    /** Per-target bookkeeping keyed by class name + loader type; only accessed through synchronized getTransformerLoad. */
    private static Map<String, TransformerLoad> TransformerLoadMap = new HashMap<>();
    /** Delay in milliseconds before the deferred retransformation; negative disables it. */
    private static int delayStartRetransformTime = -1;

    static {
        // Could later be replaced by SPI-based registration.
        enhanceModels.put("org/apache/logging/log4j/core/layout/PatternLayout$Builder","bytekit.Log4j2ByteEnhance");
        enhanceModels.put("ch/qos/logback/core/pattern/PatternLayoutBase","bytekit.LogbackByteEnhance");
        enhanceModels.put("ch/qos/logback/classic/PatternLayout","bytekit.LogbackPatternByteEnhance");
    }

    /**
     * Enhances registered target classes; returns every other class unchanged.
     *
     * <p>Any failure is logged and recorded on the class's {@link TransformerLoad}, and the
     * original bytes are returned.
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
        if (!enhanceModels.containsKey(className)) {
            return classfileBuffer;
        }
        System.out.println("[bytekit] "+new Date()+" bytekitTransformer transform loader:"+loader+" className:"+className);
        String abstractByteEnhance = enhanceModels.get(className);
        TransformerLoad transformerLoad = null;
        try {
            Class abstractByteEnhanceCla;
            try {
                if (loader == null) {
                    // Bootstrap-defined target: look the plugin up in the agent / context loader instead.
                    loader = getClassLoader(abstractByteEnhance);
                    if (loader == null) {
                        throw new ClassNotFoundException(abstractByteEnhance);
                    }
                }
                transformerLoad = getTransformerLoad(className, loader, abstractByteEnhance);
                // transform() runs concurrently on class-loading threads, so make check-then-increment atomic.
                synchronized (transformerLoad) {
                    // On retransformation/redefinition the JVM supplies bytes without this transformer's
                    // earlier output. Returning them unchanged would strip the enhancement, so the
                    // "already loaded" guard only applies to the initial definition.
                    if (classBeingRedefined == null && transformerLoad.getLoadCount() > 0) {
                        System.out.println("[ bytekit] "+new Date()+" bytekitTransformer transform 已经加载过 loader:"+loader+" className:"+className+" plugin:"+abstractByteEnhance);
                        return classfileBuffer;
                    }
                    transformerLoad.setLoadCount();
                }
                abstractByteEnhanceCla = loader.loadClass(abstractByteEnhance);
            } catch (Throwable e) {
                System.err.println("[bytekit] "+new Date()+" bytekitTransformer loader加载不到类 loader:"+loader+" className:"+className+" plugin:"+abstractByteEnhance+" e:"+e.getMessage());
                if (transformerLoad != null) {
                    transformerLoad.setErrorCount();
                    transformerLoad.setError(e.getMessage());
                }
                return classfileBuffer;
            }
            Object obj = abstractByteEnhanceCla.newInstance();
            Method enhance = abstractByteEnhanceCla.getMethod("enhance",ClassLoader.class,String.class,Class.class,ProtectionDomain.class,byte[].class);
            byte[] bytes = (byte[]) enhance.invoke(obj,loader,className,classBeingRedefined,protectionDomain,classfileBuffer);
            transformerLoad.setSuccessCount();
            System.out.println("[bytekit] "+new Date()+" bytekitTransformer transform 执行插件类enhance成功 loader:"+loader+" className:"+className+" plugin:"+abstractByteEnhance);
            return bytes;
        } catch (Throwable e) {
            e.printStackTrace();
            System.err.println("[bytekit] BytekitTransformer transform error e:"+e.getMessage());
            if (transformerLoad != null) {
                transformerLoad.setErrorCount();
                transformerLoad.setError(e.getMessage());
            }
        }
        return classfileBuffer;
    }

    /**
     * Picks a loader for the plugin class when the target was defined by the bootstrap loader.
     *
     * <p>Tries the loader of {@code GovernanceAgent} first, then the thread context loader.
     *
     * @param cla fully qualified plugin class name
     * @return the first loader that can load {@code cla}, or {@code null} if neither can
     */
    private ClassLoader getClassLoader(String cla){
        ClassLoader agentLoader = GovernanceAgent.class.getClassLoader();
        if (agentLoader != null) {
            try {
                agentLoader.loadClass(cla);
                return agentLoader;
            } catch (ClassNotFoundException ignored) {
                // Fall back to the context class loader below.
            }
        }
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        try {
            if (contextLoader == null) {
                throw new ClassNotFoundException(cla);
            }
            contextLoader.loadClass(cla);
            return contextLoader;
        } catch (ClassNotFoundException e) {
            System.out.println("[bytekit] "+new Date()+" bytekitTransformer getClassLoader error loader:"+contextLoader+" className:"+cla+ " e:"+e.getMessage());
            return null;
        }
    }

    /**
     * Returns the bookkeeping entry for ({@code className}, loader type), creating it on first use.
     *
     * <p>Synchronized because the backing map is a plain {@code HashMap} and {@code transform}
     * is called from many class-loading threads.
     */
    private static synchronized TransformerLoad getTransformerLoad(String className,ClassLoader classLoader,String enhance){
        String key = className+"_"+classLoader.getClass().getTypeName();
        TransformerLoad transformerLoad = TransformerLoadMap.get(key);
        if (transformerLoad != null) {
            return transformerLoad;
        }
        transformerLoad = new TransformerLoad();
        transformerLoad.setClassName(className);
        transformerLoad.setByteEnhance(enhance);
        transformerLoad.setClassLoader(classLoader);
        TransformerLoadMap.put(key,transformerLoad);
        return transformerLoad;
    }

    /** Starts the deferred-retransformation daemon thread (see {@link #run()}). */
    public void start(){
        Thread thread = new Thread(this, "bytekitTransformer thread");
        thread.setDaemon(true);
        thread.start();
    }

    /**
     * Deferred retransformation of already-loaded target classes.
     *
     * <p>Reads {@code GOV_RETRANSORM_TIME} in milliseconds. A negative or malformed value
     * disables this step. Otherwise, after that delay, each modifiable target class is
     * retransformed.
     */
    @Override
    public void run() {
        try {
            delayStartRetransformTime = Integer.valueOf(EnvUtils.getProperty("GOV_RETRANSORM_TIME","-1"));
        } catch (NumberFormatException e) {
            System.err.println("[bytekit] "+new Date()+" bytekitTransformer GOV_RETRANSORM_TIME 配置非法,跳过延时增强 e:"+e.getMessage());
            return;
        }
        if (delayStartRetransformTime < 0) {
            return;
        }
        System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强启动");
        Instrumentation inst = InstrumentConfig.INSTANCE.get();
        if (inst == null) {
            System.err.println("[bytekit] "+new Date()+" bytekitTransformer Instrumentation is null, 跳过延时增强");
            return;
        }
        try {
            Thread.sleep(delayStartRetransformTime);
        } catch (InterruptedException e) {
            // Honour the interruption: restore the flag and abandon the deferred enhancement.
            Thread.currentThread().interrupt();
            return;
        }
        List<Class<?>> unEnhance = new ArrayList<>();
        ClassLoader classLoader = ClassLoaderUtil.getClassLoader();
        System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强 loader:"+classLoader);
        enhanceModels.forEach((internalName, enhance) -> {
            String binaryName = internalName.replaceAll("/",".");
            try {
                unEnhance.add(classLoader.loadClass(binaryName));
            } catch (Exception e) {
                System.err.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强加载增强类失败 loader:"+classLoader+" className:"+binaryName+" enhance:"+enhance+" e:"+e.getMessage());
            }
        });
        for (Class<?> aClass : unEnhance) {
            if (!inst.isModifiableClass(aClass) || !inst.isRetransformClassesSupported()) {
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer 不能retransformClasses className:"+aClass);
                // Previously fell through and attempted the retransform anyway.
                continue;
            }
            try {
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强 aClass"+aClass);
                inst.retransformClasses(aClass);
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer retransformClasses error className:"+aClass);
            }
        }
    }

    /**
     * Transformation bookkeeping for one target class / loader type.
     *
     * <p>The {@code setXxxCount()} methods increment the corresponding counter, and
     * {@code setError} appends a message.
     */
    public static class TransformerLoad {
        private ClassLoader classLoader;
        private String className;
        private String ByteEnhance;
        private int loadCount;
        private int successCount;
        private int errorCount;
        private List<String> error = new ArrayList<>();

        public ClassLoader getClassLoader() {
            return classLoader;
        }

        public void setClassLoader(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        public String getClassName() {
            return className;
        }

        public void setClassName(String className) {
            this.className = className;
        }

        public String getByteEnhance() {
            return ByteEnhance;
        }

        public void setByteEnhance(String byteEnhance) {
            ByteEnhance = byteEnhance;
        }

        public int getLoadCount() {
            return loadCount;
        }

        /** Increments the load counter. */
        public void setLoadCount() {
            this.loadCount++;
        }

        public int getSuccessCount() {
            return successCount;
        }

        /** Increments the success counter. */
        public void setSuccessCount() {
            this.successCount++;
        }

        public int getErrorCount() {
            return errorCount;
        }

        /** Increments the error counter. */
        public void setErrorCount() {
            this.errorCount++;
        }

        public List<String> getError() {
            return error;
        }

        /** Appends an error message. */
        public void setError(String error) {
            this.error.add(error);
        }
    }
}
抽象的增强类
import java.security.ProtectionDomain;
/**
 * Base class for bytekit plugins invoked by the transformer.
 *
 * <p>Subclasses supply the interceptor class. {@link #enhance} weaves that interceptor into
 * the target class file via {@code EnhanceUtil.enhanceClass}, provided {@link #math} accepts
 * the target.
 */
public abstract class AbstractByteEnhance {

    /**
     * Enhances {@code classfileBuffer} when {@link #math} accepts the target.
     *
     * @return the enhanced bytes, or {@code classfileBuffer} unchanged when matching fails or
     *         enhancement throws (the error is logged)
     */
    public byte[] enhance(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer){
        try {
            if (math(loader,className,classBeingRedefined,protectionDomain)){
                return EnhanceUtil.enhanceClass(classfileBuffer,getInterceptorClass(),loader);
            }
            System.out.println("AbstractByteEnhance bytekit enhance math匹配失败 className:"+className);
        } catch (Throwable e) {
            System.err.println("AbstractByteEnhance bytekit enhance error className:"+className);
            e.printStackTrace();
        }
        return classfileBuffer;
    }

    /** Returns the bytekit interceptor class to weave into the target. */
    public abstract Class getInterceptorClass();

    /**
     * Decides whether to enhance the target.
     *
     * <p>Returns {@code false} when the interceptor is {@code NullInterceptor}, or when
     * {@code loader} cannot load the interceptor class (for example a null loader, or a missing
     * class). Otherwise returns {@code true}.
     *
     * @param loader              loader of the class being transformed
     * @param className           internal name of the class being transformed
     * @param classBeingRedefined class being redefined/retransformed, or {@code null}
     * @param protectionDomain    protection domain of the class being transformed
     * @return whether the target should be enhanced
     */
    public boolean math(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain){
        // Resolve once: the catch block previously called getInterceptorClass() again, which could
        // throw (or NPE on a null result) and escape this method.
        Class interceptorClass = null;
        try {
            interceptorClass = getInterceptorClass();
            if (NullInterceptor.class.equals(interceptorClass)){
                System.err.println("AbstractByteEnhance bytekit 类加载器找不到 增强目标类:"+className+" loader:"+loader);
                return false;
            }
            // The interceptor must be visible from the target's loader, since the woven code refers to it.
            loader.loadClass(interceptorClass.getName());
            return true;
        } catch (Throwable e) {
            String interceptorName = interceptorClass == null ? "null" : interceptorClass.getName();
            System.err.println("AbstractByteEnhance bytekit 类加载器找不到 "+interceptorName+" loader:"+loader+" e:"+e.getMessage());
            return false;
        }
    }
}
拦截器
示例：对 logback 的 PatternLayoutBase.setPattern 进行增强（在 pattern 中追加 %tid）
import com.sf.plough.governance.agent.dependencies.com.alibaba.bytekit.asm.binding.Binding;
import com.alibaba.bytekit.asm.interceptor.annotation.AtExceptionExit;
import com.alibaba.bytekit.asm.interceptor.annotation.AtExit;
import java.lang.reflect.Field;
/**
 * bytekit interceptor that rewrites the logback layout pattern after {@code setPattern} runs.
 *
 * <p>The advice is inlined ({@code inline = true}) into the target method. Any class it
 * references must therefore be visible from the target's class loader.
 */
public class LogbackInterceptor {

    /**
     * Runs before a normal return from the target method.
     *
     * <p>For {@code setPattern(String)} with a non-null argument, it converts the pattern with
     * {@code PatternUtil.patternConvert(pattern, "%tid")}. The result is written back to the
     * {@code pattern} field, which is searched for starting at the receiver's superclass.
     */
    @AtExit(inline = true)
    public static void atExit(@Binding.This Object object,
                              @Binding.Class Object clazz,
                              @Binding.Args Object[] args,
                              @Binding.MethodName String methodName,
                              @Binding.MethodDesc String methodDesc,
                              @Binding.Return Object returnObject) throws Exception {
        // Skip null: String.valueOf(null) would otherwise turn it into the literal pattern "null".
        if ("setPattern".equals(methodName) && args != null && args.length == 1 && args[0] != null) {
            // Walk up the hierarchy: for deeper subclasses the declaring class is not the immediate
            // superclass, and getDeclaredField there would throw out of the woven setPattern.
            Field field = null;
            for (Class<?> cla = object.getClass().getSuperclass(); cla != null && field == null; cla = cla.getSuperclass()) {
                for (Field candidate : cla.getDeclaredFields()) {
                    if ("pattern".equals(candidate.getName())) {
                        field = candidate;
                        break;
                    }
                }
            }
            if (field == null) {
                System.err.println("logback PatternLayoutBase setPattern 找不到pattern字段 class:"+object.getClass());
            } else {
                field.setAccessible(true);
                String pattern = PatternUtil.patternConvert(String.valueOf(args[0]),"%tid");
                // Printed to stdout because the agent logger does not take effect in Spring Boot projects.
                System.out.println("logback PatternLayoutBase setPattern 重写pattern:"+pattern);
                field.set(object,pattern);
            }
        }
    }

    /** Runs when the target method throws a RuntimeException: prints it and rethrows it unchanged. */
    @AtExceptionExit(inline = true, onException = RuntimeException.class)
    public static void atExceptionExit(@Binding.Class Object clazz, @Binding.Throwable RuntimeException ex) {
        ex.printStackTrace();
        throw ex;
    }

    /** Selector invoked reflectively by EnhanceUtil: enhance only methods named {@code setPattern}. */
    public static Boolean methodMatch(String method,String methodDesc){
        return "setPattern".equals(method);
    }

    /** Selector for constructors: never enhanced. */
    public static Boolean constructMatch(String methodDesc){
        return false;
    }

    /** Selector for static initializers: never enhanced. */
    public static Boolean staticMatch(String methodDesc){
        return false;
    }
}