模拟SpringMVC注解@Controller, @Param, @RequestMapping
- @Controller
/**
 * Marks a class as a request-handling controller.
 *
 * <p>Kept at runtime so the dispatcher can find annotated classes through
 * reflection; applicable to types only.
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Controller {
}
- @Param
/**
 * Binds a handler-method parameter to the HTTP request parameter with the
 * given name. Retained at runtime so it can be read reflectively.
 */
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
public @interface Param {
    /** Name of the request parameter whose value is injected. */
    String value();
}
- @RequestMapping
/**
 * Maps a controller class or a handler method to a URL path.
 *
 * <p>Retained at runtime so it can be read reflectively.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {
    /** Path this class or method is mapped to. */
    String url();

    /**
     * HTTP method code (see the MethodType constants). When the annotation is
     * placed on a class the default 0 is left as-is.
     */
    int method() default 0;
}
- Interface MethodType
/**
 * Integer codes for the HTTP methods that a RequestMapping can target.
 */
public interface MethodType {
    /** HTTP GET. */
    public static final int GET = 1;

    /** HTTP POST. */
    public static final int POST = 2;
}
- ControllerDefinition类
/**
 * Everything needed to dispatch to one controller: its class, the GET and
 * POST handler methods keyed by their method-level url, and the instance the
 * handlers are invoked on.
 */
public class ControllerDefinition {
    // The controller class.
    private final Class<?> controllerClazz;
    // GET handlers: key = method-level @RequestMapping url, value = handler method.
    private final Map<String, Method> getMethodMap;
    // POST handlers, keyed the same way.
    private final Map<String, Method> postMethodMap;
    // Instance the handlers are invoked on; created here for now
    // (intended to be managed by Spring later).
    private final Object instance;

    /**
     * Creates the definition and instantiates the controller through its
     * no-argument constructor.
     *
     * @param controllerClazz the controller class
     * @param getMethodMap    GET handlers keyed by url
     * @param postMethodMap   POST handlers keyed by url
     * @throws IllegalStateException if the controller cannot be instantiated
     *     (no no-arg constructor, abstract class, or the constructor threw);
     *     the underlying reflection exception is attached as the cause
     */
    public ControllerDefinition(Class<?> controllerClazz, Map<String, Method> getMethodMap, Map<String, Method> postMethodMap) {
        this.controllerClazz = controllerClazz;
        this.getMethodMap = getMethodMap;
        this.postMethodMap = postMethodMap;
        // Fail fast: a null instance would only surface later as an obscure
        // NullPointerException when a handler is invoked.
        this.instance = instantiate(controllerClazz);
    }

    /** Invokes the no-arg constructor of {@code clazz}, wrapping any reflection failure. */
    private static Object instantiate(Class<?> clazz) {
        try {
            // Replaces the deprecated Class.newInstance(); also wraps exceptions
            // thrown by the constructor instead of rethrowing them unchecked.
            java.lang.reflect.Constructor<?> constructor = clazz.getDeclaredConstructor();
            // Handler methods are made accessible too, so non-public controllers work.
            constructor.setAccessible(true);
            return constructor.newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate controller " + clazz.getName(), e);
        }
    }

    /** @return the controller class */
    public Class<?> getControllerClazz() {
        return controllerClazz;
    }

    /** @return GET handlers keyed by method-level url (the map passed to the constructor) */
    public Map<String, Method> getGetMethodMap() {
        return getMethodMap;
    }

    /** @return POST handlers keyed by method-level url (the map passed to the constructor) */
    public Map<String, Method> getPostMethodMap() {
        return postMethodMap;
    }

    /** @return the controller instance handlers are invoked on; never null */
    public Object getInstance() {
        return instance;
    }
}
- MainFilter
// Usage: the application declares its own Filter that extends MainFilter and
// maps it to all requests (/*), so every request is handled here.
// Controllers must live in the same package as that user-defined Filter
// (or in one of its sub-packages).
public class MainFilter implements Filter {
    // Resolves requests to controller methods and invokes them.
    DispatcherServlet dispatcherServlet;
    // Fully-qualified names of every class found under the filter's package;
    // DispatcherServlet later keeps only the @Controller ones.
    List<String> list;

    /**
     * Scans the package of the concrete (user-defined) filter class for
     * compiled classes.
     *
     * @throws IllegalStateException if that package is not a directory on the classpath
     */
    public MainFilter() {
        // getClass() is the runtime subclass, so the user's package is scanned.
        String fullName = this.getClass().getName();
        list = classList(fullName);
    }

    /**
     * Lists the fully-qualified names of all classes in the package of
     * {@code fullName} and its sub-packages.
     *
     * @param fullName fully-qualified name of a class whose package is scanned
     * @return class names found on disk; never null
     * @throws IllegalStateException if the package cannot be located as a
     *     directory on the classpath (e.g. it is packaged inside a JAR)
     */
    private List<String> classList(String fullName) {
        int lastDot = fullName.lastIndexOf('.');
        // A class in the default package has no dot: scan the classpath root.
        String packageName = lastDot < 0 ? "" : fullName.substring(0, lastDot);
        String realPath = packageName.replace('.', '/');
        URL resource = this.getClass().getClassLoader().getResource(realPath);
        if (resource == null || !"file".equals(resource.getProtocol())) {
            throw new IllegalStateException("Cannot scan package '" + packageName
                    + "': not a directory on the classpath (" + resource + ")");
        }
        File root;
        try {
            // toURI() decodes escapes such as %20 that URL.getFile() leaves in place.
            root = new File(resource.toURI());
        } catch (java.net.URISyntaxException e) {
            throw new IllegalStateException("Invalid classpath URL " + resource, e);
        }
        List<String> classNames = new ArrayList<>();
        getClassNames(root, packageName, classNames);
        return classNames;
    }

    /**
     * Recursively collects class names below {@code dir}. Names are built from
     * the package name plus the relative path, so they do not depend on the
     * absolute location of the classes directory or on the OS path separator.
     *
     * @param dir         directory corresponding to {@code packageName}
     * @param packageName dotted package name of {@code dir} ("" for the default package)
     * @param classNames  output list the names are appended to
     */
    private void getClassNames(File dir, String packageName, List<String> classNames) {
        File[] files = dir.listFiles();
        if (files == null) {
            // Not a directory, or an I/O error occurred while listing it.
            return;
        }
        String prefix = packageName.isEmpty() ? "" : packageName + ".";
        for (File f : files) {
            String fileName = f.getName();
            if (f.isDirectory()) {
                getClassNames(f, prefix + fileName, classNames);
            } else if (fileName.endsWith(".class") && !fileName.contains("-")) {
                // '-' only appears in package-info / module-info, which cannot be controllers.
                classNames.add(prefix + fileName.substring(0, fileName.length() - ".class".length()));
            }
        }
    }

    /**
     * Called when the user-defined filter is initialized. Hands the scanned
     * class names to the DispatcherServlet, which keeps the @Controller
     * classes and builds its HandlerMapping from them.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        dispatcherServlet = new DispatcherServlet(list);
    }

    /**
     * Passes static resources through to the container; dispatches everything
     * else to a controller method and writes its result.
     *
     * @throws ServletException if the handler method could not be invoked or threw
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        String requestUri = request.getRequestURI();
        // Static resources (.css .html .js .jsp) are served by the container.
        if (checkPass(requestUri)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // getRequestURI() includes the context path, while mappings are relative
        // to the application root; strip it so non-root deployments work.
        String uri = requestUri.substring(request.getContextPath().length());
        if (uri.isEmpty()) {
            uri = "/";
        }
        try {
            Object result = dispatcherServlet.dispatchRequest(uri, request.getMethod(), request, response);
            parseResponse(result, request, response);
        } catch (InvocationTargetException | IllegalAccessException e) {
            // Propagate instead of swallowing, so the container reports an error
            // rather than sending an empty successful response.
            throw new ServletException("Failed to invoke handler for " + uri, e);
        }
    }

    /** Returns true when {@code uri} names a static resource (.jsp .css .html .js). */
    private boolean checkPass(String uri) {
        String[] ends = new String[]{".jsp", ".css", ".html", ".js"};
        for (String end : ends) {
            if (uri.endsWith(end)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Writes a handler's return value to the response.
     * <ul>
     *   <li>null: nothing is written;</li>
     *   <li>"redirect:target": response.sendRedirect(target);</li>
     *   <li>"forward:target": server-side forward to target;</li>
     *   <li>any other String: written as-is;</li>
     *   <li>any other object: serialized to JSON and written.</li>
     * </ul>
     */
    private void parseResponse(Object result, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        if (result == null) {
            // e.g. void handlers, or HTTP methods the dispatcher does not support
            return;
        }
        if (result instanceof String) {
            String path = (String) result;
            // Limit 2 keeps any further ':' in the target, e.g. "redirect:http://host/x".
            String[] splits = path.split(":", 2);
            String prefix = splits[0].trim();
            if (splits.length == 2 && "redirect".equals(prefix)) {
                response.sendRedirect(splits[1]);
            } else if (splits.length == 2 && "forward".equals(prefix)) {
                // Server-side forward (not a redirect): request attributes are kept.
                request.getRequestDispatcher(splits[1]).forward(request, response);
            } else {
                writeBody(response, path, "text/html;charset=UTF-8");
            }
        } else {
            writeBody(response, JSONObject.toJSONString(result), "application/json;charset=UTF-8");
        }
    }

    /**
     * Writes {@code body} to the response, setting {@code contentType} unless
     * the handler already chose one, and encoding with the response's charset
     * so the bytes match the Content-Type header.
     */
    private void writeBody(HttpServletResponse response, String body, String contentType) throws IOException {
        if (response.getContentType() == null) {
            response.setContentType(contentType);
        }
        ServletOutputStream outputStream = response.getOutputStream();
        outputStream.write(body.getBytes(response.getCharacterEncoding()));
    }

    @Override
    public void destroy() {
    }
}
- DispatcherServlet
/**
 * Resolves a request to a controller method through HandlerMapping and
 * invokes it with arguments taken from the servlet request.
 */
public class DispatcherServlet {
    // Maps request paths to controller instances and handler methods.
    private HandlerMapping handler;

    /**
     * Registers every @Controller class among {@code list} and builds the
     * mapping tables.
     *
     * @param list fully-qualified class names; classes without @Controller are ignored
     */
    public DispatcherServlet(List<String> list) {
        handler = new HandlerMapping();
        ClassLoader loader = DispatcherServlet.class.getClassLoader();
        for (String className : list) {
            try {
                // initialize=false: only annotations are inspected here, so do not
                // run the static initializers of every scanned class.
                Class<?> clazz = Class.forName(className, false, loader);
                scan(clazz);
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
        handler.init();
    }

    /**
     * Finds the handler for {@code uri} and HTTP method {@code type}, invokes it
     * and returns its result.
     *
     * @param uri  request path relative to the application root
     * @param type HTTP method name as returned by request.getMethod()
     * @return the handler's return value; null for HTTP methods other than GET/POST
     */
    public Object dispatchRequest(String uri, String type, HttpServletRequest request, HttpServletResponse response) throws InvocationTargetException, IllegalAccessException {
        int methodType;
        if ("GET".equals(type)) {
            methodType = MethodType.GET;
        } else if ("POST".equals(type)) {
            methodType = MethodType.POST;
        } else {
            return null;
        }
        Map<Method, Object> map = this.handler.handler(uri, methodType);
        return invoke(map, request, response);
    }

    /** Invokes the (single) method in {@code map} on its paired controller instance. */
    private Object invoke(Map<Method, Object> map, HttpServletRequest request, HttpServletResponse response) throws InvocationTargetException, IllegalAccessException {
        for (Map.Entry<Method, Object> entry : map.entrySet()) {
            Method method = entry.getKey();
            Object[] args = parseArgTypes(method, request, response);
            return method.invoke(entry.getValue(), args);
        }
        return null;
    }

    /**
     * Builds the argument array for {@code method}, one slot per parameter.
     * <ul>
     *   <li>@Param("x"): request parameter x, converted for Integer/int,
     *       Double/double, Long/long and Boolean/boolean, else passed as String;</li>
     *   <li>HttpServletRequest / HttpServletResponse / HttpSession: injected;</li>
     *   <li>anything else: null, so later arguments keep their positions.</li>
     * </ul>
     *
     * @throws RuntimeException if a @Param request parameter is missing
     */
    private Object[] parseArgTypes(Method method, HttpServletRequest request, HttpServletResponse response) {
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            String typeName = parameters[i].getType().getName();
            Param paramAnnotation = parameters[i].getDeclaredAnnotation(Param.class);
            if (paramAnnotation != null) {
                String attribute = request.getParameter(paramAnnotation.value());
                if (attribute == null) {
                    throw new RuntimeException("Illegal Null Param");
                }
                args[i] = convertParam(attribute, typeName);
            } else {
                args[i] = resolveServletArgument(typeName, request, response);
            }
        }
        return args;
    }

    /** Converts a request-parameter string to the declared parameter type. */
    private static Object convertParam(String value, String typeName) {
        switch (typeName) {
            case "java.lang.Integer":
            case "int":
                return Integer.valueOf(value);
            case "java.lang.Double":
            case "double":
                return Double.valueOf(value);
            case "java.lang.Long":
            case "long":
                return Long.valueOf(value);
            case "java.lang.Boolean":
            case "boolean":
                return Boolean.valueOf(value);
            default:
                return value;
        }
    }

    /** Returns the servlet object for an injectable type, or null for unsupported types. */
    private static Object resolveServletArgument(String typeName, HttpServletRequest request, HttpServletResponse response) {
        switch (typeName) {
            case "javax.servlet.http.HttpServletRequest":
                return request;
            case "javax.servlet.http.HttpServletResponse":
                return response;
            case "javax.servlet.http.HttpSession":
                return request.getSession();
            default:
                return null;
        }
    }

    /**
     * Registers {@code clazz} with the HandlerMapping if it carries @Controller,
     * under "/" + its class-level @RequestMapping url ("/" when absent).
     */
    private void scan(Class<?> clazz) {
        if (!clazz.isAnnotationPresent(Controller.class)) {
            return;
        }
        RequestMapping mapping = clazz.getDeclaredAnnotation(RequestMapping.class);
        // HandlerMapping looks controllers up by "/" + first path segment, so the
        // key needs exactly one leading slash and no trailing one:
        // "test", "/test" and "test/" all register as "/test".
        String url = mapping == null ? "" : stripSlashes(mapping.url());
        handler.addController("/" + url, clazz);
    }

    /** Removes leading and trailing '/' characters from {@code url}. */
    private static String stripSlashes(String url) {
        int start = 0;
        int end = url.length();
        while (start < end && url.charAt(start) == '/') {
            start++;
        }
        while (end > start && url.charAt(end - 1) == '/') {
            end--;
        }
        return url.substring(start, end);
    }
}
- HandlerMapping类
/**
 * Maps request paths to controller methods.
 *
 * <p>A path is split into a controller part ("/" + first segment) and a method
 * part (the remainder, starting with "/"). A single-segment path is looked up
 * as a method of the controller registered under "/".
 */
public class HandlerMapping {
    // key: "/" + class-level @RequestMapping url ("/" when absent); value: controller class
    Map<String, Class<?>> controllerMap;
    // key: controller class; value: its handler-method tables and instance
    Map<Class<?>, ControllerDefinition> cache;

    public HandlerMapping() {
        controllerMap = new HashMap<>();
        cache = new HashMap<>();
    }

    /** Registers {@code clazz} under controller path {@code url}; a later call for the same url wins. */
    public void addController(String url, Class<?> clazz) {
        controllerMap.put(url, clazz);
    }

    /*
     * TODO: the controller instance is currently created by ControllerDefinition
     * itself; it is meant to be managed by Spring later.
     */
    /**
     * Resolves {@code uri} to a handler method and its controller instance.
     *
     * @param uri  request path relative to the application root
     * @param type MethodType.GET or MethodType.POST
     * @return a single-entry map: handler method -> controller instance
     * @throws RuntimeException if no controller/method matches or {@code type} is unknown
     * @throws IllegalStateException if init() has not been called for the controller
     */
    public Map<Method, Object> handler(String uri, int type) {
        // Treat an empty or relative URI as rooted so the splitting below is safe.
        String path;
        if (uri == null || uri.isEmpty()) {
            path = "/";
        } else if (uri.startsWith("/")) {
            path = uri;
        } else {
            path = "/" + uri;
        }
        String rest = path.substring(1);
        int slash = rest.indexOf('/');
        String controllerMapping;
        String methodMapping;
        if (slash >= 0) {
            controllerMapping = "/" + rest.substring(0, slash);
            methodMapping = rest.substring(slash);
        } else {
            controllerMapping = "/";
            methodMapping = path;
        }

        Class<?> clazz = controllerMap.get(controllerMapping);
        if (clazz == null) {
            throw new RuntimeException("No Mapping for " + controllerMapping);
        }
        ControllerDefinition definition = cache.get(clazz);
        if (definition == null) {
            throw new IllegalStateException("HandlerMapping.init() has not been called for " + clazz.getName());
        }
        Map<String, Method> methodMap;
        if (type == MethodType.GET) {
            methodMap = definition.getGetMethodMap();
        } else if (type == MethodType.POST) {
            methodMap = definition.getPostMethodMap();
        } else {
            throw new RuntimeException("Illegal Method Type");
        }
        Method method = methodMap.get(methodMapping);
        if (method == null) {
            throw new RuntimeException("No Mapping for " + path);
        }
        Map<Method, Object> map = new HashMap<>();
        map.put(method, definition.getInstance());
        return map;
    }

    /** Builds and caches a ControllerDefinition for every registered controller class. */
    public void init() {
        for (Class<?> clazz : controllerMap.values()) {
            cache.put(clazz, createControllerDefinition(clazz));
        }
    }

    /**
     * Collects the @RequestMapping methods declared on {@code clazz} into GET and
     * POST tables. Methods whose mapping leaves method at its default (0) are
     * not registered.
     */
    private ControllerDefinition createControllerDefinition(Class<?> clazz) {
        Map<String, Method> getMap = new HashMap<>();
        Map<String, Method> postMap = new HashMap<>();
        for (Method method : clazz.getDeclaredMethods()) {
            RequestMapping mapping = method.getDeclaredAnnotation(RequestMapping.class);
            if (mapping == null) {
                continue;
            }
            method.setAccessible(true);
            // handler() looks methods up by a path that always starts with "/",
            // so url = "a" and url = "/a" must both register as "/a".
            String url = mapping.url().startsWith("/") ? mapping.url() : "/" + mapping.url();
            int value = mapping.method();
            if (value == MethodType.GET) {
                getMap.put(url, method);
            } else if (value == MethodType.POST) {
                postMap.put(url, method);
            }
        }
        return new ControllerDefinition(clazz, getMap, postMap);
    }
}
下一篇文章会把 Spring 和 SpringMVC 整合起来，再结合之前模拟的 MyBatis 的 BaseDao 一起测试，所以这里就不单独写测试了，下面只给出用法示例。
/**
 * Application entry filter: intercepts every request ("/*") and lets
 * MainFilter dispatch it. Controllers are discovered in this class's package.
 */
@WebFilter(urlPatterns = "/*")
public class MyFilter extends MainFilter {
}
/**
 * Example controller mapped under the class-level url "test".
 */
@Controller
@RequestMapping(url = "test")
public class MyController {
    // Every handler here forwards to the same JSP.
    private static final String INDEX_VIEW = "forward:/index.jsp";

    /**
     * GET "/a": stores a fixed "Message" request attribute and forwards to /index.jsp.
     */
    @RequestMapping(method = MethodType.GET, url = "/a")
    public String a(HttpServletRequest request) {
        String message = "SpringMVC测试成功";
        request.setAttribute("Message", message);
        return INDEX_VIEW;
    }

    /**
     * GET "/b": stores the "name" request parameter as the "Message" attribute
     * and forwards to /index.jsp.
     */
    @RequestMapping(method = MethodType.GET, url = "/b")
    public String b(@Param("name") String name, HttpServletRequest request) {
        request.setAttribute("Message", name);
        return INDEX_VIEW;
    }
}