模拟Spring核心IOC实现类的注入--第二篇(代码的实现)

2 篇文章 0 订阅
2 篇文章 0 订阅

本篇给出上一篇所分析的"简易模拟 Spring 核心(IoC 容器)"的完整代码实现,建议与上一篇对照阅读。
点击这里,跳转到上一篇!


com.mec.spring.core 包包含容器的内部逻辑与全部实现代码,使用者最终只需要调用其中的 BeanFactory。
com.mec.spring.demo 包中的类加上了各种注解,用来测试代码逻辑是否正确。
com.mec.spring.test 包中只有一个带 main 方法的 Test 类,用于运行测试。


下面从 com.mec.spring.core 包开始贴出代码。

com.mec.spring.core

Component注解类

/**
 * Marks a class as a component managed by the container.
 *
 * <p>Retained at runtime so that {@code BeanFactory} can find it through
 * {@code Class#isAnnotationPresent}; applicable to type declarations only.
 */
@Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
@Target(java.lang.annotation.ElementType.TYPE)
public @interface Component {
	/**
	 * Whether one shared instance is created eagerly at scan time ({@code true}, the default)
	 * or a fresh instance is created on every lookup ({@code false}).
	 */
	boolean singleton() default true;
}

AutoWired注解类

/**
 * Requests that the container inject a registered bean into the annotated field, or pass one
 * to the annotated setter method.
 *
 * <p>Retained at runtime so the container can discover it reflectively.
 */
@Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
@Target({ java.lang.annotation.ElementType.FIELD, java.lang.annotation.ElementType.METHOD })
public @interface AutoWired {
	// Marker annotation: it carries no attributes.
}

Bean注解类

/**
 * Marks a method of a {@code @Component} class whose return value the container registers as
 * a bean, keyed by the method's declared return type.
 *
 * <p>Retained at runtime so the container can discover it reflectively.
 */
@Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
@Target(java.lang.annotation.ElementType.METHOD)
public @interface Bean {
	// Marker annotation: it carries no attributes.
}

BeanDefinition类

package com.mec.spring.core;

/**
 * Registry entry describing one bean known to the container.
 *
 * <p>Holds the bean's class, its current instance (which may be {@code null}), whether
 * {@code @AutoWired} injection has already been performed on that instance, and whether the
 * bean is a singleton. A new definition starts out as a not-yet-injected singleton.
 */
public class BeanDefinition {
	/** Class of the bean. */
	private Class<?> klass;
	/** Current bean instance; {@code null} until one is set. */
	private Object object;
	/** Whether dependency injection has already been done on {@link #object}. */
	private boolean inject = false;
	/** Whether the bean is shared; defaults to {@code true}. */
	private boolean singleton = true;

	BeanDefinition() {
		// Defaults come from the field initializers above.
	}

	Class<?> getKlass() {
		return this.klass;
	}

	void setKlass(Class<?> klass) {
		this.klass = klass;
	}

	Object getObject() {
		return this.object;
	}

	void setObject(Object object) {
		this.object = object;
	}

	boolean isInject() {
		return this.inject;
	}

	void setInject(boolean inject) {
		this.inject = inject;
	}

	boolean isSingleton() {
		return this.singleton;
	}

	void setSingleton(boolean singleton) {
		this.singleton = singleton;
	}
}

BeanFactory类

package com.mec.spring.core;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.mec.util.PackageScanner;

/**
 * Static entry point of the mini IoC container.
 *
 * <p>{@link #scanPackage(String)} registers every {@code @Component} class of a package, plus
 * the values returned by their {@code @Bean} methods, in {@link #beanPool}, keyed by
 * fully-qualified class name. {@code @Bean} methods whose parameter types are not registered
 * yet are handed to {@code MethodDependence} and invoked once all of those types are available.
 * {@link #getBean(String)} returns the registered instance and performs {@code @AutoWired}
 * field and setter injection on it.
 *
 * <p>All state is static and unsynchronized.
 */
public class BeanFactory {
	/** Registered beans, keyed by fully-qualified class name. */
	public static final Map<String, BeanDefinition> beanPool  = new HashMap<>();

	public BeanFactory() {
	}

	/**
	 * Scans {@code packageName}, registers every {@code @Component} class found there, then
	 * invokes the {@code @Bean} methods whose parameters could all be resolved.
	 *
	 * @param packageName package handed to {@code PackageScanner#scanPackage}
	 */
	public static void scanPackage(String packageName) {
		new PackageScanner() {

			@Override
			public void dealClass(Class<?> klass) {
				if (isComponentCandidate(klass)) {
					registerComponent(klass);
				}
			}
		}.scanPackage(packageName);
		// Every component is registered now; run the @Bean methods that became invocable.
		MethodDependence.invokeMethod();
	}

	/** Returns whether {@code klass} is an ordinary class annotated with {@code @Component}. */
	private static boolean isComponentCandidate(Class<?> klass) {
		return !(klass.isPrimitive()
				|| klass == String.class
				|| klass.isAnnotation()
				|| klass.isArray()
				|| klass.isEnum())
				&& klass.isAnnotationPresent(Component.class);
	}

	/**
	 * Registers one {@code @Component} class. A singleton is instantiated eagerly; a
	 * non-singleton is registered without an instance. Afterwards, methods waiting for this
	 * type are released and the class's own {@code @Bean} methods are processed. Nothing is
	 * registered when instantiation fails.
	 */
	private static void registerComponent(Class<?> klass) {
		boolean singleton = klass.getAnnotation(Component.class).singleton();
		Object object = null;
		if (singleton) {
			try {
				object = klass.newInstance();
			} catch (InstantiationException | IllegalAccessException e) {
				e.printStackTrace();
				return;
			}
		}
		BeanDefinition definition = new BeanDefinition();
		definition.setSingleton(singleton);
		definition.setKlass(klass);
		definition.setObject(object);
		beanPool.put(klass.getName(), definition);
		MethodDependence.checkMethodPara(klass);
		scanBean(object, klass);
	}

	/** Registers {@code value} as the bean for {@code type} and releases methods waiting on it. */
	private static void registerBean(Class<?> type, Object value) {
		BeanDefinition bean = new BeanDefinition();
		bean.setKlass(type);
		bean.setObject(value);
		beanPool.put(type.getName(), bean);
		MethodDependence.checkMethodPara(type);
	}

	/**
	 * Returns the distinct parameter types of {@code method} that are not registered yet.
	 * Only the keys are meaningful; every value is 0.
	 */
	private static Map<Class<?>, Integer> getMethodPara(Method method) {
		Map<Class<?>, Integer> paraPool = new HashMap<>();
		for (Class<?> paraKlass : method.getParameterTypes()) {
			paraPool.put(paraKlass, 0);
		}
		// Plain existence check: going through getBeanDefinition() would needlessly create a
		// fresh instance of every non-singleton parameter bean.
		paraPool.keySet().removeIf(paraKlass -> beanPool.containsKey(paraKlass.getName()));
		return paraPool;
	}

	/**
	 * Wraps a {@code @Bean} method that takes parameters. It is queued for invocation when
	 * every parameter type is registered; otherwise it is parked until those types appear.
	 */
	private static void dealMethodPara(Object object, Class<?> klass, Method method) {
		Map<Class<?>, Integer> inParaPool = getMethodPara(method);

		MethodDefinition de = new MethodDefinition();
		de.setKlass(klass);
		de.setObject(object);
		de.setMethod(method);
		de.setParaCount(inParaPool.size());

		if (inParaPool.isEmpty()) {
			MethodDependence.addInvokeList(de);
			return;
		}
		MethodDependence.addUninvokeList(de, inParaPool);
	}

	/** Returns a textual report of the {@code @Bean} methods still waiting for parameter types. */
	public static String showDependence() {
		return MethodDependence.showDependence();
	}

	/**
	 * Invokes a {@code @Bean} method with parameters resolved from the pool and registers the
	 * result under the method's return type. Reflection failures are printed, not rethrown.
	 */
	static void invokeParaMethod(Object object, Method method) {
		Class<?>[] paraType = method.getParameterTypes();
		Object[] values = new Object[paraType.length];
		for (int index = 0; index < paraType.length; index++) {
			BeanDefinition bean = getBeanDefinition(paraType[index].getName());
			values[index] = bean.getObject();
		}
		try {
			registerBean(method.getReturnType(), method.invoke(object, values));
		} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Processes the {@code @Bean} methods declared by {@code klass}. No-arg methods are invoked
	 * immediately; methods with parameters go through {@link #dealMethodPara}.
	 *
	 * @param object the component instance, or {@code null} for a non-singleton component
	 */
	private static void scanBean(Object object, Class<?> klass) {
		for (Method method : klass.getDeclaredMethods()) {
			if (!method.isAnnotationPresent(Bean.class)) {
				continue;
			}
			// Non-singleton components have no instance here. Method.invoke(null, ...) on an
			// instance method throws an uncaught NullPointerException that would abort the scan.
			if (object == null && !Modifier.isStatic(method.getModifiers())) {
				System.out.println("@Bean method " + klass.getName() + "." + method.getName()
						+ " skipped: non-singleton component has no instance");
				continue;
			}
			if (method.getParameterCount() > 0) {
				dealMethodPara(object, klass, method);
				continue;
			}
			try {
				registerBean(method.getReturnType(), method.invoke(object));
			} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
				e.printStackTrace();
			}
		}
	}

	/** Convenience overload of {@link #getBean(String)} using {@code klass.getName()}. */
	public static <T> T getBean(Class<?> klass) {
		return getBean(klass.getName());
	}

	/**
	 * Looks up a definition. For a non-singleton, a fresh instance is created and stored in the
	 * definition before it is returned.
	 *
	 * @return the definition, or {@code null} if nothing is registered under {@code klassName}
	 */
	private static BeanDefinition getBeanDefinition(String klassName) {
		BeanDefinition bean = beanPool.get(klassName);
		if (bean == null) {
			return null;
		}
		Object object = bean.getObject();

		if (!bean.isSingleton()) {
			try {
				object = bean.getKlass().newInstance();
			} catch (InstantiationException | IllegalAccessException e) {
				e.printStackTrace();
			}
		}
		bean.setObject(object);
		return bean;
	}

	/**
	 * Returns the bean registered under {@code klassName}, performing {@code @AutoWired}
	 * injection on first retrieval of a singleton and on every retrieval of a non-singleton.
	 *
	 * @return the bean instance, or {@code null} if no such bean is registered or it has no instance
	 */
	@SuppressWarnings("unchecked")
	public static <T> T getBean(String klassName) {
		BeanDefinition bean = getBeanDefinition(klassName);
		if (bean == null) {
			System.out.println("这个bean不存在!");
			// Stop here instead of dereferencing the missing definition.
			return null;
		}
		Object object = bean.getObject();
		if (object != null && (!bean.isInject() || !bean.isSingleton())) {
			// Mark before injecting so that circular singleton references terminate.
			bean.setInject(true);
			inject(object);
		}
		return (T) object;
	}

	/** Performs field injection, then setter injection, on {@code object}. */
	private static void inject(Object object) {
		Class<?> klass = object.getClass();
		injectMember(object, klass);
		injectMethod(object, klass);
	}

	/**
	 * Calls every public, non-static, one-argument {@code set*} method annotated with
	 * {@code @AutoWired}, passing the bean registered for its parameter type. Setters whose
	 * bean is unavailable are left uncalled.
	 */
	private static void injectMethod(Object object, Class<?> klass) {
		for (Method method : klass.getDeclaredMethods()) {
			int modify = method.getModifiers();
			if (!method.isAnnotationPresent(AutoWired.class)
					|| !method.getName().startsWith("set")
					|| method.getParameterCount() != 1
					|| !Modifier.isPublic(modify)
					|| Modifier.isStatic(modify)) {
				continue;
			}
			Object value = getBean(method.getParameterTypes()[0]);
			if (value == null) {
				continue;
			}
			try {
				method.invoke(object, value);
			} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
				e.printStackTrace();
			}
		}
	}

	/**
	 * Assigns every {@code @AutoWired} field declared by {@code klass} the bean registered for
	 * the field's declared type. Fields whose bean is unavailable are left untouched.
	 */
	private static void injectMember(Object object, Class<?> klass) {
		for (Field field : klass.getDeclaredFields()) {
			if (!field.isAnnotationPresent(AutoWired.class)) {
				continue;
			}
			// Use the declared type: the field may be null, so getClass() is not an option.
			Object value = getBean(field.getType());
			if (value == null) {
				continue;
			}
			field.setAccessible(true);
			try {
				field.set(object, value);
			} catch (IllegalArgumentException | IllegalAccessException e) {
				e.printStackTrace();
			}
		}
	}

}

MethodDefinition类

package com.mec.spring.core;

import java.lang.reflect.Method;

/**
 * Describes a {@code @Bean} method waiting to be invoked.
 *
 * <p>Records the receiver object, its declaring class, the reflective method, and how many of
 * the method's parameter types are still unresolved ({@code paraCount}, initially 0).
 */
public class MethodDefinition {
	/** Receiver passed to {@code Method#invoke}. */
	private Object object;
	/** Class declaring {@link #method}. */
	private Class<?> klass;
	/** The {@code @Bean} method itself. */
	private Method method;
	/** Number of parameter types still missing; starts at zero. */
	private int paraCount = 0;

	public MethodDefinition() {
		// paraCount is initialised to zero by its field initializer.
	}

	Object getObject() {
		return this.object;
	}

	void setObject(Object object) {
		this.object = object;
	}

	Class<?> getKlass() {
		return this.klass;
	}

	void setKlass(Class<?> klass) {
		this.klass = klass;
	}

	Method getMethod() {
		return this.method;
	}

	void setMethod(Method method) {
		this.method = method;
	}

	int getParaCount() {
		return this.paraCount;
	}

	void setParaCount(int paraCount) {
		this.paraCount = paraCount;
	}

	/** Decrements the missing-parameter count and returns the new value. */
	int subParaCount() {
		this.paraCount -= 1;
		return this.paraCount;
	}
}

MethodDependence类

package com.mec.spring.core;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Bookkeeping for {@code @Bean} methods and when they can be invoked.
 *
 * <p>A method with missing parameter types is indexed under each missing type. Whenever a type
 * gets registered, {@link #checkMethodPara(Class)} decrements the missing-parameter count of
 * every method waiting on it, and moves those that reach zero to the ready queue.
 * {@link #invokeMethod()} drains that queue through {@code BeanFactory.invokeParaMethod}.
 * Whatever is still indexed afterwards is what {@link #showDependence()} reports.
 *
 * <p>All state is static and unsynchronized.
 */
public class MethodDependence {
	/** Methods still waiting for at least one parameter type. */
	private static final List<MethodDefinition> uninvokeList = new ArrayList<>();
	/** Methods ready to be invoked, consumed front to back. */
	private static final List<MethodDefinition> invokeList = new LinkedList<>();
	/** Missing parameter type -> methods waiting for that type. */
	private static final Map<Class<?>, List<MethodDefinition>> dependenceMethodPool = new HashMap<>();

	public MethodDependence() {
	}

	/**
	 * Parks {@code me} until all of its missing parameter types are registered.
	 *
	 * @param me         the waiting method; ignored when {@code null}
	 * @param inParaPool missing parameter types (only the keys are used)
	 */
	public static void addUninvokeList(MethodDefinition me, Map<Class<?>, Integer> inParaPool) {
		if (me != null) {
			uninvokeList.add(me);
			for (Class<?> missingType : inParaPool.keySet()) {
				dependenceMethodPool.computeIfAbsent(missingType, unused -> new ArrayList<>()).add(me);
			}
		}
	}

	/** Appends {@code me} to the ready queue. */
	public static void addInvokeList(MethodDefinition me) {
		invokeList.add(me);
	}

	/**
	 * Called when {@code klass} becomes available: un-indexes it, and moves every waiting
	 * method whose missing-parameter count drops to zero into the ready queue.
	 */
	public static void checkMethodPara(Class<?> klass) {
		List<MethodDefinition> waiting = dependenceMethodPool.remove(klass);
		if (waiting == null) {
			return;
		}
		for (MethodDefinition candidate : waiting) {
			if (candidate.subParaCount() == 0) {
				invokeList.add(candidate);
				uninvokeList.remove(candidate);
			}
		}
	}

	/**
	 * Lists the still-unsatisfied dependencies, one per line, as
	 * {@code <return type of waiting method>——>>——<missing class>}.
	 */
	public static String showDependence() {
		StringBuilder report = new StringBuilder("依赖关系如下:");
		for (Map.Entry<Class<?>, List<MethodDefinition>> entry : dependenceMethodPool.entrySet()) {
			Class<?> missingType = entry.getKey();
			for (MethodDefinition waiter : entry.getValue()) {
				report.append('\n')
						.append(waiter.getMethod().getReturnType().getName())
						.append("——>>——")
						.append(missingType);
			}
		}
		return report.toString();
	}

	/**
	 * Invokes ready methods until the queue is empty. Methods released during an invocation
	 * are picked up in the same loop.
	 */
	public static void invokeMethod() {
		while (!invokeList.isEmpty()) {
			MethodDefinition next = invokeList.remove(0);
			BeanFactory.invokeParaMethod(next.getObject(), next.getMethod());
		}
	}

}


com.mec.spring.demo

这里只列出几个有代表性的示例类,其余不再一一贴出。

ClassTwo类

package com.mec.spring.demo;


/**
 * Plain demo class without container annotations; {@code Config#getTwoClass()} produces an
 * instance of it as a {@code @Bean}.
 */
public class ClassTwo {
	/** Fixed text returned by {@link #toString()}. */
	private static final String DESCRIPTION = "这是一个ClassTwo的对象!";

	public ClassTwo() {
		// Nothing to initialise.
	}

	@Override
	public String toString() {
		return DESCRIPTION;
	}
}

ClassThree类

package com.mec.spring.demo;

import com.mec.spring.core.AutoWired;
import com.mec.spring.core.Component;

/**
 * Demo singleton component whose {@link ClassTwo} dependency is filled in through field
 * injection, provided a {@code ClassTwo} bean is registered with the container.
 */
@Component
public class ClassThree {
	/** Injected by the container because of {@code @AutoWired}. */
	@AutoWired
	ClassTwo two;

	public void setTwo(ClassTwo two) {
		this.two = two;
	}

	public ClassTwo getTwo() {
		return this.two;
	}
}

Config类(重点)

package com.mec.spring.demo;

import java.util.Calendar;

import com.mec.complex.core.Complex;
import com.mec.spring.core.Bean;
import com.mec.spring.core.Component;
import com.mec.spring.tmp.One;
import com.mec.spring.tmp.Two;

/**
 * Demo configuration component whose {@code @Bean} methods register additional beans, each
 * keyed by the method's declared return type.
 *
 * <p>{@link #getTwo(One)} and {@link #getOne(Two)} require each other, so neither can be
 * invoked. They appear to be kept on purpose to exercise the container's dependency reporting.
 * {@link #getClassForth(Complex)} is only invoked once a {@code Complex} bean is registered.
 */
@Component
public class Config {

	public Config() {
	}

	/** Produces a plain {@link ClassTwo}; it has no dependencies, so it runs immediately. */
	@Bean
	public ClassTwo getTwoClass() {
		return new ClassTwo();
	}

	/** Produces a {@link Two} wired to the given {@link One}. */
	@Bean
	public Two getTwo(One one) {
		Two two = new Two();
		two.setOne(one);
		return two;
	}

	/** Produces a {@link One} wired to the given {@link Two}. */
	@Bean
	public One getOne(Two two) {
		One one = new One();
		one.setTwo(two);
		return one;
	}

	/** Produces a {@link ClassForth} holding the given {@link Complex}. */
	@Bean
	public ClassForth getClassForth(Complex com) {
		ClassForth forth = new ClassForth();
		forth.setComplex(com);
		return forth;
	}

	/**
	 * Produces the current {@link Calendar}. The bean is keyed by the return type
	 * ({@code java.util.Calendar}); switching to {@code java.time} would change that key for
	 * every consumer, so the legacy type is kept.
	 */
	@Bean
	public Calendar getDate() {
		return Calendar.getInstance();
	}

	// A commented-out JDBC Connection @Bean used to live here, with a hard-coded database user
	// name and password and a malformed JDBC URL. It was removed because credentials must not
	// be committed to source. If re-added, read the URL and credentials from external
	// configuration or the environment.

}

另外还有一个专门用来处理依赖关系的包,其中简单实现了检测循环依赖并告知使用者的功能,留到下一篇再介绍。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值