1.概述
在Spring的生态中,借助@Autowired注解来实现依赖注入,可以说是非常普遍的事情了,如果让我们自定义一个注解,也实现类似的功能,那么我们可以怎么做呢?
本文介绍如何实现一个自定义的@Autowired,实现依赖服务注入
主要知识点:
BeanPostProcessor
代理类创建
2.项目环境
本项目借助SpringBoot 2.2.1.RELEASE + maven 3.5.3 + IDEA进行开发
下面是核心的pom.xml(源码可以在文末获取)
<!-- 这个依赖是干嘛的,后文会介绍 -->
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
</dependency>
</dependencies>
3.实现姿势
3.1. 代理封装类
借助Spring内置的CGLIB Enhancer来实现代理类生成,比如一个基础的工具类如下,用于自定义注入的增强
package com.spring.annotation.autowire;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import org.springframework.cglib.proxy.MethodProxy;
import java.io.Serializable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
/**
 * Helper that builds CGLIB subclass proxies whose method calls are routed either to a
 * {@link InvocationHandler} or to the original (super) implementation, depending on a
 * {@link CallbackFilter}.
 */
public class ProxyUtil {

    /**
     * Creates a proxy instance that subclasses {@code targetClass}.
     *
     * @param targetClass       class to proxy; when {@code null} no proxy is built
     * @param invocationHandler receives every call the filter accepts; it is dereferenced
     *                          lazily, so a {@code null} handler only fails once an accepted
     *                          method is actually invoked
     * @param filter            decides per method whether the handler is used; {@code null}
     *                          means "route every method to the handler"
     * @param <T>               type the caller expects the proxy to be assignable to
     * @return the proxy instance, or {@code null} when {@code targetClass} is {@code null}
     */
    @SuppressWarnings("unchecked") // the generated subclass is assignable to targetClass; T is chosen by the caller
    public static <T> T newProxyInstance(Class<?> targetClass,
                                         InvocationHandler invocationHandler,
                                         ProxyUtil.CallbackFilter filter) {
        if (targetClass == null) {
            return null;
        }
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(targetClass);
        enhancer.setUseCache(true);
        enhancer.setCallback(new SimpleMethodInterceptor(invocationHandler, filter));
        // Instantiates the generated subclass through the no-arg constructor.
        return (T) enhancer.create();
    }

    /**
     * Selects which methods of the proxied class are routed to the invocation handler.
     */
    public interface CallbackFilter {
        /**
         * @param var1 the method being invoked on the proxy
         * @return {@code true} to dispatch to the handler, {@code false} to call the super implementation
         */
        boolean accept(Method var1);
    }

    /**
     * CGLIB interceptor that applies the filter and then dispatches to either the handler
     * or the proxied class' own implementation.
     */
    private static class SimpleMethodInterceptor implements MethodInterceptor, Serializable {
        private transient InvocationHandler invocationHandler;
        private transient ProxyUtil.CallbackFilter filter;

        public SimpleMethodInterceptor(InvocationHandler invocationHandler, ProxyUtil.CallbackFilter filter) {
            this.invocationHandler = invocationHandler;
            this.filter = filter;
        }

        @Override
        public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
            // A missing filter used to blow up with a NullPointerException on every call;
            // treat it as "no filtering", i.e. every method goes to the handler.
            boolean handledByHandler = this.filter == null || this.filter.accept(method);
            if (handledByHandler) {
                return this.invocationHandler.invoke(o, method, objects);
            }
            return methodProxy.invokeSuper(o, objects);
        }
    }
}
3.2. 自定义注解
参照@Autowired的定义,实现一个自定义的注解(缩减版)
package com.spring.annotation.autowire;
import java.lang.annotation.*;
/**
 * Marker annotation (no attributes) that can be placed on methods, fields and other
 * annotation types, and is retained at runtime so it can be discovered reflectively.
 * In this project it is looked up by {@code AutoInjectPostProcessor} on a bean's
 * fields and methods.
 */
@Target({ElementType.METHOD, ElementType.FIELD, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface AutoInject {
}
3.3 自定义注入
实现BeanPostProcessor,在bean初始化之前(postProcessBeforeInitialization回调中),扫描标注了@AutoInject的field/method;为了做一个区分,下面注入的不是原始bean,而是新创建的一个代理对象
package com.spring.annotation.autowire;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
/**
 * {@link BeanPostProcessor} that fills every {@link AutoInject}-annotated field and
 * single-parameter method of a bean with a logging proxy of the requested type.
 * The proxy resolves the real bean from the {@link ApplicationContext} on each call.
 */
@Component
public class AutoInjectPostProcessor implements BeanPostProcessor {
    private final ApplicationContext applicationContext;

    public AutoInjectPostProcessor(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
    }

    /**
     * Walks the bean's class hierarchy (most specific class first) and injects a proxy into
     * each {@link AutoInject} field, then into each {@link AutoInject} method of that class.
     *
     * @param bean     the bean instance being prepared
     * @param beanName the bean's name (unused)
     * @return the same {@code bean} instance, after injection
     * @throws BeanDefinitionStoreException if an annotated method does not take exactly one parameter
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        // java.lang.Object declares no injection points, so the walk can stop before it.
        for (Class<?> clazz = bean.getClass(); clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
            injectFields(bean, clazz);
            injectMethods(bean, clazz);
        }
        return bean;
    }

    /** Sets every {@link AutoInject} field declared directly on {@code clazz}. */
    private void injectFields(Object bean, Class<?> clazz) {
        for (final Field field : clazz.getDeclaredFields()) {
            if (AnnotationUtils.findAnnotation(field, AutoInject.class) == null) {
                continue;
            }
            ReflectionUtils.makeAccessible(field);
            ReflectionUtils.setField(field, bean, processInjectionPoint(field.getType()));
        }
    }

    /** Invokes every {@link AutoInject} method declared directly on {@code clazz} with a proxy argument. */
    private void injectMethods(Object bean, Class<?> clazz) {
        for (final Method method : clazz.getDeclaredMethods()) {
            if (AnnotationUtils.findAnnotation(method, AutoInject.class) == null) {
                continue;
            }
            final Class<?>[] paramTypes = method.getParameterTypes();
            if (paramTypes.length != 1) {
                throw new BeanDefinitionStoreException(
                        "Method " + method + " doesn't have exactly one parameter.");
            }
            ReflectionUtils.makeAccessible(method);
            ReflectionUtils.invokeMethod(method, bean, processInjectionPoint(paramTypes[0]));
        }
    }

    /**
     * Builds the proxy that is injected in place of the real dependency. Every method call on
     * the proxy prints a line before and after delegating to the bean of {@code injectionType}
     * obtained from the application context at call time.
     *
     * @param injectionType type of the field or method parameter being injected
     * @param <T>           the injection type
     * @return a proxy of {@code injectionType}
     */
    protected <T> T processInjectionPoint(final Class<T> injectionType) {
        return ProxyUtil.newProxyInstance(injectionType, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                System.out.println("do before " + method.getName() + " | " + Thread.currentThread());
                try {
                    Object obj = applicationContext.getBean(injectionType);
                    return method.invoke(obj, args);
                } catch (InvocationTargetException e) {
                    // Reflection wraps whatever the real bean threw; surface the bean's own
                    // exception so callers of the proxy see what a direct call would throw.
                    Throwable cause = e.getCause();
                    throw cause != null ? cause : e;
                } finally {
                    System.out.println("do after " + method.getName() + " | " + Thread.currentThread());
                }
            }
        }, new ProxyUtil.CallbackFilter() {
            @Override
            public boolean accept(Method var1) {
                // Route every method through the logging handler.
                return true;
            }
        });
    }
}
3.4 测试
接下来验证一下自定义注入方式
package com.spring.service.atowire;
import org.springframework.stereotype.Component;
/**
 * Demo bean used as an injection target: adds two ints and prints a marker line
 * from inside the real bean on each calculation.
 */
@Component
public class DemoService {

    /**
     * Prints the inner marker line, then returns the sum of the two arguments.
     *
     * @param a first addend
     * @param b second addend
     * @return {@code a + b}
     */
    public int calculate(int a, int b) {
        announceInnerCall();
        final int sum = a + b;
        return sum;
    }

    /** Prints the marker line together with the current thread. */
    private void announceInnerCall() {
        final Thread current = Thread.currentThread();
        System.out.println("-------- inner ----------: " + current);
    }
}
package com.spring.service.atowire;
import org.springframework.stereotype.Component;
/**
 * Second demo bean used as an injection target: adds two ints and prints a marker
 * line from inside the real bean on each calculation.
 */
@Component
public class DemoService2 {

    /**
     * Prints the inner marker line, then returns the sum of the two arguments.
     *
     * @param a first addend
     * @param b second addend
     * @return {@code a + b}
     */
    public int calculate(int a, int b) {
        announceInnerCall();
        final int sum = a + b;
        return sum;
    }

    /** Prints the marker line together with the current thread. */
    private void announceInnerCall() {
        final Thread current = Thread.currentThread();
        System.out.println("-------- inner ----------: " + current);
    }
}
package com.spring.service.atowire;
import com.spring.annotation.autowire.AutoInject;
import org.springframework.stereotype.Service;
/**
 * Demo consumer that receives one dependency through an {@link AutoInject} field and
 * another through an {@link AutoInject} setter, then calls both.
 */
@Service
public class RestService {

    /** Dependency marked for field injection. */
    @AutoInject
    private DemoService demoService;

    /** Dependency assigned through {@link #setDemoService2(DemoService2)}. */
    private DemoService2 demoService2;

    /**
     * Setter marked for method injection.
     *
     * @param demoService2 the instance to store
     */
    @AutoInject
    public void setDemoService2(DemoService2 demoService2) {
        this.demoService2 = demoService2;
    }

    /** Calls both dependencies and prints each result on its own line. */
    public void test() {
        System.out.println(demoService.calculate(10, 20));
        System.out.println(demoService2.calculate(11, 22));
    }
}
测试类
package com.spring.service.atowire;
import com.spring.BaseTest;
import com.spring.annotation.autowire.AutoInject;
import org.junit.Test;
import static org.junit.Assert.*;
/**
 * Exercises {@code RestService} with a {@code RestService} instance obtained through
 * {@code @AutoInject} on the test field.
 * NOTE(review): presumably {@code BaseTest} bootstraps the Spring test context so the
 * field is populated; BaseTest is not visible here, confirm.
 */
public class RestServiceTest extends BaseTest {
// Populated via the custom annotation rather than @Autowired.
@AutoInject
private RestService restService;
/**
 * TODO(九师兄, 2023/4/29 22:21)
 * Test point: exercise the custom injection annotation.
 *
 * Console output recorded by the original author:
 *
 * do before test | Thread[main,5,main]
 * do before calculate | Thread[main,5,main]
 * -------- inner ----------: Thread[main,5,main]
 * do after calculate | Thread[main,5,main]
 * 30
 * do before calculate | Thread[main,5,main]
 * -------- inner ----------: Thread[main,5,main]
 * do after calculate | Thread[main,5,main]
 * 33
 * do after test | Thread[main,5,main]
 */
@Test
public void test1() {
restService.test();
}
}