通过手写spring理解spring的设计思想以及IOC与DI的原理

只动手不动脑,越学越晕,既动手又动脑,事半功倍。

前言
Demo只是基于web框架简单阐述spring的设计思想,spring实际实现比较复杂,大家可以先从简单原理慢慢来。

设计思想

一、配置

  1. 配置web.xml 中的Servlet拦截器: DispatcherServlet
  2. 设定init-param:application.properties(以properties文件代替xml文件,实现简单模拟)
  3. 设置url-pattern: /*
  4. 配置annotation

二、初始化

  1. 调用init()方法 ,加载配置文件。
  2. IOC容器初始化。
  3. 扫描相关的类,添加了@Controller、@Service注解的类,将扫描的类放入缓存。
  4. 创建实例化并保存至容器 通过反射机制将类实例化并放入IOC容器中。
  5. 扫描IOC容器中的类,给没有实例化的类进行实例化。
  6. 将URL与Method的对应关系进行存储。

三、运行

  1. 调用doGet/doPost方法。
  2. 匹配HandlerMapping ,获取调用的URL去HandlerMapping中匹配相应的Method。
  3. 反射调用method.invoke() , 获取请求结果
  4. response.getWrite().write(), 将结果响应给浏览器

加上写代码前的文档整理截图,应该可以看得更清晰点。
在这里插入图片描述
Demo

项目模块截图
在这里插入图片描述
一、先创建自定义注解

NsController.java
package com.ns.anno;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 控制层类注解
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NsController {
    
    String value() default "";
}

NsService.java
package com.ns.anno;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NsService {

    String value() default "";
}

NsAutowired.java
package com.ns.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 默认只提供通过类型注入
 */
@Target({ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.PARAMETER, ElementType.FIELD, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface NsAutowired {

}

NsRequestMapping.java
package com.ns.anno;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 方法注解
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NsRequestMapping {

    String value() default "";
}

NsRequestParam.java
package com.ns.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 参数注解,默认只有value的
 */
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
public @interface NsRequestParam {

    String value() default "";
}

二、编写Controller层代码

NsDemoController.java
package com.ns.controller;

import com.ns.anno.NsAutowired;
import com.ns.anno.NsController;
import com.ns.anno.NsRequestMapping;
import com.ns.anno.NsRequestParam;
import com.ns.api.NsDemoService;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

@NsController
@NsRequestMapping("/ns/demo")
public class NsDemoController {

    @NsAutowired
    private NsDemoService nsDemoService;

    @NsRequestMapping("/select")
    public void select(HttpServletRequest request, HttpServletResponse response, @NsRequestParam("id") String id){
        String result = nsDemoService.select(id);
        try {
            response.getWriter().write(result);
        } catch (IOException e) {
            throw new RuntimeException("NsDemoController.select pass Exception ,detail info:"+e.getMessage());
        }
    }
}

三、编写service层代码

NsDemoService.java
package com.ns.api;

public interface NsDemoService {

    String select(String id);
}

四、编写service实现类代码

NsDemoServiceImpl.java
package com.ns.service;

import com.ns.anno.NsService;
import com.ns.api.NsDemoService;

@NsService
public class NsDemoServiceImpl implements NsDemoService {

    @Override
    public String select(String id) {
        if("1".equals(id)){
            return "11111111111";
        }else{
            return "1234567890";
        }
    }
}

五、添加静态资源文件

application.properties(只配置一个扫描包路径)
scan.package=com.ns

六、添加拦截器

NsDispatcherServlet.java(方法就不单独粘贴了,省的看不明白)
package com.ns.servlet;

import com.ns.anno.NsAutowired;
import com.ns.anno.NsController;
import com.ns.anno.NsRequestMapping;
import com.ns.anno.NsRequestParam;
import com.ns.anno.NsService;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * 自定义拦截器
 * 重写父类的doGet方法,doPost方法,init方法
 */
public class NsDispatcherServlet extends HttpServlet {

    private Properties properties = new Properties();//缓存配置文件中的内容

    private List<String> classNames = new ArrayList<>();//存放所有扫描到的类对应的类路径

    private Map<String,Object> ioc = new HashMap<>();//IOC容器

    private List<Handler> handlerMapping = new ArrayList<>();// 存放url与method的映射关系


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        //默认doGet方法调用doPost
        this.doPost(req,resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            this.doDispatch(req,resp);
        } catch (InvocationTargetException | IllegalAccessException e) {
            throw new RuntimeException("NsDispatcherServlet.doPost pass Exception,detail info:"+e.getMessage());
        }
    }

    //请求的具体方法
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        //遍历handlerMapping,获取该请求对应的方法
        Handler handler = getHandler(req);
        if(handler == null){
            resp.getWriter().write("404 Not Found");
            return;
        }
        //方法对应的形参类型列表
        Class<?>[] paramTypes = handler.method.getParameterTypes();
        //定义请求的实参列表,长度与形参列表长度一致
        Object[] paramValues = new Object[paramTypes.length];
        //获取方法的实参列表
        Map<String,String[]> params = req.getParameterMap();
        //遍历参数
        for (Map.Entry<String, String[]> entry : params.entrySet()) {
            //获取参数对应的值
            String value = entry.getValue()[0];
            if(handler.paramIndexMapping.containsKey(entry.getKey())){
                int index = handler.paramIndexMapping.get(entry.getKey());
                //实参对应的位置放入对应的值
                paramValues[index] = convertValue(paramTypes[index],value);
            }
        }
        if(handler.paramIndexMapping.containsKey(HttpServletRequest.class.getName())){
            int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
            paramValues[reqIndex] = req;
        }
        if(handler.paramIndexMapping.containsKey(HttpServletResponse.class.getName())){
            int respIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
            paramValues[respIndex] = resp;
        }
        //请求method的invoke方法
        Object returnValue = handler.method.invoke(handler.controller,paramValues);
        if(returnValue == null || returnValue instanceof Void){
            return;
        }
        resp.getWriter().write(returnValue.toString());
    }

    //类型转换,实参的值与形参的类型进行匹配,此处只列出两种
    private Object convertValue(Class<?> paramType, String value) {
        if(paramType == Integer.class){
            return Integer.parseInt(value);
        }else if(paramType == Double.class){
            return Double.valueOf(value);
        }
        return value;
    }

    //得到对应的handler映射关系
    private Handler getHandler(HttpServletRequest req) {
        if(handlerMapping.isEmpty()){
            return null;
        }
        String url = req.getRequestURI();//拿到请求的url
        String contextPath = req.getContextPath();
        url = url.replaceAll(contextPath,"");//防止路径中有空格
        //从所有的handlerMapping中遍历查找与当前请求路径对应的Handler
        for(Handler handler : handlerMapping){
            Matcher matcher = handler.url.matcher(url);
            if(matcher.matches()){
               return handler;
            }
        }
        return null;
    }

    /**
     * Servlet初始化(模板模式的具体体现)
     * 1.加载配置文件到内存
     * 2.获取扫描包下的所有类路径
     * 3.获取所有添加了@NsController,@NsService注解的类,进行初始化并放入IOC容器中 --> IOC初始化
     * 4.拿着IOC容器中的所有类,获取类中添加了@NsAutowired的声明对象,进行依赖注入  --> DI
     * 5.拿着IOC容器中所有的类,获取类中所有的public修饰的方法,筛选出添加了@NsRequestMapping注解的方法,
     *      设置url与method的对应关系 -- HandlerMapping
     * 6.初始化完成
     * @throws ServletException
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        //1.将扫描到的文件加载到内存
        doScanner(config);
        //2.获取扫描包下的所有类路径
        doLoadClassNames(properties.getProperty("scan.package"));
        //3.遍历classNames中的所有类路径,拿到类文件,判断类文件是否添加了@NsController,@NsService注解,如果添加了,就实例化放到IOC容器中
        doInitIoc();
        //4.对ioc中的类进行依赖注入
        doAutowired();
        //5.添加url与method的映射关系
        doHandlerMapping();
    }



    //将扫描到的文件内容加载到内存中
    private void doScanner(ServletConfig config) {
        try {
            //通过类加载器以流的方式将文件加载到内存中
            InputStream is = this.getClass().getClassLoader().getResourceAsStream(config.getInitParameter("applicationContext"));
            properties.load(is);
        } catch (FileNotFoundException e) {
            throw new RuntimeException("no filed found, detail info :"+e.getMessage());
        } catch (IOException e) {
            throw new RuntimeException("load field content pass Exception, detail info:"+e.getMessage());
        }
    }

    //获取扫描包下的所有类路径
    private void doLoadClassNames(String packageName) {
        //获取扫描包对应的Url路径(将类路径转换为文件路径,即将.替换为/ ---> 目的是为了方便查找文件)
        URL url = this.getClass().getClassLoader().getResource(packageName.replaceAll("\\.","/"));
        assert url != null;
        //列出该路径下的所有文件
        File[] files = new File(url.getFile()).listFiles();
        assert files != null;
        //遍历所有的文件
        for(File file : files){
            //如果该文件为文件夹,递归
            if(file.isDirectory()){
                doLoadClassNames(packageName+"."+file.getName());
            }else{
                //如果该文件不是以.class结尾,不用管这个文件
                if(!file.getName().endsWith(".class")){ continue; }
                //如果文件是以.class结尾,将该类路径放入到classNames中
                classNames.add((packageName+"."+file.getName()).replace(".class",""));
            }
        }
    }

    /**
     * IOC容器初始化
     * 以 @NsController @NsService注解为例
     */
    private void doInitIoc() {
        //只有classNames有类路径才需要初始化,没有则表示扫描的包下没有类文件(如果自己有疑问,请检查自己配置的扫描包是否正确,或者代码是否编译)
        if(!classNames.isEmpty()){
            try {
                for(String className : classNames) {
                    //反射获取类
                    Class<?> clazz = Class.forName(className);
                    //判断该类是否添加了@NsController注解
                    if (clazz.isAnnotationPresent(NsController.class)) {
                        //获取该类对应的类名
                        ioc.put(clazz.getSimpleName(), clazz.newInstance());
                    }
                    //判断该类是否添加了@NsService注解
                    else if (clazz.isAnnotationPresent(NsService.class)) {
                        //获取到@NsService注解
                        NsService nsService = clazz.getAnnotation(NsService.class);
                        //获取到自定义value属性值
                        String beanName = nsService.value();
                        //如果自定义的值为空,则拿该类名当做ioc的key
                        if (beanName.isEmpty()) {
                            beanName = clazz.getSimpleName();
                        }
                        //此时,将该service类已经放入到容器中了
                        Object instance = clazz.newInstance();
                        ioc.put(beanName, instance);
                        //接下来,处理service对应的接口
                        for(Class<?> interClazz : clazz.getInterfaces()){
                            if(!ioc.containsKey(interClazz.getSimpleName())){
                                //因为是某个实现类对应的接口,因此对应的实例与实现类对应的实例是同一个
                                ioc.put(interClazz.getSimpleName(),instance);
                            }
                        }
                    }
                }
            } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
                throw new RuntimeException("NsDispatcherServlet.doInitIoc pass Exception , detail info :"+e.getMessage());
            }
        }
    }

    /**
     * 遍历ioc容器的类,给添加了@NsAutowired的类进行依赖注入
     */
    private void doAutowired() {
        //当ioc中有类时再处理
        if(!ioc.isEmpty()){
            //遍历ioc容器
            for(Map.Entry<String,Object> entry : ioc.entrySet()){
                //获取到类中所有声明的文件
                Field[] fields = entry.getValue().getClass().getDeclaredFields();
                //遍历这些声明属性,拿到添加了@NsAutowired注解的声明
                for(Field field : fields){
                    if(field.isAnnotationPresent(NsAutowired.class)){
                        //因为目前只提供了通过类型注入的方式,因此拿到类型对应的类名就可以了
                        String beanName = field.getType().getSimpleName();
                        //如果是public以外的修饰符,只要添加了@NsAutowired注解的,都要强制赋值
                        field.setAccessible(true);
                        //利用反射,动态给字段赋值
                        try {
                            field.set(entry.getValue(),ioc.get(beanName));
                        } catch (IllegalAccessException e) {
                            throw new RuntimeException("NsDispatcherServlet.doAutowired pass Exception, detail info:"+e.getMessage());
                        }
                    }
                }
            }
        }
    }

    /**
     * 添加url与method的映射关系
     *
     */
    private void doHandlerMapping() {
        if(!ioc.isEmpty()){
            for (Map.Entry<String, Object> entry : ioc.entrySet()) {
                //获取类对应的类信息
                Class<?> clazz = entry.getValue().getClass();
                //如果该类为@NsController注解的,再进行接下来的处理
                if(clazz.isAnnotationPresent(NsController.class)){
                    String basicUrl ="";//类中@NsRequestMapping对应的值
                    if(clazz.isAnnotationPresent(NsRequestMapping.class)){
                        NsRequestMapping nsRequestMapping = clazz.getAnnotation(NsRequestMapping.class);
                        basicUrl = nsRequestMapping.value();
                    }
                    //拿到类中所有的方法
                    Method[] methods = clazz.getMethods();
                    for (Method method : methods) {
                        //只处理添加了@NsRequestMapping的方法
                        if(method.isAnnotationPresent(NsRequestMapping.class)){
                            NsRequestMapping requestMapping = method.getAnnotation(NsRequestMapping.class);
                            //拼方法的url,把多个/替换成一个
                            String regex = (basicUrl+requestMapping.value()).replaceAll("/+","/");
                            //通过正则来处理url
                            Pattern pattern = Pattern.compile(regex);
                            handlerMapping.add(new Handler(entry.getValue(),pattern,method));
                        }
                    }
                }
            }
        }
    }

    /**
     * 内部类
     * Handler类,维护url与method对应关系类
     */
    private class Handler {

        Pattern url;   //请求url

        Object controller;//属于哪个controller的

        Method method;//调用哪个方法

        Map<String,Integer> paramIndexMapping;//方法中的参数顺序

        Handler(Object controller, Pattern url, Method method){
            this.controller = controller;
            this.url = url;
            this.method = method;
            paramIndexMapping = new HashMap<>();
            putParamIndexMapping(method);
        }

        private void putParamIndexMapping(Method method) {
            //获取方法中添加了注解的参数
            Parameter[] parameters = method.getParameters();
            for(int i=0;i<parameters.length;i++){
                Parameter parameter = parameters[i];
                if(parameter.getType() == HttpServletRequest.class || parameter.getType() == HttpServletResponse.class){
                    paramIndexMapping.put(parameter.getType().getName(),i);
                    continue;
                }
                String key;
                if(parameter.isAnnotationPresent(NsRequestParam.class)){
                    NsRequestParam nsRequestParam = parameter.getAnnotation(NsRequestParam.class);
                    key = nsRequestParam.value();
                    paramIndexMapping.put(key,i);
                }
            }
        }
    }

}

七、修改web.xml文件

web.xml
<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_4_0.xsd"
         version="4.0">
    <servlet>
        <servlet-name>nsSpringMvc</servlet-name>
        <servlet-class>com.ns.servlet.NsDispatcherServlet</servlet-class>
        <init-param>
            <param-name>applicationContext</param-name>
            <param-value>application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>
    <servlet-mapping>
        <servlet-name>nsSpringMvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

整个spring的设计原理以及IOC的初始化,DI实现都已经简单模拟完成
运行结果:
在这里插入图片描述
在这里插入图片描述
最后,希望大家都能学以致用。
注:可以先把代码贴出来看看运行过程,最好还是自己多写几遍,边写边思考,容易理解。
代码如有误导,请帮忙指正,谢谢。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值