本篇给出上一篇所分析的“简易模拟 Spring 核心”的完整代码实现,建议与上一篇对照阅读。
点击这里,跳转到上一篇!
模拟Spring核心IOC实现类的注入--第二篇(代码的实现)
com.mec.spring.core 包包含容器的内部逻辑和具体实现,其中只有 BeanFactory 对外开放使用。
com.mec.spring.demo 包中的类加上了各种注解,用来测试代码逻辑是否正确。
com.mec.spring.test 包中有一个含主函数的 Test 类,用于测试。
![](https://i-blog.csdnimg.cn/blog_migrate/f34f6ae64c3127da008b249b8ec39c31.png)
我们从com.mec.spring.core这个包开始上代码!
com.mec.spring.core
Component注解类
/**
 * Marks a class that BeanFactory should register when scanning a package.
 *
 * <p>{@link #singleton()} chooses the lifecycle: {@code true} (the default) means one
 * instance is created at scan time and shared; {@code false} means a fresh instance is
 * created on every lookup.
 */
@Retention(value = RUNTIME)
@Target(value = TYPE)
public @interface Component {
    /** Whether a single shared instance is used; defaults to {@code true}. */
    boolean singleton() default true;
}
AutoWired注解类
/**
 * Requests dependency injection. BeanFactory looks up a bean by the annotated field's
 * type, or by the single parameter type of an annotated public {@code setXxx} method, and
 * injects it when the owning bean is fetched.
 */
@Retention(value = RUNTIME)
@Target(value = { FIELD, METHOD })
public @interface AutoWired {
}
Bean注解类
/**
 * Marks a method of a component whose return value BeanFactory registers as a bean, keyed
 * by the method's declared return type. Parameters, if any, are supplied from beans that
 * are already registered.
 */
@Retention(value = RUNTIME)
@Target(value = METHOD)
public @interface Bean {
}
BeanDefinition类
package com.mec.spring.core;
/**
 * Registry entry for one bean. It holds the bean's class, the instance currently handed
 * out, and two lifecycle flags. A freshly created definition is a singleton whose
 * dependencies have not been injected yet.
 */
public class BeanDefinition {
    private Class<?> klass;
    private Object object;
    /** Whether the {@code @AutoWired} members of {@link #object} have already been injected. */
    private boolean inject;
    /** Whether one shared instance is used, as opposed to a fresh instance per lookup. */
    private boolean singleton;

    BeanDefinition() {
        singleton = true;
        inject = false;
    }

    void setKlass(Class<?> klass) {
        this.klass = klass;
    }

    Class<?> getKlass() {
        return this.klass;
    }

    void setObject(Object object) {
        this.object = object;
    }

    Object getObject() {
        return this.object;
    }

    void setInject(boolean inject) {
        this.inject = inject;
    }

    boolean isInject() {
        return this.inject;
    }

    void setSingleton(boolean singleton) {
        this.singleton = singleton;
    }

    boolean isSingleton() {
        return this.singleton;
    }
}
BeanFactory类
package com.mec.spring.core;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.mec.util.PackageScanner;
/**
 * Minimal IoC container.
 *
 * <p>It scans a package for {@link Component} classes and registers them in
 * {@link #beanPool}. It also registers the return values of their {@link Bean} methods.
 * {@link AutoWired} field and setter injection is performed lazily, the first time a bean
 * is requested through {@link #getBean(String)}. Prototype beans are re-injected on every
 * request.
 *
 * <p>All state is static and stored in a plain {@link HashMap}. Prototype lookups also
 * mutate the shared {@link BeanDefinition}, so this class is not thread-safe.
 */
public class BeanFactory {
    /** Registered beans keyed by fully-qualified class name. */
    public static final Map<String, BeanDefinition> beanPool = new HashMap<>();

    public BeanFactory() {
    }

    /**
     * Scans {@code packageName} and registers every {@link Component} class found.
     *
     * <p>For each component it also collects the {@link Bean} methods. Finally it invokes
     * every {@code @Bean} method whose parameter types have all become available.
     *
     * @param packageName the package handed to {@code PackageScanner#scanPackage}
     */
    public static void scanPackage(String packageName) {
        new PackageScanner() {
            @Override
            public void dealClass(Class<?> klass) {
                if (klass.isPrimitive()
                        || klass == String.class
                        || klass.isAnnotation()
                        || klass.isArray()
                        || klass.isEnum()
                        || !klass.isAnnotationPresent(Component.class)) {
                    return;
                }
                boolean singleton = klass.getAnnotation(Component.class).singleton();
                Object object = null;
                if (singleton) {
                    object = BeanFactory.createInstance(klass);
                    if (object == null) {
                        // Instantiation failed (already reported); do not register a broken singleton.
                        return;
                    }
                }
                BeanDefinition definition = new BeanDefinition();
                definition.setSingleton(singleton);
                definition.setKlass(klass);
                definition.setObject(object);
                beanPool.put(klass.getName(), definition);
                // Wake up @Bean methods that were waiting for this type.
                MethodDependence.checkMethodPara(klass);
                scanBean(object, klass);
            }
        }.scanPackage(packageName);
        // Run every @Bean method whose parameters are now all resolvable.
        MethodDependence.invokeMethod();
    }

    /**
     * Collects the parameter types of {@code method} that are not yet registered.
     *
     * @param method a {@code @Bean} method with parameters
     * @return a map whose key set is the distinct unregistered parameter types; the values
     *     are always 0 and carry no meaning
     */
    private static Map<Class<?>, Integer> getMethodPara(Method method) {
        Map<Class<?>, Integer> paraPool = new HashMap<>();
        for (Class<?> paraKlass : method.getParameterTypes()) {
            paraPool.put(paraKlass, 0);
        }
        // Only test for presence. Going through getBeanDefinition would instantiate
        // prototype beans as a side effect.
        paraPool.keySet().removeIf(type -> beanPool.containsKey(type.getName()));
        return paraPool;
    }

    /**
     * Queues a {@code @Bean} method that takes parameters.
     *
     * <p>The method becomes ready immediately if every parameter type is registered.
     * Otherwise it waits in {@link MethodDependence} until the missing types appear.
     */
    private static void dealMethodPara(Object object, Class<?> klass, Method method) {
        Map<Class<?>, Integer> inParaPool = getMethodPara(method);
        MethodDefinition de = new MethodDefinition();
        de.setKlass(klass);
        de.setObject(object);
        de.setMethod(method);
        de.setParaCount(inParaPool.size());
        if (inParaPool.isEmpty()) {
            MethodDependence.addInvokeList(de);
            return;
        }
        MethodDependence.addUninvokeList(de, inParaPool);
    }

    /** Returns a textual report of the {@code @Bean} methods still waiting for parameter types. */
    public static String showDependence() {
        return MethodDependence.showDependence();
    }

    /**
     * Invokes a {@code @Bean} method whose parameter types are all registered, and registers
     * its result.
     *
     * @param object the instance to invoke on (may be {@code null} for static methods)
     * @param method the {@code @Bean} method to invoke
     */
    static void invokeParaMethod(Object object, Method method) {
        Class<?>[] paraType = method.getParameterTypes();
        Object[] values = new Object[paraType.length];
        for (int index = 0; index < paraType.length; index++) {
            BeanDefinition bean = getBeanDefinition(paraType[index].getName());
            values[index] = bean == null ? null : bean.getObject();
        }
        try {
            registerMethodBean(method.getReturnType(), method.invoke(object, values));
        } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
            e.printStackTrace();
        }
    }

    /**
     * Registers every {@link Bean} method declared on {@code klass}.
     *
     * <p>Parameterless methods are invoked immediately. Methods with parameters are handed
     * to the dependency tracker.
     *
     * @param object the component instance, or {@code null} for a prototype component
     * @param klass  the component class
     */
    private static void scanBean(Object object, Class<?> klass) {
        Object host = object;
        for (Method method : klass.getDeclaredMethods()) {
            if (!method.isAnnotationPresent(Bean.class)) {
                continue;
            }
            if (host == null && !Modifier.isStatic(method.getModifiers())) {
                // Prototype components have no stored instance. Invoking an instance method
                // on null would throw, so create one instance to act as the factory.
                host = createInstance(klass);
                if (host == null) {
                    continue;
                }
            }
            if (method.getParameterCount() > 0) {
                dealMethodPara(host, klass, method);
                continue;
            }
            try {
                registerMethodBean(method.getReturnType(), method.invoke(host));
            } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Stores the product of a {@code @Bean} method under its declared return type.
     *
     * <p>Afterwards it wakes up any {@code @Bean} methods that were waiting for that type.
     */
    private static void registerMethodBean(Class<?> type, Object value) {
        BeanDefinition bean = new BeanDefinition();
        bean.setKlass(type);
        bean.setObject(value);
        beanPool.put(type.getName(), bean);
        MethodDependence.checkMethodPara(type);
    }

    /**
     * Creates an instance through the no-argument constructor.
     *
     * @return the new instance, or {@code null} if it could not be created (the failure is
     *     printed)
     */
    private static Object createInstance(Class<?> klass) {
        try {
            return klass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Looks up a bean by class.
     *
     * @see #getBean(String)
     */
    public static <T> T getBean(Class<?> klass) {
        return getBean(klass.getName());
    }

    /**
     * Returns the definition registered under {@code klassName}, or {@code null} if there is
     * none.
     *
     * <p>For prototype beans, a fresh instance is created and stored in the definition
     * first. If instantiation fails, the previous instance is kept.
     */
    private static BeanDefinition getBeanDefinition(String klassName) {
        BeanDefinition bean = beanPool.get(klassName);
        if (bean == null) {
            return null;
        }
        if (!bean.isSingleton()) {
            Object fresh = createInstance(bean.getKlass());
            if (fresh != null) {
                bean.setObject(fresh);
            }
        }
        return bean;
    }

    /**
     * Returns the bean registered under {@code klassName}.
     *
     * <p>Its {@link AutoWired} members are injected on the first request, and on every
     * request for prototype beans.
     *
     * @param klassName fully-qualified class name the bean is registered under
     * @return the bean, or {@code null} if none is registered (a message is printed)
     */
    @SuppressWarnings("unchecked")
    public static <T> T getBean(String klassName) {
        BeanDefinition bean = getBeanDefinition(klassName);
        if (bean == null) {
            System.out.println("这个bean不存在!");
            return null;
        }
        Object object = bean.getObject();
        // A @Bean method may have produced null; there is nothing to inject into then.
        if (object != null && (!bean.isInject() || !bean.isSingleton())) {
            bean.setInject(true);
            inject(object);
        }
        return (T) object;
    }

    /** Performs field injection, then setter injection, on {@code object}. */
    private static void inject(Object object) {
        Class<?> klass = object.getClass();
        injectMember(object, klass);
        injectMethod(object, klass);
    }

    /**
     * Injects beans through setters.
     *
     * <p>Only methods that are annotated with {@link AutoWired}, are public, are named
     * {@code set*}, and take exactly one parameter are used.
     */
    private static void injectMethod(Object object, Class<?> klass) {
        for (Method method : klass.getDeclaredMethods()) {
            if (!method.isAnnotationPresent(AutoWired.class)
                    || !method.getName().startsWith("set")
                    || method.getParameterCount() != 1
                    || !Modifier.isPublic(method.getModifiers())) {
                continue;
            }
            Object value = getBean(method.getParameterTypes()[0]);
            try {
                method.invoke(object, new Object[] {value});
            } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
                e.printStackTrace();
            }
        }
    }

    /** Injects beans directly into fields annotated with {@link AutoWired}, bypassing access checks. */
    private static void injectMember(Object object, Class<?> klass) {
        for (Field field : klass.getDeclaredFields()) {
            if (!field.isAnnotationPresent(AutoWired.class)) {
                continue;
            }
            // Use the declared field type: the field's value may be null, so getClass() is unavailable.
            Object value = getBean(field.getType());
            field.setAccessible(true);
            try {
                field.set(object, value);
            } catch (IllegalArgumentException | IllegalAccessException e) {
                e.printStackTrace();
            }
        }
    }
}
MethodDefinition类
package com.mec.spring.core;
import java.lang.reflect.Method;
/**
 * A pending {@code @Bean} method invocation.
 *
 * <p>It records the method, the instance to invoke it on, the declaring class, and a
 * counter of parameter types that are still unavailable.
 */
public class MethodDefinition {
    private Method method;
    private Object object;
    private Class<?> klass;
    /** Number of distinct parameter types not yet registered. */
    private int paraCount;

    public MethodDefinition() {
        this.paraCount = 0;
    }

    void setMethod(Method method) {
        this.method = method;
    }

    Method getMethod() {
        return this.method;
    }

    void setObject(Object object) {
        this.object = object;
    }

    Object getObject() {
        return this.object;
    }

    void setKlass(Class<?> klass) {
        this.klass = klass;
    }

    Class<?> getKlass() {
        return this.klass;
    }

    void setParaCount(int paraCount) {
        this.paraCount = paraCount;
    }

    int getParaCount() {
        return this.paraCount;
    }

    /** Decrements the pending-parameter counter and returns the new value. */
    int subParaCount() {
        paraCount = paraCount - 1;
        return paraCount;
    }
}
MethodDependence类
package com.mec.spring.core;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
 * Tracks {@code @Bean} methods whose parameter types are not yet registered.
 *
 * <p>A method waiting on several distinct unregistered types is listed under each of those
 * types in {@code dependenceMethodPool}. Each time one of those types is registered,
 * {@link #checkMethodPara(Class)} decrements the method's counter. When the counter reaches
 * zero, the method moves to {@code invokeList}. Entries still left in the pool are unmet
 * (possibly circular) dependencies, reported by {@link #showDependence()}.
 */
public class MethodDependence {
    /** Methods still waiting for at least one parameter type. */
    private static final List<MethodDefinition> uninvokeList = new ArrayList<>();
    /** Methods whose parameter types are all registered, processed in FIFO order. */
    private static final List<MethodDefinition> invokeList = new LinkedList<>();
    /** Unregistered parameter type -> methods waiting for it. */
    private static final Map<Class<?>, List<MethodDefinition>> dependenceMethodPool = new HashMap<>();

    public MethodDependence() {
    }

    /**
     * Records {@code me} as waiting on every type in {@code inParaPool}'s key set.
     *
     * @param me         the pending method; ignored if {@code null}
     * @param inParaPool the unregistered parameter types, as keys; {@code null} is treated as
     *                   an empty map
     */
    public static void addUninvokeList(MethodDefinition me, Map<Class<?>, Integer> inParaPool) {
        if (me == null) {
            return;
        }
        uninvokeList.add(me);
        if (inParaPool == null) {
            return;
        }
        for (Class<?> klass : inParaPool.keySet()) {
            dependenceMethodPool.computeIfAbsent(klass, key -> new ArrayList<>()).add(me);
        }
    }

    /** Queues a method whose parameters are all available. */
    public static void addInvokeList(MethodDefinition me) {
        invokeList.add(me);
    }

    /**
     * Signals that a bean of type {@code klass} has been registered.
     *
     * <p>Every method waiting on {@code klass} has its counter decremented. Methods whose
     * counter reaches zero are moved to the invoke queue.
     */
    public static void checkMethodPara(Class<?> klass) {
        List<MethodDefinition> waiting = dependenceMethodPool.remove(klass);
        if (waiting == null) {
            return;
        }
        for (MethodDefinition md : waiting) {
            if (md.subParaCount() == 0) {
                invokeList.add(md);
                uninvokeList.remove(md);
            }
        }
    }

    /**
     * Describes the dependencies that are still unmet.
     *
     * <p>The report has one line per pair, in the form
     * {@code produced-type——>>——required-type}. Both sides use fully-qualified class names.
     */
    public static String showDependence() {
        StringBuilder str = new StringBuilder("依赖关系如下:");
        for (Map.Entry<Class<?>, List<MethodDefinition>> entry : dependenceMethodPool.entrySet()) {
            // Render both sides with getName(). Appending the Class itself would print "class x.Y".
            String required = entry.getKey().getName();
            for (MethodDefinition md : entry.getValue()) {
                String produced = md.getMethod().getReturnType().getName();
                str.append("\n").append(produced).append("——>>——").append(required);
            }
        }
        return str.toString();
    }

    /**
     * Drains the invoke queue.
     *
     * <p>Each invocation may register a new bean, which can append more ready methods to the
     * queue. Those methods are processed in the same loop.
     */
    public static void invokeMethod() {
        while (!invokeList.isEmpty()) {
            MethodDefinition md = invokeList.remove(0);
            BeanFactory.invokeParaMethod(md.getObject(), md.getMethod());
        }
    }
}
com.mec.spring.demo
这里只列举几个有代表性的类,其余的就不全部贴出了。
ClassTwo类
package com.mec.spring.demo;
/**
 * Plain demo class with no annotations of its own.
 *
 * <p>It becomes a bean only through a {@code @Bean} factory method in a component.
 */
public class ClassTwo {
    private static final String DESCRIPTION = "这是一个ClassTwo的对象!";

    @Override
    public String toString() {
        return DESCRIPTION;
    }
}
ClassThree类
package com.mec.spring.demo;
import com.mec.spring.core.AutoWired;
import com.mec.spring.core.Component;
/**
 * Demo singleton component.
 *
 * <p>Its {@link ClassTwo} dependency is filled in through field injection, because
 * {@link AutoWired} sits on the field. The setter carries no annotation.
 */
@Component
public class ClassThree {
    @AutoWired
    ClassTwo two;

    public void setTwo(ClassTwo two) {
        this.two = two;
    }

    public ClassTwo getTwo() {
        return this.two;
    }
}
Config类!!!
package com.mec.spring.demo;
import java.util.Calendar;
import com.mec.complex.core.Complex;
import com.mec.spring.core.Bean;
import com.mec.spring.core.Component;
import com.mec.spring.tmp.One;
import com.mec.spring.tmp.Two;
/**
 * Demo configuration component whose {@link Bean} methods each register their return value
 * as a bean.
 *
 * <p>Methods with parameters run only after beans of all their parameter types exist.
 * {@code getTwo(One)} and {@code getOne(Two)} require each other, so unless {@code One} or
 * {@code Two} is registered some other way, neither can run. They will then appear in
 * {@code BeanFactory.showDependence()}.
 */
@Component
public class Config {
    public Config() {
    }

    @Bean
    public ClassTwo getTwoClass() {
        return new ClassTwo();
    }

    @Bean
    public Two getTwo(One one) {
        Two two = new Two();
        two.setOne(one);
        return two;
    }

    @Bean
    public One getOne(Two two) {
        One one = new One();
        one.setTwo(two);
        return one;
    }

    @Bean
    public ClassForth getClassForth(Complex com) {
        ClassForth forth = new ClassForth();
        forth.setComplex(com);
        return forth;
    }

    @Bean
    public Calendar getDate() {
        return Calendar.getInstance();
    }

    // A JDBC Connection could also be exposed as a @Bean here. Load its URL, user and
    // password from external configuration; never hard-code credentials in source.
}
还有一个专门用于处理依赖关系的包,其中也简单实现了检测循环依赖并告知使用者的功能,下次再介绍吧!