1. 利用 ClassFileTransformer 机制,以 premain()(-javaagent)方式启动。
2. 实现其 transform 方法:
@Override public byte[] transform(final ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
3. 由于需要在目标 classLoader 加载类之前就对该 classLoader 进行 mock(AOP),因此需要先找到触发这个 classLoader 加载的初始 classLoader,一般来说就是 main() 所在入口类的 classLoader。
因此需要通过 main() 方法启动的入口类的 classLoader,去 mock 需要被 mock 的 classLoader,例如自定义的 classLoader,或 spring-boot 的 org.springframework.boot.loader.LaunchedURLClassLoader。
4. 对于被 mock 的 classLoader 所加载的类,由于此时已经知道了目标 classLoader,就可以对其进行 mock 并加载。为了在第二次加载时能够区分,需要进行分组:
baseLoader --> 1. 启动类 --> 找到 classLoader。目的:mock 当前应用用于加载类的目标 classLoader。
--> 2. 目标 classLoader --> 需要 mock 的方法列表 {类: 方法}:找出目标 classLoader 后,就可以 mock 该 classLoader 下的类和方法了。
1. 建立 baseLoader 组 --> 初始化类,也就是调用 main 方法的类。
在 transferClass 之前建立好初始化类的分组 --> baseLoader 分组。
2. 在 transferClass 时会找到这个 baseLoader 分组,这时就可以配置在这个 baseLoader 下需要 mock 的类和方法了。
以下以 spring-boot 为例:
org.springframework.boot.loader.LaunchedURLClassLoader 的 mock 方式
public class SpringBootLoaderMocker extends BaseAsmMocker {

    /**
     * Names the class whose methods this mocker intercepts.
     *
     * @return the fully-qualified name of Spring Boot's LaunchedURLClassLoader
     */
    @Override
    public String getClassName() {
        return "org.springframework.boot.loader.LaunchedURLClassLoader";
    }

    /**
     * Selects the intercepted method: {@code protected Class<?> loadClass(String name, boolean resolve)}.
     *
     * @return a new mutable list holding the single pattern for {@code loadClass(String, boolean)}
     */
    @Override
    public List<String> getMethodPatterns() {
        final List<String> methodPatterns = new ArrayList<String>(1);
        methodPatterns.add("* loadClass(java.lang.String,boolean)");
        return methodPatterns;
    }

    /**
     * Builds the interceptor for {@code loadClass}. If {@code ClassRepo.getClass} returns a class for the
     * requested name, that class is logged and returned; otherwise the call is delegated to
     * {@code AsmInjector.invoke}.
     *
     * @return the loadClass interceptor
     */
    @Override
    public InvocationInterceptor getInterceptor() {
        return new InvocationInterceptor() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                final String requestedName = (String) args[0];
                // Changes how the class is looked up: when an assigned value's type may differ (same name,
                // different loader), take the class registered for the loader being converted instead, to
                // avoid class-conflict related errors.
                final Class<?> cached = ClassRepo.getClass(requestedName);
                if (cached == null) {
                    return AsmInjector.invoke(proxy, method, args);
                }
                DoomLogger.info("tracer.spring-boot" + requestedName + "@" + cached.getClassLoader());
                return cached;
            }
        };
    }
}
=================
public class AsmClassUtil { private static Method defineClass1, defineClass2; private static Method definePackage; private static final Map<String, Class<?>> baseType = new TreeMap<String, Class<?>>(); private static final Set<String> baseTypes = new HashSet<String>(); private static Unsafe unSafe; static { try { AccessController.doPrivileged(new PrivilegedExceptionAction() { public Object run() throws Exception { Class cl = Class.forName("java.lang.ClassLoader"); defineClass1 = cl.getDeclaredMethod("defineClass", new Class[]{String.class, byte[].class, int.class, int.class}); defineClass2 = cl.getDeclaredMethod("defineClass", new Class[]{String.class, byte[].class, int.class, int.class, ProtectionDomain.class}); definePackage = cl.getDeclaredMethod("definePackage", new Class[]{String.class, String.class, String.class, String.class, String.class, String.class, String.class, java.net.URL.class}); Field f = Unsafe.class.getDeclaredField("theUnsafe"); // Internal reference f.setAccessible(true); unSafe = (Unsafe) f.get(null); return null; } }); } catch (PrivilegedActionException pae) { throw new RuntimeException("cannot initialize ClassPool", pae.getException()); } baseType.put("byte", Byte.TYPE); baseType.put("int", Integer.TYPE); baseType.put("char", Character.TYPE); baseType.put("boolean", Boolean.TYPE); baseType.put("long", Long.TYPE); baseType.put("float", Float.TYPE); baseType.put("double", Double.TYPE); baseType.put("short", Short.TYPE); baseType.put("void", Void.TYPE); baseTypes.addAll(baseType.keySet()); } public static Class reload(byte[] b, String clazz, ClassLoader classLoader) throws Exception { Object domain = null; if (classLoader == null) { classLoader = Thread.currentThread().getContextClassLoader(); } try { Method method; Object[] args; if (domain == null) { method = defineClass1; args = new Object[]{clazz, b, new Integer(0), new Integer(b.length)}; } else { method = defineClass2; args = new Object[]{clazz, b, new Integer(0), new Integer(b.length), 
domain}; } System.out.println("loaded " + clazz + " to " + classLoader.getClass().getName() + "@" + classLoader); return (Class) toClass2(method, classLoader, args); } catch (RuntimeException e) { throw e; } } private static synchronized Object toClass2(Method method, ClassLoader loader, Object[] args) throws Exception { method.setAccessible(true); try { return method.invoke(loader, args); } finally { method.setAccessible(false); } } public static int pushVarInStackWithBox(MethodVisitor mv, int i, Class<?> clazz, String targetClass) { if (clazz != null && clazz.isPrimitive()) { if (clazz == Integer.TYPE) { mv.visitIntInsn(Opcodes.ILOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Integer", "valueOf", "(I)Ljava/lang/Integer;"); return 1; } else if (clazz == Character.TYPE) { mv.visitIntInsn(Opcodes.ILOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Character", "valueOf", "(C)Ljava/lang/Character;"); return 1; } else if (clazz == Long.TYPE) { mv.visitIntInsn(Opcodes.LLOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;"); return 2; } else if (clazz == Float.TYPE) { mv.visitIntInsn(Opcodes.FLOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Float", "valueOf", "(F)Ljava/lang/Float;"); return 1; } else if (clazz == Double.TYPE) { mv.visitIntInsn(Opcodes.DLOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Double", "valueOf", "(D)Ljava/lang/Double;"); return 2; } else if (clazz == Byte.TYPE) { mv.visitIntInsn(Opcodes.ILOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Byte", "valueOf", "(B)Ljava/lang/Byte;"); return 1; } else if (clazz == Short.TYPE) { mv.visitIntInsn(Opcodes.ILOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Short", "valueOf", "(S)Ljava/lang/Short;"); return 1; } else if (clazz == Boolean.TYPE) { mv.visitIntInsn(Opcodes.ILOAD, i); mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Boolean", "valueOf", "(Z)Ljava/lang/Boolean;"); return 1; 
} } else { mv.visitVarInsn(Opcodes.ALOAD, i); } return 1; } public static int convertAndReturn(MethodVisitor mv, Class<?> clazz, String selfClass) { if (clazz == null) { mv.visitTypeInsn(Opcodes.CHECKCAST, selfClass.replaceAll("\\.", "/")); mv.visitInsn(Opcodes.ARETURN); return 1; } if (clazz.isPrimitive()) { if (clazz == Integer.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Integer"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Integer", "intValue", "()I"); mv.visitInsn(Opcodes.IRETURN); return 1; } else if (clazz == Character.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Character"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Character", "charValue", "()C"); mv.visitInsn(Opcodes.IRETURN); return 1; } else if (clazz == Long.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Long"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Long", "longValue", "()J"); mv.visitInsn(Opcodes.LRETURN); return 2; } else if (clazz == Float.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Float"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Float", "floatValue", "()F"); mv.visitInsn(Opcodes.FRETURN); return 1; } else if (clazz == Double.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Double"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Double", "doubleValue", "()D"); mv.visitInsn(Opcodes.DRETURN); return 2; } else if (clazz == Byte.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Byte"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Byte", "byteValue", "()B"); mv.visitInsn(Opcodes.IRETURN); return 1; } else if (clazz == Short.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Short"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Short", "shortValue", "()S"); mv.visitInsn(Opcodes.IRETURN); return 1; } else if (clazz == Boolean.TYPE) { mv.visitTypeInsn(Opcodes.CHECKCAST, "java/lang/Boolean"); mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Boolean", "booleanValue", "()Z"); 
mv.visitInsn(Opcodes.IRETURN); return 1; } } else { mv.visitTypeInsn(Opcodes.CHECKCAST, clazz.getName().replaceAll("\\.", "/")); mv.visitInsn(Opcodes.ARETURN); } return 1; } public static class MethodInfoReadAdaptor extends ClassVisitor { private static final Set<String> baseMethodSet = new HashSet<String>(); /** * 是否忽略private方法的拦截 */ private boolean isIgnorePrivate; /** * 是否忽略protected方法的拦截 */ private boolean isIgnoreProtected; static { baseMethodSet.add("<init>"); baseMethodSet.add("toString"); baseMethodSet.add("clone"); baseMethodSet.add("hashCode()"); baseMethodSet.add("<clinit>"); } private List<String> methodInfoList = new ArrayList<String>(); public MethodInfoReadAdaptor(ClassVisitor classAdapter, boolean isIgnorePrivate, boolean isIgnoreProtected) { super(AsmAop.ASM_VERSION, classAdapter); this.isIgnorePrivate = isIgnorePrivate; this.isIgnoreProtected = isIgnoreProtected; } @Override public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) { //排除特殊方法 if (!baseMethodSet.contains(name) && !name.contains("jacocoInit") && !name.startsWith("$")) { boolean isSkip = false; if ((isIgnorePrivate && Modifier.isPrivate(access)) || (isIgnoreProtected && Modifier.isProtected(access))) { isSkip = true; } //如果忽略private和 protected,则只有public才拦截 if (isIgnorePrivate && isIgnoreProtected && !Modifier.isPublic(access)) { isSkip = true; } if (!isSkip) { Type[] paramTypes = Type.getArgumentTypes(desc); Type returnType = Type.getReturnType(desc); StringBuilder sb = new StringBuilder(); sb.append(getConvertedClassName(returnType.getClassName())); sb.append(" ").append(name).append("("); List<String> types = new ArrayList<String>(); for (Type type : paramTypes) { types.add(getConvertedClassName(type.getClassName())); } sb.append(StringUtils.join(types, ",")); sb.append(")"); methodInfoList.add(sb.toString()); } } return super.visitMethod(access, name, desc, signature, exceptions); } public List<String> getMethodInfoList() { 
return methodInfoList; } } public static List<String> readMethodInfoList(InputStream inputStream, boolean isIgnorePrivate, boolean isIgnoreProtected) throws IOException { ClassReader cr = new ClassReader(inputStream); ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS); MethodInfoReadAdaptor adaptor = new MethodInfoReadAdaptor(cw, isIgnorePrivate, isIgnoreProtected); cr.accept(adaptor, ClassWriter.COMPUTE_MAXS); return adaptor.getMethodInfoList(); } public static List<String> readMethodInfoList(InputStream inputStream) throws IOException { return readMethodInfoList(inputStream, false, false); } public static String getConvertedClassName(String className) { if (baseTypes.contains(className)) { return className; } return className; } public static Class<?> getClassByName(String className, final ClassLoader classLoader, final ResourceResolver resourceResolver) throws ClassNotFoundException { Class<?> clazz = baseType.get(className); if (clazz != null) { return clazz; } if (className.indexOf("[][]") > 0) { String subClassName = className.replaceAll("\\[]\\[]", ""); Class<?> subClazz = getClassByName(subClassName, classLoader, resourceResolver); if (subClazz.isPrimitive()) { className = String.format("[[%s", Type.getType(subClazz).getDescriptor()); } else { className = String.format("[[L%s;", subClazz.getName()); } } else if (className.indexOf("[]") > 0) { String subClassName = className.replaceAll("\\[]", ""); Class<?> subClazz = getClassByName(subClassName, classLoader, resourceResolver); if (subClazz.isPrimitive()) { className = String.format("[%s", Type.getType(subClazz).getDescriptor()); } else { className = String.format("[L%s;", subClazz.getName()); } } Class<?> typeClass = null; try { typeClass = Class.forName(className, false, classLoader); } catch (ClassNotFoundException e) { if (resourceResolver != null) { ResourceInfo resourceInfo = resourceResolver.loadStream(AsmClassUtil.forPath(className)); if (resourceInfo != null && 
resourceInfo.getInputStream() != null) { try { final String finalClassName = className; safelyExecute(classLoader, resourceResolver, new SafeExecute() { public void execute(ClassLoader loader, ResourceResolver resolver) throws Exception { ResourceInfo loadedResourceInfo = resourceResolver.loadStream(AsmClassUtil.forPath(finalClassName)); defineClassFromStream(loadedResourceInfo.getInputStream(), finalClassName, classLoader); } }); typeClass = Class.forName(className, false, classLoader); } catch (Exception e1) { e1.printStackTrace(); } } } } if (typeClass == null) { throw new ClassNotFoundException(className); } return typeClass; } public static String toClassName(Class<?> clazz) { if (clazz.isPrimitive()) { return Type.getType(clazz).getClassName(); } else if (clazz.isArray()) { Class<?> c = clazz.getComponentType(); return toClassName(c) + "[]"; } else { return clazz.getName(); } } /** * 将class加载到classLoader中 * * @param inputStream * @param className * @param classLoader * @throws Exception */ public static void defineClassFromStream(InputStream inputStream, String className, ClassLoader classLoader) throws Exception { if (inputStream != null) { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); int b = 0; try { while ((b = inputStream.read()) != -1) { outputStream.write(b); } inputStream.close(); } catch (IOException e) { throw new IllegalStateException("load class " + className + " error!", e); } //exposure to parent byte[] bytes = outputStream.toByteArray(); reload(bytes, className, classLoader); System.out.println("loaded " + className + " to " + classLoader); } } public static String forPath(String className) { return className.replaceAll("\\.", "/").concat(".class"); } public static void safelyExecute(ClassLoader loader, ResourceResolver resolver, SafeExecute executor) throws Exception { boolean isExecuteOk = false; int deep = 0; DO_EXECUTE: while (!isExecuteOk) { try { executor.execute(loader, resolver); isExecuteOk = true; } catch 
(Throwable e) { if (resolver == null) { throw new Exception(e); } while (e != null) { String missedClassName = null; if (e instanceof ClassNotFoundException) { missedClassName = e.getMessage(); } if (missedClassName != null) { missedClassName = missedClassName.trim(); if (missedClassName.startsWith("L") & missedClassName.endsWith(";")) { missedClassName = missedClassName.substring(1, missedClassName.length() - 1); } final String missedClassPath = AsmClassUtil.forPath(missedClassName); System.out.println("-->safelyExecute retry class is " + missedClassPath + "-->" + missedClassName); ResourceInfo resourceInfo = resolver.loadStream(missedClassPath); if (resourceInfo == null || resourceInfo.getInputStream() == null) { System.out.println("-->safelyExecute not find class is " + missedClassName); throw new RuntimeException(missedClassName); } final String finalMissedClassName = missedClassName; safelyExecute(loader, resolver, new SafeExecute() { public void execute(ClassLoader loader, ResourceResolver resolver) throws Exception { InputStream inputStream = resolver.loadStream(missedClassPath).getInputStream(); defineClassFromStream(inputStream, finalMissedClassName, loader); } }); System.out.println("-->safelyExecute define class success " + missedClassName); continue DO_EXECUTE; } else { if (e.getCause() == null) { e.printStackTrace(); } if (deep++ > 100) { throw new Exception(e); } e = e.getCause(); } } } } } public static void checkInit(Class<?> clazz) { unSafe.ensureClassInitialized(clazz); } public static List<MethodInfo> buildMatchedMethod(InjectContext injectContext) throws Exception { String fullClassName = injectContext.getFullClassName(); List<String> patterns = injectContext.getPatterns(); if (fullClassName == null || fullClassName.trim().isEmpty() || patterns == null) { System.out.println("buildMatchedMethod error:" + fullClassName + "@" + patterns); throw new IllegalArgumentException("argument error!"); } List<MethodInfo> matchedMethods = new 
ArrayList<MethodInfo>(); String classFile = AsmClassUtil.forPath(fullClassName); InputStream is = injectContext.getClassLoader().getResourceAsStream(classFile); if (is == null) { throw new RuntimeException("class file not found ->" + fullClassName); } //从字节码中读取到类型总的方法 List<String> methodDescs = AsmClassUtil.readMethodInfoList(is, injectContext.isIgnorePrivate(), injectContext.isIgnoreProtected()); is.close(); if (methodDescs.isEmpty()) { return matchedMethods; } int idx = 0; List<Pattern> ptns = new ArrayList<Pattern>(); List<Pattern> excludePtns = new ArrayList<Pattern>(); for (String pattern : patterns) { Pattern ptn = Pattern.compile(processPattern(pattern)); ptns.add(ptn); } if (injectContext.getExcludePatterns() != null) { for (String patten : injectContext.getExcludePatterns()) { excludePtns.add(Pattern.compile(processPattern(patten))); } } //匹配方法 for (String desc : methodDescs) { MATCH: for (Pattern pattern : ptns) { Matcher matcher = pattern.matcher(desc); if (matcher.matches()) { //执行排除逻辑 boolean isExclude = false; EXCLUDE: for (Pattern ptn : excludePtns) { Matcher excludeMatcher = ptn.matcher(desc); if (excludeMatcher.matches()) { isExclude = true; break EXCLUDE; } } if (!isExclude) { MethodInfo methodInfo = new MethodInfo().fromV2(desc, injectContext.getClassLoader(), null, fullClassName); methodInfo.setIndex(idx++); matchedMethods.add(methodInfo); } break MATCH; } } } return matchedMethods; } private static String processPattern(String ptn) { ptn = ptn.replaceAll("\\[]", "【】"); ptn = ptn.replaceAll("$__[(]", "_$_"); ptn = ptn.replaceAll("[*]", ".*").replaceAll("[(]", "[(]").replaceAll("[)]", "[)]") .replaceAll("###", "(").replaceAll("&&&", ")"); ptn = ptn.replaceAll("【】", "\\\\[]"); return ptn; } public static List<MethodInfo> buildMatchedMethod(String fullClassName, List<String> patterns, ClassLoader classLoader, ResourceResolver resourceResolver) throws Exception { InjectContext injectContext = new InjectContext(); 
injectContext.setFullClassName(fullClassName); injectContext.setPatterns(patterns); injectContext.setClassLoader(classLoader); return buildMatchedMethod(injectContext); } public static String getObjectDescriptor(String name) { StringBuilder buf = new StringBuilder(); buf.append('L'); int len = name.length(); for (int i = 0; i < len; ++i) { char car = name.charAt(i); buf.append(car == '.' ? '/' : car); } buf.append(';'); return buf.toString(); } }以上是mock 实现aop类加载