字节码增强工具bytekit封装使用

7 篇文章 0 订阅

bytekit 是阿里开源的字节码增强工具,通过它提供的 API 可以很方便地实现字节码增强。

pom 依赖

<dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>bytekit-core</artifactId>
            <version>0.0.4</version>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>net.bytebuddy</groupId>
            <artifactId>byte-buddy-agent</artifactId>
            <version>1.10.18</version>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.benf</groupId>
            <artifactId>cfr</artifactId>
            <version>0.150</version>
            <optional>true</optional>
        </dependency>

工具类

import com.alibaba.bytekit.agent.inst.NewField;
import com.alibaba.bytekit.asm.MethodProcessor;
import com.alibaba.bytekit.asm.interceptor.InterceptorProcessor;
import com.alibaba.bytekit.asm.interceptor.annotation.InterceptorParserHander;
import com.alibaba.bytekit.asm.interceptor.parser.DefaultInterceptorClassParser;
import com.alibaba.bytekit.utils.AsmUtils;
import com.alibaba.deps.org.objectweb.asm.ClassReader;
import com.alibaba.deps.org.objectweb.asm.Type;
import com.alibaba.deps.org.objectweb.asm.tree.AnnotationNode;
import com.alibaba.deps.org.objectweb.asm.tree.ClassNode;
import com.alibaba.deps.org.objectweb.asm.tree.FieldNode;
import com.alibaba.deps.org.objectweb.asm.tree.MethodNode;

import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

public class EnhanceUtil {

    /** Name of the interceptor's static hook {@code constructMatch(String desc)} consulted for constructors. */
    private static final String CONSTRUCT_MATCH_METHOD = "constructMatch";

    /** Name of the interceptor's static hook {@code staticMatch(String desc)} consulted for the static initializer. */
    private static final String STATIC_MATCH_METHOD = "staticMatch";

    /** Name of the interceptor's static hook {@code methodMatch(String name, String desc)} consulted for other methods. */
    private static final String METHOD_MATCH_METHOD = "methodMatch";

    /** JVM-internal name shared by every instance constructor. */
    private static final String CONSTRUCTOR_NAME = "<init>";

    /** JVM-internal name of the static initializer. */
    private static final String STATIC_INITIALIZER_NAME = "<clinit>";

    /**
     * Enhances a class with the bytekit interceptors declared on {@code interceptorClass}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy every interceptor field annotated with {@link NewField}, plus the interceptor's
     *       {@code getXxx}/{@code setXxx} methods for those fields, into the target class.</li>
     *   <li>For every target method, ask the interceptor's static hook ({@code constructMatch},
     *       {@code staticMatch} or {@code methodMatch}, depending on the method kind). If the hook's
     *       result prints as {@code "true"}, run all interceptor processors parsed from the
     *       interceptor's annotations on that method.</li>
     *   <li>Serialize the enhanced class.</li>
     * </ol>
     *
     * @param bytes            original class-file bytes of the target class
     * @param interceptorClass class carrying the bytekit annotations and the static match hooks
     * @param classLoader      loader passed through to {@code AsmUtils.toBytes} when writing the result
     * @return the enhanced class-file bytes
     * @throws Exception if parsing, copying or writing the class fails
     */
    public static byte[] enhanceClass(byte[] bytes, Class<?> interceptorClass, ClassLoader classLoader) throws Exception {
        // Parses the Interceptor class and its bytekit annotations.
        DefaultInterceptorClassParser interceptorClassParser = new DefaultInterceptorClassParser();

        ClassNode originClassNode = AsmUtils.toClassNode(bytes);
        ClassNode apmClassNode = getInterceptorClassNode(interceptorClass, classLoader);
        ClassNode targetClassNode = AsmUtils.copy(originClassNode);

        // Rename the interceptor ClassNode to the target's name. Presumably this makes the copied accessors
        // reference the new fields on the target class instead of on the interceptor.
        byte[] renameClass = AsmUtils.renameClass(AsmUtils.toBytes(apmClassNode),
                Type.getObjectType(originClassNode.name).getClassName());
        apmClassNode = AsmUtils.toClassNode(renameClass);

        copyNewFieldsAndAccessors(apmClassNode, targetClassNode);

        // Build the InterceptorProcessors from the interceptor's annotations and apply them to matching methods.
        List<InterceptorProcessor> processors = interceptorClassParser.parse(interceptorClass);
        for (MethodNode methodNode : targetClassNode.methods) {
            if (!shouldEnhance(interceptorClass, methodNode)) {
                continue;
            }
            MethodProcessor methodProcessor = new MethodProcessor(targetClassNode, methodNode);
            for (InterceptorProcessor interceptor : processors) {
                System.out.println("----------------[SF Agent bytekit]EnhanceUtil interceptor className:"+targetClassNode.name +" method:"+methodNode.name +" methodDesc:"+methodNode.desc +" methodProcessor:"+methodProcessor.getMethodNode().name );
                // Generate the interceptor code into the method.
                interceptor.process(methodProcessor);
            }
        }

        ClassReader classReader = AsmUtils.toClassNode(bytes, originClassNode);
        return AsmUtils.toBytes(targetClassNode, classLoader, classReader);
    }

    /**
     * Copies each {@link NewField}-annotated field of {@code source} into {@code target}, together with the
     * {@code source} methods named {@code get<Field>}/{@code set<Field>} for those fields.
     */
    private static void copyNewFieldsAndAccessors(ClassNode source, ClassNode target) {
        Type newFieldType = Type.getType(NewField.class);
        List<String> accessorNames = new ArrayList<>();
        List<FieldNode> fieldNodes = source.fields;
        if (fieldNodes != null) {
            for (FieldNode fieldNode : fieldNodes) {
                if (!hasAnnotation(fieldNode, newFieldType)) {
                    continue;
                }
                // Character.toUpperCase is locale-independent; String.toUpperCase() would turn "id" into
                // "İd" under a Turkish default locale and the accessors would not be found.
                String capitalized = Character.toUpperCase(fieldNode.name.charAt(0)) + fieldNode.name.substring(1);
                accessorNames.add("set" + capitalized);
                accessorNames.add("get" + capitalized);
                AsmUtils.addField(target, fieldNode);
            }
        }

        for (MethodNode methodNode : source.methods) {
            if (accessorNames.contains(methodNode.name)) {
                AsmUtils.addMethod(target, methodNode);
            }
        }
    }

    /** Returns whether {@code fieldNode} carries a runtime-visible annotation of type {@code annotationType}. */
    private static boolean hasAnnotation(FieldNode fieldNode, Type annotationType) {
        if (fieldNode.visibleAnnotations == null) {
            return false;
        }
        for (AnnotationNode annotationNode : fieldNode.visibleAnnotations) {
            if (annotationType.equals(Type.getType(annotationNode.desc))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Asks the interceptor's static match hook whether {@code methodNode} should be enhanced.
     * Method kinds are distinguished by exact name: a substring test would also treat ordinary
     * methods such as {@code init} or {@code clinit} as constructors or static initializers.
     */
    private static boolean shouldEnhance(Class<?> interceptorClass, MethodNode methodNode) {
        Object flag;
        if (CONSTRUCTOR_NAME.equals(methodNode.name)) {
            flag = invokeMatchHook(interceptorClass, CONSTRUCT_MATCH_METHOD, "get construct math method error",
                    new Class<?>[]{String.class}, methodNode.desc);
        } else if (STATIC_INITIALIZER_NAME.equals(methodNode.name)) {
            flag = invokeMatchHook(interceptorClass, STATIC_MATCH_METHOD, "get static math method error",
                    new Class<?>[]{String.class}, methodNode.desc);
        } else {
            flag = invokeMatchHook(interceptorClass, METHOD_MATCH_METHOD, "get method math method error",
                    new Class<?>[]{String.class, String.class}, methodNode.name, methodNode.desc);
        }
        return flag != null && "true".equals(flag.toString());
    }

    /**
     * Reflectively invokes a public static hook on the interceptor.
     * Best effort: any failure is printed to stderr and reported as {@code null} (meaning "do not enhance").
     */
    private static Object invokeMatchHook(Class<?> interceptorClass, String hookName, String errorMessage,
                                          Class<?>[] parameterTypes, Object... arguments) {
        try {
            Method hook = interceptorClass.getMethod(hookName, parameterTypes);
            return hook.invoke(null, arguments);
        } catch (Throwable e) {
            System.err.println(errorMessage);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Reads the interceptor's bytecode as a {@link ClassNode}.
     *
     * @param interceptorClass the interceptor class
     * @param classLoader      currently unused
     * @throws java.io.FileNotFoundException if a bootstrap-loaded interceptor's {@code .class} resource
     *                                       cannot be found through the system class loader
     */
    private static ClassNode getInterceptorClassNode(Class<?> interceptorClass, ClassLoader classLoader) throws Exception {
        if (interceptorClass.getClassLoader() != null) {
            return AsmUtils.loadClass(interceptorClass);
        }

        // getClassLoader() == null means the bootstrap loader defined the interceptor, so its resource
        // is looked up through the system class loader instead.
        String resource = interceptorClass.getName().replace('.', '/') + ".class";
        try (InputStream is = ClassLoader.getSystemClassLoader().getResourceAsStream(resource)) {
            if (is == null) {
                throw new java.io.FileNotFoundException("interceptor class resource not found: " + resource);
            }
            ClassReader cr = new ClassReader(is);
            ClassNode classNode = new ClassNode();
            cr.accept(classNode, ClassReader.SKIP_FRAMES);
            return classNode;
        }
    }
}

使用

BytekitTransformer 字节码增强入口类
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;


public class BytekitTransformer implements ClassFileTransformer, Runnable {

    /**
     * Enhancement targets: JVM internal class name (slash-separated) to the fully-qualified name of the plugin
     * class. The plugin must expose {@code enhance(ClassLoader, String, Class, ProtectionDomain, byte[])}.
     * Filled once in the static initializer and only read afterwards.
     */
    private static final Map<String, String> enhanceModels = new HashMap<>();

    /**
     * Bookkeeping per (target class, class-loader instance). {@link #transform} may be invoked concurrently
     * by several class-loading threads, hence the concurrent map.
     */
    private static final Map<String, TransformerLoad> TransformerLoadMap = new ConcurrentHashMap<>();

    /** Delay in milliseconds before the delayed retransform pass in {@link #run()}; negative disables it. */
    private static int delayStartRetransformTime = -1;

    static {
        // Could later be turned into an SPI-based extension mechanism.

        enhanceModels.put("org/apache/logging/log4j/core/layout/PatternLayout$Builder","bytekit.Log4j2ByteEnhance");
        enhanceModels.put("ch/qos/logback/core/pattern/PatternLayoutBase","bytekit.LogbackByteEnhance");
        enhanceModels.put("ch/qos/logback/classic/PatternLayout","bytekit.LogbackPatternByteEnhance");
    }

    /**
     * Applies the registered plugin to a class listed in {@code enhanceModels}.
     *
     * <p>For an initial class definition ({@code classBeingRedefined == null}), each (class, loader) pair
     * is enhanced at most once. For a retransformation, the enhancement is always re-applied. The JVM hands
     * retransform-capable transformers the original, untransformed bytes, so returning them unchanged would
     * strip an enhancement applied earlier.
     *
     * @return the enhanced bytes; or {@code classfileBuffer} unchanged when the class is not a target or
     *         enhancement fails
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
        String abstractByteEnhance = enhanceModels.get(className);
        if (abstractByteEnhance == null) {
            return classfileBuffer;
        }
        System.out.println("[bytekit] "+new Date()+" bytekitTransformer transform loader:"+loader+" className:"+className);

        TransformerLoad transformerLoad = null;
        try {
            Class<?> abstractByteEnhanceCla;
            try {
                if (loader == null) {
                    // Bootstrap-defined class: fall back to a loader that can see the plugin class.
                    loader = getClassLoader(abstractByteEnhance);
                }
                if (loader == null) {
                    System.err.println("[bytekit] "+new Date()+" bytekitTransformer loader加载不到类 loader:"+loader+" className:"+className+"  plugin:"+abstractByteEnhance);
                    return classfileBuffer;
                }

                transformerLoad = getTransformerLoad(className, loader, abstractByteEnhance);

                if (classBeingRedefined == null && transformerLoad.getLoadCount() > 0) {
                    System.out.println("[ bytekit] "+new Date()+" bytekitTransformer transform 已经加载过  loader:"+loader+" className:"+className+"  plugin:"+abstractByteEnhance);
                    return classfileBuffer;
                }

                transformerLoad.setLoadCount();
                abstractByteEnhanceCla = loader.loadClass(abstractByteEnhance);
            } catch (Throwable e) {
                System.err.println("[bytekit] "+new Date()+" bytekitTransformer loader加载不到类 loader:"+loader+" className:"+className+"  plugin:"+abstractByteEnhance+"  e:"+e.getMessage());
                if (transformerLoad != null) {
                    transformerLoad.setErrorCount();
                    transformerLoad.setError(e.getMessage());
                }
                return classfileBuffer;
            }

            Object obj = abstractByteEnhanceCla.getDeclaredConstructor().newInstance();
            Method enhance = abstractByteEnhanceCla.getMethod("enhance", ClassLoader.class, String.class, Class.class, ProtectionDomain.class, byte[].class);

            byte[] bytes = (byte[]) enhance.invoke(obj, loader, className, classBeingRedefined, protectionDomain, classfileBuffer);
            transformerLoad.setSuccessCount();
            System.out.println("[bytekit] "+new Date()+" bytekitTransformer transform 执行插件类enhance成功 loader:"+loader+" className:"+className+"  plugin:"+abstractByteEnhance);
            return bytes;
        } catch (Throwable e) {
            e.printStackTrace();
            System.err.println("[bytekit] BytekitTransformer  transform error e:"+e.getMessage());
            if (transformerLoad != null) {
                transformerLoad.setErrorCount();
                transformerLoad.setError(e.getMessage());
            }
        }
        return classfileBuffer;
    }

    /**
     * Finds a loader that can load the plugin class {@code cla}. This is used when the target class was
     * defined by the bootstrap loader ({@code loader == null}). The agent's own loader is tried first,
     * then the current thread's context class loader.
     *
     * @return a loader that can load {@code cla}, or {@code null} if neither can
     */
    private ClassLoader getClassLoader(String cla) {
        ClassLoader agentLoader = GovernanceAgent.class.getClassLoader();
        if (agentLoader != null) {
            try {
                agentLoader.loadClass(cla);
                return agentLoader;
            } catch (ClassNotFoundException ignored) {
                // Not visible to the agent's loader; try the context class loader below.
            }
        }

        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        if (contextLoader == null) {
            System.out.println("[bytekit] "+new Date()+" bytekitTransformer getClassLoader error loader:"+contextLoader+" className:"+cla+ " e:no context class loader");
            return null;
        }
        try {
            contextLoader.loadClass(cla);
            return contextLoader;
        } catch (ClassNotFoundException e) {
            System.out.println("[bytekit] "+new Date()+" bytekitTransformer getClassLoader error loader:"+contextLoader+" className:"+cla+ " e:"+e.getMessage());
            return null;
        }
    }

    /**
     * Returns the bookkeeping entry for {@code className} defined by this specific {@code classLoader}
     * instance, creating it atomically on first use. The key includes the loader's identity hash, so two
     * loaders of the same type (e.g. two webapp loaders) are tracked separately.
     */
    private static TransformerLoad getTransformerLoad(String className, ClassLoader classLoader, String enhance) {
        String key = className + "_" + classLoader.getClass().getTypeName()
                + "@" + Integer.toHexString(System.identityHashCode(classLoader));
        return TransformerLoadMap.computeIfAbsent(key, k -> {
            TransformerLoad created = new TransformerLoad();
            created.setClassName(className);
            created.setByteEnhance(enhance);
            created.setClassLoader(classLoader);
            return created;
        });
    }

    /** Starts {@link #run()} on a daemon thread named {@code "bytekitTransformer thread"}. */
    public void start() {
        Thread thread = new Thread(this::run);
        thread.setName("bytekitTransformer thread");
        thread.setDaemon(true);
        thread.start();
    }

    /**
     * Delayed retransform pass.
     *
     * <p>Reads {@code GOV_RETRANSORM_TIME} (milliseconds; default {@code -1}, meaning disabled) and sleeps
     * that long. It then asks the JVM to retransform every target class that {@code ClassLoaderUtil.getClassLoader()}
     * can load. Presumably this catches classes loaded before this transformer became active.
     */
    @Override
    public void run() {
        try {
            delayStartRetransformTime = Integer.parseInt(EnvUtils.getProperty("GOV_RETRANSORM_TIME","-1"));
        } catch (NumberFormatException e) {
            System.err.println("[bytekit] "+new Date()+" bytekitTransformer invalid GOV_RETRANSORM_TIME e:"+e.getMessage());
            return;
        }

        if (delayStartRetransformTime < 0) {
            return;
        }

        System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强启动");
        Instrumentation inst = InstrumentConfig.INSTANCE.get();
        try {
            Thread.sleep(delayStartRetransformTime);
        } catch (InterruptedException e) {
            // Preserve the interrupt status and abandon the delayed pass.
            Thread.currentThread().interrupt();
            return;
        }
        if (inst == null) {
            System.err.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强 Instrumentation is null");
            return;
        }

        List<Class<?>> unEnhance = new ArrayList<>();

        ClassLoader classLoader = ClassLoaderUtil.getClassLoader();
        System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强 loader:"+classLoader);
        enhanceModels.forEach((internalName, enhance) -> {
            String className = internalName.replace('/', '.');
            try {
                unEnhance.add(classLoader.loadClass(className));
            } catch (Exception e) {
                System.err.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强加载增强类失败 loader:"+classLoader+" className:"+className+" enhance:"+enhance+" e:"+e.getMessage());
            }
        });

        for (Class<?> aClass : unEnhance) {
            if (!inst.isModifiableClass(aClass) || !inst.isRetransformClassesSupported()) {
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer 不能retransformClasses className:"+aClass);
                // retransformClasses would only throw here; skip this class.
                continue;
            }

            try {
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer 延时增强 aClass"+aClass);
                inst.retransformClasses(aClass);
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println("[bytekit] "+new Date()+" bytekitTransformer retransformClasses error className:"+aClass);
            }
        }
    }

    /**
     * Bookkeeping for one (target class, class loader) pair. It records how often enhancement was attempted,
     * how often it succeeded or failed, and the collected error messages. Counters are updated from
     * concurrent {@code transform} calls, so accessors are synchronized.
     */
    public static class TransformerLoad {

        private ClassLoader classLoader;

        private String className;

        private String byteEnhance;

        private int loadCount;

        private int successCount;

        private int errorCount;

        private final List<String> error = Collections.synchronizedList(new ArrayList<>());

        public ClassLoader getClassLoader() {
            return classLoader;
        }

        public void setClassLoader(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        public String getClassName() {
            return className;
        }

        public void setClassName(String className) {
            this.className = className;
        }

        public String getByteEnhance() {
            return byteEnhance;
        }

        public void setByteEnhance(String byteEnhance) {
            this.byteEnhance = byteEnhance;
        }

        public synchronized int getLoadCount() {
            return loadCount;
        }

        /** Increments the attempt counter (despite the setter-style name). */
        public synchronized void setLoadCount() {
            this.loadCount++;
        }

        public synchronized int getSuccessCount() {
            return successCount;
        }

        /** Increments the success counter (despite the setter-style name). */
        public synchronized void setSuccessCount() {
            this.successCount++;
        }

        public synchronized int getErrorCount() {
            return errorCount;
        }

        /** Increments the error counter (despite the setter-style name). */
        public synchronized void setErrorCount() {
            this.errorCount++;
        }

        /** Returns the live, synchronized list of recorded error messages. */
        public List<String> getError() {
            return error;
        }

        /** Appends an error message (despite the setter-style name). */
        public void setError(String error) {
            this.error.add(error);
        }
    }
}

抽象的增强类

import java.security.ProtectionDomain;

/**
 * Base class for bytekit enhancement plugins.
 *
 * <p>{@link #enhance} has the signature that {@code BytekitTransformer} looks up reflectively. It enhances
 * the class with the interceptor returned by {@link #getInterceptorClass()}, but only when {@link #math}
 * accepts it.
 */
public abstract class AbstractByteEnhance {

    /**
     * Enhances {@code classfileBuffer} via {@code EnhanceUtil.enhanceClass} when {@link #math} matches.
     *
     * @return the enhanced bytes; or {@code classfileBuffer} unchanged if there is no match or an error occurs
     *         (errors are printed, never thrown)
     */
    public byte[] enhance(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
        try {
            if (math(loader, className, classBeingRedefined, protectionDomain)) {
                return EnhanceUtil.enhanceClass(classfileBuffer, getInterceptorClass(), loader);
            }

            System.out.println("AbstractByteEnhance bytekit enhance math匹配失败  className:"+className);
        } catch (Throwable e) {
            System.err.println("AbstractByteEnhance bytekit enhance error className:"+className);
            e.printStackTrace();
        }

        return classfileBuffer;
    }

    /**
     * Returns the bytekit interceptor class to apply. Returning {@code NullInterceptor.class} disables
     * enhancement.
     */
    public abstract Class getInterceptorClass();

    /**
     * Decides whether the class should be enhanced. It matches when the interceptor class is not
     * {@code NullInterceptor} and its name can be loaded through {@code loader}.
     *
     * <p>The interceptor class is looked up only once. A failing or {@code null}-returning
     * {@link #getInterceptorClass()} is therefore reported as a non-match, instead of throwing again from the
     * error handler.
     *
     * @param loader              loader the target class is being defined in
     * @param className           internal name of the target class (used for logging)
     * @param classBeingRedefined unused here
     * @param protectionDomain    unused here
     * @return {@code true} if the class should be enhanced
     */
    public boolean math(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain) {
        String interceptorName = null;
        try {
            Class interceptorClass = getInterceptorClass();
            if (NullInterceptor.class.equals(interceptorClass)) {
                System.err.println("AbstractByteEnhance bytekit 类加载器找不到 增强目标类:"+className+" loader:"+loader);
                return false;
            }
            interceptorName = interceptorClass.getName();
            loader.loadClass(interceptorName);
            return true;
        } catch (Throwable e) {
            System.err.println("AbstractByteEnhance bytekit 类加载器找不到 "+interceptorName+" loader:"+loader+" e:"+e.getMessage());
            return false;
        }
    }

}

拦截器

对logback进行增强

import com.sf.plough.governance.agent.dependencies.com.alibaba.bytekit.asm.binding.Binding;
import com.alibaba.bytekit.asm.interceptor.annotation.AtExceptionExit;
import com.alibaba.bytekit.asm.interceptor.annotation.AtExit;

import java.lang.reflect.Field;

/**
 * bytekit interceptor that rewrites a layout's {@code pattern} field after {@code setPattern(String)} returns.
 * Judging by its log message, it is intended for logback's {@code PatternLayoutBase}.
 *
 * <p>The static {@code methodMatch}/{@code constructMatch}/{@code staticMatch} hooks tell
 * {@code EnhanceUtil} which target methods to instrument.
 *
 * <p>NOTE(review): {@code Binding} is imported from a relocated package
 * ({@code com.sf.plough.governance.agent.dependencies...}), while {@code AtExit}/{@code AtExceptionExit} come
 * from {@code com.alibaba.bytekit}. Confirm that both match the bytekit copy that parses this class.
 */
public class LogbackInterceptor {

    // Example of an entry advice, kept disabled:
    /*@AtEnter(inline = true, suppress = RuntimeException.class, suppressHandler = PrintExceptionSuppressHandler.class)
    public static void atEnter(@Binding.This Object object,
                               @Binding.Class Object clazz,
                               @Binding.Args Object[] args,
                               @Binding.MethodName String methodName,
                               @Binding.MethodDesc String methodDesc) throws NoSuchFieldException {

    }*/

    /**
     * Exit advice, inlined into the target method.
     *
     * <p>After {@code setPattern(x)} returns normally with a non-null {@code x}, this replaces the instance's
     * {@code pattern} field with {@code PatternUtil.patternConvert(String.valueOf(x), "%tid")}.
     *
     * <p>The field is searched from the superclass of the runtime class upwards. This way, user subclasses of
     * the concrete layout still resolve it; the previous version only looked exactly one level up. Because
     * this body is inlined into the target class, the lookup is kept self-contained (no private helper calls).
     *
     * @throws NoSuchFieldException if no superclass declares a field named {@code pattern}
     */
    @AtExit(inline = true)
    public static void atExit(@Binding.This Object object,
                              @Binding.Class Object clazz,
                              @Binding.Args Object[] args,
                              @Binding.MethodName String methodName,
                              @Binding.MethodDesc String methodDesc,
                              @Binding.Return Object returnObject) throws Exception {
        // A null argument is left alone so it is not turned into the literal pattern text "null".
        if ("setPattern".equals(methodName) && args != null && args.length == 1 && args[0] != null) {
            Field field = null;
            for (Class<?> type = object.getClass().getSuperclass(); type != null && field == null; type = type.getSuperclass()) {
                for (Field candidate : type.getDeclaredFields()) {
                    if ("pattern".equals(candidate.getName())) {
                        field = candidate;
                        break;
                    }
                }
            }
            if (field == null) {
                throw new NoSuchFieldException("pattern");
            }
            field.setAccessible(true);
            String pattern = PatternUtil.patternConvert(String.valueOf(args[0]), "%tid");
            // Printed to stdout because the logger was reported not to work in Spring Boot projects.
            System.out.println("logback PatternLayoutBase setPattern 重写pattern:"+pattern);
            field.set(object, pattern);
        }
    }

    /** Exception-exit advice: prints any RuntimeException leaving the target method, then rethrows it unchanged. */
    @AtExceptionExit(inline = true, onException = RuntimeException.class)
    public static void atExceptionExit(@Binding.Class Object clazz, @Binding.Throwable RuntimeException ex) {
        ex.printStackTrace();
        throw ex;
    }

    /** Instruments every method named {@code setPattern}, regardless of descriptor. */
    public static Boolean methodMatch(String method, String methodDesc) {
        return "setPattern".equals(method);
    }

    /** Constructors are never instrumented. */
    public static Boolean constructMatch(String methodDesc) {
        return false;
    }

    /** The static initializer is never instrumented. */
    public static Boolean staticMatch(String methodDesc) {
        return false;
    }

}

参考链接:https://github.com/kangzhenkang/bytekit

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值