简单实现一个springmvc

前言

做Java的应该不会有人没听过Spring吧!!Spring是一个优秀的IOC容器,为了增加对spring的了解,自己简单实现一个SpringMVC。
先说明以下几点:
1、IOC容器是个啥?就是一个大Map,可以理解为HashMap<String,Object>
2、springmvc的核心从哪开始?DispatcherServlet

  • 熟悉servlet编程的小伙子,肯定都知道,面向Servlet编程时,会在web.xml中配置一个ServletMapping,指定了每个Servlet要拦截的URL,然后做对应的业务处理。这也就是说一个Servlet基本上只能处理一个URL业务了!!!
  • SpringMVC,做了一些改变,将URL细粒度到方法层面(也可以认为是一个大Map,HashMap<String,Method>),使用反射method.invoke(instance,args),对应的URL去找对应的Method执行,大大简化了代码。

下面就简单实现一个SpringMVC(附:资源下载地址),加深一下印象。

分析

DispatcherServlet都要干些啥

上面说了,所有的URL请求都要先经过Servlet,通过servlet的doPost或doGet方法进行处理。并且,每个URL对应一个方法(controller中的方法),那要做的事情,无非就是以下几件:

  • 将所有的Controller、service类扫描出来进行实例化,放入IOC容器中
  • 处理controller和service对象的依赖关系(controller需要调用service的方法)
  • 将URL和方法映射起来,做一个URLMapping。
  • 处理请求时,将请求中的参数解析为method方法的参数

实现

pom依赖

<dependencies>
  <!-- Servlet API: scope "provided", so it is needed to compile but not packaged with the application -->
  <dependency>
      <groupId>javax.servlet</groupId>
      <artifactId>javax.servlet-api</artifactId>
      <version>3.1.0</version>
      <scope>provided</scope>
  </dependency>
  <!-- fastjson: JSON support -->
  <dependency>
      <groupId>com.alibaba</groupId>
      <artifactId>fastjson</artifactId>
      <version>1.2.40</version>
  </dependency>
</dependencies>

定义一些注解

仿照SpringMVC,定义@Controller、@Service、@Autowired、@RequestMapping、@RequestParam
这些注解的作用,大家应该都是很熟悉的吧

package com.zyu.mvc.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marker annotation for classes that the dispatcher should instantiate as
 * request-handling controllers. Carries no attributes.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Controller {
}

package com.zyu.mvc.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a class as a service bean to be instantiated and stored in the container.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Service {
    /**
     * Optional explicit bean name; the empty default means no name was given.
     */
    String value() default "";
}

package com.zyu.mvc.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a field whose value should be injected from the bean container.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Autowired {
    /**
     * Optional name of the bean to inject; the empty default means no name was given.
     */
    String value() default "";
}

package com.zyu.mvc.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Binds a URL path fragment to a controller class or to one of its methods.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RequestMapping {
    /**
     * The path fragment; defaults to the empty string.
     */
    String value() default "";
}

package com.zyu.mvc.anno;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a handler-method parameter whose value comes from a request parameter.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface RequestParam {
    /**
     * Name of the request parameter to read; defaults to the empty string.
     */
    String value() default "";
}

DispatcherServlet实现

package com.zyu.mvc.servlet;

import com.alibaba.fastjson.JSON;
import com.zyu.mvc.anno.*;
import com.zyu.mvc.handle.HandService;

import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * 自定义DispatcherServlet,用于拦截请求,做业务处理
 */
public class DispatcherServlet extends HttpServlet {
    /**
     * 扫描的类路径集合
     */
    List<String> classPaths = new ArrayList<>();
    /**
     * IOC容器
     */
    HashMap<String, Object> beans = new HashMap<>();

    /**
     * url路径和方法的映射
     */
    HashMap<String, Object> urlHandlers = new HashMap<>();

    /**
     * 覆盖父类的init方法,做一些初始化内容
     *
     * @throws ServletException
     */
    @Override
    public void init() throws ServletException {
        super.init();
        // 1 扫描包路径
        doScanPackage("com.zyu.mvc");
        // 2 实例化对象
        doInstance();
        // 3 处理依赖的注入
        resolveDependencies();
        // 4 url映射
        doUrlMapping();
    }

    /**
     * 处理方法和路径的映射
     */
    private void doUrlMapping() {
        //将controller中对应方法和配置的url路径映射起来
        beans.forEach((beanName, instance) -> {
            Class<?> clazz = instance.getClass();
            if (clazz.isAnnotationPresent(Controller.class)) {
                String baseUrl = "";
                if (clazz.isAnnotationPresent(RequestMapping.class)) {
                    RequestMapping anno = clazz.getAnnotation(RequestMapping.class);
                    if (!"".equals(anno.value().trim())) {
                        baseUrl = anno.value().trim();
                    }
                }
                Method[] methods = clazz.getDeclaredMethods();
                for (Method method : methods) {
                    if (method.isAnnotationPresent(RequestMapping.class)) {
                        String methodUrl = "";
                        RequestMapping anno = method.getAnnotation(RequestMapping.class);
                        if (!"".equals(anno.value().trim())) {
                            methodUrl = anno.value().trim();
                        }
                        urlHandlers.put(baseUrl + methodUrl, method);
                    }
                }
            }
        });
    }

    /**
     * 处理自动注入
     */
    private void resolveDependencies() {
        //遍历IOC容器,解决bean依赖问题,处理Autowired自动注入
        beans.forEach((beanName, instance) -> {
            Class<?> clazz = instance.getClass();
            if (clazz.isAnnotationPresent(Controller.class)) {
                Field[] fields = clazz.getDeclaredFields();
                for (Field field : fields) {
                    if (field.isAnnotationPresent(Autowired.class)) {
                        Autowired anno = field.getAnnotation(Autowired.class);
                        String key = anno.value().trim().equals("") ? lowerFirstChar(field.getType().getSimpleName()) : anno.value();
                        field.setAccessible(true);
                        try {
                            field.set(instance, beans.get(key));
                        } catch (IllegalAccessException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        });
    }

    /**
     * 实例化扫描的对象
     */
    private void doInstance() {
        //遍历收集的类信息,实例化controller和service实例,放入ioc容器中
        classPaths.forEach(classPath -> {
            try {
                Class<?> clazz = Class.forName(classPath);
                if (clazz.isAnnotationPresent(Controller.class)) {
                    beans.put(lowerFirstChar(clazz.getSimpleName()), clazz.newInstance());
                } else if (clazz.isAnnotationPresent(Service.class)) {
                    Service anno = clazz.getAnnotation(Service.class);
                    String key = anno.value();
                    if (!"".equals(key.trim())) {
                        beans.put(key, clazz.newInstance());
                    } else {
                        Class<?>[] interfaces = clazz.getInterfaces();
                        for (Class<?> inter : interfaces) {
                            beans.put(lowerFirstChar(inter.getSimpleName()), clazz.newInstance());
                        }
                    }
                }
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            } catch (InstantiationException e) {
                e.printStackTrace();
            }
        });
    }

    /**
     * 扫描需要加载的类路径
     *
     * @param basePackage
     */
    private void doScanPackage(String basePackage) {
        //获取要加载的包路径
        URL resource = this.getClass().getClassLoader().getResource("/" + basePackage.replaceAll("\\.", "/"));
        String basePath = resource.getFile();
        File baseDir = new File(basePath);
        //递归遍历目录下的class文件,收集要加载的类路径
        File[] files = baseDir.listFiles();
        for (File file : files) {
            if (file.isDirectory()) {
                doScanPackage(basePackage + "." + file.getName());
            } else {
                if (file.getName().endsWith(".class")) {
                    classPaths.add(basePackage + "." + file.getName().replaceAll(".class", ""));
                }
            }
        }
    }

    /**
     * 将字符串的首字母转换为小写返回
     *
     * @param origin
     * @return
     */
    private String lowerFirstChar(String origin) {
        if (origin == null || origin.trim().length() == 0)
            return origin;
        char[] chars = origin.toCharArray();
        if (chars[0] >= 'A' && chars[0] <= 'Z') {
            chars[0] += 32;
        }
        return new String(chars);
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        //获取请求的url
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        //截掉contextPath,得到真实url路径
        url = url.replaceAll(contextPath, "");
        //查找url对应的method对象
        Method method = (Method) urlHandlers.get(url);
        if (null == method) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 not found");
            resp.getWriter().flush();
            return;
        }

        //获取method对应class的bean实例
        Class<?> clazz = method.getDeclaringClass();
        String beanName = lowerFirstChar(clazz.getSimpleName());
        Object instance = beans.get(beanName);
        if (null == instance) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 not found");
            resp.getWriter().flush();
            return;
        }

        //使用参数处理器从req中解析方法需要的参数
        HandService handler = (HandService) beans.get("HandServiceImpl");
        Object[] args = handler.hand(req, resp, method, beans);
        try {
            //执行反射调用
            Object result = method.invoke(instance, args);
            resp.setHeader("Content-Type", "application/json");
            resp.setStatus(HttpServletResponse.SC_OK);
            resp.getWriter().write(JSON.toJSONString(result));
            resp.getWriter().flush();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
    }

}

方法处理器:从request中解析方法参数

package com.zyu.mvc.handle;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.Map;

/**
 * Service that resolves the arguments of a handler method from the request.
 */
public interface HandService {
    /**
     * Builds the argument array used to invoke {@code method}.
     *
     * @param req    the current request
     * @param resp   the current response
     * @param method the handler method whose parameters are to be resolved
     * @param beans  the bean container (bean name to instance)
     * @return the argument values for {@code method}
     */
    Object[] hand(HttpServletRequest req, HttpServletResponse resp, Method method, Map<String,Object> beans);
}

package com.zyu.mvc.handle.impl;

import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;
import com.zyu.mvc.handle.HandService;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

/**
 * Default {@link HandService}: resolves each handler-method parameter by
 * delegating to the {@link ArgumentResolver} beans found in the container.
 */
@Service("HandServiceImpl")
public class HandServiceImpl implements HandService {
    /**
     * Resolves one value per declared parameter of {@code method}.
     *
     * <p>The value for parameter {@code i} is always stored at position
     * {@code i}. A parameter that no resolver supports is left {@code null};
     * when several resolvers support it, the first one found is used.
     *
     * @param req    the current request
     * @param resp   the current response
     * @param method the handler method whose parameters are resolved
     * @param beans  the bean container to look resolvers up in
     * @return an array with one entry per parameter of {@code method}
     */
    @Override
    public Object[] hand(HttpServletRequest req, HttpServletResponse resp, Method method, Map<String, Object> beans) {
        Class<?>[] paraTypes = method.getParameterTypes();
        Object[] args = new Object[paraTypes.length];

        Map<String, Object> argResolvers = getArgResolvers(ArgumentResolver.class, beans);
        for (int index = 0; index < paraTypes.length; index++) {
            Class<?> paraType = paraTypes[index];
            for (Object candidate : argResolvers.values()) {
                ArgumentResolver ar = (ArgumentResolver) candidate;
                if (ar.support(paraType, method, index)) {
                    // Write to the parameter's own slot so positions never drift.
                    args[index] = ar.resolver(req, resp, paraType, method, index);
                    break;
                }
            }
        }
        return args;
    }

    /**
     * Collects every bean in the container that is an instance of {@code type}.
     *
     * @param type  the resolver type to look for
     * @param beans the bean container
     * @return the matching beans keyed by bean name
     */
    private Map<String, Object> getArgResolvers(Class<?> type, Map<String, Object> beans) {
        Map<String, Object> argResolvers = new HashMap<>();
        for (Map.Entry<String, Object> entry : beans.entrySet()) {
            // isInstance also matches inherited or indirectly implemented interfaces.
            if (type.isInstance(entry.getValue())) {
                argResolvers.put(entry.getKey(), entry.getValue());
            }
        }
        return argResolvers;
    }

}

参数解析器:从request中解析指定类型的参数

参数解析器接口
package com.zyu.mvc.argumentresolver;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;

/**
 * Argument resolver interface.
 */
public interface ArgumentResolver {
    /**
     * Checks whether this resolver can resolve the given parameter.
     * @param type   the parameter type to resolve
     * @param method the method whose parameter is being resolved
     * @param index  the position of the parameter within the method
     * @return {@code true} if this resolver supports the parameter
     */
    boolean support(Class<?> type, Method method, int index);

    /**
     * Resolves the value of the parameter at the given position of the given
     * method from the request.
     * @param req    the current request
     * @param resp   the current response
     * @param type   the parameter type to resolve
     * @param method the method whose parameter is being resolved
     * @param index  the position of the parameter within the method
     * @return the resolved argument value
     */
    Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index);
}

HttpServletRequest解析器实现类
package com.zyu.mvc.argumentresolver.impl;

import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;

import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;

/**
 * Argument resolver for request-typed parameters: hands the current
 * {@link HttpServletRequest} to the handler method.
 */
@Service("httpServletRequestArgResolver")
public class HttpServletRequestArgResolver implements ArgumentResolver {
    /**
     * Supports any parameter declared as {@link ServletRequest} or a subtype of it.
     */
    @Override
    public boolean support(Class<?> type, Method method, int index) {
        final boolean requestTyped = ServletRequest.class.isAssignableFrom(type);
        return requestTyped;
    }

    /**
     * Returns the current request unchanged.
     */
    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        final Object currentRequest = req;
        return currentRequest;
    }
}

HttpServletResponse解析器实现类
package com.zyu.mvc.argumentresolver.impl;

import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;

import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
/**
 * Argument resolver for response-typed parameters: hands the current
 * {@link HttpServletResponse} to the handler method.
 */
@Service("httpServletResponseArgResolver")
public class HttpServletResponseArgResolver implements ArgumentResolver {
    /**
     * Supports any parameter declared as {@link ServletResponse} or a subtype of it.
     */
    @Override
    public boolean support(Class<?> type, Method method, int index) {
        final boolean responseTyped = ServletResponse.class.isAssignableFrom(type);
        return responseTyped;
    }

    /**
     * Returns the current response unchanged.
     */
    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        final Object currentResponse = resp;
        return currentResponse;
    }
}

RequestParam解析器实现类
package com.zyu.mvc.argumentresolver.impl;

import com.zyu.mvc.anno.RequestParam;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;

/**
 * Argument resolver for parameters annotated with {@link RequestParam}: reads
 * the named request parameter and converts it to the declared parameter type.
 */
@Service("requestParamArgResolver")
public class RequestParamArgResolver implements ArgumentResolver {
    /**
     * Supports a parameter if and only if it carries a {@link RequestParam} annotation.
     */
    @Override
    public boolean support(Class<?> type, Method method, int index) {
        return findRequestParam(method, index) != null;
    }

    /**
     * Reads the request parameter named by the annotation and converts it to {@code type}.
     *
     * @return the converted value, or {@code null} when the parameter is not
     *         annotated, is absent for a non-primitive type, or has an unsupported type
     * @throws IllegalArgumentException if the parameter is absent for a primitive
     *         type (other than {@code boolean}), or a numeric value cannot be
     *         parsed ({@link NumberFormatException})
     */
    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        RequestParam requestParam = findRequestParam(method, index);
        if (requestParam == null) {
            return null;
        }
        String name = requestParam.value();
        return convert(req.getParameter(name), type, name);
    }

    /**
     * Returns the {@link RequestParam} annotation on the parameter at
     * {@code index}, or {@code null} if it has none.
     */
    private RequestParam findRequestParam(Method method, int index) {
        // getParameterAnnotations() holds one annotation array per parameter.
        for (Annotation paramAn : method.getParameterAnnotations()[index]) {
            if (paramAn instanceof RequestParam) {
                return (RequestParam) paramAn;
            }
        }
        return null;
    }

    /**
     * Converts the raw request value to the target type.
     *
     * @param raw  the raw parameter value, {@code null} when absent
     * @param type the declared parameter type
     * @param name the request parameter name, used in error messages
     * @return the converted value, or {@code null} for unsupported types
     */
    private Object convert(String raw, Class<?> type, String name) {
        // Boolean.valueOf(null) is false, so an absent flag stays false as before.
        if (type == Boolean.class || type == boolean.class) {
            return Boolean.valueOf(raw);
        }
        if (raw == null) {
            if (type.isPrimitive()) {
                // A primitive cannot hold null; fail with a clear message instead.
                throw new IllegalArgumentException(
                        "Missing request parameter '" + name + "' for primitive type " + type.getName());
            }
            return null;
        }
        if (type.isAssignableFrom(String.class)) {
            return raw;
        } else if (type == Integer.class || type == int.class) {
            return Integer.valueOf(raw);
        } else if (type == Long.class || type == long.class) {
            return Long.valueOf(raw);
        } else if (type == Double.class || type == double.class) {
            return Double.valueOf(raw);
        } else if (type == Float.class || type == float.class) {
            return Float.valueOf(raw);
        }
        return null;
    }
}

web.xml配置

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_3_1.xsd"
         version="3.1">

    <!-- Register the custom DispatcherServlet; load-on-startup makes the container initialize it at startup -->
    <servlet>
        <servlet-name>DispatcherServlet</servlet-name>
        <servlet-class>com.zyu.mvc.servlet.DispatcherServlet</servlet-class>
        <load-on-startup>0</load-on-startup>
    </servlet>

    <!-- Map the DispatcherServlet to "/" -->
    <servlet-mapping>
        <servlet-name>DispatcherServlet</servlet-name>
        <url-pattern>/</url-pattern>
    </servlet-mapping>
</web-app>

测试

定义一个controller

package com.zyu.mvc.controller;

import com.zyu.mvc.anno.Autowired;
import com.zyu.mvc.anno.Controller;
import com.zyu.mvc.anno.RequestMapping;
import com.zyu.mvc.anno.RequestParam;
import com.zyu.mvc.service.TestService;

/**
 * Sample controller mapped under {@code /test}; delegates to {@link TestService}.
 */
@Controller
@RequestMapping("/test")
public class TestController {
    /** Service dependency, requested from the container by the name "testService". */
    @Autowired("testService")
    TestService testService;

    /**
     * Handles {@code /test/sayHello} by passing the {@code username} and
     * {@code age} request parameters to the service.
     *
     * @return the greeting produced by the service
     */
    @RequestMapping("/sayHello")
    public String sayHello(@RequestParam("username") String username, @RequestParam("age") Integer age) {
        final String greeting = testService.sayHello(username, age);
        return greeting;
    }
}

定义一个Service

接口
package com.zyu.mvc.service;

/**
 * Sample service used by the test controller.
 */
public interface TestService {
    /**
     * Produces a greeting for the given user.
     *
     * @param username the user's name
     * @param age      the user's age
     * @return the greeting text
     */
    String sayHello(String username,int age);
}

实现类
package com.zyu.mvc.service.impl;

import com.zyu.mvc.anno.Service;
import com.zyu.mvc.service.TestService;

/**
 * Sample {@link TestService} implementation, registered under the bean name "testService".
 */
@Service("testService")
public class TestServiceImpl implements TestService {
    /**
     * Formats a greeting containing the user name followed by the age.
     */
    @Override
    public String sayHello(String username, int age) {
        final Object[] formatArgs = {username, age};
        return String.format("hello,%s %d", formatArgs);
    }
}

配置tomcat

用惯了boot的我在这里还“研究”了一会。。。
在这里插入图片描述

请求

在这里插入图片描述

结束语

收获满满啊,学无止境,诸君共勉

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值