由于Java是面向对象编程语言,按理来说,只要系统设计合理,对于现有代码的可扩展,完全可以通过增加新的类或模块来拥抱变化,而不要通过像过程式编程,来通过修改现有代码来支持变化,代码织入的本质就是修改现有类的字节码,来达到改变原有类的功能,故这个代码织入功能是不需要的。
然而要架构一个好的系统,是不容易的,所以往往实现原有功能的扩展,还是需要修改原有代码的,有的时候,源代码也许不可获得,同时为了是修改代码的粒度更小,代码织入功能还是需要的。
优点:
可以对已编译的类中的字段或方法进行二次修改,从而可以方便的对产品进行二次开发及功能扩展。
缺点:
1.由于更新了编译的class,在调试时无法跟踪到正确的行号,也不能正确的指定源码,所以调试起来比较困难
2.在对原class升级的时候也需要检查织入类,所以在升级产品时会增加部分工作量
思路:
因为是为了修改目标类的一部分(为了实用性,以及实现的复杂度考虑,假设织入类不会修改目标类的构造函数),所以以目标类为基础,然后添加织入类中的成员(字段、方法),若目标类中存在,则说明是覆盖,若目标类中不存在,则是添加新的成员。
覆盖:
1.字段覆盖
对于字段的覆盖,就是赋值修改,对于字段的赋值,是在方法中完成的。若为静态字段,则需要修改<clinit>方法,若为非静态字段,则需要修改<init>方法。
2.方法覆盖
对于方法的覆盖,将织入类的方法直接保存,这个过程需要注意方法体中的super关键字,同时为了能够在织入类中能调用目标类的方法,需要保存目标类的方法,方法重命名为method0,当然这个名字可以随便命名的,只要前后一致即可。
关于super关键字,分为两种:super.field,super.method
实现过程:
对于目标类来说,为了实现目标类的成员(字段、方法)的织入,理论上只需要通过编写织入类即可,但是由于织入类一般只修改目标类的局部,同时为了能让织入类能够调用目标类的成员,所以在使用织入功能的时候,需要额外编写一个存根类,这个类的作用就是声明目标类的成员,同时需要将成员的修饰符放宽(至少private声明为protected或以上)。这样编写的织入类只要继承自存根类,就可以方便调用目标类的成员了。
补充:这个存根类,其实也不是莫名其妙就有了的,一开始,我思考着直接让织入类继承自目标类就好,但是在实现织入功能中,发现直接继承自目标类虽然使用织入功能的过程,会变得简化,毕竟只需要编写一个类就好了,但是super关键字不能访问父类的私有成员,这也就是为什么需要存根类,同时需要将访问权限放宽的原因。
核心源代码:
package com.zherop.asm;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
/**
* @author zp
* @mail zherop@163.com
* @date 2016年5月6日
*/
public class ClassWeaver {
private static String INTERNAL_INIT_METHOD_NAME = "<init>";
private static String INTERNAL_STSTIC_INIT_METHOD_NAME = "<clinit>";
private Map<String, Integer> targetMethods;// 目标类方法信息<方法名,方法的访问权限>
private List<String> weaverMethodNames;// 织入类方法
private List<String> weaverFieldNames;// 织入类字段
private final Map<FieldInfo, Object> weaverInitFieldMap;// 织入类非静态字段初始化
private final Map<FieldInfo, Object> weaverStaticInitFieldMap;// 织入类静态字段初始化
private ClassReader[] classReaders;
private String targetStubClass; // 存根类
private byte[] classBytes; // 生成的字节码
public ClassWeaver(String targetClass, String targetStubClass,
String weaverClass) {
targetMethods = new HashMap<String, Integer>();
weaverMethodNames = new ArrayList<String>();
weaverFieldNames = new ArrayList<String>();
weaverInitFieldMap = new HashMap<FieldInfo, Object>();
weaverStaticInitFieldMap = new HashMap<FieldInfo, Object>();
setTargetStubClass(targetStubClass);
initClassReaders(targetClass, weaverClass);
}
private void setTargetStubClass(String targetStubClsName) {
if (targetStubClsName.contains(".")) {
this.targetStubClass = targetStubClsName.replaceAll("\\.", "/");
} else {
this.targetStubClass = targetStubClsName;
}
}
private void initClassReaders(String targetClass, String weaverClass) {
classReaders = new ClassReader[2];
try {
classReaders[0] = new ClassReader(targetClass);
classReaders[1] = new ClassReader(weaverClass);
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 初始化织入类的成员(属性、方法)
*
* @throws IOException
*/
private void initWeaverMembers() throws IOException {
ClassWriter cw = new ClassWriter(0);
ClassReader cr = getWeaverClassReader();
ClassVisitor cv = new ClassVisitor(Opcodes.ASM5, cw) {
@Override
public FieldVisitor visitField(int access, String name,
String desc, String signature, Object value) {
weaverFieldNames.add(name);
return super.visitField(access, name, desc, signature, value);
}
@Override
public MethodVisitor visitMethod(int access, String name,
String desc, String signature, String[] exceptions) {
MethodVisitor visitMethod = super.visitMethod(access, name,
desc, signature, exceptions);
// 对织入类静态字段的处理(获取字段的初始化值)
if (INTERNAL_STSTIC_INIT_METHOD_NAME.equals(name)) {
visitMethod = new FieldInitMethodVisitor(Opcodes.ASM5,
visitMethod, weaverStaticInitFieldMap);
}
// 对织入类非静态字段的处理(获取字段的初始化值)
else if (INTERNAL_INIT_METHOD_NAME.equals(name)) {
visitMethod = new FieldInitMethodVisitor(Opcodes.ASM5,
visitMethod, weaverInitFieldMap);
} else {
// 不是静态块,构造函数就保存起来
weaverMethodNames.add(name);
}
return visitMethod;
}
};
cr.accept(cv, 0);
}
/**
* 初始化目标类的方法
*
* @throws IOException
*/
private void initTargetMethods() throws IOException {
ClassWriter cw = new ClassWriter(0);
ClassReader cr = getTargetClassReader();
ClassVisitor cv = new ClassVisitor(Opcodes.ASM5, cw) {
@Override
public MethodVisitor visitMethod(int access, String name,
String desc, String signature, String[] exceptions) {
MethodVisitor visitMethod = super.visitMethod(access, name,
desc, signature, exceptions);
if (!INTERNAL_INIT_METHOD_NAME.equals(name)
&& !INTERNAL_STSTIC_INIT_METHOD_NAME.equals(name)) {
targetMethods.put(name, access);
}
return visitMethod;
}
};
cr.accept(cv, 0);
}
private ClassReader getWeaverClassReader() {
return classReaders[1];
}
public void weaver() throws IOException {
initTargetMethods();
initWeaverMembers();
ClassReader cr = getTargetClassReader();
final ClassWriter cw = new ClassWriter(0);
ClassVisitor cv = new ClassVisitor(Opcodes.ASM5, cw) {
@Override
public FieldVisitor visitField(int access, String name,
String desc, String signature, Object value) {
// 织入类包含该字段,则删除目标类中该字段
if (weaverFieldNames.contains(name)) {
return null;
}
return super.visitField(access, name, desc, signature, value);
}
@Override
public MethodVisitor visitMethod(int access, String name,
String desc, String signature, String[] exceptions) {
// 对静态常量的处理
if (INTERNAL_STSTIC_INIT_METHOD_NAME.equals(name)) {
MethodVisitor visitMethod = super.visitMethod(access, name,
desc, signature, exceptions);
return new FieldInitValueSetMethodVisitor(Opcodes.ASM5,
visitMethod, weaverStaticInitFieldMap,
Opcodes.PUTSTATIC);
}
// 对目标类构造函数的处理(添加织入类中字段的初始化)
else if (INTERNAL_INIT_METHOD_NAME.equals(name)) {
MethodVisitor visitMethod = super.visitMethod(access, name,
desc, signature, exceptions);
return new FieldInitValueSetMethodVisitor(Opcodes.ASM5,
visitMethod, weaverInitFieldMap, Opcodes.PUTFIELD);
} else {
// 若织入类覆盖了目标类的方法,则将目标类的该方法重命名
if (weaverMethodNames.contains(name)) {
name = name + "0";
}
return super.visitMethod(access, name, desc, signature,
exceptions);
}
}
@Override
public void visitEnd() {
ClassReader weaverClassReader = getWeaverClassReader();
// 获取织入类的字段、方法
ClassVisitor classVisitor = new WeaverClassVisitor(
Opcodes.ASM5, cw);
weaverClassReader.accept(classVisitor, 0);
super.visitEnd();
}
};
cr.accept(cv, 0);
classBytes = cw.toByteArray();
}
private ClassReader getTargetClassReader() {
return classReaders[0];
}
public byte[] getClassBytes() {
return classBytes;
}
public void write2File(String output) throws IOException {
FileOutputStream fos = new FileOutputStream(output);
fos.write(classBytes);
if (fos != null) {
fos.close();
}
}
/**
* 用于保存字段的初始化的信息
*/
class FieldInfo {
String name;// 字段名称
String desc;// 字段类型描述符
String visitType;// 字段指令类型
public FieldInfo(String name, String desc, String visitType) {
this.name = name;
this.desc = desc;
this.visitType = visitType;
}
public String getName() {
return name;
}
public String getDesc() {
return desc;
}
public String getVisitType() {
return visitType;
}
}
/**
* 访问字段的初始化值(目的:保存字段的初始化值)
*/
class FieldInitMethodVisitor extends MethodVisitor {
private Map<FieldInfo, Object> output;
private Object cstValue;
private String type;
public FieldInitMethodVisitor(int api, MethodVisitor mv,
Map<FieldInfo, Object> output) {
super(api, mv);
this.output = output;
}
@Override
public void visitInsn(int opcode) {
cstValue = opcode;
type = "visitInsn";
super.visitInsn(opcode);
}
@Override
public void visitLdcInsn(Object cst) {
cstValue = cst;
type = "visitLdcInsn";
super.visitLdcInsn(cst);
}
@Override
public void visitFieldInsn(int opcode, String owner, String name,
String desc) {
FieldInfo fieldInfo = new FieldInfo(name, desc, type);
output.put(fieldInfo, cstValue);
super.visitFieldInsn(opcode, owner, name, desc);
}
}
/**
* 字段初始化(目的:在生成的字节码中,设置字段的初始化值)
*/
class FieldInitValueSetMethodVisitor extends MethodVisitor {
/**
* fieldOpcode取值为Opcodes.PUTFIELD,Opcodes.PUTSTATIC
*/
private int fieldOpcode;
private Map<FieldInfo, Object> initFieldMap;
public FieldInitValueSetMethodVisitor(int api, MethodVisitor mv,
Map<FieldInfo, Object> input, int fieldOpcode) {
super(api, mv);
this.initFieldMap = input;
this.fieldOpcode = fieldOpcode;
}
@Override
public void visitInsn(int opcode) {
if (Opcodes.RETURN == opcode) {
initWeaverField();
}
super.visitInsn(opcode);
}
private void initWeaverField() {
for (Entry<FieldInfo, Object> entry : initFieldMap.entrySet()) {
if (Opcodes.PUTFIELD == this.fieldOpcode) {
visitVarInsn(Opcodes.ALOAD, 0);
}
FieldInfo fieldInfo = entry.getKey();
if ("visitLdcInsn".equals(fieldInfo.getVisitType())) {
visitLdcInsn(entry.getValue());
} else if ("visitInsn".equals(fieldInfo.getVisitType())) {
visitInsn((Integer) entry.getValue());
} else if ("visitIntInsn".equals(fieldInfo.getVisitType())) {
visitIntInsn(Opcodes.BIPUSH, (Integer) entry.getValue());
}
visitFieldInsn(fieldOpcode, getTargetClassReader()
.getClassName(), fieldInfo.getName(),
fieldInfo.getDesc());
}
}
}
/**
* 织入类访问(只保留字段,方法)
*/
class WeaverClassVisitor extends ClassVisitor {
public WeaverClassVisitor(int api, ClassVisitor cv) {
super(api, cv);
}
// 织入类的类信息删除
@Override
public void visit(int version, int access, String name,
String signature, String superName, String[] interfaces) {
}
// 织入类的字段保留
@Override
public FieldVisitor visitField(int access, String name, String desc,
String signature, Object value) {
return super.visitField(access, name, desc, signature, value);
}
// 织入类的方法保留
@Override
public MethodVisitor visitMethod(int access, String name, String desc,
String signature, String[] exceptions) {
// 织入类的构造函数、静态块删除
if (INTERNAL_INIT_METHOD_NAME.equals(name)
|| INTERNAL_STSTIC_INIT_METHOD_NAME.equals(name)) {
return null;
}
// 修改织入类中覆盖了目标类方法的访问权限(被存根类放宽的访问权限)
if (targetMethods.containsKey(name)) {
access = targetMethods.get(name);
}
MethodVisitor visitMethod = super.visitMethod(access, name, desc,
signature, exceptions);
// 织入类的方法处理
return new MethodVisitor(Opcodes.ASM5, visitMethod) {
String owner = getTargetClassReader().getClassName();
@Override
public void visitMethodInsn(int opcode, String owner,
String name, String desc, boolean itf) {
// 织入类新增方法调用
if (!targetMethods.containsKey(name)
&& weaverMethodNames.contains(name)) {
owner = this.owner;
}
// 非新增方法调用
else {
// 方法体中包含super关键字的方法调用
if (owner.equals(targetStubClass)
&& Opcodes.INVOKESPECIAL == opcode) {
opcode = Opcodes.INVOKEVIRTUAL;
owner = this.owner;
name = name + "0";
}
}
super.visitMethodInsn(opcode, owner, name, desc, itf);
}
// 字段的owner为super的改为目标类
@Override
public void visitFieldInsn(int opcode, String owner,
String name, String desc) {
if (owner.equals(getWeaverClassReader().getClassName())) {
owner = getTargetClassReader().getClassName();
}
super.visitFieldInsn(opcode, owner, name, desc);
}
};
}
}
}
/**
*
*/
package com.zherop.asm;
import java.io.IOException;
/**
* @author zp
* @mail zherop@163.com
* @date 2016年5月6号
*/
public class TestClassLoader extends ClassLoader {
public TestClassLoader() {
super();
}
public static Class<?> testWeaver(final String targetClassName,
final String targetStubClsName, final String weaverClassName) {
final TestClassLoader cl = new TestClassLoader();
ClassWeaver classWeaver = new ClassWeaver(targetClassName,
targetStubClsName, weaverClassName);
try {
classWeaver.weaver();
classWeaver.write2File("E:/asm/" + targetClassName + ".class");
} catch (IOException e) {
e.printStackTrace();
}
byte[] classBytes = classWeaver.getClassBytes();
try {
Class<?> newClass = cl.defineClass(targetClassName, classBytes, 0,
classBytes.length);
return newClass;
} catch (SecurityException e) {
e.printStackTrace();
} catch (IllegalArgumentException e) {
e.printStackTrace();
}
return null;
}
}
示例代码:
public class DemoTarget {
private static String FIELD_STR_DEFAULT = "Target_Default"; // 静态私有成员
private String privStringField = FIELD_STR_DEFAULT; // 私有字段
protected String protStringField = FIELD_STR_DEFAULT; // 保护字段
String packStringField = FIELD_STR_DEFAULT; // 包作用域字段
public String pubStringField = FIELD_STR_DEFAULT; // 公共字段
/**
* 私有方法
*/
private void privMethod() {
staticPrivMethod();
System.out
.println("This is private method in Target Class!#privMethod");
}
/**
* 静态私有方法
*/
private static void staticPrivMethod() {
System.out
.println("This is static private method in Target Class!#staticPrivMethod");
}
/**
* 公共方法
*/
public void publishMethod() {
System.out
.println("This is publish method in Target Class!#publishMethod");
}
/**
* 静态公共方法
*/
public static void staticPublishMethod() {
System.out
.println("This is static publish method in Target Class!#staticPublishMethod");
}
/**
* 打印所有字段
*/
public void printAllField() {
System.out.println("Print All Field --------------Begin");
System.out.println(privStringField);
System.out.println(protStringField);
System.out.println(packStringField);
System.out.println(pubStringField);
System.out.println("Print All Field --------------End");
}
/**
* 执行所有方法
*/
public void doAllMethod() {
System.out.println("Do All Method --------------Begin");
privMethod();
staticPrivMethod();
publishMethod();
staticPublishMethod();
System.out.println("Do All Method --------------End");
}
}
public class DemoWeaver extends DemoStub {
public static String FIELD_STR_DEFAULT = "织入后的#FIELD_STR_DEFAULT";
private String weaverAddField = "weaverAddField";
public String privStringField = "织入后的#privStringField";
/**
* 织入添加方法
*
* @return
*/
private String weaverAddMethod() {
System.out.println("Do weaverAddMethod");
return weaverAddField;
}
/**
* 替换目标类的privMethod方法
*/
public void privMethod() {
System.out.println(weaverAddMethod());// 调用weaverAddMethod方法并打印返回结果
System.out.println("**调用目标类原privMethod方法--begin");
super.privMethod();// 调用目标类原privMethod方法
System.out.println("**调用目标类原privMethod方法--end");
System.out.println("打印目标类protStringField字段值:" + protStringField);// 访问目标类字段protStringField
}
/**
* 替换目标类的staticPrivMethod方法
*/
public static void staticPrivMethod() {
System.out
.println("This is weaver static PrivMethod //织入后的打印#staticPrivMethod");
}
/**
* 替换目标类的publishMethod方法
*/
public void publishMethod() {
System.out
.println("This is weaver static PrivMethod //织入后的打印#publishMethod");
}
}
public class DemoStub {
public static String FIELD_STR_DEFAULT;
public String privStringField; // 私有字段
public String protStringField; // 保护字段
public String packStringField; // 包作用域字段
public String pubStringField; // 公共字段
public void privMethod() {
}
public static void staticPrivMethod() {
}
public void publishMethod() {
}
public static void staticPublishMethod() {
}
}
package com.zherop.asm.test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import com.zherop.asm.TestClassLoader;
public class Test {
public static void main(String[] args) throws SecurityException,
NoSuchMethodException, InstantiationException,
IllegalAccessException, IllegalArgumentException,
InvocationTargetException {
String targetClassName = DemoTarget.class.getName();
String targetStubClsName = DemoStub.class.getName();
String weaverClassName = DemoWeaver.class.getName();
Class<?> newClass = TestClassLoader.testWeaver(targetClassName,
targetStubClsName, weaverClassName);
if (newClass == null) {
System.out.println("类生成失败");
return;
}
Object object = newClass.newInstance();
Method printAllField = newClass.getDeclaredMethod("printAllField",
new Class<?>[] {});
printAllField.invoke(object, new Object[] {});
Method doAllMethod = newClass.getDeclaredMethod("doAllMethod",
new Class<?>[] {});
doAllMethod.invoke(object, new Object[] {});
}
}
运行结果:
核心代码的实现使用了第三方字节码工具包asm。代码编写还算规范,代码中注释也比较详细,可读性应该还是不错的,不足之处,还望不吝赐教!!!
代码编写初衷:
因为自己所在的公司,框架中有这个代码织入功能,其实现是使用asm包中的tree API,然后出于兴趣,就打算根据自己的理解,将这个过程实现一遍,于是就使用了asm包中的event API,哈哈,之所以不一样,当时是避免雷同,同时这样才有挑战性。由于公司1,3,5加班,只能在闲暇之时才完成这个功能,陆陆续续花了一周时间,周六的晚上,一不小心就写到了快1点了,也算是完成了这个想法,为了检验功能是否OK了,测试代码直接从公司知识库中copy的[偷笑],毕竟前辈考虑问题还是要全面些。
在网上能找到的关于asm的文章,都是比较简单的示例,比较深入点的例子还是很少的,要实现代码织入功能,还是有点难度的,希望这篇博客可以给需要的人作为参考,少走弯路。平时不怎么写博客的我,特意写下此文,算是做点小小滴贡献!!哈哈哈