目录
1.先编写@Controller和@RequestMapping注解
根据前文,我们了解到DispatcherServlet其实就是一个Servlet。它首先进行初始化:将所有HandlerMapping缓存起来(用Map存储),并初始化其他各种组件。之后它把请求的处理都集中到doDispatch方法:由HandlerMapping找到处理该请求的method,再由HandlerAdapter去调用这个method。HandlerAdapter本质上相当于通过反射执行方法,只是其中还包含较多的参数赋值工作(请求路径中的参数、session中的参数等等)。方法执行后返回ModelAndView,再由ViewResolver解析视图、渲染视图,最后把结果返回给用户。下面我们就来实现一个简易的Spring MVC框架。
1.先编写@Controller和@RequestMapping注解
/**
 * Marks a class as a controller.
 *
 * <p>The annotation is retained at runtime so it can be discovered through
 * reflection, and it may only be placed on types.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Documented
public @interface MyController {

    /** Optional value for the annotated controller; defaults to the empty string. */
    String value() default "";
}
/**
 * Associates a path fragment with a type or a method.
 *
 * <p>The annotation is retained at runtime so it can be read through
 * reflection, and it may be placed on methods and on types.
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyRequestMapping {

    /** The path fragment; defaults to the empty string. */
    String value() default "";
}
2.实现DispatcherServlet
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.jasper.tagplugins.jstl.core.Param;
import springmvc.myspringmvc.annotation.MyController;
import springmvc.myspringmvc.annotation.MyRequestMapping;
/**
 * A minimal front-controller servlet in the spirit of Spring MVC's DispatcherServlet.
 *
 * <p>On {@link #init()} it scans a fixed package for classes annotated with
 * {@link MyController}, instantiates each one once, and registers every public
 * method annotated with {@link MyRequestMapping} under the URL formed by the
 * class-level mapping value followed by the method-level mapping value.
 * Each GET or POST request is then routed to the matching method via reflection.
 *
 * <p>Handler arguments are resolved as follows: a parameter of type
 * {@link HttpServletRequest} or {@link HttpServletResponse} receives the current
 * request/response; any other parameter receives a request parameter value as a
 * String (multiple values joined with {@code ","}). Only String-compatible
 * parameter types can be bound this way; other types make the invocation fail.
 */
public class MyDispatcherServlet extends HttpServlet {

    private static final String CLASS_SUFFIX = ".class";

    // Package to scan for controllers. Fixed here; could be read from configuration instead.
    private String packageName = "springmvc.myspringmvc.controller";
    // IoC container: fully-qualified class name -> singleton controller instance.
    private final Map<String, Object> iocMap = new HashMap<>();
    // URL -> handler method.
    private final Map<String, Method> handlerMapping = new HashMap<>();
    // URL -> controller instance that owns the handler method.
    private final Map<String, Object> urlMapping = new HashMap<>();

    /**
     * Builds the controller registry and the URL mappings.
     *
     * @throws ServletException declared by the overridden method
     * @throws IllegalStateException if the scan package cannot be located as a directory
     */
    @Override
    public void init() throws ServletException {
        // 1. Scan the package and instantiate every class annotated with @MyController.
        initClass(packageName);
        // 2. Register the methods annotated with @MyRequestMapping.
        initHandle();
    }

    /** Registers every {@code @MyRequestMapping} method of every controller in the IoC map. */
    private void initHandle() {
        for (Object controller : iocMap.values()) {
            // The instance was created from the class named by its key, so getClass()
            // yields that class without a second Class.forName lookup.
            Class<?> clazz = controller.getClass();
            String path = "";
            if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                path = clazz.getAnnotation(MyRequestMapping.class).value();
            }
            for (Method method : clazz.getMethods()) {
                if (method.isAnnotationPresent(MyRequestMapping.class)) {
                    String url = path + method.getAnnotation(MyRequestMapping.class).value();
                    urlMapping.put(url, controller);
                    handlerMapping.put(url, method);
                }
            }
        }
    }

    /**
     * Recursively scans the directory backing {@code packageName} and registers
     * every {@code .class} file found in it as a potential controller.
     *
     * @param packageName dot-separated package to scan
     * @throws IllegalStateException if the package is not found on the class path
     *         or is not backed by a plain directory
     */
    private void initClass(String packageName) {
        // ClassLoader resource names are '/'-separated and must not start with '/'.
        String resourcePath = packageName.replace('.', '/');
        URL location = this.getClass().getClassLoader().getResource(resourcePath);
        if (location == null) {
            throw new IllegalStateException("Package not found on class path: " + packageName);
        }
        if (!"file".equals(location.getProtocol())) {
            throw new IllegalStateException(
                    "Only directory-based packages can be scanned, got: " + location);
        }
        File directory;
        try {
            // toURI() decodes percent-escapes (e.g. %20), which URL.getFile() does not.
            directory = new File(location.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            throw new IllegalStateException("Cannot resolve package directory: " + location, e);
        }
        File[] entries = directory.listFiles();
        if (entries == null) {
            // Not a directory, or an I/O error occurred: nothing to scan here.
            return;
        }
        for (File entry : entries) {
            String entryName = entry.getName();
            if (entry.isDirectory()) {
                initClass(packageName + "." + entryName);
            } else if (entryName.endsWith(CLASS_SUFFIX)) {
                String simpleName =
                        entryName.substring(0, entryName.length() - CLASS_SUFFIX.length());
                registerController(packageName + "." + simpleName);
            }
        }
    }

    /** Loads {@code className} and, if it is a {@code @MyController}, stores one instance of it. */
    private void registerController(String className) {
        try {
            Class<?> clazz = Class.forName(className);
            // Instantiate at most once per class name so controllers stay singletons.
            if (clazz.isAnnotationPresent(MyController.class) && !iocMap.containsKey(className)) {
                iocMap.put(className, clazz.getDeclaredConstructor().newInstance());
            }
        } catch (ReflectiveOperationException e) {
            // One broken class must not prevent the remaining controllers from loading.
            log("Failed to load controller class " + className, e);
        }
    }

    /** Handles GET requests exactly like POST requests. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    /**
     * Dispatches the request to its handler method.
     *
     * @throws ServletException if the handler cannot be invoked or throws itself
     * @throws IOException if writing the response fails
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatch(req, resp);
        } catch (InvocationTargetException e) {
            // The handler itself threw: report the real cause instead of the reflection wrapper.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new ServletException(cause);
        } catch (IllegalAccessException | IllegalArgumentException e) {
            throw new ServletException("Cannot invoke handler for " + req.getRequestURI(), e);
        }
    }

    /**
     * Looks up the handler registered for the request path (request URI without
     * the context path) and invokes it; answers with status 404 when none exists.
     */
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp)
            throws IOException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        String url = req.getRequestURI();
        String context = req.getContextPath();
        // Strip the context path as a literal prefix only; it is not a regex and
        // must not be removed from the middle of the path.
        if (context != null && !context.isEmpty() && url.startsWith(context)) {
            url = url.substring(context.length());
        }
        Object controller = urlMapping.get(url);
        Method method = handlerMapping.get(url);
        if (controller == null || method == null) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().println("404,resource not find");
            return;
        }
        // Invoke the handler reflectively.
        method.invoke(controller, resolveArguments(method, req, resp));
    }

    /**
     * Builds the argument array for {@code method}.
     *
     * <p>When parameter names are available in the class file they are used to
     * look up the request parameter of the same name. Otherwise the n-th
     * non-servlet parameter is bound to the n-th entry of the request parameter
     * map (NOTE(review): that map's iteration order is container-specific — confirm
     * for the target container). A parameter with no matching value is bound to
     * {@code null}.
     */
    private Object[] resolveArguments(Method method, HttpServletRequest req, HttpServletResponse resp) {
        Parameter[] parameters = method.getParameters();
        Object[] values = new Object[parameters.length];
        String[][] positional = req.getParameterMap().values().toArray(new String[0][]);
        int nextPositional = 0;
        for (int i = 0; i < parameters.length; i++) {
            Class<?> type = parameters[i].getType();
            if (type == HttpServletRequest.class) {
                values[i] = req;
            } else if (type == HttpServletResponse.class) {
                values[i] = resp;
            } else {
                String[] raw;
                if (parameters[i].isNamePresent()) {
                    raw = req.getParameterValues(parameters[i].getName());
                } else {
                    raw = nextPositional < positional.length ? positional[nextPositional++] : null;
                }
                values[i] = raw == null ? null : String.join(",", raw);
            }
        }
        return values;
    }
}
3.编写Controller测试类
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest;
import springmvc.myspringmvc.annotation.MyController;
import springmvc.myspringmvc.annotation.MyRequestMapping;
/**
 * Sample controller used to exercise the custom dispatcher servlet.
 * Each handler writes a single plain-text line to the response.
 */
@MyController
public class MyControllerTest {

    /** Mapped to {@code /test}; echoes the {@code username} argument back to the client. */
    @MyRequestMapping("/test")
    public void test(HttpServletRequest request, HttpServletResponse response, String username) {
        reply(response, "success, this is test,username = " + username);
    }

    /** Mapped to {@code /ok}; writes a fixed confirmation line. */
    @MyRequestMapping("/ok")
    public void ok(HttpServletRequest request, HttpServletResponse response) {
        reply(response, "success,this is ok");
    }

    /** Prints {@code message} to the response writer; an I/O failure is only printed to stderr. */
    private void reply(HttpServletResponse response, String message) {
        try {
            response.getWriter().println(message);
        } catch (IOException failure) {
            failure.printStackTrace();
        }
    }
}
4.注册使用自定义的DispatcherServlet
<!-- Registers the custom dispatcher servlet under the name "mySpringmvc". -->
<servlet>
<servlet-name>mySpringmvc</servlet-name>
<servlet-class>springmvc.myspringmvc.servlet.MyDispatcherServlet</servlet-class>
<!-- A non-negative value asks the container to load and initialize the servlet at deployment. -->
<load-on-startup>1</load-on-startup>
</servlet>
<!-- "/*" routes every request path of the application to the servlet above. -->
<servlet-mapping>
<servlet-name>mySpringmvc</servlet-name>
<url-pattern>/*</url-pattern>
</servlet-mapping>
5.测试
先来测试一个不存在的路径。另外,由于自定义的DispatcherServlet的url-pattern配置为/*,所有请求都会被它拦截,而它又没有实现视图解析器,所以视图也是无法直接访问的。