模仿SpringMVC的DispatcherServlet 手撸300行代码提炼精华设计思想并保证功能可用(1.0版本)

前言

1、博客内容均出自于咕泡学院架构师第三期
2、架构师系列内容:架构师学习笔记(持续更新)
3、内容为手写SpringMVC的DistapcherServlet的核心功能,从V1版本到V2版本再到V3版本。

1、SpringMVC中的DispatcherServlet的核心功能是哪些?

首先看一下mvc的流程:

  1. 用户发送请求至前端控制器DispatcherServlet。
  2. DispatcherServlet收到请求调用HandlerMapping处理器映射器。
  3. 处理器映射器找到具体的处理器(可以根据xml配置、注解进行查找),生成处理器对象及处理器拦截器(如果有则生成)一并返回给DispatcherServlet。
  4. DispatcherServlet调用HandlerAdapter处理器适配器。
  5. HandlerAdapter经过适配调用具体的处理器(Controller,也叫后端控制器)。
  6. Controller执行完成返回ModelAndView。
  7. HandlerAdapter将controller执行结果ModelAndView返回给DispatcherServlet。
  8. DispatcherServlet将ModelAndView传给ViewReslover视图解析器。
  9. ViewReslover解析后返回具体View。
  10. DispatcherServlet根据View进行渲染视图(即将模型数据填充至视图中)。
  11. DispatcherServlet响应用户。

再看DispatcherServlet 主要做了什么:

@SuppressWarnings("serial")
public class DispatcherServlet extends FrameworkServlet {


   /**
    * Initialize the strategy objects that this servlet uses.
    * <p>May be overridden in subclasses in order to initialize further strategy objects.
    */
   //初始化策略
   protected void initStrategies(ApplicationContext context) {
      //多文件上传的组件
      initMultipartResolver(context);
      //初始化本地语言环境
      initLocaleResolver(context);
      //初始化模板处理器
      initThemeResolver(context);
      //handlerMapping
      initHandlerMappings(context);
      //初始化参数适配器
      initHandlerAdapters(context);
      //初始化异常拦截器
      initHandlerExceptionResolvers(context);
      //初始化视图预处理器
      initRequestToViewNameTranslator(context);
      //初始化视图转换器
      initViewResolvers(context);
      //
      initFlashMapManager(context);
   }

   /**
    * Exposes the DispatcherServlet-specific request attributes and delegates to {@link #doDispatch}
    * for the actual dispatching.
    */
   //获取请求,设置一些request的参数,然后分发给doDispatch
   @Override
   protected void doService(HttpServletRequest request, HttpServletResponse response) throws Exception {
      if (logger.isDebugEnabled()) {
         String resumed = WebAsyncUtils.getAsyncManager(request).hasConcurrentResult() ? " resumed" : "";
         logger.debug("DispatcherServlet with name '" + getServletName() + "'" + resumed +
               " processing " + request.getMethod() + " request for [" + getRequestUri(request) + "]");
      }


      // Keep a snapshot of the request attributes in case of an include,
      // to be able to restore the original attributes after the include.
      Map<String, Object> attributesSnapshot = null;
      if (WebUtils.isIncludeRequest(request)) {
         attributesSnapshot = new HashMap<>();
         Enumeration<?> attrNames = request.getAttributeNames();
         while (attrNames.hasMoreElements()) {
            String attrName = (String) attrNames.nextElement();
            if (this.cleanupAfterInclude || attrName.startsWith(DEFAULT_STRATEGIES_PREFIX)) {
               attributesSnapshot.put(attrName, request.getAttribute(attrName));
            }
         }
      }


      // Make framework objects available to handlers and view objects.
      /* 设置web应用上下文**/
      request.setAttribute(WEB_APPLICATION_CONTEXT_ATTRIBUTE, getWebApplicationContext());
      /* 国际化本地**/
      request.setAttribute(LOCALE_RESOLVER_ATTRIBUTE, this.localeResolver);
      /* 样式**/
      request.setAttribute(THEME_RESOLVER_ATTRIBUTE, this.themeResolver);
      /* 设置样式资源**/
      request.setAttribute(THEME_SOURCE_ATTRIBUTE, getThemeSource());
      //请求刷新时保存属性
      if (this.flashMapManager != null) {
         FlashMap inputFlashMap = this.flashMapManager.retrieveAndUpdate(request, response);
         if (inputFlashMap != null) {
            request.setAttribute(INPUT_FLASH_MAP_ATTRIBUTE, Collections.unmodifiableMap(inputFlashMap));
         }
         //Flash attributes 在对请求的重定向生效之前被临时存储(通常是在session)中,并且在重定向之后被立即移除
         request.setAttribute(OUTPUT_FLASH_MAP_ATTRIBUTE, new FlashMap());
         //FlashMap 被用来管理 flash attributes 而 FlashMapManager 则被用来存储,获取和管理 FlashMap 实体.
         request.setAttribute(FLASH_MAP_MANAGER_ATTRIBUTE, this.flashMapManager);
      }


      try {
         doDispatch(request, response);
      }
      finally {
         if (!WebAsyncUtils.getAsyncManager(request).isConcurrentHandlingStarted()) {
            // Restore the original attribute snapshot, in case of an include.
            if (attributesSnapshot != null) {
               restoreAttributesAfterInclude(request, attributesSnapshot);
            }
         }
      }
   }


   /**
    * Process the actual dispatching to the handler.
    * <p>The handler will be obtained by applying the servlet's HandlerMappings in order.
    * The HandlerAdapter will be obtained by querying the servlet's installed HandlerAdapters
    * to find the first that supports the handler class.
    * <p>All HTTP methods are handled by this method. It's up to HandlerAdapters or handlers
    * themselves to decide which methods are acceptable.
    * @param request current HTTP request
    * @param response current HTTP response
    * @throws Exception in case of any kind of processing failure
    */
   /**
    * 中央控制器,控制请求的转发
    * 将Handler进行分发,handler会被handlerMapping有序的获得
    * 通过查询servlet安装的HandlerAdapters来获得HandlerAdapters来查找第一个支持handler的类
    * 所有的HTTP的方法都会被这个方法掌控。取决于HandlerAdapters 或者handlers 他们自己去决定哪些方法是可用
    * **/
   protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
      HttpServletRequest processedRequest = request;
      HandlerExecutionChain mappedHandler = null;
      boolean multipartRequestParsed = false;


      WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);


      try {
         ModelAndView mv = null;
         Exception dispatchException = null;


         try {
            // 1.检查是否是文件上传的请求
            processedRequest = checkMultipart(request);
            multipartRequestParsed = (processedRequest != request);


            // Determine handler for the current request.
            // 2.取得处理当前请求的controller,这里也称为handler,处理器,
            //      第一个步骤的意义就在这里体现了.这里并不是直接返回controller,
            //  而是返回的HandlerExecutionChain请求处理器链对象,
            //  该对象封装了handler和interceptors.
            mappedHandler = getHandler(processedRequest);
            // 如果handler为空,则返回404
            if (mappedHandler == null) {
               noHandlerFound(processedRequest, response);
               return;
            }


            // Determine handler adapter for the current request.
            //3. 获取处理request的处理器适配器handler adapter
            HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());


            // Process last-modified header, if supported by the handler.
            // 处理 last-modified 请求头
            String method = request.getMethod();
            boolean isGet = "GET".equals(method);
            if (isGet || "HEAD".equals(method)) {
               long lastModified = ha.getLastModified(request, mappedHandler.getHandler());
               if (logger.isDebugEnabled()) {
                  logger.debug("Last-Modified value for [" + getRequestUri(request) + "] is: " + lastModified);
               }
               if (new ServletWebRequest(request, response).checkNotModified(lastModified) && isGet) {
                  return;
               }
            }


            if (!mappedHandler.applyPreHandle(processedRequest, response)) {
               return;
            }


            // Actually invoke the handler.
            // 4.实际的处理器处理请求,返回结果视图对象
            mv = ha.handle(processedRequest, response, mappedHandler.getHandler());


            if (asyncManager.isConcurrentHandlingStarted()) {
               return;
            }


            // 结果视图对象的处理
            applyDefaultViewName(processedRequest, mv);
            mappedHandler.applyPostHandle(processedRequest, response, mv);
         }
         catch (Exception ex) {
            dispatchException = ex;
         }
         catch (Throwable err) {
            // As of 4.3, we're processing Errors thrown from handler methods as well,
            // making them available for @ExceptionHandler methods and other scenarios.
            dispatchException = new NestedServletException("Handler dispatch failed", err);
         }
         processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException);
      }
      catch (Exception ex) {
         triggerAfterCompletion(processedRequest, response, mappedHandler, ex);
      }
      catch (Throwable err) {
         triggerAfterCompletion(processedRequest, response, mappedHandler,
               new NestedServletException("Handler processing failed", err));
      }
      finally {
         if (asyncManager.isConcurrentHandlingStarted()) {
            // Instead of postHandle and afterCompletion
            if (mappedHandler != null) {
               // 请求成功响应之后的方法
               mappedHandler.applyAfterConcurrentHandlingStarted(processedRequest, response);
            }
         }
         else {
            // Clean up any resources used by a multipart request.
            if (multipartRequestParsed) {
               cleanupMultipart(processedRequest);
            }
         }
      }
   }
}

DispatcherServlet 主要做了initStrategies 初始化策略 ,doService获取请求,设置一些request的参数,然后分发给doDispatch ,doDispatch 中央控制器,控制请求的转发

再来看下面这个图:
在这里插入图片描述
接下来就按照DispatcherServlet 的核心代码,加上图中的流程来模拟DispatchServlet的核心代码

2、配置阶段

配置 application.properties 文件
为了解析方便,我们用 application.properties 来代替 application.xml 文件,具体配置内容如下:

scanPackage=com.jarvisy.demo

配置 web.xml 文件
所有依赖于 web 容器的项目,都是从读取 web.xml 文件开始的。我们先配置好 web.xml中的内容:

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_3_1.xsd"
         version="3.1">

    <!--配置springmvc DispatcherServlet-->
    <servlet>
        <servlet-name>springMVC</servlet-name>
        // 配置成自己实现的MyDispatcherServlet类
        <servlet-class>com.jarvisy.demo.v1.servlet.MyDispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            // 正常这里应该是application.xml文件
            <param-value>application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
        <async-supported>true</async-supported>
    </servlet>

    <servlet-mapping>
        <servlet-name>springMVC</servlet-name>
        // 这里拦截所以请求
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

自定义Annotation
把SpringMVC的注解拷过来,去掉看不懂的,对目前来说无用的:

@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyAutowired {
    String value() default "";
}
//----------------------------分割线,不同类---------------------------------------
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyController {
    String value() default "";
}
//----------------------------分割线,不同类---------------------------------------
@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyRequestMapping {
    String value() default "";
}
//----------------------------分割线,不同类---------------------------------------
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyRequestParam {
    String value() default "";
}
//----------------------------分割线,不同类---------------------------------------
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyService {
    String value() default "";
}

编写Controller 类,Servive类:

@MyController
@MyRequestMapping("/demo")
public class DemoController {


    @MyAutowired
    private DemoService demoService;


    @MyRequestMapping("/name")
    public void name(HttpServletRequest req, HttpServletResponse resp, @MyRequestParam() String name) {
        String result = demoService.get(name);
        try {
            resp.getWriter().write(result);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }


    @MyRequestMapping("/name1")
    public void name1(HttpServletRequest req, HttpServletResponse resp, @MyRequestParam() String[] name) {
        String result = Arrays.asList(name).toString();
        try {
            resp.getWriter().write(result);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }


    @MyRequestMapping("/add")
    public void add(HttpServletRequest req, HttpServletResponse resp, @MyRequestParam("a") Integer a, @MyRequestParam("b") Integer b) {
        try {
            resp.getWriter().write(a + " + " + b + " = " + (a + b));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

//----------------------------分割线,不同类---------------------------------------
public interface DemoService {


    public String get(String name );
}
//----------------------------分割线,不同类---------------------------------------


@MyService
public class DemoServiceImpl implements DemoService {
    @Override
    public String get(String name) {
        return "My name is " + name;
    }
}

3、容器初始化

实现V1版本
所有的核心逻辑全部写在一个 init()方法中,没有设计模式可言,代码比较长,混乱。不能一目了然。

public class MyDispatcherServletBak extends HttpServlet {
    private Map<String, Object> mapping = new HashMap<String, Object>();

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            resp.getWriter().write("500 Exception " + Arrays.toString(e.getStackTrace()));
        }
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        if (!this.mapping.containsKey(url)) {
            resp.getWriter().write("404 Not Found!!");
            return;
        }
        Method method = (Method) this.mapping.get(url);
        Map<String, String[]> params = req.getParameterMap();
        method.invoke(this.mapping.get(method.getDeclaringClass().getName()), new Object[]{req, resp, params.get("name")[0]});
    }

    //init方法肯定干得的初始化的工作
    //inti首先我得初始化所有的相关的类,IOC容器、servletBean
    @Override
    public void init(ServletConfig config) throws ServletException {
        InputStream is = null;
        try {
            Properties configContext = new Properties();
            is = this.getClass().getClassLoader().getResourceAsStream(config.getInitParameter("contextConfigLocation"));
            configContext.load(is);
            String scanPackage = configContext.getProperty("scanPackage");
            doScanner(scanPackage);
            for (String className : mapping.keySet()) {
                if (!className.contains(".")) {
                    continue;
                }
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(MyController.class)) {
                    mapping.put(className, clazz.newInstance());
                    String baseUrl = "";
                    if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                        MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
                        baseUrl = requestMapping.value();
                    }
                    Method[] methods = clazz.getMethods();
                    for (Method method : methods) {
                        if (!method.isAnnotationPresent(MyRequestMapping.class)) {
                            continue;
                        }
                        MyRequestMapping requestMapping = method.getAnnotation(MyRequestMapping.class);
                        String url = (baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                        mapping.put(url, method);
                        System.out.println("Mapped " + url + "," + method);
                    }
                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    MyService service = clazz.getAnnotation(MyService.class);
                    String beanName = service.value();
                    if ("".equals(beanName)) {
                        beanName = clazz.getName();
                    }
                    Object instance = clazz.newInstance();
                    mapping.put(beanName, instance);
                    for (Class<?> i : clazz.getInterfaces()) {
                        mapping.put(i.getName(), instance);
                    }
                } else {
                    continue;
                }
            }
            for (Object object : mapping.values()) {
                if (object == null) {
                    continue;
                }
                Class clazz = object.getClass();
                if (clazz.isAnnotationPresent(MyController.class)) {
                    Field[] fields = clazz.getDeclaredFields();
                    for (Field field : fields) {
                        if (!field.isAnnotationPresent(MyAutowired.class)) {
                            continue;
                        }
                        MyAutowired autowired = field.getAnnotation(MyAutowired.class);
                        String beanName = autowired.value();
                        if ("".equals(beanName)) {
                            beanName = field.getType().getName();
                        }
                        field.setAccessible(true);
                        try {
                            field.set(mapping.get(clazz.getName()), mapping.get(beanName));
                        } catch (IllegalAccessException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (is != null) {
                try {
                    is.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        System.out.print("My MVC Framework is init");
    }

    private void doScanner(String scanPackage) {
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File classDir = new File(url.getFile());
        for (File file : classDir.listFiles()) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else {
                if (!file.getName().endsWith(".class")) {
                    continue;
                }
                String clazzName = (scanPackage + "." + file.getName().replace(".class", ""));
                mapping.put(clazzName, null);
            }
        }
    }
}

实现V2版本
在 V1 版本上进了优化,采用了常用的设计模式(工厂模式、单例模式、委派模式、策略模式),将 init()方法中的代码进行封装。按照之前的实现思路,先搭基础框架,再填肉注血,具体代码如下
其中:
doInstance() : IOC容器就是注册式单例,工厂模式应用案例
initHandlerMapping() : handlerMapping 就是策略模式的应用案例
doPost() : 用了委派模式,委派模式的具体逻辑在 doDispatch()方法中

public class MyDispatcherServlet extends HttpServlet {
    //保存application.properties配置文件中的内容
    private Properties contextConfig = new Properties();
    //保存扫描的所有的类名
    private List<String> classNames = new ArrayList<>();
    // 传说中的IOC容器,为了简化程序只做演示,不考虑线程安全问题,不考虑ConcurrentHashMap
    private Map<String, Object> ioc = new HashMap<>();
    //保存url和method 的对应关系
    private Map<String, Method> handlerMapping = new HashMap<>();


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }


    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        //6、调用,运行阶段
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            e.printStackTrace();
            resp.getWriter().write("500 Exception , Detail:" + Arrays.toString(e.getStackTrace()));
        }
    }


    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        // 绝对路径->处理成相对路径
        String url = req.getRequestURI().replaceAll(req.getContextPath(), "").replaceAll("/+", "/");


        if (!this.handlerMapping.containsKey(url)) {
            resp.getWriter().write("My DispatcherServlet 404 Not Found");
            return;
        }
        Method method = this.handlerMapping.get(url);
        //通过反射拿到method所在的class,拿到className,再获取beanName
        String beanName = toLowerFirstCase(method.getDeclaringClass().getSimpleName());




        //从req中拿到url传过来的参数
        Map<String, String[]> params = req.getParameterMap();


        DefaultParameterNameDiscoverer discover = new DefaultParameterNameDiscoverer();
        //获取方法参数的真实名称
        String[] parameterNames = discover.getParameterNames(method);
        //获取方法的形参列表
        Class<?>[] parameterTypes = method.getParameterTypes();
        //获取参数上的所有注解 二维数组,可以有多个参数,每个参数可以有多个注解
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        // 这里不考虑对象赋值,只是简化的版本
        Object[] paramValues = new Object[parameterTypes.length];


        for (int i = 0; i < parameterTypes.length; i++) {
            Class parameterType = parameterTypes[i];
            //不能用instanceof,parameterType它不是实参,而是形参
            if (parameterType == HttpServletRequest.class) {
                paramValues[i] = req;
                continue;
            } else if (parameterType == HttpServletResponse.class) {
                paramValues[i] = resp;
                continue;
            } else {// 这里不考虑对象赋值等, 只做简化演示
                // 如果没有MyRequestParam注解,或者MyRequestParam的value 为默认值的话,就直接用形参name去获取
                String paramName = "";
                for (Annotation a : parameterAnnotations[i]) {
                    if (a instanceof MyRequestParam) {
                        paramName = ((MyRequestParam) a).value();
                        break;
                    }
                }
                if ("".equals(paramName)) paramName = parameterNames[i];
                if (params.containsKey(paramName)) {
                    paramValues[i] = convert(parameterType, params.get(paramName));
                }
            }
        }


        method.invoke(ioc.get(beanName), paramValues);
    }


    //进行数据类型转换
    private Object convert(Class<?> type, String[] value) {
        //如果是int
        if (Integer.class == type) {
            return Integer.valueOf(value[0]);
        } else if (Integer[].class == type) {
            //do something
            return null;
        } else if (String.class == type) {
            return value[0];
        }else if (String[].class == type) {
            return value;
        }
        //如果还有double或者其他类型,继续加if
        //这时候,我们应该想到策略模式了,在这里暂时不实现
        return value;
    }


    //初始化阶段
    @Override
    public void init(ServletConfig config) throws ServletException {
        // 1、加载配置文件
        doLoadConfig(config.getInitParameter("contextConfigLocation"));
        //2、扫描相关的类
        doScanner(contextConfig.getProperty("scanPackage"));
        //3、初始化扫描到的类,并且将它们放到IOC容器之中
        doInstance();
        //4、完成依赖注入
        doAutowired();
        //5、初始化HandlerMapping
        initHandlerMapping();


        System.out.println("My Spring Framework Is Init");
    }


    // 初始化url和Method的一对一的关系
    private void initHandlerMapping() {
        if (ioc.isEmpty()) return;


        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) continue;


            //保存写在类上面的@MyRequestMapping("/demo")
            String baseUrl = "";
            if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
                baseUrl = requestMapping.value();
            }
            //默认获取所有的public方法
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(MyRequestMapping.class)) continue;


                MyRequestMapping requestMapping = method.getAnnotation(MyRequestMapping.class);
                String url = ("/" + baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                handlerMapping.put(url, method);
                System.out.println("Mapped :" + url + "," + method);
            }
        }
    }


    private void doAutowired() {
        if (ioc.isEmpty()) return;


        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            //Declared 所有的,特定的,字段。包括private,protected,default
            // 正常来讲,普通的OOP编程只能拿到public的属性
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(MyAutowired.class)) continue;
                MyAutowired autowired = field.getAnnotation(MyAutowired.class);


                // 如果没有自定义beanName 则采用类型注入
                String beanName = autowired.value().trim();
                //field.getType().getName()  获得接口类型作为key去取
                if ("".equals(beanName)) beanName = field.getType().getName();


                //强吻 暴力访问
                field.setAccessible(true);
                try {
                    //用反射机制,动态给字段属性复制
                    field.set(entry.getValue(), ioc.get(beanName));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }


    }


    private void doInstance() {
        //初始化,为DI做准备


        //如果为null 就不往下走
        if (classNames.isEmpty()) return;


        try {
            for (String className : classNames) {
                Class<?> clazz = Class.forName(className);
                //这里只需要初始化加了我们自定义注解的类
                //这里只是做演示,体会其流程,设计思想,只举例@Controller 和@Service...


                if (clazz.isAnnotationPresent(MyController.class)) {
                    //Spring默认类名首字母小写
                    ioc.put(toLowerFirstCase(clazz.getSimpleName()), clazz.newInstance());
                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    //1、自定义beanName
                    MyService service = clazz.getAnnotation(MyService.class);
                    String beanName = service.value();
                    //2、如果没有自定义beanName,则默认类名首字母小写
                    if ("".equals(beanName.trim())) beanName = toLowerFirstCase(clazz.getSimpleName());
                    Object newInstance = clazz.newInstance();
                    ioc.put(beanName, newInstance);


                    //3、根据类型自动赋值 这里是找到他的所有接口然后给他实现类的值,是为了Autowired的时候方便(在注入的时候直接用接口类型去ioc取)
                    for (Class<?> anInterface : clazz.getInterfaces()) {
                        if (ioc.containsKey(anInterface.getName()))
                            throw new Exception("The" + anInterface.getName() + " is exists!!!");
                        // 把接口的类型直接当成key
                        ioc.put(anInterface.getName(), newInstance);
                    }
                }


            }
        } catch (Exception e) {
            e.printStackTrace();
        }


    }


    //转换首字母小写
    private String toLowerFirstCase(String beanName) {
        return beanName.replaceFirst("^.", String.valueOf(beanName.charAt(0)).toLowerCase());
    }


    //扫描出相关的类
    private void doScanner(String scanPackage) {
        //scanPackage = com.jarvisy.demo.mvc
        //需要把包路径转换为文件路径 classpath 路径
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File classPath = new File(url.getFile());
        for (File file : classPath.listFiles()) {
            //如果是文件夹,就还需要往下循环一层一层的找
            if (file.isDirectory()) doScanner(scanPackage + "." + file.getName());
            else {
                //只扫描class文件
                if (!file.getName().endsWith(".class")) continue;
                String className = (scanPackage + "." + file.getName().replace(".class", ""));
                classNames.add(className);
            }
        }


    }


    //加载配置文件
    private void doLoadConfig(String contextConfigLocation) {
        //直接从类路径下找到Spring主配置文件所在的路径,并且将其读取出来放到Properties对象中
        //相当于把scanPackage=com.jarvisy.demo.mvc 从文件中保存到了内存中。
        try (InputStream inputStream = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation)) {
            contextConfig.load(inputStream);
        } catch (IOException e) {
            e.printStackTrace();
        }


    }
}

V3版本优化:
在 V2 版本中,基本功能以及完全实现,但代码的优雅程度还不如人意。譬如 HandlerMapping 还不能像 SpringMVC一样支持正则,url 参数还不支持强制类型转换,在反射调用前还需要重新获取 beanName,在 V3 版本中,继续优化
首先,改造 HandlerMapping,在真实的 Spring 源码中,HandlerMapping 其实是一个 List 而非 Map。List 中的元素是一个自定义的类型。

public class MyDispatcherServlet extends HttpServlet {
    //保存application.properties配置文件中的内容
    private Properties contextConfig = new Properties();
    //保存扫描的所有的类名
    private List<String> classNames = new ArrayList<>();
    // 传说中的IOC容器,为了简化程序只做演示,不考虑线程安全问题,不考虑ConcurrentHashMap
    private Map<String, Object> ioc = new HashMap<>();
    //保存url和method 的对应关系
    // 为什么不用map?
    // 用Map,key只能是url ,但是HandlerMapping本身功能就是把url跟method关系对应,已经具备map的功能
    //根据设计原则:单一职能原则,最少知道原则。
    private List<HandlerMapping> handlerMapping = new ArrayList<>();


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }


    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        //6、调用,运行阶段
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            e.printStackTrace();
            resp.getWriter().write("500 Exception , Detail:" + Arrays.toString(e.getStackTrace()));
        }
    }


    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        HandlerMapping handlerMapping = getHandle(req);
        if (handlerMapping == null) {
            resp.getWriter().write("My DispatcherServlet 404 Not Found");
            return;
        }
        Object[] paramValues = new Object[handlerMapping.paramTypes.length];
        Map<String, String[]> params = req.getParameterMap();
        for (Map.Entry<String, String[]> param : params.entrySet()) {
            if (!handlerMapping.paramIndexMapping.containsKey(param.getKey())) continue;
            Integer index = handlerMapping.paramIndexMapping.get(param.getKey());
            paramValues[index] = convert(handlerMapping.paramTypes[index], param.getValue());
        }


        if (handlerMapping.paramIndexMapping.containsKey(HttpServletRequest.class.getName()))
            paramValues[handlerMapping.paramIndexMapping.get(HttpServletRequest.class.getName())] = req;
        if (handlerMapping.paramIndexMapping.containsKey(HttpServletResponse.class.getName()))
            paramValues[handlerMapping.paramIndexMapping.get(HttpServletResponse.class.getName())] = resp;


        handlerMapping.method.invoke(handlerMapping.controller, paramValues);
    }


    private HandlerMapping getHandle(HttpServletRequest req) {
        if (this.handlerMapping.isEmpty()) return null;
        // 绝对路径->处理成相对路径
        String url = req.getRequestURI().replaceAll(req.getContextPath(), "").replaceAll("/+", "/");
        for (HandlerMapping mapping : this.handlerMapping) {
            if (mapping.getUrl().equals(url)) return mapping;
        }
        return null;
    }


    //进行数据类型转换
    //Spring 做了顶层转换策略  public interface Converter<S, T> 实现了很多转换类型
    private Object convert(Class<?> type, String[] value) {
        //如果是int
        if (Integer.class == type) {
            return Integer.valueOf(value[0]);
        } else if (Integer[].class == type) {
            //do something
            return null;
        } else if (String.class == type) {
            return value[0];
        } else if (String[].class == type) {
            return value;
        }
        //如果还有double或者其他类型,继续加if
        //这时候,我们应该想到策略模式了,在这里暂时不实现
        return value;
    }


    //初始化阶段
    @Override
    public void init(ServletConfig config) throws ServletException {
        // 1、加载配置文件
        doLoadConfig(config.getInitParameter("contextConfigLocation"));
        //2、扫描相关的类
        doScanner(contextConfig.getProperty("scanPackage"));
        //3、初始化扫描到的类,并且将它们放到IOC容器之中
        doInstance();
        //4、完成依赖注入
        doAutowired();
        //5、初始化HandlerMapping
        initHandlerMapping();


        System.out.println("My Spring Framework Is Init");
    }


    // 初始化url和Method的一对一的关系
    private void initHandlerMapping() {
        if (ioc.isEmpty()) return;


        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) continue;


            //保存写在类上面的@MyRequestMapping("/demo")
            String baseUrl = "";
            if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
                baseUrl = requestMapping.value();
            }
            //默认获取所有的public方法
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(MyRequestMapping.class)) continue;


                MyRequestMapping requestMapping = method.getAnnotation(MyRequestMapping.class);
                String url = ("/" + baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                this.handlerMapping.add(new HandlerMapping(url, entry.getValue(), method));
                System.out.println("Mapped :" + url + "," + method);
            }
        }
    }


    private void doAutowired() {
        if (ioc.isEmpty()) return;


        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            //Declared 所有的,特定的,字段。包括private,protected,default
            // 正常来讲,普通的OOP编程只能拿到public的属性
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(MyAutowired.class)) continue;
                MyAutowired autowired = field.getAnnotation(MyAutowired.class);


                // 如果没有自定义beanName 则采用类型注入
                String beanName = autowired.value().trim();
                //field.getType().getName()  获得接口类型作为key去取
                if ("".equals(beanName)) beanName = field.getType().getName();


                //强吻 暴力访问
                field.setAccessible(true);
                try {
                    //用反射机制,动态给字段属性复制
                    field.set(entry.getValue(), ioc.get(beanName));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }


    }


    private void doInstance() {
        //初始化,为DI做准备


        //如果为null 就不往下走
        if (classNames.isEmpty()) return;


        try {
            for (String className : classNames) {
                Class<?> clazz = Class.forName(className);
                //这里只需要初始化加了我们自定义注解的类
                //这里只是做演示,体会其流程,设计思想,只举例@Controller 和@Service...


                if (clazz.isAnnotationPresent(MyController.class)) {
                    //Spring默认类名首字母小写
                    ioc.put(toLowerFirstCase(clazz.getSimpleName()), clazz.newInstance());
                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    //1、自定义beanName
                    MyService service = clazz.getAnnotation(MyService.class);
                    String beanName = service.value();
                    //2、如果没有自定义beanName,则默认类名首字母小写
                    if ("".equals(beanName.trim())) beanName = toLowerFirstCase(clazz.getSimpleName());
                    Object newInstance = clazz.newInstance();
                    ioc.put(beanName, newInstance);


                    //3、根据类型自动赋值 这里是找到他的所有接口然后给他实现类的值,是为了Autowired的时候方便(在注入的时候直接用接口类型去ioc取)
                    for (Class<?> anInterface : clazz.getInterfaces()) {
                        if (ioc.containsKey(anInterface.getName()))
                            throw new Exception("The" + anInterface.getName() + " is exists!!!");
                        // 把接口的类型直接当成key
                        ioc.put(anInterface.getName(), newInstance);
                    }
                }


            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }


    //转换首字母小写
    private String toLowerFirstCase(String beanName) {
        return beanName.replaceFirst("^.", String.valueOf(beanName.charAt(0)).toLowerCase());
    }


    //扫描出相关的类
    private void doScanner(String scanPackage) {
        //scanPackage = com.jarvisy.demo.mvc
        //需要把包路径转换为文件路径 classpath 路径
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File classPath = new File(url.getFile());
        for (File file : classPath.listFiles()) {
            //如果是文件夹,就还需要往下循环一层一层的找
            if (file.isDirectory()) doScanner(scanPackage + "." + file.getName());
            else {
                //只扫描class文件
                if (!file.getName().endsWith(".class")) continue;
                String className = (scanPackage + "." + file.getName().replace(".class", ""));
                classNames.add(className);
            }
        }
    }

    //加载配置文件
    private void doLoadConfig(String contextConfigLocation) {
        //直接从类路径下找到Spring主配置文件所在的路径,并且将其读取出来放到Properties对象中
        //相当于把scanPackage=com.jarvisy.demo.mvc 从文件中保存到了内存中。
        try (InputStream inputStream = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation)) {
            contextConfig.load(inputStream);
        } catch (IOException e) {
            e.printStackTrace();
        }


    }




    //保存一个url和一个Method的关系
    public class HandlerMapping {
        //请求url
        private String url;
        // url对应的method
        private Method method;
        private Object controller;
        //形参列表 参数的名字作为key,参数的顺序位置作为值
        private Map<String, Integer> paramIndexMapping;
        private Class<?>[] paramTypes;


        public HandlerMapping(String url, Object controller, Method method) {
            this.url = url;
            this.method = method;
            this.controller = controller;
            paramIndexMapping = new HashMap<>();
            paramTypes = method.getParameterTypes();
            putParamIndexMapping(method);
        }

        private void putParamIndexMapping(Method method) {
            DefaultParameterNameDiscoverer discover = new DefaultParameterNameDiscoverer();
            //获取方法参数的真实名称
            String[] parameterNames = discover.getParameterNames(method);
            //提取方法中加了注解的参数
            Annotation[][] pa = method.getParameterAnnotations();
            for (int i = 0; i < paramTypes.length; i++) {
                Class<?> type = paramTypes[i];
                if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
                    paramIndexMapping.put(type.getName(), i);
                    continue;
                }
                String paramName = "";
                for (Annotation a : pa[i]) {
                    if (a instanceof MyRequestParam) {
                        paramName = ((MyRequestParam) a).value();
                        break;
                    }
                }
                if ("".equals(paramName)) paramName = parameterNames[i];
                paramIndexMapping.put(paramName, i);
            }
        }


        public String getUrl() {
            return url;
        }

        public Method getMethod() {
            return method;
        }

        public Object getController() {
            return controller;
        }

        public Class<?>[] getParamTypes() {
            return paramTypes;
        }
    }

}
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值