Spring源码:手写SpringIOC

一、分析

IOC是Inversion of Control(控制反转)的缩写。它是一种软件设计原则,用于解耦组件之间的依赖关系。

也就是说,依赖对象的获取方式被反转了:从原来由我们自己创建对象,变为从IOC容器中获取,由容器来管理。

这样的好处是什么?

代码更加简洁,不用自己去new对象
面向接口编程:解耦,易扩展,替换实现类;方便进行AOP编程

那么,你有啥思路?

其实可以这么理解,IOC容器=Bean工厂,Beanfactory会对外提供bean实例,所以需要提供getBean()方法;那么你要什么样的Bean,得描述告诉Bean工厂吧,所以需要一个Bean定义信息BeanDefinition,告诉它应该创建什么对象;我们定义的这些BeanDefinition存在哪?就需要一个注册器BeanDefinitionRegistry去维护这些信息

二、实现

1、版本1:实现Bean注入IOC容器,并从容器中获取

1)定义BeanDefinition

描述我们的bean是要单例还是多例,是通过什么去创建(直接new,还是通过工厂类创建),初始化以及销毁方法

import com.baomidou.mybatisplus.core.toolkit.StringUtils;

/**
 * Bean definition: tells the factory how to create and manage one bean — from which class or
 * factory, in which scope, with which init/destroy callbacks, and whether it is the primary
 * candidate for its type.
 */
public interface BeanDefinition {

    /** Scope of a bean that is created once and then reused. */
    String SCOPE_SINGLETON = "singleton";

    /** Scope of a bean that is created anew on every lookup. */
    String SCOPE_PROTOTYPE = "prototype";

    /**
     * Bean class: either instantiated directly, or the class declaring the static factory method.
     * Left {@code null} when the bean is produced by a factory bean.
     */
    Class<?> getBeanClass();
    void setBeanClass(Class<?> beanClass);

    /** Scope name, e.g. {@link #SCOPE_SINGLETON} or {@link #SCOPE_PROTOTYPE}. */
    void setScope(String scope);
    String getScope();

    /** Whether the bean is a singleton. */
    boolean isSingleton();

    /** Whether the bean is a prototype. */
    boolean isPrototype();

    /** Name of the factory bean whose factory method creates this bean. */
    String getFactoryBeanName();
    void setFactoryBeanName(String factoryBeanName);

    /** Name of the factory method: static on the bean class, or an instance method of the factory bean. */
    String getFactoryMethodName();
    void setFactoryMethodName(String factoryMethodName);

    /** Name of the method invoked right after the bean is created. */
    String getInitMethodName();
    void setInitMethodName(String initMethodName);

    /** Name of the method invoked on a created singleton when the factory is closed. */
    String getDestroyMethodName();
    void setDestroyMethodName(String destroyMethodName);

    /** Whether this bean is the primary candidate when several beans match a requested type. */
    boolean isPrimary();
    void setPrimary(boolean primary);

    /**
     * Checks that this definition is internally consistent:
     * <ul>
     *   <li>without a bean class, both a factory bean name and a factory method name are required;</li>
     *   <li>a bean class and a factory bean name are mutually exclusive;</li>
     *   <li>a non-blank scope must be recognized by the implementation as singleton or prototype.</li>
     * </ul>
     *
     * @return {@code true} if the definition may be registered
     */
    default boolean validate() {
        // No class: the bean must come from a factory bean plus a factory method
        if (this.getBeanClass() == null) {
            if (StringUtils.isBlank(getFactoryBeanName()) || StringUtils.isBlank(getFactoryMethodName())) {
                return false;
            }
        }

        // A class (constructor / static factory) and a factory bean cannot both be specified
        if (this.getBeanClass() != null && StringUtils.isNotBlank(getFactoryBeanName())) {
            return false;
        }

        // An unrecognized scope (e.g. a typo such as "singelton") would otherwise be silently
        // handled as a prototype, because the factory only special-cases singletons
        if (StringUtils.isNotBlank(getScope()) && !isSingleton() && !isPrototype()) {
            return false;
        }

        return true;
    }

}

2)定义BeanDefinition实现类

定义一个通用实现类,实现BeanDefinition接口,对值的设置和获取

/**
 * General-purpose, mutable {@link BeanDefinition} backed by plain fields.
 */
public class GenericBeanDefinition implements BeanDefinition {

	/** Scope value meaning "not set"; treated as singleton. */
	public static final String SCOPE_DEFAULT = "";

	private Class<?> beanClass;

	private String scope = SCOPE_DEFAULT;

	private String factoryBeanName;

	private String factoryMethodName;

	private String initMethodName;

	private String destroyMethodName;

	private boolean primary;

	@Override
	public Class<?> getBeanClass() {
		return beanClass;
	}

	@Override
	public void setBeanClass(Class<?> beanClass) {
		this.beanClass = beanClass;
	}

	@Override
	public String getScope() {
		return scope;
	}

	@Override
	public void setScope(String scope) {
		this.scope = scope;
	}

	/**
	 * Singleton is the default scope: a bean is a singleton when its scope is
	 * {@code "singleton"} or unset ({@link #SCOPE_DEFAULT} or {@code null}).
	 */
	@Override
	public boolean isSingleton() {
		// null is treated like the default scope; otherwise setScope(null) would make a bean
		// neither singleton nor prototype
		return this.scope == null || SCOPE_DEFAULT.equals(this.scope) || SCOPE_SINGLETON.equals(this.scope);
	}

	@Override
	public boolean isPrototype() {
		return SCOPE_PROTOTYPE.equals(this.scope);
	}

	@Override
	public String getFactoryBeanName() {
		return factoryBeanName;
	}

	@Override
	public void setFactoryBeanName(String factoryBeanName) {
		this.factoryBeanName = factoryBeanName;
	}

	@Override
	public String getFactoryMethodName() {
		return factoryMethodName;
	}

	@Override
	public void setFactoryMethodName(String factoryMethodName) {
		this.factoryMethodName = factoryMethodName;
	}

	@Override
	public String getInitMethodName() {
		return initMethodName;
	}

	@Override
	public void setInitMethodName(String initMethodName) {
		this.initMethodName = initMethodName;
	}

	@Override
	public String getDestroyMethodName() {
		return destroyMethodName;
	}

	@Override
	public void setDestroyMethodName(String destroyMethodName) {
		this.destroyMethodName = destroyMethodName;
	}

	@Override
	public boolean isPrimary() {
		return primary;
	}

	@Override
	public void setPrimary(boolean primary) {
		this.primary = primary;
	}

	/**
	 * Readable description; the factory's registration error messages concatenate the definition,
	 * which would otherwise print only {@code GenericBeanDefinition@hash}.
	 */
	@Override
	public String toString() {
		return "GenericBeanDefinition{"
				+ "beanClass=" + (beanClass == null ? null : beanClass.getName())
				+ ", scope='" + scope + '\''
				+ ", factoryBeanName='" + factoryBeanName + '\''
				+ ", factoryMethodName='" + factoryMethodName + '\''
				+ ", initMethodName='" + initMethodName + '\''
				+ ", destroyMethodName='" + destroyMethodName + '\''
				+ ", primary=" + primary
				+ '}';
	}
}

3)定义BeanDefinitionRegistry

/**
 * Registry that stores bean definitions by bean name.
 */
public interface BeanDefinitionRegistry {

	/**
	 * Registers a definition under the given bean name.
	 *
	 * @param beanName       name to register under
	 * @param beanDefinition definition to store
	 * @throws Exception if the implementation rejects the registration
	 */
	public abstract void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception;

	/**
	 * Tells whether a definition is registered under the given name.
	 *
	 * @param beanName name to check
	 * @return {@code true} if a definition exists for the name
	 */
	public abstract boolean containsBeanDefinition(String beanName);

	/**
	 * Looks up the definition registered under the given name.
	 *
	 * @param beanName name to look up
	 * @return the registered definition (DefaultBeanFactory returns {@code null} when there is none)
	 */
	public abstract BeanDefinition getBeanDefinition(String beanName);
}

4)定义Beanfactory

提供getBean方法,方便外部获取bean

/**
 * Bean factory: the lookup entry point through which callers obtain beans.
 */
public interface Beanfactory {

	/**
	 * Returns the bean registered under the given name.
	 *
	 * @param name bean name
	 * @return the bean instance
	 * @throws Exception if the bean cannot be resolved or created
	 */
	public abstract Object getBean(String name) throws Exception;
}

5)定义默认Beanfactory实现类

实现BeanDefinitionRegistry和Beanfactory接口,定义一个存储结构,保存beanName和beanDefinition的映射关系;重写registerBeanDefinition和getBean等方法,定义了三种创建对象的方式:

  • 直接new:new BeanClass
  • 工厂静态方法:BeanClass.factoryMethodName()
  • 工厂bean对象调用方法:getBean(factoryBeanName).factoryMethodName()(工厂bean本身也是从容器中获取的,而不是直接new)
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Version-1 bean factory: a name -> definition registry that builds a new bean instance on every
 * {@link #getBean(String)} call (no singleton cache yet).
 * <p>
 * Supported creation strategies:
 * <ul>
 *   <li>constructor: {@code new BeanClass()}</li>
 *   <li>static factory method: {@code BeanClass.factoryMethodName()}</li>
 *   <li>factory bean: {@code getBean(factoryBeanName).factoryMethodName()}</li>
 * </ul>
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory {

	// Storage: bean name -> bean definition
	protected Map<String, BeanDefinition> beanDefintionMap = new ConcurrentHashMap<>(256);

	/**
	 * Validates and registers a bean definition.
	 *
	 * @throws NullPointerException if beanName or beanDefinition is null
	 * @throws RuntimeException     if the definition is invalid or the name is already registered
	 */
	@Override
	public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
		Objects.requireNonNull(beanName, "注册bean需要提供beanName");
		Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");

		// Reject inconsistent definitions up front
		if (!beanDefinition.validate()) {
			throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
		}

		// Overriding an existing definition is rejected, like Spring Boot's default
		// (Spring Boot allows it via spring.main.allow-bean-definition-overriding: true)
		if (this.containsBeanDefinition(beanName)) {
			throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + this.getBeanDefinition(beanName));
		}

		beanDefintionMap.put(beanName, beanDefinition);
	}

	@Override
	public BeanDefinition getBeanDefinition(String beanName) {
		return beanDefintionMap.get(beanName);
	}

	@Override
	public boolean containsBeanDefinition(String beanName) {
		return beanDefintionMap.containsKey(beanName);
	}

	@Override
	public Object getBean(String name) throws Exception {
		return this.doGetBean(name);
	}

	// Looks up the definition and creates a fresh instance from it
	private Object doGetBean(String beanName) throws Exception {
		Objects.requireNonNull(beanName, "beanName不能为空");

		BeanDefinition bd = this.getBeanDefinition(beanName);
		Objects.requireNonNull(bd, "beanDefinition不能为空");

		return doCreateInstance(bd);
	}

	// Picks the creation strategy from the definition, then runs the init method (if any)
	private Object doCreateInstance(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		Object instance;
		if (beanClass != null) {
			if (StringUtils.isBlank(bd.getFactoryMethodName())) {
				// Constructor
				instance = this.createInstanceByConstructor(bd);
			} else {
				// Static factory method
				instance = this.createInstanceByStaticFactoryMethod(bd);
			}
		} else {
			// Factory bean + factory method
			instance = this.createInstanceByFactoryBean(bd);
		}

		// Run the init method
		this.doInit(bd, instance);

		return instance;
	}

	// Constructor: new BeanClass() through the no-arg constructor
	private Object createInstanceByConstructor(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		try {
			// Class#newInstance is deprecated (Java 9+) because it rethrows the constructor's checked
			// exceptions undeclared; the Constructor API wraps them in InvocationTargetException.
			// All reflective failures are logged now, not only SecurityException.
			return beanClass.getDeclaredConstructor().newInstance();
		} catch (ReflectiveOperationException | SecurityException e) {
			log.error("创建bean的实例异常,beanDefinition:" + bd, e);
			throw e;
		}
	}

	// Static factory method: BeanClass.factoryMethodName() (public, no-arg)
	private Object createInstanceByStaticFactoryMethod(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		Method m = beanClass.getMethod(bd.getFactoryMethodName());
		return m.invoke(beanClass);
	}

	// Factory bean: getBean(factoryBeanName).factoryMethodName() — the factory bean itself comes from this container
	private Object createInstanceByFactoryBean(BeanDefinition bd) throws Exception {
		Object factoryBean = this.doGetBean(bd.getFactoryBeanName());
		Method m = factoryBean.getClass().getMethod(bd.getFactoryMethodName());
		return m.invoke(factoryBean);
	}

	// Invokes the configured public no-arg init method, if one is set
	private void doInit(BeanDefinition bd, Object instance) throws Exception {
		if (StringUtils.isNotBlank(bd.getInitMethodName())) {
			Method m = instance.getClass().getMethod(bd.getInitMethodName());
			m.invoke(instance);
		}
	}
}

2、版本2:新增工厂关闭方法和支持单例bean

主要调整DefaultBeanFactory:实现Closeable接口并重写close方法,只对单例Bean执行销毁处理。原型Bean不用销毁,因为容器不知道会创建多少个原型对象,也没有持有这些对象的引用,所以不用去管。同时调整doGetBean方法以支持单例模式。

import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.io.Closeable;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;


/**
 * Version 2 of DefaultBeanFactory (only the changed members are shown): caches singleton
 * instances and destroys them on {@link #close()}.
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

	// Singleton cache: bean name -> created instance
	private final Map<String, Object> singletonBeanMap = new ConcurrentHashMap<>(256);

	// Other members omitted...

	/**
	 * Returns the bean named beanName. A singleton is created at most once: the cache is checked,
	 * then re-checked under a lock on the cache before creating. Any other scope gets a new
	 * instance per call.
	 */
	private Object doGetBean(String beanName) throws Exception {
		Objects.requireNonNull(beanName, "beanName不能为空");

		Object instance = singletonBeanMap.get(beanName);
		if (instance != null) {
			return instance;
		}

		BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
		Objects.requireNonNull(beanDefinition, "beanDefinition不能为空");

		if (beanDefinition.isSingleton()) {
			synchronized (singletonBeanMap) {
				instance = singletonBeanMap.get(beanName);
				if (instance == null) {
					instance = doCreateInstance(beanDefinition);
					singletonBeanMap.put(beanName, instance);
				}
			}
		} else {
			instance = doCreateInstance(beanDefinition);
		}
		return instance;
	}

	/**
	 * Invokes the destroy method of every singleton that has been created, then clears the cache
	 * so that a repeated close() does not destroy the same instances twice. Prototype beans are not
	 * tracked and are therefore never destroyed. A failing destroy method is logged and does not
	 * stop the remaining beans.
	 */
	@Override
	public void close() {
		for (Map.Entry<String, BeanDefinition> beanDefinitionEntry : beanDefintionMap.entrySet()) {
			String beanName = beanDefinitionEntry.getKey();
			BeanDefinition beanDefinition = beanDefinitionEntry.getValue();
			Object instance = this.singletonBeanMap.get(beanName);
			if (!beanDefinition.isSingleton() || instance == null) {
				continue;
			}
			// No destroy method configured: nothing to call. getMethod(null) would throw an NPE
			// that was logged as a destroy failure for every such singleton.
			String destroyMethodName = beanDefinition.getDestroyMethodName();
			if (StringUtils.isBlank(destroyMethodName)) {
				continue;
			}
			try {
				Method m = instance.getClass().getMethod(destroyMethodName);
				m.invoke(instance);
			} catch (Exception e) {
				log.error("执行名字为[" + beanName + "] 的bean销毁方法异常", e);
			}
		}
		this.singletonBeanMap.clear();
	}
}

额外扩展:实现一个预构建BeanFactory,可以在系统启动时,提前初始化

import lombok.extern.slf4j.Slf4j;
import java.util.Map;

/**
 * Bean factory that can create all singleton beans eagerly (e.g. at application startup)
 * instead of on first lookup.
 */
@Slf4j
public class PreBuildBeanFactory extends DefaultBeanFactory {

	/**
	 * Calls getBean(name) for every registered singleton definition so that its instance exists
	 * before first use; non-singleton definitions are skipped.
	 * <p>
	 * NOTE(review): the lock is taken on beanDefintionMap, but registerBeanDefinition does not
	 * synchronize on that map, so this does not block concurrent registrations.
	 *
	 * @throws Exception propagated from the first bean that fails to be created
	 */
	public void preInstantiateSingletons() throws Exception {
		final Map<String, BeanDefinition> definitions = this.beanDefintionMap;
		synchronized (definitions) {
			for (Map.Entry<String, BeanDefinition> definition : definitions.entrySet()) {
				BeanDefinition bd = definition.getValue();
				if (!bd.isSingleton()) {
					continue;
				}
				String beanName = definition.getKey();
				this.getBean(beanName);
				if (log.isDebugEnabled()) {
					log.debug("preInstantiate: name=" + beanName + " " + bd);
				}
			}
		}
	}
}

3、版本3:支持获取指定名字的类型

主要调整Beanfactory和DefaultBeanFactory,Beanfactory新增一个getType方法,让子类去实现

/**
 * Bean factory (version 3 addition): resolves the type of a named bean.
 */
public interface Beanfactory {

	/**
	 * Returns the type of the bean registered under the given name.
	 *
	 * @param name bean name
	 * @return the bean's type
	 * @throws Exception if the type cannot be resolved
	 */
	public abstract Class<?> getType(String name) throws Exception;
}

重写getType方法

/**
 * Version 3 of DefaultBeanFactory (only the members relevant to getType are shown).
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

	/**
	 * Resolves the type of the bean named name from its definition, without creating the bean:
	 * <ul>
	 *   <li>constructor bean: the bean class;</li>
	 *   <li>static-factory bean: the declared return type of the factory method on the bean class;</li>
	 *   <li>factory-bean bean: the declared return type of the factory method on the factory bean's type.</li>
	 * </ul>
	 * Factory methods are looked up with {@link Class#getMethod}, the same lookup the factory uses
	 * when it creates the bean, so inherited public factory methods resolve here as well.
	 *
	 * @throws NullPointerException  if no definition is registered under the name (or under a factory bean it references)
	 * @throws NoSuchMethodException if the factory method is not a public no-arg method of the type
	 */
	@Override
	public Class<?> getType(String name) throws Exception {
		BeanDefinition beanDefinition = this.getBeanDefinition(name);
		Objects.requireNonNull(beanDefinition, "名字为[" + name + "] 的beanDefinition不存在");

		Class<?> beanClass = beanDefinition.getBeanClass();
		if (beanClass == null) {
			// Factory bean: resolve the factory bean's type first, then the factory method's return type
			Class<?> factoryBeanType = getType(beanDefinition.getFactoryBeanName());
			return factoryBeanType.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
		}
		if (StringUtils.isNotBlank(beanDefinition.getFactoryMethodName())) {
			// Static factory method: the bean's type is what the method returns
			return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
		}
		return beanClass;
	}

	@Override
	public BeanDefinition getBeanDefinition(String beanName) {
		return beanDefintionMap.get(beanName);
	}
}

4、版本4:获取指定类型的所有bean以及唯一bean

一个type可能对应多个name,使用的存储结构如下:

// Type -> names of the beans registered under that type (one type may map to several beans).
// final: the map reference is never reassigned, only its contents change.
private final Map<Class<?>, Set<String>> typeNameMap = new ConcurrentHashMap<>(256);

调整DefaultBeanFactory的registerBeanDefinition,新增registerTypeNameMap方法,建立"类型 → beanName集合"的映射,支持当前类、父类以及实现的接口。下面实现的registerSuperClassTypeNaemMap负责注册其继承的父类,可以选择不要(最终版本中将其注释掉了)。注意:Spring按类型查找Bean(getBean(Class)、@Autowired)时,父类和接口都能匹配到子类Bean,判断依据是类型是否可赋值(isAssignableFrom)。

/**
 * Validates and registers a bean definition, then refreshes the type -> bean-name index.
 * If re-indexing fails, the new definition is removed again and the exception is rethrown.
 * registerTypeNameMap() re-scans every definition, so a definition left behind in a broken
 * state would make every later registration fail as well.
 *
 * @throws NullPointerException if beanName or beanDefinition is null
 * @throws RuntimeException     if the definition is invalid or the name is already registered
 * @throws Exception            if indexing the bean's type fails
 */
@Override
public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
	Objects.requireNonNull(beanName, "注册bean需要提供beanName");
	Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");
	// Reject inconsistent definitions up front
	if (!beanDefinition.validate()) {
		throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
	}
	// Overriding an existing definition is rejected, like Spring Boot's default
	// (Spring Boot allows it via spring.main.allow-bean-definition-overriding: true)
	if (this.containsBeanDefinition(beanName)) {
		throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + this.getBeanDefinition(beanName));
	}
	beanDefintionMap.put(beanName, beanDefinition);
	try {
		registerTypeNameMap();
	} catch (Exception e) {
		// Roll back so the failed definition does not stay registered
		beanDefintionMap.remove(beanName);
		throw e;
	}
}

	/**
	 * Re-indexes every registered definition under its own type, its superclasses and its interfaces.
	 * Some definitions cannot be resolved yet: a factory bean they depend on (directly or
	 * transitively) is not registered, or the factory-bean chain is cyclic. Those are skipped
	 * instead of failing with an NPE or a stack overflow. Because this method re-scans every
	 * definition, they are indexed by a later call once the missing factory bean has been registered.
	 */
	public void registerTypeNameMap() throws Exception {
		for (String name : beanDefintionMap.keySet()) {
			if (!isFactoryChainResolvable(name)) {
				continue;
			}
			Class<?> type = this.getType(name);
			// Index the bean's own type
			this.registerTypeNameMap(type, name);
			// Index its superclasses
			this.registerSuperClassTypeNaemMap(type, name);
			// Index its implemented interfaces
			this.registerInterfaceTypeNaemMap(type, name);
		}
	}

	// True when following factoryBeanName links from name ends at a registered definition with a bean class
	private boolean isFactoryChainResolvable(String name) {
		Set<String> visited = new HashSet<>();
		String current = name;
		while (current != null && visited.add(current)) {
			BeanDefinition bd = this.getBeanDefinition(current);
			if (bd == null) {
				return false;
			}
			if (bd.getBeanClass() != null) {
				return true;
			}
			current = bd.getFactoryBeanName();
		}
		// Either the chain ended without a class or it looped back on itself
		return false;
	}

	// Adds name to the bean names indexed under type. computeIfAbsent makes creating the set atomic,
	// and the concurrent set lets readers iterate it while registrations add names.
	private void registerTypeNameMap(Class<?> type, String name) {
		typeNameMap.computeIfAbsent(type, k -> ConcurrentHashMap.newKeySet()).add(name);
	}

	// Indexes name under every superclass of type (stopping before Object) and under the interfaces
	// each of those superclasses declares
	private void registerSuperClassTypeNaemMap(Class<?> type, String name) {
		Class<?> ancestor = type.getSuperclass();
		while (ancestor != null && ancestor != Object.class) {
			this.registerTypeNameMap(ancestor, name);
			this.registerInterfaceTypeNaemMap(ancestor, name);
			ancestor = ancestor.getSuperclass();
		}
	}

	// Indexes name under each interface type directly declares, recursing into the interfaces those extend
	private void registerInterfaceTypeNaemMap(Class<?> type, String name) {
		Class<?>[] directInterfaces = type.getInterfaces();
		for (int i = 0; i < directInterfaces.length; i++) {
			Class<?> iface = directInterfaces[i];
			registerTypeNameMap(iface, name);
			registerInterfaceTypeNaemMap(iface, name);
		}
	}

Beanfactory新增两个方法如下:

/**
 * Bean factory (version 4 additions): type-based bean lookup.
 */
public interface Beanfactory {

	/**
	 * Returns all beans registered under the given type, keyed by bean name.
	 *
	 * @param type requested type
	 * @return bean name -> bean
	 * @throws Exception if a bean cannot be created
	 */
	public abstract <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception;

	/**
	 * Returns the single bean registered under the given type.
	 *
	 * @param requiredType requested type
	 * @return the bean
	 * @throws Exception if the bean cannot be resolved or created
	 */
	public abstract <T> T getBean(Class<T> requiredType) throws Exception;
}

DefaultBeanFactory 实现如下:

/**
 * Returns the single bean registered under requiredType in the type index.
 * With one candidate it is returned directly; with several, the candidate whose definition is
 * marked primary is returned.
 *
 * @return the bean, or {@code null} when no bean is registered under the type
 * @throws RuntimeException if several beans match and none or more than one of them is primary
 */
@Override
@SuppressWarnings("unchecked")
public <T> T getBean(Class<T> requiredType) throws Exception {
	Set<String> beanNames = typeNameMap.get(requiredType);
	if (beanNames == null || beanNames.isEmpty()) {
		return null;
	}
	if (beanNames.size() == 1) {
		return (T) this.getBean(beanNames.iterator().next());
	}
	String primaryBeanName = null;
	for (String beanName : beanNames) {
		BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
		if (beanDefinition == null || !beanDefinition.isPrimary()) {
			continue;
		}
		if (primaryBeanName != null) {
			throw new RuntimeException(requiredType + "类存在多个Primary,无法确定唯一一个Bean");
		}
		primaryBeanName = beanName;
	}
	if (primaryBeanName == null) {
		// Several candidates but no primary: report the ambiguity rather than claiming nothing was found
		throw new RuntimeException(requiredType + "类存在多个Bean" + beanNames + ",但没有指定Primary,无法确定唯一一个Bean");
	}
	return (T) this.getBean(primaryBeanName);
}

/**
 * Returns every bean registered under the given type, keyed by bean name.
 *
 * @return a mutable bean name -> bean map; empty (never {@code null}) when no bean matches
 */
@Override
@SuppressWarnings("unchecked")
public <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception {
	Map<String, T> nameBeanMap = new HashMap<>();
	Set<String> beanNames = typeNameMap.get(type);
	if (beanNames == null) {
		// Empty map instead of null so callers can iterate without a null check
		return nameBeanMap;
	}
	for (String beanName : beanNames) {
		nameBeanMap.put(beanName, (T) this.getBean(beanName));
	}
	return nameBeanMap;
}

5、版本5:支持Bean别名

待实现
在这里插入图片描述

三、最终完整版本

在这里插入图片描述

BeanDefinition

import com.baomidou.mybatisplus.core.toolkit.StringUtils;

/**
 * Bean definition: tells the factory how to create and manage one bean — from which class or
 * factory, in which scope, with which init/destroy callbacks, and whether it is the primary
 * candidate for its type.
 */
public interface BeanDefinition {

    /** Scope of a bean that is created once and then reused. */
    String SCOPE_SINGLETON = "singleton";

    /** Scope of a bean that is created anew on every lookup. */
    String SCOPE_PROTOTYPE = "prototype";

    /**
     * Bean class: either instantiated directly, or the class declaring the static factory method.
     * Left {@code null} when the bean is produced by a factory bean.
     */
    Class<?> getBeanClass();
    void setBeanClass(Class<?> beanClass);

    /** Scope name, e.g. {@link #SCOPE_SINGLETON} or {@link #SCOPE_PROTOTYPE}. */
    void setScope(String scope);
    String getScope();

    /** Whether the bean is a singleton. */
    boolean isSingleton();

    /** Whether the bean is a prototype. */
    boolean isPrototype();

    /** Name of the factory bean whose factory method creates this bean. */
    String getFactoryBeanName();
    void setFactoryBeanName(String factoryBeanName);

    /** Name of the factory method: static on the bean class, or an instance method of the factory bean. */
    String getFactoryMethodName();
    void setFactoryMethodName(String factoryMethodName);

    /** Name of the method invoked right after the bean is created. */
    String getInitMethodName();
    void setInitMethodName(String initMethodName);

    /** Name of the method invoked on a created singleton when the factory is closed. */
    String getDestroyMethodName();
    void setDestroyMethodName(String destroyMethodName);

    /** Whether this bean is the primary candidate when several beans match a requested type. */
    boolean isPrimary();
    void setPrimary(boolean primary);

    /**
     * Checks that this definition is internally consistent:
     * <ul>
     *   <li>without a bean class, both a factory bean name and a factory method name are required;</li>
     *   <li>a bean class and a factory bean name are mutually exclusive;</li>
     *   <li>a non-blank scope must be recognized by the implementation as singleton or prototype.</li>
     * </ul>
     *
     * @return {@code true} if the definition may be registered
     */
    default boolean validate() {
        // No class: the bean must come from a factory bean plus a factory method
        if (this.getBeanClass() == null) {
            if (StringUtils.isBlank(getFactoryBeanName()) || StringUtils.isBlank(getFactoryMethodName())) {
                return false;
            }
        }

        // A class (constructor / static factory) and a factory bean cannot both be specified
        if (this.getBeanClass() != null && StringUtils.isNotBlank(getFactoryBeanName())) {
            return false;
        }

        // An unrecognized scope (e.g. a typo such as "singelton") would otherwise be silently
        // handled as a prototype, because the factory only special-cases singletons
        if (StringUtils.isNotBlank(getScope()) && !isSingleton() && !isPrototype()) {
            return false;
        }

        return true;
    }

}

GenericBeanDefinition

/**
 * General-purpose, mutable {@link BeanDefinition} backed by plain fields.
 */
public class GenericBeanDefinition implements BeanDefinition {

	/** Scope value meaning "not set"; treated as singleton. */
	public static final String SCOPE_DEFAULT = "";

	private Class<?> beanClass;

	private String scope = SCOPE_DEFAULT;

	private String factoryBeanName;

	private String factoryMethodName;

	private String initMethodName;

	private String destroyMethodName;

	private boolean primary;

	@Override
	public Class<?> getBeanClass() {
		return beanClass;
	}

	@Override
	public void setBeanClass(Class<?> beanClass) {
		this.beanClass = beanClass;
	}

	@Override
	public String getScope() {
		return scope;
	}

	@Override
	public void setScope(String scope) {
		this.scope = scope;
	}

	/**
	 * Singleton is the default scope: a bean is a singleton when its scope is
	 * {@code "singleton"} or unset ({@link #SCOPE_DEFAULT} or {@code null}).
	 */
	@Override
	public boolean isSingleton() {
		// null is treated like the default scope; otherwise setScope(null) would make a bean
		// neither singleton nor prototype
		return this.scope == null || SCOPE_DEFAULT.equals(this.scope) || SCOPE_SINGLETON.equals(this.scope);
	}

	@Override
	public boolean isPrototype() {
		return SCOPE_PROTOTYPE.equals(this.scope);
	}

	@Override
	public String getFactoryBeanName() {
		return factoryBeanName;
	}

	@Override
	public void setFactoryBeanName(String factoryBeanName) {
		this.factoryBeanName = factoryBeanName;
	}

	@Override
	public String getFactoryMethodName() {
		return factoryMethodName;
	}

	@Override
	public void setFactoryMethodName(String factoryMethodName) {
		this.factoryMethodName = factoryMethodName;
	}

	@Override
	public String getInitMethodName() {
		return initMethodName;
	}

	@Override
	public void setInitMethodName(String initMethodName) {
		this.initMethodName = initMethodName;
	}

	@Override
	public String getDestroyMethodName() {
		return destroyMethodName;
	}

	@Override
	public void setDestroyMethodName(String destroyMethodName) {
		this.destroyMethodName = destroyMethodName;
	}

	@Override
	public boolean isPrimary() {
		return primary;
	}

	@Override
	public void setPrimary(boolean primary) {
		this.primary = primary;
	}

	/**
	 * Readable description; the factory's registration error messages concatenate the definition,
	 * which would otherwise print only {@code GenericBeanDefinition@hash}.
	 */
	@Override
	public String toString() {
		return "GenericBeanDefinition{"
				+ "beanClass=" + (beanClass == null ? null : beanClass.getName())
				+ ", scope='" + scope + '\''
				+ ", factoryBeanName='" + factoryBeanName + '\''
				+ ", factoryMethodName='" + factoryMethodName + '\''
				+ ", initMethodName='" + initMethodName + '\''
				+ ", destroyMethodName='" + destroyMethodName + '\''
				+ ", primary=" + primary
				+ '}';
	}
}

BeanDefinitionRegistry

/**
 * Registry that stores bean definitions by bean name.
 */
public interface BeanDefinitionRegistry {

	/**
	 * Registers a definition under the given bean name.
	 *
	 * @param beanName       name to register under
	 * @param beanDefinition definition to store
	 * @throws Exception if the implementation rejects the registration
	 */
	public abstract void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception;

	/**
	 * Tells whether a definition is registered under the given name.
	 *
	 * @param beanName name to check
	 * @return {@code true} if a definition exists for the name
	 */
	public abstract boolean containsBeanDefinition(String beanName);

	/**
	 * Looks up the definition registered under the given name.
	 *
	 * @param beanName name to look up
	 * @return the registered definition (DefaultBeanFactory returns {@code null} when there is none)
	 */
	public abstract BeanDefinition getBeanDefinition(String beanName);
}

Beanfactory

import java.util.Map;

/**
 * Bean factory: the lookup entry point for beans, by name or by type.
 */
public interface Beanfactory {

	/**
	 * Returns the bean registered under the given name.
	 *
	 * @param name bean name
	 * @return the bean instance
	 * @throws Exception if the bean cannot be resolved or created
	 */
	public abstract Object getBean(String name) throws Exception;

	/**
	 * Returns the single bean registered under the given type.
	 *
	 * @param requiredType requested type
	 * @return the bean
	 * @throws Exception if the bean cannot be resolved or created
	 */
	public abstract <T> T getBean(Class<T> requiredType) throws Exception;

	/**
	 * Returns all beans registered under the given type, keyed by bean name.
	 *
	 * @param type requested type
	 * @return bean name -> bean
	 * @throws Exception if a bean cannot be created
	 */
	public abstract <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception;

	/**
	 * Returns the type of the bean registered under the given name.
	 *
	 * @param name bean name
	 * @return the bean's type
	 * @throws Exception if the type cannot be resolved
	 */
	public abstract Class<?> getType(String name) throws Exception;
}

DefaultBeanFactory

import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;

import java.io.Closeable;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Default bean container: a {@link BeanDefinitionRegistry} and a {@link Beanfactory} in one class.
 * <p>
 * A bean is created by one of three strategies:
 * <ul>
 *   <li>its no-arg constructor;</li>
 *   <li>a static factory method on the bean class;</li>
 *   <li>a factory method on another (factory) bean.</li>
 * </ul>
 * Singletons are cached, and their destroy methods are invoked by {@link #close()}. Prototypes are
 * created on every lookup and never destroyed. A type -> bean-name index backs
 * {@link #getBean(Class)} and {@link #getBeansOfType(Class)}.
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

	// Bean name -> bean definition (protected: PreBuildBeanFactory iterates it)
	protected Map<String, BeanDefinition> beanDefintionMap = new ConcurrentHashMap<>(256);

	// Bean name -> created singleton instance
	private final Map<String, Object> singletonBeanMap = new ConcurrentHashMap<>(256);

	// Type -> names of the beans registered under that type (one type may map to several beans)
	private final Map<Class<?>, Set<String>> typeNameMap = new ConcurrentHashMap<>(256);

	/**
	 * Validates and registers a bean definition, then refreshes the type index.
	 * If indexing fails, the definition is removed again so it cannot break later registrations.
	 *
	 * @throws NullPointerException if beanName or beanDefinition is null
	 * @throws RuntimeException     if the definition is invalid or the name is already registered
	 * @throws Exception            if indexing the bean's type fails
	 */
	@Override
	public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
		Objects.requireNonNull(beanName, "注册bean需要提供beanName");
		Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");

		// Reject inconsistent definitions up front
		if (!beanDefinition.validate()) {
			throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
		}

		// Overriding an existing definition is rejected, like Spring Boot's default
		// (Spring Boot allows it via spring.main.allow-bean-definition-overriding: true)
		if (this.containsBeanDefinition(beanName)) {
			throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + this.getBeanDefinition(beanName));
		}

		beanDefintionMap.put(beanName, beanDefinition);
		try {
			registerTypeNameMap();
		} catch (Exception e) {
			// Roll back: registerTypeNameMap() re-scans every definition, so a definition left in a
			// broken state would make every later registration fail too
			beanDefintionMap.remove(beanName);
			throw e;
		}
	}

	/**
	 * Re-indexes every registered definition under its own type and its interfaces.
	 * Some definitions cannot be resolved yet: a factory bean they depend on is not registered, or
	 * the factory-bean chain is cyclic. Those are skipped. Because every registration re-runs this
	 * scan, they are indexed once the missing factory bean has been registered.
	 */
	public void registerTypeNameMap() throws Exception {
		for (String name : beanDefintionMap.keySet()) {
			if (!isFactoryChainResolvable(name)) {
				continue;
			}
			Class<?> type = this.getType(name);
			// The bean's own type
			this.registerTypeNameMap(type, name);
			// Superclasses: disabled here by design. NOTE: Spring's by-type lookup does match
			// superclasses; uncomment this line for the same behavior.
			// this.registerSuperClassTypeNaemMap(type, name);
			// Interfaces, including those declared on superclasses: getInterfaces() lists only the
			// directly declared ones, yet a subclass is still an instance of its parent's interfaces
			for (Class<?> c = type; c != null && c != Object.class; c = c.getSuperclass()) {
				this.registerInterfaceTypeNaemMap(c, name);
			}
		}
	}

	// True when following factoryBeanName links from name ends at a registered definition with a bean class
	private boolean isFactoryChainResolvable(String name) {
		Set<String> visited = new HashSet<>();
		String current = name;
		while (current != null && visited.add(current)) {
			BeanDefinition bd = this.getBeanDefinition(current);
			if (bd == null) {
				return false;
			}
			if (bd.getBeanClass() != null) {
				return true;
			}
			current = bd.getFactoryBeanName();
		}
		// Either the chain ended without a class or it looped back on itself
		return false;
	}

	// Adds name to the bean names indexed under type; atomic set creation + concurrent set for readers
	private void registerTypeNameMap(Class<?> type, String name) {
		typeNameMap.computeIfAbsent(type, k -> ConcurrentHashMap.newKeySet()).add(name);
	}

	// Indexes name under every superclass of type (stopping before Object) and their interfaces
	private void registerSuperClassTypeNaemMap(Class<?> type, String name) {
		Class<?> superclass = type.getSuperclass();
		if (superclass != null && !superclass.equals(Object.class)) {
			this.registerTypeNameMap(superclass, name);
			this.registerSuperClassTypeNaemMap(superclass, name);
			this.registerInterfaceTypeNaemMap(superclass, name);
		}
	}

	// Indexes name under each interface type declares, recursing into the interfaces those extend
	private void registerInterfaceTypeNaemMap(Class<?> type, String name) {
		Class<?>[] interfaces = type.getInterfaces();
		for (Class<?> anInterface : interfaces) {
			this.registerTypeNameMap(anInterface, name);
			this.registerInterfaceTypeNaemMap(anInterface, name);
		}
	}

	@Override
	public BeanDefinition getBeanDefinition(String beanName) {
		return beanDefintionMap.get(beanName);
	}

	@Override
	public boolean containsBeanDefinition(String beanName) {
		return beanDefintionMap.containsKey(beanName);
	}

	@Override
	public Object getBean(String name) throws Exception {
		return this.doGetBean(name);
	}

	/**
	 * Resolves the type of the bean named name from its definition, without creating the bean:
	 * <ul>
	 *   <li>constructor bean: the bean class;</li>
	 *   <li>static-factory bean: the declared return type of the factory method on the bean class;</li>
	 *   <li>factory-bean bean: the declared return type of the factory method on the factory bean's type.</li>
	 * </ul>
	 * Uses {@link Class#getMethod}, the same lookup the creation methods use, so inherited public
	 * factory methods resolve as well.
	 *
	 * @throws NullPointerException  if no definition is registered under the name (or a factory bean it references)
	 * @throws NoSuchMethodException if the factory method is not a public no-arg method of the type
	 */
	@Override
	public Class<?> getType(String name) throws Exception {
		BeanDefinition beanDefinition = this.getBeanDefinition(name);
		Objects.requireNonNull(beanDefinition, "名字为[" + name + "] 的beanDefinition不存在");

		Class<?> beanClass = beanDefinition.getBeanClass();
		if (beanClass == null) {
			// Factory bean: resolve the factory bean's type, then the factory method's return type
			Class<?> factoryBeanType = getType(beanDefinition.getFactoryBeanName());
			return factoryBeanType.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
		}
		if (StringUtils.isNotBlank(beanDefinition.getFactoryMethodName())) {
			// Static factory method: the bean's type is what the method returns
			return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
		}
		return beanClass;
	}

	/**
	 * Returns the single bean registered under requiredType. With several candidates, the one
	 * whose definition is marked primary is returned.
	 *
	 * @return the bean, or {@code null} when no bean is registered under the type
	 * @throws RuntimeException if several beans match and none or more than one of them is primary
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <T> T getBean(Class<T> requiredType) throws Exception {
		Set<String> beanNames = typeNameMap.get(requiredType);
		if (beanNames == null || beanNames.isEmpty()) {
			return null;
		}
		if (beanNames.size() == 1) {
			return (T) this.getBean(beanNames.iterator().next());
		}
		String primaryBeanName = null;
		for (String beanName : beanNames) {
			BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
			if (beanDefinition == null || !beanDefinition.isPrimary()) {
				continue;
			}
			if (primaryBeanName != null) {
				throw new RuntimeException(requiredType + "类存在多个Primary,无法确定唯一一个Bean");
			}
			primaryBeanName = beanName;
		}
		if (primaryBeanName == null) {
			// Several candidates but no primary: report the ambiguity rather than claiming nothing was found
			throw new RuntimeException(requiredType + "类存在多个Bean" + beanNames + ",但没有指定Primary,无法确定唯一一个Bean");
		}
		return (T) this.getBean(primaryBeanName);
	}

	/**
	 * Returns every bean registered under the given type, keyed by bean name.
	 *
	 * @return a mutable bean name -> bean map; empty (never {@code null}) when no bean matches
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception {
		Map<String, T> nameBeanMap = new HashMap<>();
		Set<String> beanNames = typeNameMap.get(type);
		if (beanNames == null) {
			return nameBeanMap;
		}
		for (String beanName : beanNames) {
			nameBeanMap.put(beanName, (T) this.getBean(beanName));
		}
		return nameBeanMap;
	}

	/**
	 * Returns the bean named beanName. A singleton is created at most once: the cache is checked,
	 * then re-checked under a lock on the cache before creating. Any other scope gets a new
	 * instance per call.
	 */
	private Object doGetBean(String beanName) throws Exception {
		Objects.requireNonNull(beanName, "beanName不能为空");

		Object instance = singletonBeanMap.get(beanName);
		if (instance != null) {
			return instance;
		}

		BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
		Objects.requireNonNull(beanDefinition, "beanDefinition不能为空");

		if (beanDefinition.isSingleton()) {
			synchronized (singletonBeanMap) {
				instance = singletonBeanMap.get(beanName);
				if (instance == null) {
					instance = doCreateInstance(beanDefinition);
					singletonBeanMap.put(beanName, instance);
				}
			}
		} else {
			instance = doCreateInstance(beanDefinition);
		}
		return instance;
	}

	// Picks the creation strategy from the definition, then runs the init method (if any)
	private Object doCreateInstance(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		Object instance;
		if (beanClass != null) {
			if (StringUtils.isBlank(bd.getFactoryMethodName())) {
				// Constructor
				instance = this.createInstanceByConstructor(bd);
			} else {
				// Static factory method
				instance = this.createInstanceByStaticFactoryMethod(bd);
			}
		} else {
			// Factory bean + factory method
			instance = this.createInstanceByFactoryBean(bd);
		}

		// Run the init method
		this.doInit(bd, instance);

		return instance;
	}

	// Constructor: new BeanClass() through the no-arg constructor
	private Object createInstanceByConstructor(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		try {
			// Class#newInstance is deprecated (Java 9+) because it rethrows the constructor's checked
			// exceptions undeclared; the Constructor API wraps them in InvocationTargetException.
			// All reflective failures are logged now, not only SecurityException.
			return beanClass.getDeclaredConstructor().newInstance();
		} catch (ReflectiveOperationException | SecurityException e) {
			log.error("创建bean的实例异常,beanDefinition:" + bd, e);
			throw e;
		}
	}

	// Static factory method: BeanClass.factoryMethodName() (public, no-arg)
	private Object createInstanceByStaticFactoryMethod(BeanDefinition bd) throws Exception {
		Class<?> beanClass = bd.getBeanClass();
		Method m = beanClass.getMethod(bd.getFactoryMethodName());
		return m.invoke(beanClass);
	}

	// Factory bean: getBean(factoryBeanName).factoryMethodName() — the factory bean itself comes from this container
	private Object createInstanceByFactoryBean(BeanDefinition bd) throws Exception {
		Object factoryBean = this.doGetBean(bd.getFactoryBeanName());
		Method m = factoryBean.getClass().getMethod(bd.getFactoryMethodName());
		return m.invoke(factoryBean);
	}

	// Invokes the configured public no-arg init method, if one is set
	private void doInit(BeanDefinition bd, Object instance) throws Exception {
		if (StringUtils.isNotBlank(bd.getInitMethodName())) {
			Method m = instance.getClass().getMethod(bd.getInitMethodName());
			m.invoke(instance);
		}
	}

	/**
	 * Invokes the destroy method of every singleton that has been created, then clears the cache
	 * so that a repeated close() does not destroy the same instances twice. Prototypes are not
	 * tracked and therefore never destroyed. A failing destroy method is logged and does not stop
	 * the remaining beans.
	 */
	@Override
	public void close() {
		for (Map.Entry<String, BeanDefinition> beanDefinitionEntry : beanDefintionMap.entrySet()) {
			String beanName = beanDefinitionEntry.getKey();
			BeanDefinition beanDefinition = beanDefinitionEntry.getValue();
			Object instance = this.singletonBeanMap.get(beanName);
			if (!beanDefinition.isSingleton() || instance == null) {
				continue;
			}
			// No destroy method configured: getMethod(null) would throw an NPE that got logged as a failure
			String destroyMethodName = beanDefinition.getDestroyMethodName();
			if (StringUtils.isBlank(destroyMethodName)) {
				continue;
			}
			try {
				Method m = instance.getClass().getMethod(destroyMethodName);
				m.invoke(instance);
			} catch (Exception e) {
				log.error("执行名字为[" + beanName + "] 的bean销毁方法异常", e);
			}
		}
		this.singletonBeanMap.clear();
	}
}

PreBuildBeanFactory

import lombok.extern.slf4j.Slf4j;
import java.util.Map;

/**
 * Bean factory that can create all singleton beans eagerly (e.g. at application startup)
 * instead of on first lookup.
 */
@Slf4j
public class PreBuildBeanFactory extends DefaultBeanFactory {

	/**
	 * Calls getBean(name) for every registered singleton definition so that its instance exists
	 * before first use; non-singleton definitions are skipped.
	 * <p>
	 * NOTE(review): the lock is taken on beanDefintionMap, but registerBeanDefinition does not
	 * synchronize on that map, so this does not block concurrent registrations.
	 *
	 * @throws Exception propagated from the first bean that fails to be created
	 */
	public void preInstantiateSingletons() throws Exception {
		final Map<String, BeanDefinition> definitions = this.beanDefintionMap;
		synchronized (definitions) {
			for (Map.Entry<String, BeanDefinition> definition : definitions.entrySet()) {
				BeanDefinition bd = definition.getValue();
				if (!bd.isSingleton()) {
					continue;
				}
				String beanName = definition.getKey();
				this.getBean(beanName);
				if (log.isDebugEnabled()) {
					log.debug("preInstantiate: name=" + beanName + " " + bd);
				}
			}
		}
	}
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序员Forlan

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值