前言
做Java的应该不会有人没听过Spring吧!!Spring是一个优秀的IOC容器,为了增加对spring的了解,自己简单实现一个SpringMVC。
先说明以下几点:
1、IOC容器是个啥?就是一个大Map,可以理解为HashMap<String,Object>
2、springmvc的核心从哪开始?DispatcherServlet
- 熟悉servlet编程的小伙子,肯定都知道,面向Servlet编程时,会在web.xml中配置一个ServletMapping,指定了每个Servlet要拦截的URL,然后做对应的业务处理。这也就是说一个Servlet基本上只能处理一个URL业务了!!!
- SpringMVC做了一些改变:将URL的映射细化到方法层面(也可以理解为一个大Map,HashMap<String,Method>),根据请求的URL找到对应的Method,再通过反射method.invoke(instance,args)执行,大大简化了代码。
下面就动手简单实现一个SpringMVC,加深一下印象(完整代码见资源下载地址)。
分析
DispatcherServlet都要干些啥
上面说了,所有的URL请求都要先经过DispatcherServlet,由servlet的doPost或doGet方法进行处理。并且,每个URL对应一个方法(controller中的方法),那么要做的事情无非就是以下几件:
- 将所有的Controller、service类扫描出来进行实例化,放入IOC容器中
- 处理controller和service对象的依赖关系(controller需要调用service的方法)
- 将URL和方法映射起来,做一个URLMapping。
- 处理请求时,将请求中的参数解析为method方法的参数
实现
pom依赖
<dependencies>
    <!-- Servlet API: supplied by the servlet container (e.g. Tomcat) at runtime -->
    <dependency>
        <groupId>javax.servlet</groupId>
        <artifactId>javax.servlet-api</artifactId>
        <version>3.1.0</version>
        <scope>provided</scope>
    </dependency>
    <!-- fastjson: serializes handler return values to JSON.
         1.2.83 instead of 1.2.40: earlier 1.2.x releases have published
         autoType deserialization vulnerabilities (e.g. CVE-2022-25845). -->
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>1.2.83</version>
    </dependency>
</dependencies>
定义一些注解
仿照SpringMVC,定义@Controller、@Service、@Autowired、@RequestMapping、@RequestParam
这些注解的作用,大家应该都是很熟悉的吧
package com.zyu.mvc.anno;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a class as a web controller. Controllers are instantiated into the IOC container
 * and their {@code @RequestMapping} methods are exposed as URL handlers.
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Controller {
}
package com.zyu.mvc.anno;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a class as a service bean.
 *
 * <p>{@link #value()} is the bean name; when left blank the bean is registered under
 * names derived from the class's interfaces.
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Service {

    /** Explicit bean name; empty means "derive it". */
    String value() default "";
}
package com.zyu.mvc.anno;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Requests injection of a bean from the IOC container into the annotated field.
 *
 * <p>{@link #value()} names the bean to inject; when blank, the name is derived from the
 * field's type (simple name with a lower-case first letter).
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Autowired {

    /** Name of the bean to inject; empty means "derive it from the field type". */
    String value() default "";
}
package com.zyu.mvc.anno;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Maps a URL path to a controller.
 *
 * <p>On a class it supplies the path prefix; on a method it supplies the suffix. The
 * handler URL is the two concatenated.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {

    /** Path fragment, e.g. {@code "/test"}; empty contributes nothing. */
    String value() default "";
}
package com.zyu.mvc.anno;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Binds a handler-method parameter to the request parameter named by {@link #value()}.
 */
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestParam {

    /** Name of the request parameter to read. */
    String value() default "";
}
DispatcherServlet实现
package com.zyu.mvc.servlet;
import com.alibaba.fastjson.JSON;
import com.zyu.mvc.anno.*;
import com.zyu.mvc.handle.HandService;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
 * Custom front controller.
 *
 * <p>On initialization it:
 * <ul>
 *   <li>scans a base package;</li>
 *   <li>instantiates {@link Controller}/{@link Service} classes into a small IOC container;</li>
 *   <li>injects {@link Autowired} fields;</li>
 *   <li>maps {@link RequestMapping} URLs to controller methods.</li>
 * </ul>
 * Each request is dispatched to the mapped method, and the return value is written as JSON.
 */
public class DispatcherServlet extends HttpServlet {

    /** Package scanned when no {@code basePackage} init-param is configured. */
    private static final String DEFAULT_BASE_PACKAGE = "com.zyu.mvc";

    /** File-name suffix of compiled classes. */
    private static final String CLASS_SUFFIX = ".class";

    /** Bean name of the {@link HandService} that resolves handler-method arguments. */
    private static final String HAND_SERVICE_BEAN = "HandServiceImpl";

    /** Fully qualified names of every class found while scanning. */
    List<String> classPaths = new ArrayList<>();

    /** IOC container: bean name -> bean instance. */
    HashMap<String, Object> beans = new HashMap<>();

    /** URL path (context path stripped) -> handler {@link Method}. */
    HashMap<String, Object> urlHandlers = new HashMap<>();

    /**
     * Builds the container once, when the servlet is initialized.
     *
     * <p>The package to scan comes from the {@code basePackage} init-param and defaults to
     * {@value #DEFAULT_BASE_PACKAGE}.
     *
     * @throws ServletException if the parent initialization fails
     */
    @Override
    public void init() throws ServletException {
        super.init();
        String basePackage = getInitParameter("basePackage");
        if (basePackage == null || basePackage.trim().isEmpty()) {
            basePackage = DEFAULT_BASE_PACKAGE;
        }
        // 1. collect class names under the base package
        doScanPackage(basePackage.trim());
        // 2. instantiate @Controller / @Service classes into the container
        doInstance();
        // 3. inject @Autowired fields
        resolveDependencies();
        // 4. map URLs to controller methods
        doUrlMapping();
    }

    /**
     * Registers every {@code @RequestMapping} method of every controller bean.
     *
     * <p>The key is the class-level prefix concatenated with the method-level path.
     */
    private void doUrlMapping() {
        beans.forEach((beanName, instance) -> {
            Class<?> clazz = instance.getClass();
            if (!clazz.isAnnotationPresent(Controller.class)) {
                return;
            }
            String baseUrl = clazz.isAnnotationPresent(RequestMapping.class)
                    ? clazz.getAnnotation(RequestMapping.class).value().trim()
                    : "";
            for (Method method : clazz.getDeclaredMethods()) {
                if (method.isAnnotationPresent(RequestMapping.class)) {
                    String methodUrl = method.getAnnotation(RequestMapping.class).value().trim();
                    urlHandlers.put(baseUrl + methodUrl, method);
                }
            }
        });
    }

    /**
     * Injects {@code @Autowired} fields of every bean (controllers and services).
     *
     * <p>The bean name is the annotation value, or the field type's simple name with a
     * lower-case first letter. Missing beans are logged and the field is left untouched.
     */
    private void resolveDependencies() {
        beans.forEach((beanName, instance) -> {
            for (Field field : instance.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                String qualifier = field.getAnnotation(Autowired.class).value().trim();
                String key = qualifier.isEmpty() ? lowerFirstChar(field.getType().getSimpleName()) : qualifier;
                Object dependency = beans.get(key);
                if (dependency == null) {
                    log("No bean named '" + key + "' for field " + field + " of bean '" + beanName + "'");
                    continue;
                }
                field.setAccessible(true);
                try {
                    field.set(instance, dependency);
                } catch (IllegalAccessException e) {
                    log("Cannot inject field " + field, e);
                }
            }
        });
    }

    /**
     * Instantiates every scanned {@code @Controller} and {@code @Service} class.
     *
     * <p>Classes that cannot be loaded or instantiated are logged and skipped.
     */
    private void doInstance() {
        for (String className : classPaths) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(Controller.class)) {
                    beans.put(lowerFirstChar(clazz.getSimpleName()), clazz.newInstance());
                } else if (clazz.isAnnotationPresent(Service.class)) {
                    registerService(clazz);
                }
            } catch (ReflectiveOperationException e) {
                log("Cannot instantiate bean class " + className, e);
            }
        }
    }

    /**
     * Registers ONE instance of a {@code @Service} class under a name chosen as follows:
     * <ul>
     *   <li>its explicit name, if the annotation gives one;</li>
     *   <li>otherwise, the lower-first-char simple name of each implemented interface;</li>
     *   <li>otherwise (no interfaces), its own lower-first-char simple name.</li>
     * </ul>
     */
    private void registerService(Class<?> clazz) throws ReflectiveOperationException {
        String name = clazz.getAnnotation(Service.class).value().trim();
        // a single shared instance, so every alias refers to the same singleton
        Object instance = clazz.newInstance();
        if (!name.isEmpty()) {
            beans.put(name, instance);
            return;
        }
        Class<?>[] interfaces = clazz.getInterfaces();
        if (interfaces.length == 0) {
            beans.put(lowerFirstChar(clazz.getSimpleName()), instance);
            return;
        }
        for (Class<?> inter : interfaces) {
            beans.put(lowerFirstChar(inter.getSimpleName()), instance);
        }
    }

    /**
     * Recursively collects the fully qualified names of all classes under {@code basePackage}.
     *
     * <p>Only packages located in a file-system directory can be scanned. Anything else is
     * logged and skipped.
     *
     * @param basePackage dotted package name, e.g. {@code com.zyu.mvc}
     */
    private void doScanPackage(String basePackage) {
        // ClassLoader resource names must not start with '/'
        URL resource = getClass().getClassLoader().getResource(basePackage.replace('.', '/'));
        if (resource == null) {
            log("Package not found on the classpath: " + basePackage);
            return;
        }
        File baseDir;
        try {
            baseDir = toFile(resource);
        } catch (IllegalArgumentException e) {
            // e.g. a jar: URL, which is not a plain directory
            log("Cannot scan package " + basePackage + " at " + resource, e);
            return;
        }
        File[] files = baseDir.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                doScanPackage(basePackage + "." + fileName);
            } else if (fileName.endsWith(CLASS_SUFFIX)) {
                // strip the literal suffix (a regex ".class" would also eat e.g. "Subclass")
                String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                classPaths.add(basePackage + "." + simpleName);
            }
        }
    }

    /** Converts a file: URL to a File, decoding escapes such as {@code %20}. */
    private static File toFile(URL resource) {
        try {
            return new File(resource.toURI());
        } catch (URISyntaxException e) {
            // not a valid URI (e.g. unescaped characters): fall back to the raw path
            return new File(resource.getFile());
        }
    }

    /**
     * Returns {@code origin} with an ASCII upper-case first letter lowered.
     *
     * @param origin string to convert; may be {@code null}
     * @return the converted string, or {@code origin} itself if it is {@code null} or blank
     */
    private String lowerFirstChar(String origin) {
        if (origin == null || origin.trim().isEmpty()) {
            return origin;
        }
        char first = origin.charAt(0);
        if (first >= 'A' && first <= 'Z') {
            return (char) (first + ('a' - 'A')) + origin.substring(1);
        }
        return origin;
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    /**
     * Dispatches the request to the mapped handler method and writes its result as UTF-8 JSON.
     *
     * <p>Status codes:
     * <ul>
     *   <li>404 for unmapped URLs;</li>
     *   <li>400 when request arguments cannot be resolved or do not fit the method;</li>
     *   <li>500 when the handler fails.</li>
     * </ul>
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String url = stripContextPath(req.getRequestURI(), req.getContextPath());
        Method method = (Method) urlHandlers.get(url);
        if (method == null) {
            writeError(resp, HttpServletResponse.SC_NOT_FOUND, "404 not found");
            return;
        }
        Object instance = beans.get(lowerFirstChar(method.getDeclaringClass().getSimpleName()));
        if (instance == null) {
            writeError(resp, HttpServletResponse.SC_NOT_FOUND, "404 not found");
            return;
        }
        HandService handler = (HandService) beans.get(HAND_SERVICE_BEAN);
        if (handler == null) {
            log("No HandService bean named '" + HAND_SERVICE_BEAN + "' is registered");
            writeError(resp, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "500 internal server error");
            return;
        }
        Object result;
        try {
            Object[] args = handler.hand(req, resp, method, beans);
            result = method.invoke(instance, args);
        } catch (IllegalArgumentException e) {
            // unparsable/missing parameters, or arguments that do not fit the method signature
            log("Bad request for " + url, e);
            writeError(resp, HttpServletResponse.SC_BAD_REQUEST, "400 bad request");
            return;
        } catch (IllegalAccessException e) {
            log("Cannot access handler method " + method, e);
            writeError(resp, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "500 internal server error");
            return;
        } catch (InvocationTargetException e) {
            log("Handler method " + method + " threw an exception", e.getCause());
            writeError(resp, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "500 internal server error");
            return;
        }
        resp.setStatus(HttpServletResponse.SC_OK);
        // an explicit charset; otherwise the servlet default (ISO-8859-1) garbles non-ASCII text
        resp.setContentType("application/json;charset=UTF-8");
        resp.getWriter().write(JSON.toJSONString(result));
        resp.getWriter().flush();
    }

    /** Removes a leading context path (literally, not as a regex) from the request URI. */
    private static String stripContextPath(String uri, String contextPath) {
        if (contextPath != null && !contextPath.isEmpty() && uri.startsWith(contextPath)) {
            return uri.substring(contextPath.length());
        }
        return uri;
    }

    /** Writes a plain error message with the given status, unless the response is already committed. */
    private static void writeError(HttpServletResponse resp, int status, String message) throws IOException {
        if (resp.isCommitted()) {
            return;
        }
        resp.setStatus(status);
        resp.getWriter().write(message);
        resp.getWriter().flush();
    }
}
方法处理器:从request中解析方法参数
package com.zyu.mvc.handle;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.Map;
/**
 * Builds the argument list for a handler method from the current request.
 */
public interface HandService {

    /**
     * Resolves the parameters of {@code method} for one request.
     *
     * @param req    current HTTP request
     * @param resp   current HTTP response
     * @param method handler method whose parameters are to be resolved
     * @param beans  IOC container (bean name to bean instance), available for looking up helper beans
     * @return the arguments to pass to {@code method}; presumably one slot per declared parameter
     */
    Object[] hand(HttpServletRequest req, HttpServletResponse resp, Method method, Map<String, Object> beans);
}
package com.zyu.mvc.handle.impl;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;
import com.zyu.mvc.handle.HandService;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
/**
 * Default {@link HandService}: each handler-method parameter is resolved by the first
 * {@link ArgumentResolver} bean that supports it.
 */
@Service("HandServiceImpl")
public class HandServiceImpl implements HandService {

    /**
     * Builds the argument array for {@code method} from the current request.
     *
     * <p>{@code args[i]} always holds the value for parameter {@code i}, even if an earlier
     * parameter could not be resolved. A parameter that no resolver supports is left
     * {@code null}. If several resolvers support one parameter, the first one encountered is
     * used; the iteration order of the bean map is unspecified.
     *
     * @param req    current request
     * @param resp   current response
     * @param method handler method whose parameters are resolved
     * @param beans  IOC container searched for {@link ArgumentResolver} beans
     * @return one argument per declared parameter of {@code method}
     */
    @Override
    public Object[] hand(HttpServletRequest req, HttpServletResponse resp, Method method, Map<String, Object> beans) {
        Class<?>[] paraTypes = method.getParameterTypes();
        Object[] args = new Object[paraTypes.length];
        Map<String, Object> argResolvers = getArgResolvers(ArgumentResolver.class, beans);
        for (int index = 0; index < paraTypes.length; index++) {
            for (Object bean : argResolvers.values()) {
                ArgumentResolver resolver = (ArgumentResolver) bean;
                if (resolver.support(paraTypes[index], method, index)) {
                    // index-aligned slot; stop at the first match so one parameter never fills two slots
                    args[index] = resolver.resolver(req, resp, paraTypes[index], method, index);
                    break;
                }
            }
        }
        return args;
    }

    /**
     * Collects every bean that is an instance of {@code type}.
     *
     * <p>Indirect implementations, e.g. through a base class, are included too.
     *
     * @param type  required bean type
     * @param beans IOC container to search
     * @return matching beans keyed by bean name
     */
    private Map<String, Object> getArgResolvers(Class<?> type, Map<String, Object> beans) {
        Map<String, Object> argResolvers = new HashMap<>();
        beans.forEach((beanName, bean) -> {
            if (type.isInstance(bean)) {
                argResolvers.put(beanName, bean);
            }
        });
        return argResolvers;
    }
}
参数解析器:从request中解析指定类型的参数
参数解析器接口
package com.zyu.mvc.argumentresolver;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
/**
 * Strategy that produces the value of one handler-method parameter from the current request.
 */
public interface ArgumentResolver {

    /**
     * Tells whether this resolver can supply the parameter at position {@code index}.
     *
     * @param type   declared type of the parameter
     * @param method handler method that owns the parameter
     * @param index  zero-based position of the parameter in {@code method}'s signature
     * @return {@code true} if {@link #resolver} should be used for this parameter
     */
    boolean support(Class<?> type, Method method, int index);

    /**
     * Produces the value for parameter {@code index} of {@code method}.
     *
     * @param req    current HTTP request
     * @param resp   current HTTP response
     * @param type   declared type of the parameter
     * @param method handler method that owns the parameter
     * @param index  zero-based position of the parameter
     * @return the argument value to pass to {@code method}
     */
    Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index);
}
HttpServletRequest解析器实现类
package com.zyu.mvc.argumentresolver.impl;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
/**
 * Hands the current request to handler parameters declared as {@link ServletRequest} or a subtype.
 */
@Service("httpServletRequestArgResolver")
public class HttpServletRequestArgResolver implements ArgumentResolver {

    @Override
    public boolean support(Class<?> type, Method method, int index) {
        // true for ServletRequest itself and any of its subtypes, e.g. HttpServletRequest
        boolean isRequestType = ServletRequest.class.isAssignableFrom(type);
        return isRequestType;
    }

    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        // nothing to parse: the request object received by the servlet is passed through unchanged
        final Object current = req;
        return current;
    }
}
HttpServletResponse解析器实现类
package com.zyu.mvc.argumentresolver.impl;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
/**
 * Hands the current response to handler parameters declared as {@link ServletResponse} or a subtype.
 */
@Service("httpServletResponseArgResolver")
public class HttpServletResponseArgResolver implements ArgumentResolver {

    @Override
    public boolean support(Class<?> type, Method method, int index) {
        // true for ServletResponse itself and any of its subtypes, e.g. HttpServletResponse
        boolean isResponseType = ServletResponse.class.isAssignableFrom(type);
        return isResponseType;
    }

    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        // nothing to parse: the response object received by the servlet is passed through unchanged
        final Object current = resp;
        return current;
    }
}
RequestParam解析器实现类
package com.zyu.mvc.argumentresolver.impl;
import com.zyu.mvc.anno.RequestParam;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.argumentresolver.ArgumentResolver;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
/**
 * Resolves {@link RequestParam}-annotated parameters from request parameters.
 *
 * <p>Supported target types are:
 * <ul>
 *   <li>{@code String};</li>
 *   <li>{@code Integer}/{@code int}, {@code Long}/{@code long}, {@code Double}/{@code double},
 *       {@code Float}/{@code float};</li>
 *   <li>{@code Boolean}/{@code boolean}.</li>
 * </ul>
 * Other types resolve to {@code null}.
 */
@Service("requestParamArgResolver")
public class RequestParamArgResolver implements ArgumentResolver {

    @Override
    public boolean support(Class<?> type, Method method, int index) {
        return findRequestParam(method, index) != null;
    }

    /**
     * Reads the request parameter named by the {@code @RequestParam} annotation and converts it.
     *
     * <p>Missing or blank values behave as follows:
     * <ul>
     *   <li>numeric wrapper types become {@code null};</li>
     *   <li>booleans become {@code false}, as {@link Boolean#valueOf(String)} does;</li>
     *   <li>numeric primitives raise {@link IllegalArgumentException}.</li>
     * </ul>
     *
     * @throws NumberFormatException    if a numeric value cannot be parsed
     * @throws IllegalArgumentException if a numeric primitive parameter has no value
     */
    @Override
    public Object resolver(HttpServletRequest req, HttpServletResponse resp, Class<?> type, Method method, int index) {
        RequestParam requestParam = findRequestParam(method, index);
        if (requestParam == null) {
            return null;
        }
        String name = requestParam.value();
        return convert(req.getParameter(name), type, name);
    }

    /** Returns the {@code @RequestParam} on parameter {@code index} of {@code method}, or {@code null}. */
    private static RequestParam findRequestParam(Method method, int index) {
        // getParameterAnnotations(): one array of annotations per parameter
        for (Annotation annotation : method.getParameterAnnotations()[index]) {
            if (annotation instanceof RequestParam) {
                return (RequestParam) annotation;
            }
        }
        return null;
    }

    /** Converts a raw request-parameter value to {@code type}; unsupported types yield {@code null}. */
    private static Object convert(String rawValue, Class<?> type, String name) {
        if (type == String.class) {
            return rawValue;
        }
        if (type == Boolean.class || type == boolean.class) {
            return Boolean.valueOf(rawValue);
        }
        String value = rawValue == null ? null : rawValue.trim();
        if (value == null || value.isEmpty()) {
            if (type.isPrimitive()) {
                // null cannot be passed for a primitive parameter
                throw new IllegalArgumentException("Missing value for request parameter '" + name + "'");
            }
            return null;
        }
        if (type == Integer.class || type == int.class) {
            return Integer.valueOf(value);
        }
        if (type == Long.class || type == long.class) {
            return Long.valueOf(value);
        }
        if (type == Double.class || type == double.class) {
            return Double.valueOf(value);
        }
        if (type == Float.class || type == float.class) {
            return Float.valueOf(value);
        }
        return null;
    }
}
web.xml配置
<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_3_1.xsd"
         version="3.1">

    <!-- Register the custom DispatcherServlet. load-on-startup >= 0 asks the container
         to initialize it at deployment time rather than on the first request. -->
    <servlet>
        <servlet-name>DispatcherServlet</servlet-name>
        <servlet-class>com.zyu.mvc.servlet.DispatcherServlet</servlet-class>
        <load-on-startup>0</load-on-startup>
    </servlet>

    <!-- "/" makes DispatcherServlet the default servlet: every request not matched by a
         more specific mapping is routed to it. -->
    <servlet-mapping>
        <servlet-name>DispatcherServlet</servlet-name>
        <url-pattern>/</url-pattern>
    </servlet-mapping>
</web-app>
测试
定义一个controller
package com.zyu.mvc.controller;
import com.zyu.mvc.anno.Autowired;
import com.zyu.mvc.anno.Controller;
import com.zyu.mvc.anno.RequestMapping;
import com.zyu.mvc.anno.RequestParam;
import com.zyu.mvc.service.TestService;
/**
 * Demo controller whose handlers are mapped under {@code /test}.
 */
@Controller
@RequestMapping("/test")
public class TestController {

    /** Filled by field injection with the bean registered as "testService". */
    @Autowired("testService")
    TestService testService;

    /**
     * Handles {@code /test/sayHello?username=...&age=...}.
     *
     * <p>NOTE(review): {@code age} is auto-unboxed to {@code int} when calling the service,
     * so a {@code null} age throws {@link NullPointerException}.
     *
     * @param username value of the {@code username} request parameter
     * @param age      value of the {@code age} request parameter
     * @return the greeting produced by {@link TestService#sayHello}
     */
    @RequestMapping("/sayHello")
    public String sayHello(@RequestParam("username") String username, @RequestParam("age") Integer age) {
        final TestService service = this.testService;
        return service.sayHello(username, age);
    }
}
定义一个Service
接口
package com.zyu.mvc.service;
/**
 * Demo service that builds a greeting.
 */
public interface TestService {

    /**
     * Builds a greeting for the given user.
     *
     * @param username name to greet
     * @param age      age to include in the greeting
     * @return the greeting text
     */
    String sayHello(String username, int age);
}
实现类
package com.zyu.mvc.service.impl;
import com.zyu.mvc.anno.Service;
import com.zyu.mvc.service.TestService;
/**
 * Default {@link TestService}, registered in the container under the name "testService".
 */
@Service("testService")
public class TestServiceImpl implements TestService {

    /** Formats the greeting as {@code hello,<username> <age>}. */
    @Override
    public String sayHello(String username, int age) {
        final String template = "hello,%s %d";
        return String.format(template, username, age);
    }
}
配置tomcat
用惯了Spring Boot的我,在这里还“研究”了一会儿……
请求
结束语
收获满满啊,学无止境,诸君共勉