关于通过 mock 或 AOP 方式增强 ClassLoader 的思路描述

1. 利用 java.lang.instrument 的 ClassFileTransformer 机制,以 premain()(即 -javaagent)方式启动。

2. 实现其 transform 方法:

@Override
public byte[] transform(final ClassLoader loader, String className, Class<?> classBeingRedefined,
                        ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {



3. 由于需要在目标 ClassLoader 加载类之前就对该 ClassLoader 进行 mock(AOP),因此需要找到触发该 ClassLoader 加载的初始 ClassLoader,一般来说就是 main() 方法所在入口类的 ClassLoader。

因此需要通过 main() 方法所在入口类的 ClassLoader,去 mock 需要被 mock 的 ClassLoader,例如自定义的 ClassLoader,或 Spring Boot 的 org.springframework.boot.loader.LaunchedURLClassLoader。


4. mock 了 ClassLoader 之后,就能知道它正在加载哪些类,此时即可对这些类进行 mock 并加载。为了在第二次加载时能够区分,需要进行分组:

baseLoader --> 1. 启动类 --> 找到目标 ClassLoader。目的:mock 当前应用用于加载类的目标 ClassLoader。

                   --> 2. 目标 ClassLoader --> 需要 mock 的方法列表 {类: 方法}:找出目标 ClassLoader 后,就可以 mock 该 ClassLoader 下的类和方法了。


 1. 建立 baseLoader 分组 --> 对应初始化类,也就是调用 main() 方法的类。

    在 transferClass 之前,先建立好初始化类的分组,即 baseLoader 分组。

2. 在 transferClass 时即可找到这个 baseLoader 分组,此时就可以配置该 baseLoader 下需要 mock 的类和方法了。



以下以 Spring Boot 为例:

org.springframework.boot.loader.LaunchedURLClassLoader 的 mock 方式:


/**
 * Mocker for Spring Boot's {@code LaunchedURLClassLoader}: intercepts its
 * {@code loadClass(String, boolean)} so that a class registered in {@code ClassRepo} is returned
 * directly; every other lookup falls through to the original implementation via
 * {@code AsmInjector.invoke}.
 */
public class SpringBootLoaderMocker extends BaseAsmMocker {

    /** Fully-qualified name of the class loader being mocked. */
    private static final String TARGET_CLASS = "org.springframework.boot.loader.LaunchedURLClassLoader";

    @Override
    public String getClassName() {
        return TARGET_CLASS;
    }

    /**
     * Method patterns to intercept. Only
     * {@code protected Class<?> loadClass(String name, boolean resolve)} is targeted.
     *
     * @return a new mutable list holding the single {@code loadClass} pattern
     */
    @Override
    public List<String> getMethodPatterns() {
        List<String> patterns = new ArrayList<String>();
        patterns.add("* loadClass(java.lang.String,boolean)");
        return patterns;
    }

    /**
     * Interceptor for {@code loadClass}.
     *
     * @return an interceptor that prefers {@code ClassRepo} and otherwise delegates to the original
     *         method
     */
    @Override
    public InvocationInterceptor getInterceptor() {
        return new InvocationInterceptor() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                // args[0] is the "name" parameter of loadClass(String, boolean).
                String className = (String) args[0];

                // Change how the class is looked up. When the type being assigned may come from a
                // different loader, take the class from ClassRepo, presumably the copy held by the
                // class loader that needs the conversion. This avoids class-conflict errors
                // between duplicate copies of the same class.
                Class<?> clazz = ClassRepo.getClass(className);
                if (clazz != null) {
                    DoomLogger.info("tracer.spring-boot " + className + "@" + clazz.getClassLoader());
                    return clazz;
                }
                // Not registered: run the original LaunchedURLClassLoader.loadClass logic.
                return AsmInjector.invoke(proxy, method, args);
            }
        };
    }
}


=================

/**
 * Reflection and ASM helpers used by the mock/AOP class-loading machinery.
 *
 * <p>Covers four areas:
 * <ul>
 *   <li>defining classes directly into an arbitrary {@link ClassLoader};</li>
 *   <li>emitting box/unbox bytecode;</li>
 *   <li>reading method signatures out of class files;</li>
 *   <li>matching those signatures against user-supplied method patterns.</li>
 * </ul>
 */
public class AsmClassUtil {
    /** {@code ClassLoader.defineClass(String, byte[], int, int)}, invoked reflectively by {@link #reload}. */
    private static Method defineClass1;
    /**
     * {@code ClassLoader.defineClass(String, byte[], int, int, ProtectionDomain)}. Looked up but
     * currently not used by this class.
     */
    private static Method defineClass2;
    /** {@code ClassLoader.definePackage(...)}. Looked up but currently not used by this class. */
    private static Method definePackage;
    /** Maps primitive type names ({@code "int"}, {@code "void"}, ...) to their {@code Class} objects. */
    private static final Map<String, Class<?>> baseType = new TreeMap<String, Class<?>>();

    /**
     * Upper bound on both the number of "define the missing class and retry" rounds and the
     * cause-chain depth inspected by {@link #safelyExecute}, so that it can never loop forever.
     */
    private static final int MAX_SAFE_EXECUTE_RETRIES = 100;

    private static Unsafe unSafe;

    static {
        try {
            // Resolve the protected ClassLoader methods and the Unsafe singleton once up front;
            // they are invoked reflectively later.
            AccessController.doPrivileged(new PrivilegedExceptionAction<Object>() {
                public Object run() throws Exception {
                    Class<?> cl = ClassLoader.class;
                    defineClass1 = cl.getDeclaredMethod("defineClass",
                            new Class[]{String.class, byte[].class,
                                    int.class, int.class});

                    defineClass2 = cl.getDeclaredMethod("defineClass",
                            new Class[]{String.class, byte[].class,
                                    int.class, int.class, ProtectionDomain.class});

                    definePackage = cl.getDeclaredMethod("definePackage",
                            new Class[]{String.class, String.class, String.class,
                                    String.class, String.class, String.class,
                                    String.class, java.net.URL.class});

                    Field f = Unsafe.class.getDeclaredField("theUnsafe"); // Internal reference
                    f.setAccessible(true);
                    unSafe = (Unsafe) f.get(null);
                    return null;
                }
            });
        } catch (PrivilegedActionException pae) {
            throw new RuntimeException("cannot initialize ClassPool", pae.getException());
        }
        baseType.put("byte", Byte.TYPE);
        baseType.put("int", Integer.TYPE);
        baseType.put("char", Character.TYPE);
        baseType.put("boolean", Boolean.TYPE);
        baseType.put("long", Long.TYPE);
        baseType.put("float", Float.TYPE);
        baseType.put("double", Double.TYPE);
        baseType.put("short", Short.TYPE);
        baseType.put("void", Void.TYPE);
    }


    /**
     * Defines a class from raw class-file bytes inside the given loader.
     *
     * <p>Works by reflectively calling the protected {@code ClassLoader.defineClass} method.
     *
     * @param b           class-file bytes
     * @param clazz       binary (dotted) name of the class being defined
     * @param classLoader target loader; when {@code null}, the current thread's context class
     *                    loader is used
     * @return the newly defined class
     * @throws Exception any failure of the reflective call. For example, an
     *                   {@code InvocationTargetException} wrapping the error raised by
     *                   {@code defineClass}.
     */
    public static Class reload(byte[] b, String clazz, ClassLoader classLoader) throws Exception {
        if (classLoader == null) {
            classLoader = Thread.currentThread().getContextClassLoader();
        }
        // No ProtectionDomain is ever supplied, so the 4-argument defineClass overload is used.
        Object[] args = new Object[]{clazz, b, Integer.valueOf(0), Integer.valueOf(b.length)};
        System.out.println("loaded " + clazz + " to " + classLoader.getClass().getName() + "@" + classLoader);
        return (Class) toClass2(defineClass1, classLoader, args);
    }

    /**
     * Invokes a (normally inaccessible) {@code ClassLoader} method on {@code loader}.
     *
     * <p>Accessibility is enabled only for the duration of the call. The method is
     * {@code synchronized} so concurrent callers cannot reset the flag underneath each other.
     */
    private static synchronized Object toClass2(Method method,
                                                ClassLoader loader, Object[] args)
            throws Exception {
        method.setAccessible(true);
        try {
            return method.invoke(loader, args);
        } finally {
            method.setAccessible(false);
        }
    }


    /**
     * Emits bytecode that pushes local variable {@code i} onto the operand stack.
     *
     * <p>When {@code clazz} is a primitive type, the value is boxed into its wrapper type.
     *
     * @param mv          visitor receiving the instructions
     * @param i           local-variable slot to load
     * @param clazz       declared type of the variable; {@code null} is treated as a reference type
     * @param targetClass not used by this method
     * @return number of local-variable slots the value occupies: 2 for {@code long}/{@code double},
     *         otherwise 1
     */
    public static int pushVarInStackWithBox(MethodVisitor mv, int i, Class<?> clazz, String targetClass) {
        if (clazz == null || !clazz.isPrimitive()) {
            mv.visitVarInsn(Opcodes.ALOAD, i);
            return 1;
        }
        // Loading a local is a *var* instruction; visitIntInsn is only meant for
        // BIPUSH/SIPUSH/NEWARRAY and does not let ASM track local-variable usage.
        if (clazz == Integer.TYPE) {
            mv.visitVarInsn(Opcodes.ILOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Integer", "valueOf", "(I)Ljava/lang/Integer;");
            return 1;
        } else if (clazz == Character.TYPE) {
            mv.visitVarInsn(Opcodes.ILOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Character", "valueOf", "(C)Ljava/lang/Character;");
            return 1;
        } else if (clazz == Long.TYPE) {
            mv.visitVarInsn(Opcodes.LLOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;");
            return 2;
        } else if (clazz == Float.TYPE) {
            mv.visitVarInsn(Opcodes.FLOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Float", "valueOf", "(F)Ljava/lang/Float;");
            return 1;
        } else if (clazz == Double.TYPE) {
            mv.visitVarInsn(Opcodes.DLOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Double", "valueOf", "(D)Ljava/lang/Double;");
            return 2;
        } else if (clazz == Byte.TYPE) {
            mv.visitVarInsn(Opcodes.ILOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Byte", "valueOf", "(B)Ljava/lang/Byte;");
            return 1;
        } else if (clazz == Short.TYPE) {
            mv.visitVarInsn(Opcodes.ILOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Short", "valueOf", "(S)Ljava/lang/Short;");
            return 1;
        } else if (clazz == Boolean.TYPE) {
            mv.visitVarInsn(Opcodes.ILOAD, i);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Boolean", "valueOf", "(Z)Ljava/lang/Boolean;");
            return 1;
        }
        // Only void is left: there is no value to load.
        return 1;
    }


    /**
     * Emits bytecode that converts the {@code Object} on top of the stack to {@code clazz} and
     * returns it.
     *
     * <p>Reference types get a CHECKCAST. Primitive types are unboxed from their wrapper type.
     *
     * @param mv        visitor receiving the instructions
     * @param clazz     declared return type; {@code null} means "cast to {@code selfClass}"
     * @param selfClass dotted class name used as the cast target when {@code clazz} is {@code null}
     * @return 2 for {@code long}/{@code double}, otherwise 1
     */
    public static int convertAndReturn(MethodVisitor mv, Class<?> clazz, String selfClass) {

        if (clazz == null) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, selfClass.replace('.', '/'));
            mv.visitInsn(Opcodes.ARETURN);
            return 1;
        }

        if (!clazz.isPrimitive()) {
            // Class.getName() of an array ("[Ljava.lang.String;") becomes a valid internal name too.
            mv.visitTypeInsn(Opcodes.CHECKCAST, clazz.getName().replace('.', '/'));
            mv.visitInsn(Opcodes.ARETURN);
            return 1;
        }

        if (clazz == Integer.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Integer");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Integer", "intValue", "()I");
            mv.visitInsn(Opcodes.IRETURN);
            return 1;
        } else if (clazz == Character.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Character");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Character", "charValue", "()C");
            mv.visitInsn(Opcodes.IRETURN);
            return 1;
        } else if (clazz == Long.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Long");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Long", "longValue", "()J");
            mv.visitInsn(Opcodes.LRETURN);
            return 2;
        } else if (clazz == Float.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Float");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Float", "floatValue", "()F");
            mv.visitInsn(Opcodes.FRETURN);
            return 1;
        } else if (clazz == Double.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Double");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Double", "doubleValue", "()D");
            mv.visitInsn(Opcodes.DRETURN);
            return 2;
        } else if (clazz == Byte.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Byte");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Byte", "byteValue", "()B");
            mv.visitInsn(Opcodes.IRETURN);
            return 1;
        } else if (clazz == Short.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Short");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Short", "shortValue", "()S");
            mv.visitInsn(Opcodes.IRETURN);
            return 1;
        } else if (clazz == Boolean.TYPE) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Boolean");
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Boolean", "booleanValue", "()Z");
            mv.visitInsn(Opcodes.IRETURN);
            return 1;
        }
        // NOTE(review): for void no instruction is emitted at all. Confirm that callers emit the
        // POP/RETURN themselves for void methods.
        return 1;
    }

    /**
     * ClassVisitor that records a readable signature for each interceptable method.
     *
     * <p>Each entry has the form {@code "returnType name(paramType,paramType)"}. Constructors,
     * static initializers, {@code toString}/{@code clone}/{@code hashCode}, JaCoCo probes and names
     * starting with {@code $} are ignored. Private and/or protected methods can be skipped as well.
     */
    public static class MethodInfoReadAdaptor extends ClassVisitor {

        /** Method names that are never intercepted. */
        private static final Set<String> baseMethodSet = new HashSet<String>();

        /**
         * Whether private methods are excluded from interception.
         */
        private boolean isIgnorePrivate;
        /**
         * Whether protected methods are excluded from interception.
         */
        private boolean isIgnoreProtected;

        static {
            baseMethodSet.add("<init>");
            baseMethodSet.add("toString");
            baseMethodSet.add("clone");
            // Entries are compared against bare method names, so no "()" suffix here.
            baseMethodSet.add("hashCode");
            baseMethodSet.add("<clinit>");
        }

        private List<String> methodInfoList = new ArrayList<String>();

        /**
         * @param classAdapter      visitor to delegate to; may be {@code null} when only the
         *                          collected method list is needed
         * @param isIgnorePrivate   skip private methods
         * @param isIgnoreProtected skip protected methods
         */
        public MethodInfoReadAdaptor(ClassVisitor classAdapter, boolean isIgnorePrivate, boolean isIgnoreProtected) {
            super(AsmAop.ASM_VERSION, classAdapter);
            this.isIgnorePrivate = isIgnorePrivate;
            this.isIgnoreProtected = isIgnoreProtected;
        }

        @Override
        public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
            // Exclude special / synthetic methods, then apply the visibility filter.
            if (!baseMethodSet.contains(name)
                    && !name.contains("jacocoInit")
                    && !name.startsWith("$")
                    && !isSkippedByVisibility(access)) {
                methodInfoList.add(describe(name, desc));
            }
            return super.visitMethod(access, name, desc, signature, exceptions);
        }

        /** Applies the private/protected filters to the method's access flags. */
        private boolean isSkippedByVisibility(int access) {
            if ((isIgnorePrivate && Modifier.isPrivate(access))
                    || (isIgnoreProtected && Modifier.isProtected(access))) {
                return true;
            }
            // When both private and protected are ignored, only public methods are intercepted
            // (package-private methods are skipped as well).
            return isIgnorePrivate && isIgnoreProtected && !Modifier.isPublic(access);
        }

        /** Builds {@code "returnType name(paramType,...)"} from a JVM method descriptor. */
        private static String describe(String name, String desc) {
            Type[] paramTypes = Type.getArgumentTypes(desc);
            List<String> types = new ArrayList<String>(paramTypes.length);
            for (Type type : paramTypes) {
                types.add(getConvertedClassName(type.getClassName()));
            }
            return getConvertedClassName(Type.getReturnType(desc).getClassName())
                    + " " + name + "(" + StringUtils.join(types, ",") + ")";
        }

        /** @return the signatures collected so far (the live internal list) */
        public List<String> getMethodInfoList() {
            return methodInfoList;
        }
    }


    /**
     * Reads the signatures of the interceptable methods declared in a class file.
     *
     * <p>The caller remains responsible for closing {@code inputStream}.
     *
     * @return signatures in the format produced by {@link MethodInfoReadAdaptor}
     */
    public static List<String> readMethodInfoList(InputStream inputStream, boolean isIgnorePrivate, boolean isIgnoreProtected) throws IOException {
        ClassReader cr = new ClassReader(inputStream);
        // Only method headers are needed: no downstream writer, and method bodies are skipped.
        MethodInfoReadAdaptor adaptor = new MethodInfoReadAdaptor(null, isIgnorePrivate, isIgnoreProtected);
        cr.accept(adaptor, ClassReader.SKIP_CODE);
        return adaptor.getMethodInfoList();
    }

    /** Same as {@link #readMethodInfoList(InputStream, boolean, boolean)} with no visibility filter. */
    public static List<String> readMethodInfoList(InputStream inputStream) throws IOException {
        return readMethodInfoList(inputStream, false, false);
    }

    /**
     * Maps a type name as it should appear in method signatures.
     *
     * <p>This is currently the identity for both primitive and reference names. It is presumably
     * kept as a single hook for renaming types.
     */
    public static String getConvertedClassName(String className) {
        return className;
    }

    /**
     * Resolves a source-style type name to a {@code Class}.
     *
     * <p>Accepted forms are a primitive name such as {@code "int"}, a dotted class name, or an
     * array of any dimension (e.g. {@code "java.lang.String[][]"}).
     *
     * <p>If {@code classLoader} cannot find the class and a {@code resourceResolver} is given, the
     * class bytes are fetched from the resolver and defined into {@code classLoader}. Missing
     * dependencies are defined on the way (see {@link #safelyExecute}).
     *
     * @throws ClassNotFoundException when the type cannot be resolved. If defining it from the
     *                                resolver failed, that failure is attached as the cause.
     */
    public static Class<?> getClassByName(String className, final ClassLoader classLoader, final ResourceResolver resourceResolver) throws ClassNotFoundException {
        Class<?> primitive = baseType.get(className);
        if (primitive != null) {
            return primitive;
        }

        // Array in source form: resolve the component type, then build an array class with the
        // requested number of dimensions (works for any dimension, primitive or reference).
        String componentName = className;
        int dimensions = 0;
        while (componentName.endsWith("[]")) {
            componentName = componentName.substring(0, componentName.length() - 2);
            dimensions++;
        }
        if (dimensions > 0 && !componentName.isEmpty()) {
            Class<?> component = getClassByName(componentName, classLoader, resourceResolver);
            if (component == Void.TYPE) {
                throw new ClassNotFoundException(className);
            }
            return java.lang.reflect.Array.newInstance(component, new int[dimensions]).getClass();
        }

        try {
            return Class.forName(className, false, classLoader);
        } catch (ClassNotFoundException notFound) {
            if (resourceResolver == null) {
                throw new ClassNotFoundException(className, notFound);
            }
            ResourceInfo resourceInfo = resourceResolver.loadStream(AsmClassUtil.forPath(className));
            // NOTE(review): this probe stream is only used as an existence check and is never
            // closed. Confirm that loadStream returns a fresh stream per call before closing it here.
            if (resourceInfo == null || resourceInfo.getInputStream() == null) {
                throw new ClassNotFoundException(className, notFound);
            }
            try {
                final String finalClassName = className;
                safelyExecute(classLoader, resourceResolver, new SafeExecute() {
                    public void execute(ClassLoader loader, ResourceResolver resolver) throws Exception {
                        // Re-open the resource on every attempt: a previous attempt consumed its stream.
                        ResourceInfo loadedResourceInfo = resolver.loadStream(AsmClassUtil.forPath(finalClassName));
                        defineClassFromStream(loadedResourceInfo.getInputStream(), finalClassName, loader);
                    }
                });
                return Class.forName(className, false, classLoader);
            } catch (Exception defineFailure) {
                throw new ClassNotFoundException(className, defineFailure);
            }
        }
    }

    /**
     * Inverse of {@link #getClassByName}: renders a class as a source-style type name.
     *
     * <p>Examples are {@code "int"}, {@code "java.lang.String"} and {@code "int[][]"}.
     */
    public static String toClassName(Class<?> clazz) {
        if (clazz.isArray()) {
            return toClassName(clazz.getComponentType()) + "[]";
        }
        // For primitives Class.getName() already yields the keyword ("int", "void", ...).
        return clazz.getName();
    }

    /**
     * Reads the class bytes from {@code inputStream} and defines them in {@code classLoader}.
     *
     * <p>The stream is always closed. A {@code null} stream is silently ignored.
     *
     * @param inputStream class-file bytes; may be {@code null}
     * @param className   binary (dotted) class name
     * @param classLoader target loader (see {@link #reload} for the {@code null} case)
     * @throws IllegalStateException if reading the stream fails
     * @throws Exception             if defining the class fails
     */
    public static void defineClassFromStream(InputStream inputStream, String className, ClassLoader classLoader) throws Exception {
        if (inputStream == null) {
            return;
        }
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        try {
            int read;
            while ((read = inputStream.read(buffer)) != -1) {
                outputStream.write(buffer, 0, read);
            }
        } catch (IOException e) {
            throw new IllegalStateException("load class " + className + " error!", e);
        } finally {
            try {
                inputStream.close();
            } catch (IOException ignored) {
                // Either all bytes were already read, or the read failure is being reported above;
                // a failure to close does not change the outcome.
            }
        }
        //exposure to parent
        byte[] bytes = outputStream.toByteArray();
        reload(bytes, className, classLoader);
        System.out.println("loaded " + className + " to " + classLoader);
    }

    /** Converts a dotted class name to its class-file resource path, e.g. {@code a.b.C -> a/b/C.class}. */
    public static String forPath(String className) {
        return className.replace('.', '/').concat(".class");
    }


    /**
     * Runs {@code executor}, recovering from missing classes.
     *
     * <p>When the failure's cause chain contains a {@link ClassNotFoundException}, the named class
     * is loaded from {@code resolver` and defined into {@code loader} (recursively, so its own
     * missing dependencies are handled too). Then {@code executor} is retried.
     *
     * @throws Exception        wrapping the failure when {@code resolver} is {@code null}, when no
     *                          missing class can be identified, or after
     *                          {@link #MAX_SAFE_EXECUTE_RETRIES} retries
     * @throws RuntimeException when the missing class cannot be found by {@code resolver}
     */
    public static void safelyExecute(ClassLoader loader, ResourceResolver resolver, SafeExecute executor) throws Exception {
        int retries = 0;
        while (true) {
            try {
                executor.execute(loader, resolver);
                return;
            } catch (Throwable e) {
                if (resolver == null) {
                    throw new Exception(e);
                }
                String missedClassName = findMissingClassName(e);
                if (missedClassName == null || retries++ >= MAX_SAFE_EXECUTE_RETRIES) {
                    // Nothing that defining a class could fix (or we keep failing): give up now
                    // instead of re-running the same failing executor.
                    throw new Exception(e);
                }
                final String missedClassPath = AsmClassUtil.forPath(missedClassName);
                System.out.println("-->safelyExecute retry class is " + missedClassPath + "-->" + missedClassName);
                ResourceInfo resourceInfo = resolver.loadStream(missedClassPath);
                if (resourceInfo == null || resourceInfo.getInputStream() == null) {
                    System.out.println("-->safelyExecute not find class is " + missedClassName);
                    throw new RuntimeException(missedClassName);
                }
                final String finalMissedClassName = missedClassName;
                safelyExecute(loader, resolver, new SafeExecute() {
                    public void execute(ClassLoader loader, ResourceResolver resolver) throws Exception {
                        InputStream inputStream = resolver.loadStream(missedClassPath).getInputStream();
                        defineClassFromStream(inputStream, finalMissedClassName, loader);
                    }
                });
                System.out.println("-->safelyExecute define class success " + missedClassName);
            }
        }
    }

    /**
     * Walks the cause chain of {@code failure} looking for a {@link ClassNotFoundException} with a
     * message.
     *
     * @return the trimmed class name, with a surrounding {@code L...;} descriptor form removed; or
     *         {@code null} when none is found within {@link #MAX_SAFE_EXECUTE_RETRIES} steps
     */
    private static String findMissingClassName(Throwable failure) {
        Throwable current = failure;
        for (int depth = 0; current != null && depth <= MAX_SAFE_EXECUTE_RETRIES; depth++) {
            if (current instanceof ClassNotFoundException && current.getMessage() != null) {
                String name = current.getMessage().trim();
                if (name.startsWith("L") && name.endsWith(";")) {
                    name = name.substring(1, name.length() - 1);
                }
                return name;
            }
            current = current.getCause();
        }
        return null;
    }


    /** Forces static initialization of {@code clazz} through {@code Unsafe.ensureClassInitialized}. */
    public static void checkInit(Class<?> clazz) {
        unSafe.ensureClassInitialized(clazz);
    }


    /**
     * Finds the methods of {@code injectContext.getFullClassName()} to intercept.
     *
     * <p>A method is selected when it matches at least one include pattern and no exclude pattern.
     * The class file is read through {@code injectContext.getClassLoader()}.
     *
     * @return the matched methods, indexed consecutively from 0 in declaration order
     * @throws IllegalArgumentException if the class name is blank or the patterns are {@code null}
     * @throws RuntimeException         if the class file cannot be found
     */
    public static List<MethodInfo> buildMatchedMethod(InjectContext injectContext) throws Exception {

        String fullClassName = injectContext.getFullClassName();
        List<String> patterns = injectContext.getPatterns();
        if (fullClassName == null || fullClassName.trim().isEmpty() || patterns == null) {
            System.out.println("buildMatchedMethod error:" + fullClassName + "@" + patterns);
            throw new IllegalArgumentException("argument error!");
        }

        List<MethodInfo> matchedMethods = new ArrayList<MethodInfo>();

        String classFile = AsmClassUtil.forPath(fullClassName);
        InputStream is = injectContext.getClassLoader().getResourceAsStream(classFile);
        if (is == null) {
            throw new RuntimeException("class file not found ->" + fullClassName);
        }

        // Read all candidate method signatures from the bytecode; always release the stream.
        List<String> methodDescs;
        try {
            methodDescs = AsmClassUtil.readMethodInfoList(is, injectContext.isIgnorePrivate(), injectContext.isIgnoreProtected());
        } finally {
            is.close();
        }

        if (methodDescs.isEmpty()) {
            return matchedMethods;
        }

        List<Pattern> includePatterns = new ArrayList<Pattern>();
        for (String pattern : patterns) {
            includePatterns.add(Pattern.compile(processPattern(pattern)));
        }
        List<Pattern> excludePatterns = new ArrayList<Pattern>();
        if (injectContext.getExcludePatterns() != null) {
            for (String pattern : injectContext.getExcludePatterns()) {
                excludePatterns.add(Pattern.compile(processPattern(pattern)));
            }
        }

        int idx = 0;
        for (String desc : methodDescs) {
            if (matchesAny(includePatterns, desc) && !matchesAny(excludePatterns, desc)) {
                MethodInfo methodInfo = new MethodInfo().fromV2(desc, injectContext.getClassLoader(), null, fullClassName);
                methodInfo.setIndex(idx++);
                matchedMethods.add(methodInfo);
            }
        }

        return matchedMethods;
    }

    /** Returns {@code true} when {@code text} fully matches at least one of {@code patterns}. */
    private static boolean matchesAny(List<Pattern> patterns, String text) {
        for (Pattern pattern : patterns) {
            if (pattern.matcher(text).matches()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Translates a user method pattern into a regular expression.
     *
     * <p>The rules are:
     * <ul>
     *   <li>{@code *} becomes {@code .*};</li>
     *   <li>parentheses and {@code []} are matched literally;</li>
     *   <li>{@code ###} / {@code &&&} become a regex group {@code (} / {@code )}.</li>
     * </ul>
     *
     * <p>NOTE: other regex metacharacters ({@code .}, {@code $}, {@code +}, ...) are passed
     * through unescaped.
     */
    private static String processPattern(String ptn) {
        // Protect array brackets from the "[...]" rewrites below.
        ptn = ptn.replaceAll("\\[]", "【】");
        ptn = ptn.replaceAll("[*]", ".*").replaceAll("[(]", "[(]").replaceAll("[)]", "[)]")
                .replaceAll("###", "(").replaceAll("&&&", ")");
        // Restore the array brackets as the escaped regex literal "\[]".
        ptn = ptn.replaceAll("【】", "\\\\[]");

        return ptn;
    }

    /**
     * Convenience overload building an {@code InjectContext} with default visibility filters and
     * no exclude patterns.
     *
     * <p>NOTE(review): {@code resourceResolver} is accepted but not forwarded anywhere.
     */
    public static List<MethodInfo> buildMatchedMethod(String fullClassName, List<String> patterns, ClassLoader classLoader, ResourceResolver resourceResolver) throws Exception {
        InjectContext injectContext = new InjectContext();
        injectContext.setFullClassName(fullClassName);
        injectContext.setPatterns(patterns);
        injectContext.setClassLoader(classLoader);
        return buildMatchedMethod(injectContext);
    }


    /** Returns the JVM object descriptor for a dotted class name, e.g. {@code a.B -> La/B;}. */
    public static String getObjectDescriptor(String name) {
        return "L" + name.replace('.', '/') + ";";
    }

}
以上是通过 mock 实现 AOP 类加载的方式。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值