模仿一个简单版的SpringMVC框架

一、如果只使用传统的servlet处理web请求时,我们的代码可能是这样的。

public class BaseServlet extends HttpServlet {
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        String action = request.getParameter("action");
        System.out.println("action为:" + action);
        // 通过反射拿到方法名,并执行方法
        try {
            Class clazz = this.getClass();
            Method method = clazz.getMethod(action, HttpServletRequest.class, HttpServletResponse.class);
            method.invoke(this, request, response);
        } catch (Exception e) {
            e.printStackTrace();
        }

    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }
}

@WebServlet(urlPatterns="/linkman")
public class LinkManServlet extends BaseServlet {

    public void query(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("query....");
    }
    public void add(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("add...");
    }
    public void update(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("update....");
    }
    public void del(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("del...");
    }
}
存在的问题
  • 前端的每个请求都要传一个action参数
  • 每个servlet都要继承BaseServlet

二、为了解决这两个问题

  • 我们首先要有这么一个工具类: 用来获取指定包下的所有的class
public class ClassScannerUtils {

    /**
     * 获得包下面的所有的class
     * @param
     * @return List包含所有class的实例
     */
    public static List<Class<?>> getClasssFromPackage(String packageName) {
        List clazzs = new ArrayList<>();
        // 是否循环搜索子包
        boolean recursive = true;
        // 包名对应的路径名称
        String packageDirName = packageName.replace('.', '/');
        Enumeration<URL> dirs;

        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {

                URL url = dirs.nextElement();
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    findClassInPackageByFile(packageName, filePath, recursive, clazzs);
                }
            }

        } catch (Exception e) {
            e.printStackTrace();
        }
        return clazzs;
    }

    /**
     * 在package对应的路径下找到所有的class
     */
    public static void findClassInPackageByFile(String packageName, String filePath, final boolean recursive,
                                                List<Class<?>> clazzs) {
        File dir = new File(filePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        // 在给定的目录下找到所有的文件,并且进行条件过滤
        File[] dirFiles = dir.listFiles(new FileFilter() {

            public boolean accept(File file) {
                boolean acceptDir = recursive && file.isDirectory();// 接受dir目录
                boolean acceptClass = file.getName().endsWith("class");// 接受class文件
                return acceptDir || acceptClass;
            }
        });

        for (File file : dirFiles) {
            if (file.isDirectory()) {
                findClassInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, clazzs);
            } else {
                String className = file.getName().substring(0, file.getName().length() - 6);
                try {
                    clazzs.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + "." + className));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }


}
  • 我们定义一个注解
@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {
    String value();
}
  • 写一个的Servlet处理器: 目的是为了,处理所有的以.do为结尾的请求。按照以下的逻辑去执行。大概的意思就是:先获取请求的路径不包含项目路径。再拿到指定包下所有的class集合,遍历class集合,并反射拿到有RequestMapping注解的方法的注解值,与请求路径比较,相同就invoke调用方法。
@WebServlet(urlPatterns = "*.do")
public class DispatcherServlet extends HttpServlet {
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) {
        try {
            // 获取客户端请求路径
            String uri = request.getRequestURI();
            //  /my_springMVC_war_exploded/login.do
            System.out.println(uri);
            // 获取项目路径
            String contextPath = request.getContextPath();
            //  /my_springMVC_war_exploded
            String requestPath = uri.substring(contextPath.length(), uri.lastIndexOf("."));
            // 截取到请求路径为: /login
            System.out.println(requestPath);
            // 找到注解@RequestMapping的value和客户端请求路径相等的方法。
            // 获得包下面的所有的class
            List<Class<?>> classList = ClassScannerUtils.getClasssFromPackage("com.hy.controller");

            // 遍历class集合
            for (Class<?> clazz : classList) {
                // 反射拿到该class中所有的方法
                Method[] methods = clazz.getMethods();
                for (Method method : methods) {
                    boolean isMappingMethod = method.isAnnotationPresent(RequestMapping.class);
                    if (isMappingMethod) {
                        RequestMapping mapping = method.getAnnotation(RequestMapping.class);
                        String mappingPath = mapping.value();
                        System.out.println(mappingPath);

                        if (mappingPath.equals(requestPath)) {
                            method.invoke(clazz.newInstance(),request, response);
                            return;
                        }
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }
}
这样也存在问题
  • 客户端每次请求都要遍历指定包里所有类。效率低。
  • 客户端每次请求都要反射创建对象。效率低。

三、于是乎,我们可以这样,定义一个map用来存储,第一次请求时,遍历得到的类.方法.注解的值,后面每次请求,我们就从Map中拿数据。

  • 再重新定义一下注解: @Controller和@RequestMapping
@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {
    String value();
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Controller {
}
  • 定义一个POJO类:用来存储方法和类的实例
public class MVCMethod {
    private Method method;
    private Object object;

    @Override
    public String toString() {
        return "MVCMethod{" +
                "method=" + method +
                ", object=" + object +
                '}';
    }

    public MVCMethod(Method method, Object object) {
        this.method = method;
        this.object = object;
    }

    public Method getMethod() {
        return method;
    }

    public void setMethod(Method method) {
        this.method = method;
    }

    public Object getObject() {
        return object;
    }

    public void setObject(Object object) {
        this.object = object;
    }
}
  • 定义一个核心的控制器类:map里存着的就是要反射调用方法需要的class实例,和method对象。
@WebServlet(urlPatterns = "*.do", loadOnStartup = 1)
public class DispatcherServlet2 extends HttpServlet {

    private Map<String, MVCMethod> map = new HashMap<>();

    @Override
    public void init() throws ServletException {
        // 获取指定包下所有的class
        try {
            List<Class<?>> classList = ClassScannerUtils.getClasssFromPackage("com.my.controller");

            for (Class<?> clazz : classList) {
                // 判断该clazz是否有controller注解
                boolean annotationPresent = clazz.isAnnotationPresent(Controller.class);
                if (!annotationPresent) {
                    continue;
                }
                boolean clazzAnnotationPresent = clazz.isAnnotationPresent(RequestMapping.class);
                String mappingPath = "";
                if (clazzAnnotationPresent) {
                    mappingPath = clazz.getAnnotation(RequestMapping.class).value();
                    System.out.println("类上的请求参数:" + mappingPath);
                }
                // 拿到该class里所有的方法
                Method[] methods = clazz.getMethods();
                for (Method method : methods) {
                    // 判断该方法上是否有RequestMapping注解
                    boolean methodAnnotationPresent = method.isAnnotationPresent(RequestMapping.class);
                    if (methodAnnotationPresent) {
                        // 如果有这个注解就把这个方法对象和clazz实例存到MVCMethod对象中
                        MVCMethod mvcMethod = new MVCMethod(method, clazz.newInstance());
                        RequestMapping mapping = method.getAnnotation(RequestMapping.class);
                        System.out.println("方法上的请求参数:" + mapping.value());
                        String path = mappingPath + mapping.value();
                        System.out.println("最终的请求参数:" + path);
                        //以mappingPath为key,mvcMethod为value存储到map中
                        map.put(path, mvcMethod);
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

    }

    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {

        try {
            //  /my_springMVC_war_exploded/login.do
            String requestURI = request.getRequestURI();
            String contextPath = request.getContextPath();
            String reqPath = requestURI.substring(contextPath.length(), requestURI.lastIndexOf("."));

            // 反射调用方法
            MVCMethod mvcMethod = map.get(reqPath);
            Method method = mvcMethod.getMethod();
            method.invoke(mvcMethod.getObject(),request, response);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }
}
  • 于是乎我们就可以这样写一个Controller类
@Controller
@RequestMapping("/linkman")
public class LinkManController {

    @RequestMapping("/queryAll")
    public void queryAll(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("queryAll: 查询所有联系人。。。");
    }

    @RequestMapping("/delete")
    public void delete(HttpServletRequest request, HttpServletResponse response) {
        System.out.println("delete: 删除联系人。。。");
    }
}

到此一个简易版的SpringMVC就结束了。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值