手写一个springx(一)

源码:https://github.com/zhaoyunxing92/springx

前言

如果你稍微留意下commit就会这个项目早在好2个月前就创建了,我也一直想写下,但是确实太懒了就一直拖着着。这几天由于没有太多开发任务就有想写了。

依赖环境

  • jdk:1.8

  • tomcat8.5

  • servlet3.1.0

  • maven3.6.0

正文

​ 1. 这个项目又两个分支:webxmlnowebxml,我开始是基于servlet2.5写的,后面我又想用servlet3.x了于是就两个分支了(后面主要代码都在nowebxml上)。

2. Servlet容器启动会扫描,当前应用里面每一个jar包的ServletContainerInitializer的实现,然后就注入一个DispatcherServlet 就可以了,跟在web.xml配置一样,只是这里都在java代码里面了(看过spring代码的同学可能一眼就看出了它也是有个SpringServletContainerInitializer类作为人口的)

public class SpringxServletContainerInitializer implements ServletContainerInitializer {
    @Override
    public void onStartup(Set<Class<?>> c, ServletContext ctx) {
        ServletRegistration.Dynamic dispatcherServlet = ctx.addServlet("dispatcherServlet", DispatcherServlet.class);
        dispatcherServlet.setInitParameter("scanPackage", "com.sunny.springx.example");
        dispatcherServlet.addMapping("/*");
        dispatcherServlet.setLoadOnStartup(1);
    }
}

3.上面代码可以看出有个scanPackage参数,这个是告诉springx从那里开始工作。DispatcherServlet继承HttpServlet所以开始会调用init方法,然后就是大概四步就可以完成一个简单的springx

    @Override
    public void init(ServletConfig config) {
        Instant start = Instant.now();
        // 1.扫描相关类
        doScanner(config.getInitParameter("scanPackage"));
        // 2.初始化类
        doInstance();
        // 3. 注入
        doAutoWried();
        // 4. 初始化url
        initHandlerMapping();

        Instant end = Instant.now();

        System.out.println("springx init in " + Duration.between(start, end).toMillis() + " ms");
    }

​ 4. 主要逻辑都在DispatcherServlet这里面

public class DispatcherServlet extends HttpServlet {

    //存放class name
    private List<String> classNames = new ArrayList<>();
    // ioc 容器
    private Map<String, Object> ioc = new HashMap<>();
    //handMapping
    private List<Handler> handlerMapping = new ArrayList<>();

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            resp.getWriter().write("500 " + e.getMessage());
        }
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");

        Handler handler = getHandler(url);
        //url 不存在404
        if (Objects.isNull(handler)) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("not found url : " + url);
            return;
        }

        // 获取方法
        Object invoke = handler.method.invoke(handler.controller, null);
        resp.setStatus(HttpServletResponse.SC_OK);
        resp.getWriter().write(invoke.toString());
    }


    @Override
    public void init(ServletConfig config) {
        Instant start = Instant.now();
        // 1.扫描相关类
        doScanner(config.getInitParameter("scanPackage"));
        // 2.初始化类
        doInstance();
        // 3. 注入
        doAutoWried();
        // 4. 初始化url
        initHandlerMapping();

        Instant end = Instant.now();

        System.out.println("springx init in " + Duration.between(start, end).toMillis() + " ms");
    }

    /**
     * 获取 handler
     *
     * @param url 请求url
     * @return
     */
    private Handler getHandler(String url) {
        if (handlerMapping.isEmpty()) return null;

        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(url);
            if (!matcher.matches()) continue;
            return handler;
        }
        return null;
    }

    private void initHandlerMapping() {
        if (ioc.isEmpty()) return;
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            // 获取类
            Class<?> clazz = entry.getValue().getClass();
            //没有controller注解的跳过
            if (!clazz.isAnnotationPresent(Controller.class)) continue;
            //根url
            String rootUrl = "";
            if (clazz.isAnnotationPresent(RequestMapping.class)) {
                RequestMapping requestMapping = clazz.getAnnotation(RequestMapping.class);
                rootUrl = requestMapping.value();
            }

            //扫描类全部方法
            for (Method method : clazz.getMethods()) {

                //只处理RequestMapping注解
                if (!method.isAnnotationPresent(RequestMapping.class)) continue;
                RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                // 拼接rootUrl 全局替换避免两个斜杠 正则
                String reg = ("/" + rootUrl + requestMapping.value()).replaceAll("/+", "/");

                handlerMapping.add(new Handler(method, entry.getValue(), Pattern.compile(reg)));
                System.out.println("mapping:" + reg + "," + method);
            }
        }

    }

    private void doAutoWried() {
        if (ioc.isEmpty()) return;

        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            // 获取类中全部的字段
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {

                // 不包含autowried注解的跳过
                if (!field.isAnnotationPresent(Autowried.class)) continue;

                Autowried autowried = field.getAnnotation(Autowried.class);

                String beanId = autowried.value().trim();
                if (StringUtils.isBlank(beanId)) {
                    beanId = field.getName();
                }
                //设置授权
                field.setAccessible(true);
                try {
                    // 字段赋值
                    field.set(entry.getValue(), ioc.get(beanId));
                } catch (IllegalAccessException ex) {
                    ex.printStackTrace();
                }
            }
        }
    }

    private void doInstance() {
        if (classNames.isEmpty()) return;
        try {
            for (String className : classNames) {
                //根据名称实例化加了controller和service注解的类
                Class<?> clazz = Class.forName(className);
                //beanId 默认类小写
                String beanId;
                if (clazz.isAnnotationPresent(Controller.class)) {// 处理Controller注解
                    Controller controller = clazz.getAnnotation(Controller.class);
                    beanId = controller.value();

                    if (StringUtils.isBlank(beanId)) {
                        beanId = StringUtils.lowerFirstCase(clazz.getSimpleName());
                    }

                    ioc.put(beanId, clazz.newInstance());
                } else if (clazz.isAnnotationPresent(Service.class)) { // 处理Service 注解
                    // 获取注解
                    Service service = clazz.getAnnotation(Service.class);

                    beanId = service.value();
                    if (StringUtils.isBlank(beanId)) {
                        beanId = StringUtils.lowerFirstCase(clazz.getSimpleName());
                    }

                    Object instance = clazz.newInstance();
                    ioc.put(beanId, instance);
                    // 获取类的实现
                    for (Class<?> anInterface : clazz.getInterfaces()) {
                        ioc.put(StringUtils.lowerFirstCase(anInterface.getSimpleName()), instance);
                    }
                }
            }
        } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
            e.printStackTrace();
        }
    }

    /**
     * 加载路径下的全部class name
     *
     * @param scanPackage 包开始扫描路径
     */
    private void doScanner(String scanPackage) {
        // com.sunny.springx.example 路径替换文件路径
        URL url = getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        assert url != null;
        File classDir = new File(url.getFile());

        for (File file : Objects.requireNonNull(classDir.listFiles())) {
            //如果是文件夹继续扫描
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else {
                String className = scanPackage + "." + file.getName().replaceAll(".class", "");
                classNames.add(className);
            }
        }
    }

    private ClassLoader getClassLoader() {
        return this.getClass().getClassLoader();
    }

    private class Handler {
        Method method;
        //方法对象实力
        Object controller;
        //url正则
        Pattern pattern;

        Handler(Method method, Object controller, Pattern pattern) {
            this.method = method;
            this.controller = controller;
            this.pattern = pattern;
        }
    }
}

结尾

这个项目写的比较粗糙,后面有时间我打算把spring的能力都在这个项目里面体现下。比如参数绑定我就没有实现,只是定义了注解还没有实现。当然你也可以fork去写这玩玩,感谢围观
  zhaoyunxing微信公众号

关注公众号一起写bug,一起嘻嘻哈哈

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值