手写springMVC的流程
看过Spring MVC源码的小伙伴应该都了解它的运行过程。Spring MVC是基于Servlet实现的:DispatcherServlet继承FrameworkServlet,FrameworkServlet继承HttpServletBean,HttpServletBean又继承HttpServlet,其中还涉及到其他类,这里不多说了。当我们启动Tomcat容器时,容器初始化DispatcherServlet,经过一系列操作后会调用从FrameworkServlet继承来的initWebApplicationContext()方法,随后调用onRefresh(wac)方法:
/**
 * Hook run after the web application context has been refreshed.
 * All the work is delegated to {@link #initStrategies(ApplicationContext)}.
 */
@Override
protected void onRefresh(ApplicationContext context) {
    this.initStrategies(context);
}
/**
* Initializes the strategy objects this servlet uses, one init call per strategy:
* the multipart, locale and theme resolvers, the handler mappings (the URL-to-handler
* relationships), the handler adapters, the handler exception resolvers, the
* request-to-view-name translator, the view resolvers and the flash-map manager.
* Every call receives the same {@code context}, and the calls run in the fixed order below.
* <p>May be overridden in subclasses in order to initialize further strategy objects.
*/
protected void initStrategies(ApplicationContext context) {
initMultipartResolver(context);
initLocaleResolver(context);
initThemeResolver(context);
initHandlerMappings(context);
initHandlerAdapters(context);
initHandlerExceptionResolvers(context);
initRequestToViewNameTranslator(context);
initViewResolvers(context);
initFlashMapManager(context);
}
总结下来大致为以下步骤(简化版)
1.加载配置文件,初始化环境。其中重要的就是获得需要扫描的包名
2.扫描包,然后将所有类的全类名加入到一个list中
3.遍历list,利用反射得到相应的类,然后判断是否含有@Controller、@Service等注解,如果有,将其实例放入IOC容器iocMap中。iocMap的类型为Map<String,Object>:String为类的简单名称(Class.getSimpleName(),首字母小写)或者@Service(value = "")中的value,Object为类的实例。此时beanName已经确定。另外,对于@Service类,还会以它实现的每个接口的全类名为key再放入一次,这样第4步才能按字段类型找到对应的实现类。
4.DI注入。遍历iocMap,利用反射得到类中声明的字段(包括私有字段),然后判断该字段是否含有@Autowired注解(本文自定义的注解类名为Autowoired)。如果有,就以该字段类型的全类名(或注解中指定的value)为beanName去iocMap中查找,然后通过反射赋值。
5.handlerMapping初始化。遍历iocMap,如果类含有@Controller注解,再判断类上是否含有@RequestMapping注解;然后遍历此类中的Method,判断方法上是否含有@RequestMapping,拼接出url(上下文路径 + 类上的路径 + 方法上的路径),将url与方法的对应关系放入Map<String,Method> handlerMapping中。
6.doDispatcherServlet:根据请求的url去handlerMapping中查找method,然后通过method.invoke()调用该方法。
上代码
引入依赖
<dependencies>
    <!-- The Servlet API is supplied by the servlet container (e.g. Tomcat) at runtime,
         so it is only needed for compilation and must not be packaged into the WAR. -->
    <dependency>
        <groupId>javax.servlet</groupId>
        <artifactId>javax.servlet-api</artifactId>
        <version>3.1.0</version>
        <scope>provided</scope>
    </dependency>
    <!-- Logging backend for the SLF4J Logger used by DispatcherServlet. -->
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-core</artifactId>
        <version>1.2.3</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.3</version>
    </dependency>
</dependencies>
4个注解类
/**
 * Marks a field that the hand-written container should inject from its IoC map.
 *
 * <p>NOTE: the (misspelled) type name is referenced by annotated fields elsewhere,
 * so it is kept as-is for compatibility.
 *
 * @PACKAGE_NAME: com.lzx.springmvc.annotation
 * @AUTHOR: lzx
 * @DATE: 2019/10/23 15:15
 **/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface Autowoired {

    /** Explicit bean name to inject; empty means "derive the name from the field's type". */
    String value() default "";

    /** Intended to mark the dependency as mandatory; enforcement is up to the injector. */
    boolean required() default true;
}
/**
 * Marks a class as a web controller so the hand-written container instantiates it and
 * scans its {@code @RequestMapping} methods.
 *
 * @PACKAGE_NAME: com.lzx.springmvc.annotation
 * @AUTHOR: lzx
 * @DATE: 2019/10/23 15:26
 **/
@Documented
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface Controller {

    /**
     * Optional bean name for the controller (mirrors Spring's {@code @Controller#value});
     * whether it is honoured depends on the container that scans this annotation.
     */
    String value() default "";
}
/**
 * Maps a URL path fragment to a controller class or handler method.
 *
 * <p>On a class the value acts as a prefix; on a method it is appended to that prefix
 * to form the full handler path.
 *
 * @PACKAGE_NAME: com.lzx.springmvc.annotation
 * @AUTHOR: lzx
 * @DATE: 2019/10/23 15:23
 **/
@Documented
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {

    /** The path fragment, e.g. {@code "/hello"}; empty by default. */
    String value() default "";
}
/**
 * Marks a class as a service bean for the hand-written IoC container.
 *
 * @PACKAGE_NAME: com.lzx.springmvc.annotation
 * @AUTHOR: lzx
 * @DATE: 2019/10/23 15:35
 **/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Service {

    /** Bean name; when empty the container derives one from the simple class name. */
    String value() default "";
}
核心类DispatcherServlet
/**
* @PACKAGE_NAME: com.lzx.springmvc.dispatcherServlet
* @AUTHOR: lzx
* @DATE: 2019/10/23 15:42
* @DESCRIBE:
**/
@WebServlet(name = "mySpringMVC",
urlPatterns = "/",
loadOnStartup = 1,
asyncSupported = true,
initParams = {@WebInitParam(name = "contextConfigLocation", value = "application.properties")})
public class DispatcherServlet extends HttpServlet {
private Logger logger = LoggerFactory.getLogger(DispatcherServlet.class);
private Properties properties = new Properties();
//IOC容器
private Map<String, Object> iocMap = new HashMap<>();
private Map<String, Method> handlerMapping = new HashMap<>();
//存储所有全类名
private List<String> classList = new ArrayList<>();
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
try {
doDispatcherServlet(req, resp);
} catch (InvocationTargetException e) {
e.printStackTrace();
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.service(req, resp);
}
@Override
public void init(ServletConfig config) throws ServletException {
//1.加载配置文件
//2.扫描包 并将handlerMapping放入list中,将instance放入ioc容器中
//3.进行DI操作,将ioc中没有赋值的类实例赋值
//4.查找所有handlerMapping
loadConfig(config.getInitParameter("contextConfigLocation"));
doScanPackage(properties.getProperty("scan.package"));
initIOC();
doAutowired();
doHandlerMapping(config);
printAllData();
}
/**
* 1,加载配置文件
*
* @param initParameter
*/
public void loadConfig(String initParameter) {
InputStream in = this.getClass().getClassLoader().getResourceAsStream(initParameter);
try {
properties.load(in);
logger.info("[mySpringMVC INFO-1] : 加载配置文件");
} catch (IOException e) {
e.printStackTrace();
} finally {
if (null != in) {
try {
in.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
/**
* 2.开始进行包扫描
* 将类的全类名放入{@link DispatcherServlet#classList}中
*/
public void doScanPackage(String scanPackage) {
URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
if (url == null) {
logger.info("[mySpringMVC INFO-2] : 扫描路径不对");
return;
}
File files = new File(url.getPath());
for (File file : files.listFiles()) {
if (!file.isDirectory()) {
if (!file.getName().endsWith(".class")) {
logger.info("[mySpringMVC INFO-2] : " + file.getName() + "不是一个class文件");
continue;
} else {
String className = (scanPackage + "." + file.getName()).replaceAll("/", ".")
.replace(".class", "");
classList.add(className);
logger.info("[mySpringMVC INFO-2] : {} 已放入classList中", file.getName());
}
} else {
doScanPackage(scanPackage + "." + file.getName());
logger.info("[mySpringMVC INFO-2] : {} 是一个文件夹", file.getName());
}
}
}
/**
* 3.初始化ioc容器,并将所有类实例放入{@link DispatcherServlet#iocMap}容器中
*/
public void initIOC() {
if (classList.isEmpty()) {
return;
}
for (String name : classList) {
try {
Class<?> clazz = Class.forName(name);
Object instance;
String beanName;
if (clazz.isAnnotationPresent(Controller.class)) {
instance = clazz.newInstance();
beanName = getFirstLowercase(clazz.getSimpleName());
iocMap.put(beanName, instance);
logger.info("iocMap放入beanName为{}", beanName);
} else if (clazz.isAnnotationPresent(Service.class)) {
beanName = getFirstLowercase(clazz.getSimpleName());
Service service = clazz.getAnnotation(Service.class);
if (!"".equals(service.value())) {
beanName = service.value();
}
instance = clazz.newInstance();
iocMap.put(beanName, instance);
logger.info("iocMap放入beanName为{}", beanName);
for (Class c : clazz.getInterfaces()) {
if (iocMap.containsKey(c.getName())) {
throw new Exception("此beanName已经存在");
}
iocMap.put(c.getName(), instance);
logger.info("iocMap放入beanName为{}", c.getName());
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
/**
* 4.开始依赖注入
*/
public void doAutowired(){
if(iocMap.isEmpty()){
return;
}
for(Map.Entry<String,Object> entry : iocMap.entrySet()){
Field[] fields = entry.getValue().getClass().getDeclaredFields();
for(Field field : fields){
field.setAccessible(true);
if(!field.isAnnotationPresent(Autowoired.class)){
continue;
}
Autowoired autowoired = field.getAnnotation(Autowoired.class);
String beanName = field.getType().getName();
if(!"".equals(autowoired.value())){
beanName = autowoired.value();
}
try {
field.set(entry.getValue(),iocMap.get(beanName));
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
}
/**
* 6.初始化handlerMapping 记录url
*/
public void doHandlerMapping(ServletConfig config){
if(iocMap.isEmpty()){
return;
}
for(Map.Entry<String,Object> entry : iocMap.entrySet()){
Class<?> clazz = entry.getValue().getClass();
if(clazz.isAnnotationPresent(Controller.class)){
String url;
if(clazz.isAnnotationPresent(RequestMapping.class)){
RequestMapping requestMapping = clazz.getAnnotation(RequestMapping.class);
url = requestMapping.value();
for(Method method : clazz.getMethods()){
if(method.isAnnotationPresent(RequestMapping.class)){
RequestMapping methodRequestMapping = method.getAnnotation(RequestMapping.class);
String methodUrl = methodRequestMapping.value();
methodUrl = url + methodUrl ;
String contextPath = config.getServletContext().getContextPath() + methodUrl;
handlerMapping.put(contextPath,method);
}
}
}
}
}
}
/**
* 运行阶段,方法调用 通过请求的url。然后找到相对应的方法,用反射调用
* @param req
* @param resp
*/
public void doDispatcherServlet(HttpServletRequest req, HttpServletResponse resp) throws InvocationTargetException, IllegalAccessException {
String uri = req.getRequestURI();
Method method = handlerMapping.get(uri);
if(method != null){
String simpleName = method.getDeclaringClass().getSimpleName();
Object o = iocMap.get(getFirstLowercase(simpleName));
method.invoke(o,req,resp);
}else {
try {
PrintWriter writer = resp.getWriter();
writer.write(new String("请输入正确的url路径".getBytes("utf-8")));
writer.flush();
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* 打印所有数据
*/
public void printAllData() {
logger.info("---------iocMap---------");
iocMap.entrySet().stream().forEach(t -> logger.info(String.valueOf(t)));
logger.info("---------classList---------");
classList.stream().forEach(t -> logger.info(t));
logger.info("---------handlerMapping---------");
handlerMapping.entrySet().stream().forEach(t -> logger.info(String.valueOf(t)));
}
/**
* 将beanName的首字母小写
*
* @param beanName
* @return
*/
public String getFirstLowercase(String beanName) {
char[] chars = beanName.toCharArray();
chars[0] += 32;
return new String(chars);
}
}
字符编码过滤器Filter
/**
 * @PACKAGE_NAME: com.lzx.springmvc.Filter
 * @AUTHOR: lzx
 * @DATE: 2019/10/24 16:23
 * @DESCRIBE: Filter on every URL ("/*") that sets the request encoding to UTF-8 and marks the
 * response as UTF-8 {@code text/html} before continuing down the filter chain.
 **/
@WebFilter(filterName = "encoding", urlPatterns = "/*")
public class EncodingFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        System.out.println(this + "过滤器初始化.....");
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        System.out.println("过滤器生效中.............");
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // The request encoding only takes effect if set before any parameter is read downstream.
        httpRequest.setCharacterEncoding("utf-8");
        httpResponse.setHeader("Content-type", "text/html;charset=UTF-8");
        httpResponse.setCharacterEncoding("UTF-8");
        chain.doFilter(httpRequest, httpResponse);
    }

    @Override
    public void destroy() {
        System.out.println(this + "销毁.....");
    }
}
测试用例
/**
 * Demo controller mapped under {@code /hello}; its greeting comes from the injected {@link World}.
 */
@Controller
@RequestMapping("/hello")
public class HelloWorld {

    @Autowoired
    private World world;

    /** {@code /hello/say?name=...}: writes the greeting built by {@link World#helloWorld(String)}. */
    @RequestMapping("/say")
    public void say(HttpServletRequest request, HttpServletResponse response) throws IOException {
        String greeting = world.helloWorld(request.getParameter("name"));
        writeAndClose(response, greeting);
    }

    /** {@code /hello/send}: writes a fixed text. */
    @RequestMapping("/send")
    public void send(HttpServletRequest request, HttpServletResponse response) throws IOException {
        writeAndClose(response, "QQ.com");
    }

    /** Writes {@code body} to the response writer, then flushes and closes it. */
    private static void writeAndClose(HttpServletResponse response, String body) throws IOException {
        PrintWriter out = response.getWriter();
        out.write(body);
        out.flush();
        out.close();
    }
}
/**
 * @PACKAGE_NAME: com.lzx.springmvc.mvcDemo.service.impl
 * @AUTHOR: lzx
 * @DATE: 2019/10/23 20:38
 * @DESCRIBE: Default {@link World} implementation, annotated as a service bean.
 **/
@Service
public class WorldImpl implements World {

    /** Returns {@code "hello "} followed by {@code name} ({@code "hello null"} for a null name). */
    @Override
    public String helloWorld(String name) {
        return new StringBuilder("hello ").append(name).toString();
    }
}
/**
 * Greeting service contract used by the demo controller.
 */
public interface World {

    /**
     * Builds a greeting for the given name.
     *
     * @param name the name to greet
     * @return the greeting text
     */
    String helloWorld(String name);
}
配置文件application.properties
# Root package that DispatcherServlet scans for @Controller / @Service classes
scan.package=com.lzx.springmvc
日志文件
<?xml version="1.0" encoding="UTF-8"?>
<!-- Re-read every 5 seconds when changed; logs INFO and above to the console. -->
<configuration scan="true" scanPeriod="5 seconds" debug="false">
    <property name="mySpringMVC2" value="mySpringMVC2"/>
    <!-- Timestamps usable as ${byDate} / ${bySecond}, e.g. in log file names. -->
    <timestamp key="byDate" datePattern="yyyyMMdd"/>
    <timestamp key="bySecond" datePattern="yyyyMMdd'T'HHmmss"/>
    <!-- Use the property defined above; ${appName} was never defined. -->
    <contextName>${mySpringMVC2}</contextName>
    <appender name="console" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} %5p %t %logger{15} - %m%n</pattern>
        </encoder>
    </appender>
    <root level="INFO">
        <appender-ref ref="console"/>
    </root>
</configuration>
到此结束,还有很多很多需要改进的地方,日后慢慢学习,大神轻喷,望指正不足。多谢。