一、分析
IOC是Inversion of Control(控制反转)的缩写。它是一种软件设计原则,用于解耦组件之间的依赖关系。
也就是说,依赖对象的获取方式被反转了:从原来由我们自己创建对象,变为从IOC容器中获取,由容器统一管理。
这样的好处是什么?
代码更加简洁,不用自己去new对象
面向接口编程:解耦,易扩展,替换实现类;方便进行AOP编程
那么,你有啥思路?
其实可以这么理解:IOC容器 = Bean工厂。Beanfactory对外提供bean实例,所以需要提供getBean()方法;你要什么样的Bean,得先描述清楚并告诉Bean工厂,所以需要一个Bean定义信息BeanDefinition,告诉它应该创建什么对象;我们定义的这些BeanDefinition存在哪?这就需要一个注册器BeanDefinitionRegistry来维护这些信息。
二、实现
1、版本1:实现Bean注入IOC容器,并从容器中获取
1)定义BeanDefinition
描述我们的bean是要单例还是多例,是通过什么去创建(直接new,还是通过工厂类创建),初始化以及销毁方法
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
/**
 * Describes one bean for the container: which class or factory produces it,
 * its scope, and its optional init/destroy callbacks.
 */
public interface BeanDefinition {

    /** Scope name for singleton beans. */
    String SCOPE_SINGLETON = "singleton";

    /** Scope name for prototype beans. */
    String SCOPE_PROTOTYPE = "prototype";

    /** @return the bean class, or {@code null} when the bean is produced by a factory bean */
    Class<?> getBeanClass();

    void setBeanClass(Class<?> beanClass);

    /** @param scope the scope name, e.g. {@link #SCOPE_SINGLETON} or {@link #SCOPE_PROTOTYPE} */
    void setScope(String scope);

    String getScope();

    /** @return whether this definition describes a singleton bean */
    boolean isSingleton();

    /** @return whether this definition describes a prototype bean */
    boolean isPrototype();

    /** @return the name of the factory bean whose method creates this bean, if any */
    String getFactoryBeanName();

    void setFactoryBeanName(String factoryBeanName);

    /** @return the name of the factory method (static, or on the factory bean), if any */
    String getFactoryMethodName();

    void setFactoryMethodName(String factoryMethodName);

    /** @return the name of the initialization method, if any */
    String getInitMethodName();

    void setInitMethodName(String initMethodName);

    /** @return the name of the destroy method, if any */
    String getDestroyMethodName();

    void setDestroyMethodName(String destroyMethodName);

    /** @return whether this bean is the primary candidate among several beans of one type */
    boolean isPrimary();

    void setPrimary(boolean primary);

    /**
     * Checks that this definition is internally consistent.
     *
     * <p>A definition is valid in exactly two shapes:
     * <ul>
     *   <li>a bean class is set and no factory bean name is set (constructor or
     *       static factory method on that class);</li>
     *   <li>no bean class is set and both the factory bean name and the factory
     *       method name are non-blank (instance factory method).</li>
     * </ul>
     *
     * @return {@code true} if the definition has one of the shapes above
     */
    default boolean validate() {
        if (this.getBeanClass() != null) {
            // A class-based definition must not also name a factory bean.
            return StringUtils.isBlank(getFactoryBeanName());
        }
        // Without a class, both halves of the factory reference are required.
        return !StringUtils.isBlank(getFactoryBeanName())
                && !StringUtils.isBlank(getFactoryMethodName());
    }
}
2)定义BeanDefinition实现类
定义一个通用实现类,实现BeanDefinition接口,对值的设置和获取
/**
 * Plain mutable {@link BeanDefinition} implementation that stores every
 * attribute in a field.
 */
public class GenericBeanDefinition implements BeanDefinition {

    /** Scope value used when no scope has been set; treated as singleton. */
    public static final String SCOPE_DEFAULT = "";

    private Class<?> beanClass;
    private String scope = SCOPE_DEFAULT;
    private String factoryBeanName;
    private String factoryMethodName;
    private String initMethodName;
    private String destroyMethodName;
    private boolean primary;

    @Override
    public Class<?> getBeanClass() {
        return beanClass;
    }

    @Override
    public void setBeanClass(Class<?> beanClass) {
        this.beanClass = beanClass;
    }

    @Override
    public String getScope() {
        return scope;
    }

    @Override
    public void setScope(String scope) {
        this.scope = scope;
    }

    /**
     * @return {@code true} when the scope is {@link #SCOPE_SINGLETON} or the
     *         default empty scope, i.e. singleton is the default
     */
    @Override
    public boolean isSingleton() {
        return SCOPE_SINGLETON.equals(this.scope) || SCOPE_DEFAULT.equals(this.scope);
    }

    /** @return {@code true} when the scope is {@link #SCOPE_PROTOTYPE} */
    @Override
    public boolean isPrototype() {
        return SCOPE_PROTOTYPE.equals(this.scope);
    }

    @Override
    public String getFactoryBeanName() {
        return factoryBeanName;
    }

    @Override
    public void setFactoryBeanName(String factoryBeanName) {
        this.factoryBeanName = factoryBeanName;
    }

    @Override
    public String getFactoryMethodName() {
        return factoryMethodName;
    }

    @Override
    public void setFactoryMethodName(String factoryMethodName) {
        this.factoryMethodName = factoryMethodName;
    }

    @Override
    public String getInitMethodName() {
        return initMethodName;
    }

    @Override
    public void setInitMethodName(String initMethodName) {
        this.initMethodName = initMethodName;
    }

    @Override
    public String getDestroyMethodName() {
        return destroyMethodName;
    }

    @Override
    public void setDestroyMethodName(String destroyMethodName) {
        this.destroyMethodName = destroyMethodName;
    }

    @Override
    public boolean isPrimary() {
        return primary;
    }

    @Override
    public void setPrimary(boolean primary) {
        this.primary = primary;
    }

    /**
     * Lists every attribute. The bean factory concatenates definitions into its
     * error and log messages, which would otherwise show only an identity hash.
     */
    @Override
    public String toString() {
        return "GenericBeanDefinition{"
                + "beanClass=" + beanClass
                + ", scope='" + scope + '\''
                + ", factoryBeanName='" + factoryBeanName + '\''
                + ", factoryMethodName='" + factoryMethodName + '\''
                + ", initMethodName='" + initMethodName + '\''
                + ", destroyMethodName='" + destroyMethodName + '\''
                + ", primary=" + primary
                + '}';
    }
}
3)定义BeanDefinitionRegistry
/**
 * Registry that stores bean definitions by bean name.
 */
public interface BeanDefinitionRegistry {
/** Registers {@code beanDefinition} under {@code beanName}. */
void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception;
/** Looks up the definition registered under {@code beanName}. */
BeanDefinition getBeanDefinition(String beanName);
/** Tells whether a definition is registered under {@code beanName}. */
boolean containsBeanDefinition(String beanName);
}
4)定义Beanfactory
提供getBean方法,方便外部获取bean
/**
 * Entry point for obtaining beans from the container.
 */
public interface Beanfactory {
/** Returns the bean registered under {@code name}. */
Object getBean(String name) throws Exception;
}
5)定义默认Beanfactory实现类
实现BeanDefinitionRegistry和Beanfactory接口,定义一个存储结构,保存beanName和beanDefinition的映射关系;重写registerBeanDefinition和getBean等方法,定义了三种创建对象的方式:
- 直接new:new BeanClass
- 工厂静态方法:BeanClass.factoryMethodName()
- 工厂bean对象调用方法:new FactoryBeanName().factoryMethodName()
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Minimal bean factory (version 1): keeps bean definitions by name and creates
 * a new instance on every {@link #getBean(String)} call.
 *
 * <p>Instances are created in one of three ways, chosen from the definition:
 * no-arg constructor, static factory method on the bean class, or a method on
 * another bean (the factory bean).
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory {

    // Bean name -> bean definition.
    protected Map<String, BeanDefinition> beanDefintionMap = new ConcurrentHashMap<>(256);

    /**
     * Registers a definition under a unique name.
     *
     * @throws NullPointerException if either argument is {@code null}
     * @throws RuntimeException     if the definition is invalid or the name is taken
     */
    @Override
    public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
        Objects.requireNonNull(beanName, "注册bean需要提供beanName");
        Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");
        // Reject definitions that are internally inconsistent.
        if (!beanDefinition.validate()) {
            throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
        }
        // Overriding is not allowed (Spring also rejects it by default; it can be enabled
        // there with spring.main.allow-bean-definition-overriding: true).
        // putIfAbsent makes the duplicate check and the insert a single atomic step.
        BeanDefinition existing = beanDefintionMap.putIfAbsent(beanName, beanDefinition);
        if (existing != null) {
            throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + existing);
        }
    }

    @Override
    public BeanDefinition getBeanDefinition(String beanName) {
        return beanDefintionMap.get(beanName);
    }

    @Override
    public boolean containsBeanDefinition(String beanName) {
        return beanDefintionMap.containsKey(beanName);
    }

    @Override
    public Object getBean(String name) throws Exception {
        return this.doGetBean(name);
    }

    /** Resolves the definition for {@code beanName} and creates an instance from it. */
    private Object doGetBean(String beanName) throws Exception {
        Objects.requireNonNull(beanName, "beanName不能为空");
        BeanDefinition bd = this.getBeanDefinition(beanName);
        Objects.requireNonNull(bd, "beanDefinition不能为空");
        return doCreateInstance(bd);
    }

    /** Picks the creation strategy from the definition, then runs the init method. */
    private Object doCreateInstance(BeanDefinition bd) throws Exception {
        Object instance;
        if (bd.getBeanClass() == null) {
            // No class: the bean is produced by a method on a factory bean.
            instance = this.createInstanceByFactoryBean(bd);
        } else if (StringUtils.isBlank(bd.getFactoryMethodName())) {
            // Class only: use the no-arg constructor.
            instance = this.createInstanceByConstructor(bd);
        } else {
            // Class plus method name: static factory method.
            instance = this.createInstanceByStaticFactoryMethod(bd);
        }
        this.doInit(bd, instance);
        return instance;
    }

    /**
     * Equivalent of {@code new BeanClass()}.
     *
     * <p>Uses {@code getDeclaredConstructor().newInstance()} because
     * {@code Class.newInstance()} is deprecated; exceptions thrown by the
     * constructor arrive wrapped in {@code InvocationTargetException}.
     */
    private Object createInstanceByConstructor(BeanDefinition bd) throws Exception {
        try {
            return bd.getBeanClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException | SecurityException e) {
            log.error("创建bean的实例异常,beanDefinition:" + bd, e);
            throw e;
        }
    }

    /** Equivalent of {@code BeanClass.factoryMethodName()}. */
    private Object createInstanceByStaticFactoryMethod(BeanDefinition bd) throws Exception {
        Class<?> beanClass = bd.getBeanClass();
        Method m = beanClass.getMethod(bd.getFactoryMethodName());
        // The target argument is ignored for static methods.
        return m.invoke(beanClass);
    }

    /** Equivalent of {@code getBean(factoryBeanName).factoryMethodName()}. */
    private Object createInstanceByFactoryBean(BeanDefinition bd) throws Exception {
        Object factoryBean = this.doGetBean(bd.getFactoryBeanName());
        Method m = factoryBean.getClass().getMethod(bd.getFactoryMethodName());
        return m.invoke(factoryBean);
    }

    /** Invokes the configured no-arg init method on the instance, if one is configured. */
    private void doInit(BeanDefinition bd, Object instance) throws Exception {
        if (StringUtils.isNotBlank(bd.getInitMethodName())) {
            Method m = instance.getClass().getMethod(bd.getInitMethodName());
            m.invoke(instance);
        }
    }
}
2、版本2:新增工厂关闭方法和支持单例bean
主要调整DefaultBeanFactory:实现Closeable接口并重写close方法,只对单例Bean做销毁处理;原型Bean不知道会创建多少个对象,容器也不持有这些对象,因此不用去销毁。同时调整doGetBean方法以支持单例模式。
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.io.Closeable;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Bean factory (version 2 excerpt): adds a singleton cache and a
 * {@link #close()} hook that runs destroy methods of created singletons.
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

    // Singleton cache: bean name -> created instance.
    private Map<String, Object> singletonBeanMap = new ConcurrentHashMap<>(256);

    // Other members are omitted in this excerpt...

    /**
     * Returns the cached singleton, or creates the bean from its definition.
     * Singletons are created at most once; other scopes get a new instance per call.
     */
    private Object doGetBean(String beanName) throws Exception {
        Objects.requireNonNull(beanName, "beanName不能为空");
        // Fast path: already-created singleton, no locking needed.
        Object instance = singletonBeanMap.get(beanName);
        if (instance != null) {
            return instance;
        }
        BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
        Objects.requireNonNull(beanDefinition, "beanDefinition不能为空");
        if (!beanDefinition.isSingleton()) {
            return doCreateInstance(beanDefinition);
        }
        synchronized (singletonBeanMap) {
            // Re-check under the lock so the instance is created only once.
            instance = singletonBeanMap.get(beanName);
            if (instance == null) {
                instance = doCreateInstance(beanDefinition);
                singletonBeanMap.put(beanName, instance);
            }
        }
        return instance;
    }

    /**
     * Invokes the destroy method of every singleton that has been created and
     * declares one. Beans without a destroy method are skipped rather than
     * triggering a reflective lookup with a {@code null} name. A failing
     * destroy method is logged and does not stop the remaining beans.
     */
    @Override
    public void close() {
        for (Map.Entry<String, BeanDefinition> beanDefinitionEntry : beanDefintionMap.entrySet()) {
            String beanName = beanDefinitionEntry.getKey();
            BeanDefinition beanDefinition = beanDefinitionEntry.getValue();
            if (!beanDefinition.isSingleton()) {
                continue;
            }
            Object instance = this.singletonBeanMap.get(beanName);
            String destroyMethodName = beanDefinition.getDestroyMethodName();
            if (instance == null || StringUtils.isBlank(destroyMethodName)) {
                // Never created, or nothing to call.
                continue;
            }
            try {
                Method m = instance.getClass().getMethod(destroyMethodName);
                m.invoke(instance);
            } catch (Exception e) {
                log.error("执行名字为[" + beanName + "] 的bean销毁方法异常", e);
            }
        }
    }
}
额外扩展:实现一个预构建BeanFactory,可以在系统启动时,提前初始化
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
/**
 * Bean factory that can create all singleton beans up front instead of
 * waiting for the first lookup.
 */
@Slf4j
public class PreBuildBeanFactory extends DefaultBeanFactory {

    /**
     * Walks every registered definition and calls {@code getBean} for each
     * singleton so that it is created immediately.
     *
     * @throws Exception if creating any singleton fails
     */
    public void preInstantiateSingletons() throws Exception {
        synchronized (this.beanDefintionMap) {
            for (Map.Entry<String, BeanDefinition> registered : this.beanDefintionMap.entrySet()) {
                final BeanDefinition definition = registered.getValue();
                if (!definition.isSingleton()) {
                    // Non-singleton beans are only created on demand.
                    continue;
                }
                final String beanName = registered.getKey();
                this.getBean(beanName);
                if (log.isDebugEnabled()) {
                    log.debug("preInstantiate: name=" + beanName + " " + definition);
                }
            }
        }
    }
}
3、版本3:支持获取指定名字的bean的类型
主要调整Beanfactory和DefaultBeanFactory,Beanfactory新增一个getType方法,让子类去实现
/**
 * Bean factory contract (version 3 excerpt showing only the new method).
 */
public interface Beanfactory {
/** Returns the type of the bean registered under {@code name}. */
Class<?> getType(String name) throws Exception;
}
重写getType方法
/**
 * Bean factory (version 3 excerpt): adds type lookup by bean name.
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

    /**
     * Determines the bean's type from its definition without creating it.
     *
     * <ul>
     *   <li>bean class only: the bean class itself;</li>
     *   <li>bean class plus factory method: the static factory method's return type;</li>
     *   <li>factory bean plus factory method: the return type of that method on
     *       the factory bean's type.</li>
     * </ul>
     *
     * <p>Factory methods are resolved with {@code getMethod}, the same lookup
     * used when the bean is instantiated, so a method that works for creation
     * (including an inherited public one) also works here.
     *
     * @throws NullPointerException  if no definition is registered under {@code name}
     * @throws NoSuchMethodException if the factory method cannot be found
     */
    @Override
    public Class<?> getType(String name) throws Exception {
        BeanDefinition beanDefinition = this.getBeanDefinition(name);
        Objects.requireNonNull(beanDefinition, "名字为[" + name + "] 的bean定义不存在");
        Class<?> beanClass = beanDefinition.getBeanClass();
        if (beanClass == null) {
            // Instance factory method: look the method up on the factory bean's type.
            beanClass = getType(beanDefinition.getFactoryBeanName());
            return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
        }
        if (StringUtils.isNotBlank(beanDefinition.getFactoryMethodName())) {
            // Static factory method: the bean's type is whatever the method returns.
            return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
        }
        return beanClass;
    }

    @Override
    public BeanDefinition getBeanDefinition(String beanName) {
        return beanDefintionMap.get(beanName);
    }
}
4、版本4:获取指定类型的所有bean以及唯一bean
一个type可能对应多个name,使用的存储结构如下:
// Type -> names of the beans registered for that type (one type may map to several names).
private Map<Class<?>, Set<String>> typeNameMap = new ConcurrentHashMap<>(256);
调整DefaultBeanFactory的registerBeanDefinition,新增registerTypeNameMap方法,建立类型到beanName集合的映射,支持当前类、父类以及实现的接口。在Spring中,当你注入一个子类时,它会自动注入该子类实现的接口,而不会自动注入其继承的父类。
所以下面实现的registerSuperClassTypeNaemMap(按继承的父类类型注册)是可选的,可以选择不要。
/**
 * Registers a definition under a unique name and indexes the bean's type,
 * superclasses and interfaces for type-based lookup.
 *
 * <p>Only the new bean is indexed; earlier beans are not rescanned. If the
 * bean's type cannot be resolved, the definition is removed again so a failed
 * registration does not affect later ones.
 *
 * @throws NullPointerException if either argument is {@code null}
 * @throws RuntimeException     if the definition is invalid or the name is taken
 * @throws Exception            if the bean's type cannot be resolved
 */
@Override
public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
    Objects.requireNonNull(beanName, "注册bean需要提供beanName");
    Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");
    // Reject definitions that are internally inconsistent.
    if (!beanDefinition.validate()) {
        throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
    }
    // Overriding is not allowed (Spring also rejects it by default; it can be enabled
    // there with spring.main.allow-bean-definition-overriding: true).
    // putIfAbsent makes the duplicate check and the insert a single atomic step.
    BeanDefinition existing = beanDefintionMap.putIfAbsent(beanName, beanDefinition);
    if (existing != null) {
        throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + existing);
    }
    try {
        // getType runs before any index is touched, so a failure leaves the index unchanged.
        Class<?> type = this.getType(beanName);
        this.registerTypeNameMap(type, beanName);
        this.registerSuperClassTypeNaemMap(type, beanName);
        this.registerInterfaceTypeNaemMap(type, beanName);
    } catch (Exception e) {
        beanDefintionMap.remove(beanName, beanDefinition);
        throw e;
    }
}
/**
 * Rebuilds the type index for every registered bean: each bean name is
 * recorded under its own type, its superclasses and its interfaces.
 *
 * @throws Exception if the type of any registered bean cannot be resolved
 */
public void registerTypeNameMap() throws Exception {
    for (String registeredName : beanDefintionMap.keySet()) {
        final Class<?> resolved = this.getType(registeredName);
        // the bean's own type
        this.registerTypeNameMap(resolved, registeredName);
        // its superclass chain
        this.registerSuperClassTypeNaemMap(resolved, registeredName);
        // the interfaces it implements
        this.registerInterfaceTypeNaemMap(resolved, registeredName);
    }
}
/**
 * Records {@code name} as one of the beans available under {@code type}.
 *
 * <p>{@code computeIfAbsent} creates the per-type set atomically, and the set
 * itself is concurrent, so two registrations for the same type cannot lose a name.
 */
private void registerTypeNameMap(Class<?> type, String name) {
    typeNameMap.computeIfAbsent(type, key -> ConcurrentHashMap.<String>newKeySet()).add(name);
}
/**
 * Records {@code name} under every superclass of {@code type} (excluding
 * {@code Object}) and under the interfaces of each of those superclasses.
 */
private void registerSuperClassTypeNaemMap(Class<?> type, String name) {
    // Walk up the hierarchy iteratively instead of recursing.
    for (Class<?> ancestor = type.getSuperclass();
         ancestor != null && !ancestor.equals(Object.class);
         ancestor = ancestor.getSuperclass()) {
        this.registerTypeNameMap(ancestor, name);
        this.registerInterfaceTypeNaemMap(ancestor, name);
    }
}
/**
 * Records {@code name} under every interface that {@code type} declares
 * directly, and recursively under the interfaces those interfaces extend.
 */
private void registerInterfaceTypeNaemMap(Class<?> type, String name) {
    final Class<?>[] declared = type.getInterfaces();
    for (int i = 0; i < declared.length; i++) {
        final Class<?> current = declared[i];
        this.registerTypeNameMap(current, name);
        // Descend into the super-interfaces of this interface.
        this.registerInterfaceTypeNaemMap(current, name);
    }
}
Beanfactory新增两个方法如下:
/**
 * Bean factory contract (version 4 excerpt showing only the new type-based lookups).
 */
public interface Beanfactory {
/** Returns the single bean that matches {@code requiredType}. */
<T> T getBean(Class<T> requiredType) throws Exception;
/** Returns all beans that match {@code type}, keyed by bean name. */
<T> Map<String,T> getBeansOfType(Class<T> type)throws Exception;
}
DefaultBeanFactory 实现如下:
/**
 * Returns the single bean registered for {@code requiredType}.
 *
 * <p>If exactly one bean name is registered for the type, that bean is
 * returned. If several are registered, the one marked primary is returned.
 *
 * @return the matching bean, or {@code null} if no bean is registered for the type
 * @throws RuntimeException if several beans match and none, or more than one,
 *                          is marked primary
 */
@Override
@SuppressWarnings("unchecked") // the type index only lists names whose bean type matches requiredType
public <T> T getBean(Class<T> requiredType) throws Exception {
    Set<String> beanNames = typeNameMap.get(requiredType);
    if (null == beanNames) {
        return null;
    }
    if (beanNames.size() == 1) {
        return (T) this.getBean(beanNames.iterator().next());
    }
    String primaryBeanName = null;
    for (String beanName : beanNames) {
        BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
        if (beanDefinition == null || !beanDefinition.isPrimary()) {
            continue;
        }
        if (primaryBeanName != null) {
            throw new RuntimeException(requiredType + "类存在多个Primary,无法确定唯一一个Bean");
        }
        primaryBeanName = beanName;
    }
    if (primaryBeanName == null) {
        // Candidates exist here, so the failure is ambiguity, not absence.
        throw new RuntimeException(requiredType + "类存在多个候选Bean" + beanNames
                + ",但没有标记Primary的Bean,无法确定唯一一个Bean");
    }
    return (T) this.getBean(primaryBeanName);
}
/**
 * Collects every bean registered for {@code type}.
 *
 * @return a map from bean name to bean instance, or {@code null} when no
 *         bean is registered for the type
 */
@Override
public <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception {
    final Set<String> matchingNames = typeNameMap.get(type);
    if (matchingNames != null) {
        final Map<String, T> beansByName = new HashMap<>();
        for (String candidate : matchingNames) {
            final Object bean = this.getBean(candidate);
            beansByName.put(candidate, (T) bean);
        }
        return beansByName;
    }
    return null;
}
5、版本5:支持Bean别名
待实现
三、最终完整版本
BeanDefinition
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
/**
 * Describes one bean for the container: which class or factory produces it,
 * its scope, and its optional init/destroy callbacks.
 */
public interface BeanDefinition {

    /** Scope name for singleton beans. */
    String SCOPE_SINGLETON = "singleton";

    /** Scope name for prototype beans. */
    String SCOPE_PROTOTYPE = "prototype";

    /** @return the bean class, or {@code null} when the bean is produced by a factory bean */
    Class<?> getBeanClass();

    void setBeanClass(Class<?> beanClass);

    /** @param scope the scope name, e.g. {@link #SCOPE_SINGLETON} or {@link #SCOPE_PROTOTYPE} */
    void setScope(String scope);

    String getScope();

    /** @return whether this definition describes a singleton bean */
    boolean isSingleton();

    /** @return whether this definition describes a prototype bean */
    boolean isPrototype();

    /** @return the name of the factory bean whose method creates this bean, if any */
    String getFactoryBeanName();

    void setFactoryBeanName(String factoryBeanName);

    /** @return the name of the factory method (static, or on the factory bean), if any */
    String getFactoryMethodName();

    void setFactoryMethodName(String factoryMethodName);

    /** @return the name of the initialization method, if any */
    String getInitMethodName();

    void setInitMethodName(String initMethodName);

    /** @return the name of the destroy method, if any */
    String getDestroyMethodName();

    void setDestroyMethodName(String destroyMethodName);

    /** @return whether this bean is the primary candidate among several beans of one type */
    boolean isPrimary();

    void setPrimary(boolean primary);

    /**
     * Checks that this definition is internally consistent.
     *
     * <p>A definition is valid in exactly two shapes:
     * <ul>
     *   <li>a bean class is set and no factory bean name is set (constructor or
     *       static factory method on that class);</li>
     *   <li>no bean class is set and both the factory bean name and the factory
     *       method name are non-blank (instance factory method).</li>
     * </ul>
     *
     * @return {@code true} if the definition has one of the shapes above
     */
    default boolean validate() {
        if (this.getBeanClass() != null) {
            // A class-based definition must not also name a factory bean.
            return StringUtils.isBlank(getFactoryBeanName());
        }
        // Without a class, both halves of the factory reference are required.
        return !StringUtils.isBlank(getFactoryBeanName())
                && !StringUtils.isBlank(getFactoryMethodName());
    }
}
GenericBeanDefinition
/**
 * Plain mutable {@link BeanDefinition} implementation that stores every
 * attribute in a field.
 */
public class GenericBeanDefinition implements BeanDefinition {

    /** Scope value used when no scope has been set; treated as singleton. */
    public static final String SCOPE_DEFAULT = "";

    private Class<?> beanClass;
    private String scope = SCOPE_DEFAULT;
    private String factoryBeanName;
    private String factoryMethodName;
    private String initMethodName;
    private String destroyMethodName;
    private boolean primary;

    @Override
    public Class<?> getBeanClass() {
        return beanClass;
    }

    @Override
    public void setBeanClass(Class<?> beanClass) {
        this.beanClass = beanClass;
    }

    @Override
    public String getScope() {
        return scope;
    }

    @Override
    public void setScope(String scope) {
        this.scope = scope;
    }

    /**
     * @return {@code true} when the scope is {@link #SCOPE_SINGLETON} or the
     *         default empty scope, i.e. singleton is the default
     */
    @Override
    public boolean isSingleton() {
        return SCOPE_SINGLETON.equals(this.scope) || SCOPE_DEFAULT.equals(this.scope);
    }

    /** @return {@code true} when the scope is {@link #SCOPE_PROTOTYPE} */
    @Override
    public boolean isPrototype() {
        return SCOPE_PROTOTYPE.equals(this.scope);
    }

    @Override
    public String getFactoryBeanName() {
        return factoryBeanName;
    }

    @Override
    public void setFactoryBeanName(String factoryBeanName) {
        this.factoryBeanName = factoryBeanName;
    }

    @Override
    public String getFactoryMethodName() {
        return factoryMethodName;
    }

    @Override
    public void setFactoryMethodName(String factoryMethodName) {
        this.factoryMethodName = factoryMethodName;
    }

    @Override
    public String getInitMethodName() {
        return initMethodName;
    }

    @Override
    public void setInitMethodName(String initMethodName) {
        this.initMethodName = initMethodName;
    }

    @Override
    public String getDestroyMethodName() {
        return destroyMethodName;
    }

    @Override
    public void setDestroyMethodName(String destroyMethodName) {
        this.destroyMethodName = destroyMethodName;
    }

    @Override
    public boolean isPrimary() {
        return primary;
    }

    @Override
    public void setPrimary(boolean primary) {
        this.primary = primary;
    }

    /**
     * Lists every attribute. The bean factory concatenates definitions into its
     * error and log messages, which would otherwise show only an identity hash.
     */
    @Override
    public String toString() {
        return "GenericBeanDefinition{"
                + "beanClass=" + beanClass
                + ", scope='" + scope + '\''
                + ", factoryBeanName='" + factoryBeanName + '\''
                + ", factoryMethodName='" + factoryMethodName + '\''
                + ", initMethodName='" + initMethodName + '\''
                + ", destroyMethodName='" + destroyMethodName + '\''
                + ", primary=" + primary
                + '}';
    }
}
BeanDefinitionRegistry
/**
 * Registry that stores bean definitions by bean name.
 */
public interface BeanDefinitionRegistry {
/** Registers {@code beanDefinition} under {@code beanName}. */
void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception;
/** Looks up the definition registered under {@code beanName}. */
BeanDefinition getBeanDefinition(String beanName);
/** Tells whether a definition is registered under {@code beanName}. */
boolean containsBeanDefinition(String beanName);
}
Beanfactory
import java.util.Map;
/**
 * Entry point for obtaining beans from the container, by name or by type.
 */
public interface Beanfactory {
/** Returns the bean registered under {@code name}. */
Object getBean(String name) throws Exception;
/** Returns the type of the bean registered under {@code name}. */
Class<?> getType(String name) throws Exception;
/** Returns the single bean that matches {@code requiredType}. */
<T> T getBean(Class<T> requiredType) throws Exception;
/** Returns all beans that match {@code type}, keyed by bean name. */
<T> Map<String,T> getBeansOfType(Class<T> type)throws Exception;
}
DefaultBeanFactory
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.io.Closeable;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Default bean factory: stores bean definitions by name, creates beans on
 * demand, caches singletons, indexes beans by type, and runs destroy methods
 * on {@link #close()}.
 *
 * <p>Instances are created in one of three ways, chosen from the definition:
 * no-arg constructor, static factory method on the bean class, or a method on
 * another bean (the factory bean).
 */
@Slf4j
public class DefaultBeanFactory implements BeanDefinitionRegistry, Beanfactory, Closeable {

    // Bean name -> bean definition. Protected so subclasses can iterate the definitions.
    protected Map<String, BeanDefinition> beanDefintionMap = new ConcurrentHashMap<>(256);
    // Singleton cache: bean name -> created instance.
    private Map<String, Object> singletonBeanMap = new ConcurrentHashMap<>(256);
    // Type index: type -> names of the beans registered for that type.
    private Map<Class<?>, Set<String>> typeNameMap = new ConcurrentHashMap<>(256);

    /**
     * Registers a definition under a unique name and indexes the bean's type
     * and interfaces for type-based lookup.
     *
     * <p>Only the new bean is indexed; earlier beans are not rescanned. If the
     * bean's type cannot be resolved, the definition is removed again so a
     * failed registration does not affect later ones.
     *
     * @throws NullPointerException if either argument is {@code null}
     * @throws RuntimeException     if the definition is invalid or the name is taken
     * @throws Exception            if the bean's type cannot be resolved
     */
    @Override
    public void registerBeanDefinition(String beanName, BeanDefinition beanDefinition) throws Exception {
        Objects.requireNonNull(beanName, "注册bean需要提供beanName");
        Objects.requireNonNull(beanDefinition, "注册bean需要提供beanDefinition");
        // Reject definitions that are internally inconsistent.
        if (!beanDefinition.validate()) {
            throw new RuntimeException("名字为[" + beanName + "] 的bean定义不合法:" + beanDefinition);
        }
        // Overriding is not allowed (Spring also rejects it by default; it can be enabled
        // there with spring.main.allow-bean-definition-overriding: true).
        // putIfAbsent makes the duplicate check and the insert a single atomic step.
        BeanDefinition existing = beanDefintionMap.putIfAbsent(beanName, beanDefinition);
        if (existing != null) {
            throw new RuntimeException("名字为[" + beanName + "] 的bean定义已存在:" + existing);
        }
        try {
            registerTypeNames(beanName);
        } catch (Exception e) {
            beanDefintionMap.remove(beanName, beanDefinition);
            throw e;
        }
    }

    /**
     * Rebuilds the type index for every registered bean.
     *
     * @throws Exception if the type of any registered bean cannot be resolved
     */
    public void registerTypeNameMap() throws Exception {
        for (String name : beanDefintionMap.keySet()) {
            registerTypeNames(name);
        }
    }

    /** Indexes one bean under its own type and the interfaces that type implements. */
    private void registerTypeNames(String name) throws Exception {
        // getType runs before any index is touched, so a failure leaves the index unchanged.
        Class<?> type = this.getType(name);
        this.registerTypeNameMap(type, name);
        // Superclass indexing is intentionally not enabled in this version;
        // registerSuperClassTypeNaemMap is kept for when it is wanted.
        this.registerInterfaceTypeNaemMap(type, name);
    }

    /**
     * Records {@code name} as one of the beans available under {@code type}.
     * The per-type set is created atomically and is itself concurrent.
     */
    private void registerTypeNameMap(Class<?> type, String name) {
        typeNameMap.computeIfAbsent(type, key -> ConcurrentHashMap.<String>newKeySet()).add(name);
    }

    /**
     * Records {@code name} under every superclass of {@code type} (excluding
     * {@code Object}) and the interfaces of those superclasses. Currently unused.
     */
    private void registerSuperClassTypeNaemMap(Class<?> type, String name) {
        Class<?> superclass = type.getSuperclass();
        if (superclass != null && !superclass.equals(Object.class)) {
            this.registerTypeNameMap(superclass, name);
            this.registerSuperClassTypeNaemMap(superclass, name);
            this.registerInterfaceTypeNaemMap(superclass, name);
        }
    }

    /** Records {@code name} under every interface of {@code type}, including inherited super-interfaces. */
    private void registerInterfaceTypeNaemMap(Class<?> type, String name) {
        for (Class<?> anInterface : type.getInterfaces()) {
            this.registerTypeNameMap(anInterface, name);
            this.registerInterfaceTypeNaemMap(anInterface, name);
        }
    }

    @Override
    public BeanDefinition getBeanDefinition(String beanName) {
        return beanDefintionMap.get(beanName);
    }

    @Override
    public boolean containsBeanDefinition(String beanName) {
        return beanDefintionMap.containsKey(beanName);
    }

    @Override
    public Object getBean(String name) throws Exception {
        return this.doGetBean(name);
    }

    /**
     * Determines the bean's type from its definition without creating it:
     * the bean class itself, or the return type of the static or instance
     * factory method.
     *
     * <p>Factory methods are resolved with {@code getMethod}, the same lookup
     * used when the bean is instantiated, so a method that works for creation
     * (including an inherited public one) also works here.
     *
     * @throws NullPointerException  if no definition is registered under {@code name}
     * @throws NoSuchMethodException if the factory method cannot be found
     */
    @Override
    public Class<?> getType(String name) throws Exception {
        BeanDefinition beanDefinition = this.getBeanDefinition(name);
        Objects.requireNonNull(beanDefinition, "名字为[" + name + "] 的bean定义不存在");
        Class<?> beanClass = beanDefinition.getBeanClass();
        if (beanClass == null) {
            // Instance factory method: look the method up on the factory bean's type.
            beanClass = getType(beanDefinition.getFactoryBeanName());
            return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
        }
        if (StringUtils.isNotBlank(beanDefinition.getFactoryMethodName())) {
            // Static factory method: the bean's type is whatever the method returns.
            return beanClass.getMethod(beanDefinition.getFactoryMethodName()).getReturnType();
        }
        return beanClass;
    }

    /**
     * Returns the single bean registered for {@code requiredType}; when several
     * are registered, the one marked primary wins.
     *
     * @return the matching bean, or {@code null} if no bean is registered for the type
     * @throws RuntimeException if several beans match and none, or more than one,
     *                          is marked primary
     */
    @Override
    @SuppressWarnings("unchecked") // the type index only lists names whose bean type matches requiredType
    public <T> T getBean(Class<T> requiredType) throws Exception {
        Set<String> beanNames = typeNameMap.get(requiredType);
        if (null == beanNames) {
            return null;
        }
        if (beanNames.size() == 1) {
            return (T) this.getBean(beanNames.iterator().next());
        }
        String primaryBeanName = null;
        for (String beanName : beanNames) {
            BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
            if (beanDefinition == null || !beanDefinition.isPrimary()) {
                continue;
            }
            if (primaryBeanName != null) {
                throw new RuntimeException(requiredType + "类存在多个Primary,无法确定唯一一个Bean");
            }
            primaryBeanName = beanName;
        }
        if (primaryBeanName == null) {
            // Candidates exist here, so the failure is ambiguity, not absence.
            throw new RuntimeException(requiredType + "类存在多个候选Bean" + beanNames
                    + ",但没有标记Primary的Bean,无法确定唯一一个Bean");
        }
        return (T) this.getBean(primaryBeanName);
    }

    /**
     * Collects every bean registered for {@code type}.
     *
     * @return a map from bean name to bean instance, or {@code null} when no
     *         bean is registered for the type
     */
    @Override
    @SuppressWarnings("unchecked") // see getBean(Class)
    public <T> Map<String, T> getBeansOfType(Class<T> type) throws Exception {
        Set<String> beanNames = typeNameMap.get(type);
        if (null == beanNames) {
            return null;
        }
        Map<String, T> nameBeanMap = new HashMap<>();
        for (String beanName : beanNames) {
            nameBeanMap.put(beanName, (T) this.getBean(beanName));
        }
        return nameBeanMap;
    }

    /**
     * Returns the cached singleton, or creates the bean from its definition.
     * Singletons are created at most once; other scopes get a new instance per call.
     */
    private Object doGetBean(String beanName) throws Exception {
        Objects.requireNonNull(beanName, "beanName不能为空");
        // Fast path: already-created singleton, no locking needed.
        Object instance = singletonBeanMap.get(beanName);
        if (instance != null) {
            return instance;
        }
        BeanDefinition beanDefinition = this.getBeanDefinition(beanName);
        Objects.requireNonNull(beanDefinition, "beanDefinition不能为空");
        if (!beanDefinition.isSingleton()) {
            return doCreateInstance(beanDefinition);
        }
        synchronized (singletonBeanMap) {
            // Re-check under the lock so the instance is created only once.
            instance = singletonBeanMap.get(beanName);
            if (instance == null) {
                instance = doCreateInstance(beanDefinition);
                singletonBeanMap.put(beanName, instance);
            }
        }
        return instance;
    }

    /** Picks the creation strategy from the definition, then runs the init method. */
    private Object doCreateInstance(BeanDefinition bd) throws Exception {
        Object instance;
        if (bd.getBeanClass() == null) {
            // No class: the bean is produced by a method on a factory bean.
            instance = this.createInstanceByFactoryBean(bd);
        } else if (StringUtils.isBlank(bd.getFactoryMethodName())) {
            // Class only: use the no-arg constructor.
            instance = this.createInstanceByConstructor(bd);
        } else {
            // Class plus method name: static factory method.
            instance = this.createInstanceByStaticFactoryMethod(bd);
        }
        this.doInit(bd, instance);
        return instance;
    }

    /**
     * Equivalent of {@code new BeanClass()}.
     *
     * <p>Uses {@code getDeclaredConstructor().newInstance()} because
     * {@code Class.newInstance()} is deprecated; exceptions thrown by the
     * constructor arrive wrapped in {@code InvocationTargetException}.
     */
    private Object createInstanceByConstructor(BeanDefinition bd) throws Exception {
        try {
            return bd.getBeanClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException | SecurityException e) {
            log.error("创建bean的实例异常,beanDefinition:" + bd, e);
            throw e;
        }
    }

    /** Equivalent of {@code BeanClass.factoryMethodName()}. */
    private Object createInstanceByStaticFactoryMethod(BeanDefinition bd) throws Exception {
        Class<?> beanClass = bd.getBeanClass();
        Method m = beanClass.getMethod(bd.getFactoryMethodName());
        // The target argument is ignored for static methods.
        return m.invoke(beanClass);
    }

    /** Equivalent of {@code getBean(factoryBeanName).factoryMethodName()}. */
    private Object createInstanceByFactoryBean(BeanDefinition bd) throws Exception {
        Object factoryBean = this.doGetBean(bd.getFactoryBeanName());
        Method m = factoryBean.getClass().getMethod(bd.getFactoryMethodName());
        return m.invoke(factoryBean);
    }

    /** Invokes the configured no-arg init method on the instance, if one is configured. */
    private void doInit(BeanDefinition bd, Object instance) throws Exception {
        if (StringUtils.isNotBlank(bd.getInitMethodName())) {
            Method m = instance.getClass().getMethod(bd.getInitMethodName());
            m.invoke(instance);
        }
    }

    /**
     * Invokes the destroy method of every singleton that has been created and
     * declares one. Beans without a destroy method are skipped rather than
     * triggering a reflective lookup with a {@code null} name. A failing
     * destroy method is logged and does not stop the remaining beans.
     */
    @Override
    public void close() {
        for (Map.Entry<String, BeanDefinition> beanDefinitionEntry : beanDefintionMap.entrySet()) {
            String beanName = beanDefinitionEntry.getKey();
            BeanDefinition beanDefinition = beanDefinitionEntry.getValue();
            if (!beanDefinition.isSingleton()) {
                continue;
            }
            Object instance = this.singletonBeanMap.get(beanName);
            String destroyMethodName = beanDefinition.getDestroyMethodName();
            if (instance == null || StringUtils.isBlank(destroyMethodName)) {
                // Never created, or nothing to call.
                continue;
            }
            try {
                Method m = instance.getClass().getMethod(destroyMethodName);
                m.invoke(instance);
            } catch (Exception e) {
                log.error("执行名字为[" + beanName + "] 的bean销毁方法异常", e);
            }
        }
    }
}
PreBuildBeanFactory
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
/**
 * Bean factory that can create all singleton beans up front instead of
 * waiting for the first lookup.
 */
@Slf4j
public class PreBuildBeanFactory extends DefaultBeanFactory {

    /**
     * Walks every registered definition and calls {@code getBean} for each
     * singleton so that it is created immediately.
     *
     * @throws Exception if creating any singleton fails
     */
    public void preInstantiateSingletons() throws Exception {
        synchronized (this.beanDefintionMap) {
            for (Map.Entry<String, BeanDefinition> registered : this.beanDefintionMap.entrySet()) {
                final BeanDefinition definition = registered.getValue();
                if (!definition.isSingleton()) {
                    // Non-singleton beans are only created on demand.
                    continue;
                }
                final String beanName = registered.getKey();
                this.getBean(beanName);
                if (log.isDebugEnabled()) {
                    log.debug("preInstantiate: name=" + beanName + " " + definition);
                }
            }
        }
    }
}