手写mini springIoc

背景

每次在工作中使用spring,总会被其流畅代码思路和迷人的架构模式以及设计模式吸引,也看过一阵子源码,但总觉得了解的是是而非,所以想着何不自己手写一个springIoc用来梳理巩固自己所学的知识,以及自我充电

spring核心

  • 控制反转 (IoC,Inversion of Control)(本篇文章实现)

传统的JAVA开发模式中,当需要一个对象时,我们使用new或者通过getInstance等直接或者间接调用构造方法创建一个对象,而在Spring开发模式中,Spring容器使用工厂模式为我们创建了所需要的对象,不需要我们自己去创建了,直接调用Spring提供的对象就可以了,这就是控制反转,相信我们在用spring的时候,用xml或者注解了解过

  • 面向切面编程(AOP)(后续实现)

在面向对象编程(OOP)中,我们将事务纵向抽成一个个的对象,而在面向切面编程中,我们将一个个的对象某些类似的方面横向抽成一个切面,对这个切面进行一些如权限控制,事务管理,日志记录等公用操作处理的过程,就是面向切面编程的思想。

面向切面编程也是spring非常具有特色的功能,在实际工作中也非常广泛应用,就比如之前在公司我用springAop和spring SPEL机制实现的一部分功能

前置环境准备

假如我是一个spring开发人员,我要实现一个ioc,我需要怎么做呢?

  • 需要一个解析获取用户定义包扫描路径 比如`component-scan`注解扫描路径,可以放在xml或者文件等等,这里放在 `applicationContext.xml`
  • 获取所有java.class文件,java编译好的文件都放在xxx/target/classes下,如果是测试包则放在xxx/target/test-classes下,所以我们需要递归获取这下面的文件
  • 解析处理文件全路径名称,根据系统环境区分`/`和'\',拼接类路径
  • 根据类路径,进行反射构造实例 class.forName(xxxx)
  • 根据自定义注解标识,哪些类需要构造实例
  • 构造实例后需要填充属性,注解标识属性实例构造填充到该实例
  • 创建一个存取构造好的bean的容器

搭建项目

maven依赖

创建一个maven项目,因为需要解析xml获取component-scan 导入`dom4j依赖`和`lombok依赖`以及相关log日志依赖

 <dependencies>

        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-api</artifactId>
            <version>2.17.1</version>
        </dependency>
        <!-- Log4j Core -->
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-core</artifactId>
            <version>2.17.1</version>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.dom4j/dom4j -->
        <dependency>
            <groupId>org.dom4j</groupId>
            <artifactId>dom4j</artifactId>
            <version>2.1.4</version>
        </dependency>

    </dependencies>

包结构

  • com.xiaohu.springioc下面的包业务代码,写控制层,service层代码
  • org.springframework.xx下面的包为自己简易的写spring相关ioc

编码

实现bean容器

package org.springframework.container;

import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.*;
import org.springframework.xml.XmlParser;

import java.io.File;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * @Version 1.0
 * @Author xiaohugg
 * @Description ClassPathXmlApplicationContext
 * @Date 2023/11/29 11:27
 **/
public class ClassPathXmlApplicationContext {

    private static final Logger logger = LogManager.getLogger(ClassPathXmlApplicationContext.class);

    /**
     * spring的ioc名字作为key
     */
    private final Map<String, Object> iocNameContainer = new ConcurrentHashMap<>();
    /**
     * spring的class作为key
     */
    private final Map<Class<?>, Object> iocClassContainer = new ConcurrentHashMap<>();

    /**
     * 根据接口,获取接口下的实现类
     * 类似 context.getBean(UserService.class)
     */
    private final Map<Class<?>, List<Object>> iocInterfacesContainer = new ConcurrentHashMap<>();

    private final Set<String> classFiles = new HashSet<>();

    private final String xmlPath;

    public ClassPathXmlApplicationContext(String xmlPath) {
        this.xmlPath = xmlPath;
        refresh();
    }

    @SneakyThrows
    private void refresh() {
        //解析componentScanPath 包扫描路径
        String componentScanPath = XmlParser.parse(xmlPath);

        //获取包扫描路径的class文件路径
        File file = findClassPath(componentScanPath);

        //获取.class文件结尾的包全路径名
        findClassFiles(file, componentScanPath, classFiles);

        //反射
        newInstance(classFiles);

        //实现对象的属性的依赖注入
        doDI();

        logger.fatal("iocNameContainer {}", iocNameContainer);
        logger.fatal("iocClassContainer {}", iocClassContainer);
        logger.fatal("iocInterfacesContainer {}", iocInterfacesContainer);
    }

    private void doDI() {
        Set<Map.Entry<Class<?>, Object>> entries = iocClassContainer.entrySet();
        entries.forEach(it -> {
            Class<?> aClass = it.getKey();
            Field[] declaredFields = aClass.getDeclaredFields();
            Set<Field> hasAutowiredField = Arrays.stream(declaredFields).filter(field -> field.isAnnotationPresent(Autowired.class)).collect(Collectors.toSet());
            hasAutowiredField.forEach(field -> {
                //依赖注入属性
                Autowired annotation = field.getAnnotation(Autowired.class);
                String value = annotation.value();
                Object bean;
                if ("".equals(value)) {
                    //默认按类型获取
                    Class<?> type = field.getType();
                    bean = getBean(type);
                    if (Objects.isNull(bean)) {
                        throw new IllegalStateException("获取不到 bean: " + type.getName());
                    }
                } else {
                    //按用户填写的beanName获取
                    bean = iocNameContainer.getOrDefault(value, new IllegalArgumentException("找不到beanName: " + value));
                }
                try {
                    field.setAccessible(true);
                    field.set(iocClassContainer.get(aClass), bean);
                } catch (IllegalAccessException e) {
                    logger.error("属性注入失败 {}", e.getMessage());
                }
            });
        });
    }

    private static File findClassPath(String componentScanPath) {
        String path = Objects.requireNonNull(Thread.currentThread().getContextClassLoader().getResource("")).getPath();
        String url = path + componentScanPath.replace(".", File.separator);
        // windows环境去除路径前面的 '/'
        if (System.getProperty("os.name").toLowerCase().contains("win")) {
            url = url.replaceFirst("/", "");
        }
        if (url.contains("test-classes")) {
            url = url.replace("test-classes", "classes");
        }
        return new File(url);
    }

    public static String getBeanName(Class<?> c) {
        try {
            Annotation annotation = c.getAnnotations()[0];
            Method valueMethod = annotation.annotationType().getDeclaredMethod("value");
            String value = (String) valueMethod.invoke(annotation);
            if (value != null && !value.isEmpty()) {
                return value;
            }
        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
            // 处理异常: 可能是注解没有value()方法,或者其他反射调用错误
            logger.error("获取beanName 失败 {}", e.getMessage());
        }
        //没指定beanName 默认用类型首字母小写
        return Character.toLowerCase(c.getSimpleName().charAt(0)) + c.getSimpleName().substring(1);
    }

    private void putIoc(Class<?>[] interfaces, Object instance, String beanName, Class<?> c) {
        for (Class<?> anInterface : interfaces) {
            iocInterfacesContainer.computeIfAbsent(anInterface, k -> new ArrayList<>()).add(instance);
        }
        iocNameContainer.compute(beanName, (key, value) -> {
            if (value != null) {
                throw new IllegalStateException("Bean with name '" + beanName + "' already exists.");
            }
            return instance;
        });
        iocClassContainer.compute(c, (key, value) -> {
            if (value != null) {
                throw new IllegalStateException("Bean with class name '" + c.getSimpleName() + "' already exists.");
            }
            return instance;
        });
    }

    public Object getBean(String beanName) {
        return iocNameContainer.getOrDefault(beanName, null);
    }

    public <T> T getBean(Class<T> clazz) {
        //首先根据class获取,获取不到再通过接口获取
        if (iocClassContainer.containsKey(clazz)) {
            return clazz.cast(iocClassContainer.get(clazz));
        }
        List<Object> computed = iocInterfacesContainer.compute(clazz, (key, value) -> {
            if (value == null || value.isEmpty()) {
                return null;
            }
            if (value.size() > 1) {
                throw new IllegalArgumentException("只能获取到一个bean 但是获取到了 " + value.size() + "个相同类型的bean");
            }
            return value;
        });
        return computed == null ? null : clazz.cast(computed.get(0));
    }

    private void newInstance(Set<String> classFiles) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
        for (String classFile : classFiles) {
            try {
                classFile = classFile.replace(File.separator, ".").replace(".class", "");
                Class<?> c = Class.forName(classFile);
                Annotation[] annotations = new Annotation[]{c.getAnnotation(Component.class), c.getAnnotation(Controller.class),
                        c.getAnnotation(Service.class), c.getAnnotation(Repository.class)};
                if (Arrays.stream(annotations).anyMatch(Objects::nonNull)) {
                    String beanName = getBeanName(c);
                    Object instance = c.newInstance();
                    Class<?>[] interfaces = c.getInterfaces();
                    putIoc(interfaces, instance, beanName, c);
                }
            } catch (Exception e) {
                logger.error("构造bean失败 失败原因 {}", e.getMessage());
                throw e;
            }
        }
    }

    private void findClassFiles(File classFiles, String componentScanPath, Set<String> classNameList) {
        File[] files = classFiles.listFiles();
        if (files != null) {
            for (File file : files) {
                if (file.isFile() && file.getName().endsWith(".class")) {
                    // 如果是.class文件,添加到列表
                    String fullPath = file.getAbsolutePath();
                    int index = fullPath.indexOf(componentScanPath.replace(".", File.separator));
                    if (index != -1) {
                        String filePath = fullPath.substring(index);
                        classNameList.add(filePath);
                    }
                } else if (file.isDirectory()) {
                    // 如果是目录,递归调用
                    findClassFiles(file, componentScanPath, classNameList);
                }
            }
        }
    }
}

注解标识

package org.springframework.stereotype;

import java.lang.annotation.*;

@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {

     String value() default "";
}
package org.springframework.stereotype;

import java.lang.annotation.*;

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
    String value() default "";
}
package org.springframework.stereotype;

import java.lang.annotation.*;

/**
 * @Version 1.0
 * @Author xiaohugg
 * @Description Repository
 * @Date 2023/11/30 11:50
 **/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Repository {

    String value() default "";
}
package org.springframework.stereotype;

import java.lang.annotation.*;

/**
 * @Version 1.0
 * @Author xiaohugg
 * @Description Service
 * @Date 2023/11/29 11:14
 **/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Service {

    String value() default "";
}

!!! 每个注解都有一个value的方法,代表beanName 如果用户填了,则根据用户填写的value获取bean,否则则根据类的名称获取,首字母小写

解析xml的parse

package org.springframework.xml;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.dom4j.Attribute;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;

import java.io.InputStream;

/**
 * @Version 1.0
 * @Author xiaohugg
 * @Description XmlParser
 * @Date 2023/11/29 11:47
 **/
public class XmlParser {
    private static final Logger logger = LogManager.getLogger(XmlParser.class);
    private XmlParser() {
    }

    public static String parse(String path) {
        try (InputStream inputStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(path)) {
            SAXReader saxReader = SAXReader.createDefault();
            Document document = saxReader.read(inputStream);
            Element rootElement = document.getRootElement();
            Element element = rootElement.element("component-scan");
            Attribute basePackage = element.attribute("base-package");
            return basePackage.getText();
        } catch (Exception e) {
            logger.error("解析xml失败 {}",e.getMessage());
           throw new IllegalArgumentException("解析错误");
        }
    }
}

applicationContext.xml

填写注解扫描包路径

<?xml version="1.0" encoding="UTF-8" ?>
<beans>
    <!--扫描包-->
    <component-scan base-package="com.xiaohu"/>
</beans>

测试

package com.xiaohu.springioc;

import com.xiaohu.springioc.controller.EmployeesController;
import org.springframework.container.ClassPathXmlApplicationContext;

/**
 * @Version 1.0
 * @Author huqiang
 * @Description MainTest
 * @Date 2023/11/29 11:44
 **/
public class MainTest {

    public static void main(String[] args) {
        ClassPathXmlApplicationContext classPathXmlApplicationContext = new ClassPathXmlApplicationContext("applicationContext.xml");
        //通过类型获取
        EmployeesController employeesController = classPathXmlApplicationContext.getBean(EmployeesController.class);
        employeesController.findEmployees();

        System.out.println("======================================================");
        //通过名称获取
        EmployeesController bean = (EmployeesController)classPathXmlApplicationContext.getBean("oc");
        bean.findEmployees();
        bean.selectById(1);

    }
}

可以看到相关的类已经构造好了

注入两个相同类型的bean

往往我们在使用spring的时候,注入接口,其下有多个实现类,往往会提示注入多个bean,在这次代码中,也实现了

    public <T> T getBean(Class<T> clazz) {
        //首先根据class获取,获取不到再通过接口获取
        if (iocClassContainer.containsKey(clazz)) {
            return clazz.cast(iocClassContainer.get(clazz));
        }
        List<Object> computed = iocInterfacesContainer.compute(clazz, (key, value) -> {
            if (value == null || value.isEmpty()) {
                return null;
            }
            if (value.size() > 1) {
                throw new IllegalArgumentException("只能获取到一个bean 但是获取到了 " + value.size() + "个相同类型的bean");
            }
            return value;
        });
        return computed == null ? null : clazz.cast(computed.get(0));
    }

比如 我现在有个 `EmployeesService`下面有两个实现类 `EmployeesServiceImpl`、`EmployeesService1`

如果我在controller层,不根据beanName获取,则根据class获取和接口获取,就会出现找到多个bean异常

根据bean名称获取

在Autowored注解填充需要注入的beanName

可以看到注入的是impl

结论

从这个简单的代码中,基本实现了一个简单的ioc控制反转,实现了属性实例的传递,这个并不能解决循环依赖的问题,以及spring三级缓存扩展,bean的生命周期,相关前置后置处理扩展点等等,所以相当于学习了解的作用,多看看源码学习,多学习源码架构思想和精髓

代码地址:  spring-ioc: 手写简易的spring

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值