JDK动态代理原理剖析(上)

JDK动态代理示例

首先从一个常见的JDK动态代理示例来开始:
代理目标接口:

/**
 * @Description: 代理目标接口
 * @Author: binga
 * @Date: 2020/8/25 10:19
 * @Blog: https://blog.csdn.net/pang5356
 */
public interface TargetInterface {
    void doSomething();
}

实现类如下:

/**
 * @Description: 接口实现
 * @Author: binga
 * @Date: 2020/8/25 10:28
 * @Blog: https://blog.csdn.net/pang5356
 */
public class TargetImpl implements TargetInterface {
    public void doSomething() {
        System.out.println("TargetImpl method");
    }
}

处理器实现:

/**
 * @Description: 代理类
 * @Author: binga
 * @Date: 2020/8/25 10:16
 * @Blog: https://blog.csdn.net/pang5356
 */
public class TargetProxy implements InvocationHandler {

    private TargetInterface targetInterface;

    public TargetProxy(TargetInterface targetInterface) {
        this.targetInterface = targetInterface;
    }

    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Object invoke = null;

        System.out.println(method.getName() + "前置处理");
        if (targetInterface != null)
            invoke  = method.invoke(targetInterface, args);
        System.out.println(method.getName() + "后置处理");
        return invoke;
    }

    public TargetInterface getProxyInstance() {
        Object o = Proxy.newProxyInstance(TargetProxy.class.getClassLoader(),
                targetInterface.getClass().getInterfaces(), this);
		return (TargetInterface) o;
    }
}

测试如下:

/**
 * @Description: 代理模式测试
 * @Author: binga
 * @Date: 2020/8/25 10:15
 * @Blog: https://blog.csdn.net/pang5356
 */
public class ProxyTest {

    public static void main(String[] args) {
        TargetImpl target = new TargetImpl();
        TargetProxy proxy = new TargetProxy(target);
        TargetInterface proxyInstance = proxy.getProxyInstance();
        proxyInstance.doSomething();
    }
}

结果如下:

doSomething前置处理
TargetImpl method
doSomething后置处理

可以看到,执行doSomething方法前轴有加强信息输出。

原理剖析

JDK动态代理使用两种技术反射和字节码动态生成。

  1. 反射:通过反射获取到每接口的方法对应的Method对象,然后通过Method以及被代理对象调用invoke方法,从而实现对原对象的方法调用,而在Method对象调用invoke前后可以自定义的增强,而InvocationHander的invoke方法则是扮演的“钩子”角色。
  2. 字节码动态生成:JDK在生成代理对象的过程中,首先是在内存中根据参数的接口从而动态的生成一个类的字节码数据,这个类实现了所有的参数接口,同时这个类还继承了Proxy类,在内存中生成类的字节码数据后,然后使用参数类加载器进行加载并返回该类的Class对象,然后通过这个类的Class对象生成代理对象。
    那么其流程如下:
    在这里插入图片描述
    通过这个流程可以了解,代理对象对应的类是在内存中动态生成的,而InvocationHandler的实现并不是代理对象,这个角色是要搞清楚的,上面说了InvocationHandler就是一个“钩子”,用于提供我们进行增强的,那么这个代理类究竟是什么样的呢?可以通过以下代码将其输出,如下:
/**
 * @Description: 输出字节码
 * @Author: binga
 * @Date: 2020/8/25 10:34
 * @Blog: https://blog.csdn.net/pang5356
 */
public class GenerateByteCodeTest {

    public static final String CLAZZ_PATH = "F:\\test\\";

    public static void main(String[] args) {
        TargetProxy proxy = new TargetProxy(new TargetImpl());
        TargetInterface proxyInstance = proxy.getProxyInstance();
        proxyInstance.doSomething();

        String clazzName = proxyInstance.getClass().getName();

        byte[] classByteCode = ProxyGenerator.generateProxyClass(clazzName,
                new Class[]{TargetInterface.class});

        String name = clazzName.substring(clazzName.lastIndexOf(".") + 1);
        File file = new File(CLAZZ_PATH, name + ".class");

        try (FileOutputStream fout = new FileOutputStream(file);) {
            fout.write(classByteCode);
            fout.flush();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

运行以后则会在F:\test\下生成代理类的class文件。

package com.sun.proxy;


import com.binga.designpatterns.proxy.TargetInterface;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;


public final class $Proxy0 extends Proxy
  implements TargetInterface
{
  private static Method m1;
  private static Method m3;
  private static Method m2;
  private static Method m0;


  public $Proxy0(InvocationHandler paramInvocationHandler)
    throws 
  {
    super(paramInvocationHandler);
  }


  public final boolean equals(Object paramObject)
    throws 
  {
    try
    {
      return ((Boolean)this.h.invoke(this, m1, new Object[] { paramObject })).booleanValue();
    }
    catch (RuntimeException localRuntimeException)
    {
      throw localRuntimeException;
    }
    catch (Throwable localThrowable)
    {
      throw new UndeclaredThrowableException(localThrowable);
    }
  }


  public final void doSomething()
    throws 
  {
    try
    {
      this.h.invoke(this, m3, null);
      return;
    }
    catch (RuntimeException localRuntimeException)
    {
      throw localRuntimeException;
    }
    catch (Throwable localThrowable)
    {
      throw new UndeclaredThrowableException(localThrowable);
    }
  }


  public final String toString()
    throws 
  {
    try
    {
      return ((String)this.h.invoke(this, m2, null));
    }
    catch (RuntimeException localRuntimeException)
    {
      throw localRuntimeException;
    }
    catch (Throwable localThrowable)
    {
      throw new UndeclaredThrowableException(localThrowable);
    }
  }


  public final int hashCode()
    throws 
  {
    try
    {
      return ((Integer)this.h.invoke(this, m0, null)).intValue();
    }
    catch (RuntimeException localRuntimeException)
    {
      throw localRuntimeException;
    }
    catch (Throwable localThrowable)
    {
      throw new UndeclaredThrowableException(localThrowable);
    }
  }


  static
  {
    try
    {
      m1 = Class.forName("java.lang.Object").getMethod("equals", new Class[] { Class.forName("java.lang.Object") });
      m3 = Class.forName("com.binga.designpatterns.proxy.TargetInterface").getMethod("doSomething", new Class[0]);
      m2 = Class.forName("java.lang.Object").getMethod("toString", new Class[0]);
      m0 = Class.forName("java.lang.Object").getMethod("hashCode", new Class[0]);
      return;
    }
    catch (NoSuchMethodException localNoSuchMethodException)
    {
      throw new NoSuchMethodError(localNoSuchMethodException.getMessage());
    }
    catch (ClassNotFoundException localClassNotFoundException)
    {
      throw new NoClassDefFoundError(localClassNotFoundException.getMessage());
    }
  }
}

通过生成的代理类的类信息可以知道以下几点:

  1. 生成的代理类继承了Proxy,那么其该类的实例就有一个InvocationHandler属性,为h。
  2. 代理类实现了所有参数接口。这里是实现的TargetInterface接口。
  3. 代理通过反射维护了所有接口的方法的Method属性。
  4. 在所有的继承的方法中,调用了h(InvocationHandler)的invoke方法,参数是代表当前方法的Mehtod对象,代理类对象以及传输的参数。

通过上面的分析,就可以了解自己实现的InvocationHandler的invoke方法中的各个参数的含义了:

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    Object invoke = null;

    System.out.println(method.getName() + "前置处理");
    if (targetInterface != null)
        invoke  = method.invoke(targetInterface, args);
    System.out.println(method.getName() + "后置处理");
    return invoke;
}

那个这几个参数:

  • proxy:就是动态生成代理类的实例,即代理对象。
  • mehtod:代表接口的某一个方法的对应反射的Method实例。
  • args:调用代理对象方法传输的参数。

这里有一个注意点,就是在invoke方法中调用method的invoke方法时,不能将proxy参数传入,如果传入会导致递归的调用,从而抛出StackOverflowError异常。

实现自己的Proxy

接下来来自己实现一个Proxy,如下:

/**
 * @Description: 模拟JDK中的Proxy类
 *               生成代理类主要分为三步骤
 *               1>通过需要指定的代理接口,从而生成对应的实现类代码;
 *               2>使用JavaCompiler对生成的代码进行编译,生成Class文件;
 *               3>生成class文件后,通过AppClassLoader加载class文件初始化,然后实例化就是代理对象了。
 * @Author: binga
 * @Date: 2020/8/25 11:02
 * @Blog: https://blog.csdn.net/pang5356
 */
public class MyProxy {
    
    /**
     * 功能描述: 定义的钩子属性,InvocationHandler的 invoke方法在代理对象调用方法时会调用
     * @param: 
     * @return: 
     * @auther: binga
     * @date: 2020/8/25 18:05
     */
    protected InvocationHandler h;
    
    /**
     * 功能描述: 
     * @param: 不允许实例化
     * @return: 
     * @auther: binga
     * @date: 2020/8/25 18:07
     */
    private MyProxy() {

    }

    /**
     * 功能描述: 定义有参的构造函数,用于添加InvocationHandler的添加
     * @param: [h]
     * @return:
     * @auther: binga
     * @date: 2020/8/26 10:02
     */
    protected MyProxy(InvocationHandler h) {
        this.h = h;
    }
    
    /**
     * 功能描述: 该方法就是通过生成代码、编译、加载及创建实例从而动态的生成代理类的
     * @param: [classLoader, interfaces, h]
     * @return: java.lang.Object
     * @auther: binga
     * @date: 2020/8/25 18:07
     */
    public static Object getProxyInstance(ClassLoader classLoader,Class<?>[] interfaces, InvocationHandler h) {
        // 校验
        if (null == interfaces || interfaces.length ==0) {
            throw new RuntimeException("interfaces can not be null");
        }
        // 1.生成代码
        String code = generateCode(interfaces);
        // 2.保存代码
        File file = saveCode(code);
        // 3.编译文件,生成class文件
        compileCode(file);
        // 4. 通过反射生产代理对象
        Object proxyObj = generateProxyObject(classLoader, h);
        return proxyObj;
    }

    /**
     * 功能描述: 在加载完代理类的字节码文件后通过反射获取代理对象实例
     * @param: [classLoader, h]
     * @return: java.lang.Object
     * @auther: binga
     * @date: 2020/8/25 18:08
     */
    private static Object generateProxyObject(ClassLoader classLoader, InvocationHandler h) {
        Object object = null;
        try {
            String clazzName = MyProxy.class.getPackage().getName() + "." + "$Proxy0";
            Class<?> aClass = classLoader.loadClass(clazzName);
            Constructor<?> declaredConstructor = aClass.getDeclaredConstructor(new Class[]{InvocationHandler.class});
            object = declaredConstructor.newInstance(h);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
        return object;
    }

    /**
     * 功能描述: 保存动态生成的代码,保存至当前类编译后的字节码路径下,以便于编译时
     *          能够将class字节码文件输出的指定的路径下,从而可以是AppClassLoader扫描
     *          到
     * @param: [code]
     * @return: java.io.File
     * @auther: binga
     * @date: 2020/8/25 18:09
     */
    private static File saveCode(String code) {
        String path = MyProxy.class.getResource("").getPath();
        File file = new File(path, "$Proxy0.java");
        try(FileOutputStream out = new FileOutputStream(file)) {
            out.write(code.getBytes());
            out.flush();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return file;
    }

    /**
     * 功能描述: 编译动态生成的代码生成字节码文件
     * @param: [file]
     * @return: void
     * @auther: binga
     * @date: 2020/8/25 18:11
     */
    public static void compileCode(File file) {
        try {
            // 获取编译器
            JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
            StandardJavaFileManager manager = compiler.getStandardFileManager(null, null, null);
            Iterable<? extends JavaFileObject> javaFileObjects = manager.getJavaFileObjects(file);

            JavaCompiler.CompilationTask task = compiler.getTask(null, manager, null, null, null, javaFileObjects);
            task.call();
            manager.close();
            System.out.println("compile finish");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 功能描述: 根据指定的接口生成MyProxy的子类代码,即.java文件
     * @param: [interfaces]
     * @return: java.lang.String
     * @auther: binga
     * @date: 2020/8/25 18:11
     */
    private static String generateCode(Class<?>[] interfaces) {
        StringBuilder code = new StringBuilder();
        // 声明包
        code.append("package com.binga.designpatterns.proxy.simulateproxy; \r\n");

        // 导入
        for (Class<?> anInterface: interfaces) {
            code.append("import " + anInterface.getName() + "; \r\n");
        }

        // 声明类
        code.append("public class $Proxy0 extends com.binga.designpatterns.proxy.simulateproxy.MyProxy implements ");
        for (int i = 0; i < interfaces.length; i++) {
            code.append(interfaces[i].getName());
            if (i != interfaces.length -1) {
                code.append(",");
            }
        }
        code.append(" { \r\n");

        // equal
        code.append("   private static java.lang.reflect.Method m0; \r\n");
        // toString
        code.append("   private static java.lang.reflect.Method m1; \r\n");
        // hashCode
        code.append("   private static java.lang.reflect.Method m2; \r\n");

        int index = 3;
        for (int i = 0; i < interfaces.length; i++) {
            Class<?> anInterface = interfaces[i];
            Method[] declaredMethods = anInterface.getDeclaredMethods();
            for (int j = 0; j < declaredMethods.length; j++) {
                code.append("   private static java.lang.reflect.Method m" + index++ + "; \r\n");
            }
        }

        // 构造
        code.append("   public $Proxy0(java.lang.reflect.InvocationHandler h) { \r\n");
        code.append("       super(h); \r\n");
        code.append("   } \r\n");

        //equals方法
        code.append("   @Override\r\n");
        code.append("   public final boolean equals(java.lang.Object paramObject) { \r\n");
        code.append("       try { \r\n");
        code.append("           return ((Boolean)this.h.invoke(this, m0, new Object[] { paramObject })).booleanValue(); \r\n");
        code.append("       } catch(Throwable e) { \r\n");
        code.append("           e.printStackTrace(); \r\n");
        code.append("       }\r\n");
        code.append("       return false;\r\n");
        code.append("   }\r\n");

        // toString方法
        code.append("   @Override\r\n");
        code.append("   public java.lang.String toString() { \r\n");
        code.append("       try { \r\n");
        code.append("           return ((String)this.h.invoke(this, m1, null));\r\n");
        code.append("       } catch(Throwable e) { \r\n");
        code.append("           e.printStackTrace(); \r\n");
        code.append("       }\r\n");
        code.append("       return null;\r\n");
        code.append("   }\r\n");

        //hashCode 方法
        code.append("   @Override\r\n");
        code.append("   public int hashCode() {\r\n");
        code.append("       try { \r\n");
        code.append("           return ((Integer)this.h.invoke(this, m2, null)).intValue();\r\n");
        code.append("       } catch(Throwable e) { \r\n");
        code.append("           e.printStackTrace(); \r\n");
        code.append("       }\r\n");
        code.append("       return 0;\r\n");
        code.append("   }\r\n");

        int methodIndex = 3;
        for (int i = 0; i < interfaces.length; i++) {
            Class<?> anInterface = interfaces[i];
            Method[] declaredMethods = anInterface.getDeclaredMethods();
            for (int j = 0; j < declaredMethods.length; j++) {
                Method declaredMethod = declaredMethods[j];
                Class<?> returnType = declaredMethod.getReturnType();
                String returnStr = getReturnStr(returnType);
                code.append("   public " + returnStr + " " + declaredMethod.getName() + "( ");
                Class<?>[] parameterTypes = declaredMethod.getParameterTypes();
                for (int k = 0; k < parameterTypes.length; k++) {
                    String paramStr = getReturnStr(parameterTypes[k]);
                    code.append(paramStr + "var" + k);
                    if (parameterTypes.length - 1 != k)
                        code.append(", ");
                }
                code.append(") { \r\n");
                code.append("       try { \r\n");
                code.append("           java.lang.Object[] args = {");
                for (int k = 0; k < parameterTypes.length; k++) {
                    code.append("var" + k);
                    if (parameterTypes.length != 1) {
                        code.append(",");
                    }
                }
                code.append("}; \r\n");
                if ("void".equals(returnStr)) {
                    code.append("       this.h.invoke(this, m" + methodIndex++  + ", null); \r\n");
                    code.append("       return;\r\n");
                } else {
                    if(isPrimitive(returnType)) {
                        String primitiveCast = getPrimitiveCast(returnType);
                        String primitive = getReturnStr(returnType) + "Value()";
                        code.append("       return ((" + primitiveCast + ") this.h.invoke(this, m" + methodIndex++ + ", args))." + primitive +"; \r\n");
                    } else {
                        code.append("       return (" + returnStr + ") this.h.invoke(this, m" + methodIndex++ + ", args); \r\n");
                    }
                }

                code.append("       } catch(Throwable e) { \r\n");
                code.append("           e.printStackTrace(); \r\n");
                code.append("       }\r\n");
                code.append("   }\r\n");
            }
        }

        // 静态代码块
        code.append("   static { \r\n");
        code.append("       try{ \r\n");
        code.append("           m0 = Class.forName(\"java.lang.Object\").getMethod(\"equals\", new Class[] { Class.forName(\"java.lang.Object\") });\r\n");
        code.append("           m1 = Class.forName(\"java.lang.Object\").getMethod(\"toString\", new Class[0]);\r\n");
        code.append("           m2 = Class.forName(\"java.lang.Object\").getMethod(\"hashCode\", new Class[0]);\r\n");
        int staticIndex = 3;
        for (int i = 0; i < interfaces.length; i++) {
            Class<?> anInterface = interfaces[i];
            String interfaceName = anInterface.getName();
            Method[] declaredMethods = anInterface.getDeclaredMethods();
            for (int j = 0; j < declaredMethods.length; j++) {
                Method declaredMethod = declaredMethods[j];
                String methodName = declaredMethod.getName();
                code.append("           m" + staticIndex++ + " = Class.forName(\"" + interfaceName + "\").getMethod(\"" + methodName + "\",");
                if (declaredMethod.getParameterTypes().length == 0) {
                    code.append("new Class[0]");
                } else {
                    Class<?>[] parameterTypes = declaredMethod.getParameterTypes();
                    code.append("new Class[]{");
                    for (int k = 0; k < parameterTypes.length; k++) {
                        if(isPrimitive(parameterTypes[k])) {
                            String paraStr = getReturnStr(parameterTypes[k]) + ".class";
                            code.append(paraStr);
                        } else {
                            code.append("Class.forName(\"" + getReturnStr(parameterTypes[k]) + "\")");
                        }

                        if (k != parameterTypes.length -1) {
                            code.append(",");
                        }
                    }
                    code.append("}");
                }

                code.append("); \r\n");
                // 组装param

            }
        }
        code.append("       } catch(Throwable e) { \r\n ");
        code.append("           e.printStackTrace(); \r\n");
        code.append("       }\r\n");
        code.append("   }\r\n");

        code.append("}");

        return code.toString();
    }


    /**
     * 功能描述: 根据基本类型,返回其封装类的名称
     * @param: [returnType]
     * @return: java.lang.String
     * @auther: binga
     * @date: 2020/8/26 10:03
     */
    private static String getPrimitiveCast(Class<?> returnType) {
        if (byte.class == returnType)
            return "Byte";
        if (short.class == returnType)
            return "Short";
        if (char.class == returnType)
            return "Character";
        if (int.class == returnType)
            return "Integer";
        if (boolean.class == returnType)
            return "Boolean";
        if (float.class == returnType)
            return "Float";
        if (double.class == returnType)
            return "Double";
        if (long.class == returnType)
            return "Long";
        return "";
    }

    /**
     * 功能描述: 判断是否是基础数据
     * @param: [returnType]
     * @return: boolean
     * @auther: binga
     * @date: 2020/8/26 10:04
     */
    private static boolean isPrimitive(Class<?> returnType) {
        if (byte.class == returnType || short.class == returnType
                || char.class == returnType || int.class == returnType
                || boolean.class == returnType || float.class == returnType
                || double.class == returnType || long.class == returnType) {
            return true;
        }
        return false;
    }

    /**
     * 功能描述: 组装返回类型的字符串信息
     * @param: [returnType]
     * @return: java.lang.String
     * @auther: binga
     * @date: 2020/8/26 10:04
     */
    private static String getReturnStr(Class<?> returnType) {
        //无返回类型
        if (void.class == returnType)
            return "void";

        // 基础数据类型
        if (byte.class == returnType)
            return "byte";
        if (short.class == returnType)
            return "short";
        if (char.class == returnType)
            return "char";
        if (int.class == returnType)
            return "int";
        if (boolean.class == returnType)
            return "boolean";
        if (float.class == returnType)
            return "float";
        if (double.class == returnType)
            return "double";
        if (long.class == returnType)
            return "long";
        return returnType.getName();
    }
}

该类主要是通过字符串的拼接从而生成代理类的.java文件,然后对其进行编译和加载(该实现与JDK的动态代理相去甚远,只做参考用于理解JDK的动态代机制即可)。
接下来实现自己的InvocationHandler,如下:

/**
 * @Description: handler
 * @Author: binga
 * @Date: 2020/8/25 15:30
 * @Blog: https://blog.csdn.net/pang5356
 */
public class MyHandler implements InvocationHandler {

    private Object object;

    public MyHandler(Object object) {
        this.object = object;
    }


    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Object obj = null;
        System.out.println("method before");
        if (proxy != null)
            obj =  method.invoke(this.object, args);

        System.out.println("method after");
        return obj;
    }
}

测试代码如下:

/**
 * @Description: 自定义生成代理
 * @Author: binga
 * @Date: 2020/8/25 15:29
 * @Blog: https://blog.csdn.net/pang5356
 */
public class CustomProxyTest extends java.lang.Object implements Serializable {
    
    @Override
    public String toString() {
        return super.toString();
    }

    public static void main(String[] args) {
        MyHandler myHandler = new MyHandler(new TargetImpl());
        Object proxyInstance = MyProxy.getProxyInstance(CustomProxyTest.class.getClassLoader(),
                new Class[]{TargetInterface.class}, myHandler);
        TargetInterface proxyInstance1 = (TargetInterface) proxyInstance;
        ((TargetInterface) proxyInstance).doSomething();
    }
}

测试结果如下:

compile finish
method before
TargetImpl method
method after

接下来来了解一下JDK的动态代理源码吧–> JDK动态代理原理剖析(下)
代码下载:码云地址

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值