整合一下自己写的简单版Spring, SpringMVC, MyBatis-Plus
不废话上代码
修改的部分都写了注释
- 修改MainFilter
/**
 * Servlet filter that acts as the web entry point: it builds the
 * {@link ApplicationContext} from the subclass-provided configuration class
 * and routes every non-static request to the context's {@link DispatcherServlet}.
 */
public abstract class MainFilter implements Filter {
    /** URI suffixes that are handed to the rest of the filter chain untouched. */
    private static final String[] PASS_THROUGH_SUFFIXES = {".jsp", ".css", ".html", ".js"};

    private DispatcherServlet dispatcherServlet;
    private ApplicationContext spring;

    /**
     * Supplies the configuration class passed to the container.
     *
     * <p>NOTE: this is invoked from the constructor of this class, i.e. before
     * the subclass constructor body runs, so implementations must not depend
     * on subclass instance state.
     *
     * @return the configuration class used to initialize the container
     */
    protected abstract Class getConfigClass();

    /** Initializes the container with the subclass configuration and grabs its dispatcher. */
    public MainFilter() {
        spring = new ApplicationContext(getConfigClass());
        this.dispatcherServlet = spring.dispatcherServlet;
    }

    /**
     * Lets static resources pass through and dispatches everything else to a controller.
     *
     * @throws ServletException if the controller method throws or cannot be accessed
     *                          (the original failure is preserved as the cause)
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        String uri = request.getRequestURI();
        if (checkPass(uri)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        try {
            Object result = dispatcherServlet.dispatchRequest(uri, request.getMethod(), request, response);
            parseResponse(result, request, response);
        } catch (InvocationTargetException e) {
            // Surface the controller's own exception instead of silently answering with an empty body.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new ServletException("Handler for " + uri + " threw an exception", cause);
        } catch (IllegalAccessException e) {
            throw new ServletException("Handler for " + uri + " is not accessible", e);
        }
    }

    /** Returns true when the URI ends with one of the pass-through suffixes. */
    private boolean checkPass(String uri) {
        for (String suffix : PASS_THROUGH_SUFFIXES) {
            if (uri.endsWith(suffix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Turns a controller return value into an HTTP response.
     *
     * <ul>
     *   <li>{@code "redirect:<target>"} sends a redirect to the target;</li>
     *   <li>{@code "forward:<target>"} forwards the request to the target;</li>
     *   <li>any other string is written as the response body;</li>
     *   <li>{@code null} writes nothing;</li>
     *   <li>any other object is serialized with {@code JSONObject.toJSONString}.</li>
     * </ul>
     *
     * @throws ServletException if a redirect/forward prefix has no target
     */
    private void parseResponse(Object result, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        if (result == null) {
            return;
        }
        if (!(result instanceof String)) {
            String json = JSONObject.toJSONString(result);
            response.getOutputStream().write(json.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            return;
        }

        String path = (String) result;
        // Split only at the FIRST colon so that targets such as
        // "http://host:8080/x" are kept intact.
        int separator = path.indexOf(':');
        String prefix = separator < 0 ? "" : path.substring(0, separator).trim();
        boolean isRedirect = "redirect".equals(prefix);
        boolean isForward = "forward".equals(prefix);
        if (!isRedirect && !isForward) {
            // Explicit charset: the no-arg getBytes() depends on the platform default.
            response.getOutputStream().write(path.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            return;
        }

        String target = path.substring(separator + 1);
        if (target.isEmpty()) {
            throw new ServletException("Missing target in view string: " + path);
        }
        if (isRedirect) {
            response.sendRedirect(target);
        } else {
            request.getRequestDispatcher(target).forward(request, response);
        }
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }
}
- 修改ApplicationContext类
/**
 * Minimal IoC container.
 *
 * <p>On construction it scans the package named by the configuration class's
 * {@code @ComponentScan}, registers {@code @Component} and {@code @Controller}
 * classes, registers the results of the configuration class's {@code @Bean}
 * methods, eagerly creates all singleton beans (injecting {@code @Autowired}
 * fields by field name) and finally hands the controller instances to the
 * internally created {@link DispatcherServlet}.
 */
public class ApplicationContext {
    private Class configClass;
    /** Singleton instances keyed by bean name. */
    private ConcurrentHashMap<String, Object> singletonObjects = new ConcurrentHashMap<>();
    /** Bean definitions keyed by bean name. */
    private ConcurrentHashMap<String, BeanDefinition> beanDefinitionMap = new ConcurrentHashMap<>();
    /** Classes annotated with @Controller that were found during the scan. */
    private List<Class> controllers = new ArrayList<>();
    /** Dispatcher created inside the container from the scanned controllers. */
    DispatcherServlet dispatcherServlet;

    /**
     * Builds and fully initializes the container.
     *
     * @param configClass configuration class; must carry {@code @ComponentScan}
     * @throws RuntimeException if the configuration is missing or invalid, or if a
     *                          bean cannot be created
     */
    public ApplicationContext(Class configClass) {
        this.configClass = configClass;
        scan(configClass);
        dispatcherServlet = new DispatcherServlet(controllers);
        registerConfigBeans(configClass);
        createSingletons();

        // Controllers are registered under their fully-qualified class name.
        Map<String, Object> controllerObjects = new HashMap<>();
        for (Class clazz : controllers) {
            String name = clazz.getName();
            controllerObjects.put(name, singletonObjects.get(name));
        }
        dispatcherServlet.setControllerObjects(controllerObjects);
    }

    /** Invokes every @Bean method of the configuration class and caches the result under the method name. */
    private void registerConfigBeans(Class configClass) {
        Object config = null;
        try {
            for (Method method : configClass.getDeclaredMethods()) {
                if (!method.isAnnotationPresent(Bean.class)) {
                    continue;
                }
                if (config == null) {
                    // Instantiate lazily: a configuration without @Bean methods
                    // does not need an accessible no-arg constructor.
                    config = configClass.newInstance();
                }
                method.setAccessible(true);
                Object instance = method.invoke(config);
                String beanName = method.getName();
                if (instance == null) {
                    // ConcurrentHashMap rejects null values; fail with a clear message instead.
                    throw new RuntimeException("@Bean method returned null: " + beanName);
                }
                singletonObjects.put(beanName, instance);
                beanDefinitionMap.put(beanName, new BeanDefinition(method.getReturnType(), "singleton"));
            }
        } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException("Failed to process @Bean methods of " + configClass.getName(), e);
        }
    }

    /** Eagerly creates every singleton bean and stores it in the singleton cache. */
    private void createSingletons() {
        for (Map.Entry<String, BeanDefinition> entry : beanDefinitionMap.entrySet()) {
            String beanName = entry.getKey();
            BeanDefinition beanDefinition = entry.getValue();
            if (!"singleton".equals(beanDefinition.getScope())) {
                continue;
            }
            try {
                singletonObjects.put(beanName, createBean(beanName, beanDefinition));
            } catch (IllegalAccessException | InstantiationException e) {
                // Previously the failure was printed and a null bean was then put
                // into the ConcurrentHashMap, which threw an unrelated NullPointerException.
                throw new RuntimeException("Failed to create bean '" + beanName + "'", e);
            }
        }
    }

    /**
     * Scans the package given by {@code @ComponentScan} and registers bean
     * definitions for classes annotated with {@code @Controller} (keyed by class
     * name, singleton) and {@code @Component} (keyed by the annotation value,
     * scope from {@code @Scope} or singleton by default).
     */
    private void scan(Class configClass) {
        if (!configClass.isAnnotationPresent(ComponentScan.class)) {
            throw new RuntimeException("找不到指定配置信息");
        }
        ComponentScan componentScanAnnotation = (ComponentScan) configClass.getDeclaredAnnotation(ComponentScan.class);
        String path = componentScanAnnotation.value();

        ClassLoader classLoader = ApplicationContext.class.getClassLoader();
        // Package name, e.g. com.jxz.annotation, becomes resource path com/jxz/annotation.
        URL resource = classLoader.getResource(path.replace(".", "/"));
        if (resource == null) {
            throw new RuntimeException("Component scan package not found on classpath: " + path);
        }

        List<String> list = new ArrayList<>();
        getClassNames(toFile(resource), path, list);

        for (String className : list) {
            try {
                Class<?> clazz = classLoader.loadClass(className);
                if (clazz.isAnnotationPresent(Controller.class)) {
                    controllers.add(clazz);
                    BeanDefinition beanDefinition = new BeanDefinition();
                    beanDefinition.setClazz(clazz);
                    beanDefinition.setScope("singleton");
                    beanDefinitionMap.put(className, beanDefinition);
                }
                if (clazz.isAnnotationPresent(Component.class)) {
                    Component componentAnnotation = clazz.getDeclaredAnnotation(Component.class);
                    String beanName = componentAnnotation.value();
                    BeanDefinition beanDefinition = new BeanDefinition();
                    beanDefinition.setClazz(clazz);
                    if (clazz.isAnnotationPresent(Scope.class)) {
                        String scope = clazz.getDeclaredAnnotation(Scope.class).value();
                        if (!"singleton".equals(scope) && !"prototype".equals(scope)) {
                            throw new RuntimeException("不支持的Scope类型");
                        }
                        beanDefinition.setScope(scope);
                    } else {
                        beanDefinition.setScope("singleton");
                    }
                    beanDefinitionMap.put(beanName, beanDefinition);
                }
            } catch (ClassNotFoundException e) {
                // Best effort: a class file that cannot be loaded is skipped.
                e.printStackTrace();
            }
        }
    }

    /**
     * Converts a classpath URL to a file. Going through the URI decodes escaped
     * characters (e.g. %20); if that is not possible the raw URL path is used.
     */
    private static File toFile(URL resource) {
        try {
            return new File(resource.toURI());
        } catch (java.net.URISyntaxException | IllegalArgumentException e) {
            return new File(resource.getFile());
        }
    }

    /**
     * Creates a bean and injects its {@code @Autowired} fields, looking each
     * dependency up by the field's name. A cached singleton is returned as is.
     *
     * <p>NOTE: dependencies are resolved recursively without any cycle tracking,
     * so two beans that autowire each other will recurse without end.
     */
    private Object createBean(String beanName, BeanDefinition beanDefinition) throws IllegalAccessException, InstantiationException {
        Object cached = singletonObjects.get(beanName);
        if (null != cached && "singleton".equals(beanDefinition.getScope())) {
            return cached;
        }
        Class clazz = beanDefinition.getClazz();
        Object instance = clazz.newInstance();
        for (Field field : clazz.getDeclaredFields()) {
            if (!field.isAnnotationPresent(Autowired.class)) {
                continue;
            }
            field.setAccessible(true);
            String dependencyName = field.getName();
            Object bean = getBean(dependencyName);
            if (bean == null) {
                // Not created yet: build it now and cache it when it is a singleton.
                BeanDefinition dependencyDefinition = beanDefinitionMap.get(dependencyName);
                if (null == dependencyDefinition) {
                    throw new RuntimeException("Can't autowired" + dependencyName);
                }
                bean = createBean(dependencyName, dependencyDefinition);
                if ("singleton".equals(dependencyDefinition.getScope())) {
                    singletonObjects.put(dependencyName, bean);
                }
            }
            field.set(instance, bean);
        }
        return instance;
    }

    /**
     * Recursively collects the fully-qualified names of all .class files below a
     * directory. Names are derived from the package name and the file names, so
     * the result does not depend on the platform's path separator or on where
     * the classes directory lives on disk.
     *
     * @param file        directory that corresponds to {@code packageName}
     * @param packageName dotted package name of {@code file} (may be empty)
     * @param classNames  output list the class names are appended to
     */
    private void getClassNames(File file, String packageName, List<String> classNames) {
        File[] files = file.listFiles();
        if (files == null) {
            // Not a directory, or it could not be read.
            return;
        }
        String prefix = packageName.isEmpty() ? "" : packageName + ".";
        for (File f : files) {
            String name = f.getName();
            if (f.isDirectory()) {
                getClassNames(f, prefix + name, classNames);
            } else if (name.endsWith(".class")) {
                classNames.add(prefix + name.substring(0, name.length() - ".class".length()));
            }
        }
    }

    /**
     * Looks a bean up by name.
     *
     * @param beanName name the bean was registered under
     * @return the cached instance for a singleton, a fresh instance for a
     *         prototype, or {@code null} if the name is unknown, the singleton
     *         has not been created yet, or prototype creation failed
     */
    public Object getBean(String beanName) {
        BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
        if (beanDefinition == null) {
            return null;
        }
        if ("singleton".equals(beanDefinition.getScope())) {
            return singletonObjects.get(beanName);
        }
        if ("prototype".equals(beanDefinition.getScope())) {
            try {
                return createBean(beanName, beanDefinition);
            } catch (IllegalAccessException | InstantiationException e) {
                e.printStackTrace();
            }
        }
        return null;
    }
}
- 修改DispatcherServlet
/**
 * Request dispatcher: registers controller classes with a {@link HandlerMapping}
 * and invokes the matching controller method for an incoming request.
 *
 * <p>Controller instances are not created here; they are supplied through
 * {@link #setControllerObjects(Map)}, keyed by fully-qualified class name.
 */
public class DispatcherServlet {
    private HandlerMapping handler;
    /** Controller singletons keyed by fully-qualified class name. */
    private Map<String, Object> controllerObjects;

    /**
     * Registers the given controller classes and initializes the handler mapping.
     *
     * @param controllers controller classes found by the container's scan
     */
    public DispatcherServlet(List<Class> controllers) {
        handler = new HandlerMapping();
        for (Class clazz : controllers) {
            scan(clazz);
        }
        handler.init();
    }

    /**
     * Finds the handler for the URI/HTTP method and invokes it.
     *
     * @param uri  request URI
     * @param type HTTP method name; only "GET" and "POST" are dispatched
     * @return the controller method's return value, or {@code null} when the
     *         HTTP method is unsupported or no handler is available
     * @throws InvocationTargetException if the controller method throws
     * @throws IllegalAccessException    if the controller method is not accessible
     */
    public Object dispatchRequest(String uri, String type, HttpServletRequest request, HttpServletResponse response) throws InvocationTargetException, IllegalAccessException {
        MethodType methodType;
        if ("GET".equals(type)) {
            methodType = MethodType.GET;
        } else if ("POST".equals(type)) {
            methodType = MethodType.POST;
        } else {
            return null;
        }
        Map<Method, Class> map = this.handler.handler(uri, methodType);
        return invoke(map, request, response);
    }

    /** Invokes the first method in the map on its controller instance and returns the result. */
    private Object invoke(Map<Method, Class> map, HttpServletRequest request, HttpServletResponse response) throws InvocationTargetException, IllegalAccessException {
        if (map == null) {
            // Defensive: treat a missing mapping like "no handler".
            return null;
        }
        for (Map.Entry<Method, Class> entry : map.entrySet()) {
            Method method = entry.getKey();
            Object instance = controllerObjects.get(entry.getValue().getName());
            Object[] args = parseArgTypes(method, request, response);
            return method.invoke(instance, args);
        }
        return null;
    }

    /**
     * Builds the argument array for a controller method.
     *
     * <p>Parameters annotated with {@code @Param} are read from the request
     * parameter of that name and converted to the declared type; parameters of
     * type HttpServletRequest, HttpServletResponse or HttpSession receive the
     * corresponding object. Any other parameter receives {@code null}, so the
     * array always has exactly one slot per declared parameter.
     *
     * @throws RuntimeException if a {@code @Param} request parameter is absent
     */
    private Object[] parseArgTypes(Method method, HttpServletRequest request, HttpServletResponse response) {
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            String typeName = parameters[i].getType().getName();
            if (parameters[i].isAnnotationPresent(Param.class)) {
                Param paramAnnotation = parameters[i].getDeclaredAnnotation(Param.class);
                String attribute = request.getParameter(paramAnnotation.value());
                if (attribute == null) {
                    throw new RuntimeException("Illegal Null Param");
                }
                args[i] = convert(attribute, typeName);
            } else if ("javax.servlet.http.HttpServletRequest".equals(typeName)) {
                args[i] = request;
            } else if ("javax.servlet.http.HttpServletResponse".equals(typeName)) {
                args[i] = response;
            } else if ("javax.servlet.http.HttpSession".equals(typeName)) {
                args[i] = request.getSession();
            }
            // Unsupported types keep the default null so the argument count still matches.
        }
        return args;
    }

    /**
     * Converts a raw request parameter to the declared parameter type.
     * Boxed values are passed for primitive parameters as well; reflection unboxes them.
     */
    private static Object convert(String raw, String typeName) {
        switch (typeName) {
            case "int":
            case "java.lang.Integer":
                return Integer.valueOf(raw);
            case "long":
            case "java.lang.Long":
                return Long.valueOf(raw);
            case "double":
            case "java.lang.Double":
                return Double.valueOf(raw);
            case "boolean":
            case "java.lang.Boolean":
                return Boolean.valueOf(raw);
            default:
                return raw;
        }
    }

    /**
     * Registers a {@code @Controller} class with the handler mapping, under
     * "/" + the {@code @RequestMapping} url when present, otherwise under "/".
     */
    private void scan(Class clazz) {
        if (!clazz.isAnnotationPresent(Controller.class)) {
            return;
        }
        if (clazz.isAnnotationPresent(RequestMapping.class)) {
            RequestMapping requestMappingAnnotation = (RequestMapping) clazz.getDeclaredAnnotation(RequestMapping.class);
            String url = requestMappingAnnotation.url();
            handler.addController("/" + url, clazz);
        } else {
            handler.addController("/", clazz);
        }
    }

    /** Sets the controller instances, keyed by fully-qualified class name. */
    public void setControllerObjects(Map<String, Object> controllerObjects) {
        this.controllerObjects = controllerObjects;
    }
}