框架的本质就是反射 + 注解 + 设计模式。本文旨在实现一个简易的 SSM 框架,帮助大家理解框架背后的设计思想。Spring 的 AOP 部分本文不再展开,不熟悉的同学可以参考博主的另一篇文章:《手写JDK动态代理并简易实现Spring的各类通知》。
首先是Spring的容器部分:
MyClassPathXmlApplicationContext.java
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * A minimal XML-driven IoC container.
 * <p>
 * It reads {@code <bean id=".." class="..">} elements from a classpath XML file and creates
 * every bean. It then wires {@code <property>} children into fields by reflection, for example
 * setting the FruitDAO field of a FruitBizImpl bean to the FruitDAOImpl bean.
 */
public class MyClassPathXmlApplicationContext implements MyBeanFactory {
    /** Registry of created beans, keyed by bean id. */
    Map<String, Object> beanMap = new HashMap<>();

    /**
     * Loads and wires every bean declared in {@code contextConfigLocation}.
     * <p>
     * Phase 1 instantiates each bean. It uses the no-arg constructor, unless
     * {@code <constructor-arg name=".." value=".."/>} children are present. In that case it
     * uses the public constructor whose parameter names equal those args, in the same order.
     * <p>
     * Phase 2 injects each {@code <property name=".." ref=".."/>} or
     * {@code <property name=".." value=".."/>} into the field with that name. A {@code ref}
     * is resolved against the beans created in phase 1.
     * <p>
     * Constructor matching relies on {@link Parameter#getName()}. It only sees real source
     * names when the bean classes are compiled with {@code -parameters}; otherwise the names
     * are arg0, arg1, ...
     *
     * @param contextConfigLocation classpath path of the XML file, e.g. "applicationContext.xml"
     * @throws IllegalStateException if the file is missing or malformed, or a bean cannot be
     *                               created or wired
     */
    public MyClassPathXmlApplicationContext(String contextConfigLocation) {
        try (InputStream is = getClass().getClassLoader().getResourceAsStream(contextConfigLocation)) {
            if (is == null) {
                throw new IllegalStateException("Config file not found on classpath: " + contextConfigLocation);
            }
            DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
            Document document = builder.parse(is);
            NodeList beanNodes = document.getElementsByTagName("bean");
            // Phase 1: create every bean first, so property refs may point at beans declared
            // later in the file.
            for (int i = 0; i < beanNodes.getLength(); i++) {
                Element beanElement = (Element) beanNodes.item(i);
                beanMap.put(beanElement.getAttribute("id"), createBean(beanElement));
            }
            // Phase 2: property injection.
            for (int i = 0; i < beanNodes.getLength(); i++) {
                Element beanElement = (Element) beanNodes.item(i);
                injectProperties(beanMap.get(beanElement.getAttribute("id")), beanElement);
            }
        } catch (ParserConfigurationException | SAXException | IOException | ReflectiveOperationException e) {
            // Fail fast: a half-initialised container would only surface later as NPEs.
            throw new IllegalStateException("Failed to load beans from " + contextConfigLocation, e);
        }
    }

    /** Instantiates one bean from its {@code <bean>} element (phase 1); never returns null. */
    private Object createBean(Element beanElement) throws ReflectiveOperationException {
        Class<?> clazz = Class.forName(beanElement.getAttribute("class"));
        List<String> argNames = new ArrayList<>();
        List<String> argValues = new ArrayList<>();
        for (Element arg : childElements(beanElement, "constructor-arg")) {
            argNames.add(arg.getAttribute("name"));
            argValues.add(arg.getAttribute("value"));
        }
        if (argNames.isEmpty()) {
            return clazz.getDeclaredConstructor().newInstance();
        }
        for (Constructor<?> constructor : clazz.getConstructors()) {
            // Comparing lists (not concatenated strings) avoids "ab"+"c" == "a"+"bc" collisions.
            if (argNames.equals(parameterNames(constructor))) {
                Class<?>[] types = constructor.getParameterTypes();
                Object[] args = new Object[types.length];
                for (int j = 0; j < types.length; j++) {
                    args[j] = convert(types[j], argValues.get(j));
                }
                return constructor.newInstance(args);
            }
        }
        throw new NoSuchMethodException(clazz.getName() + " has no public constructor with parameters " + argNames);
    }

    /** Injects the {@code <property>} children of {@code beanElement} into {@code bean} (phase 2). */
    private void injectProperties(Object bean, Element beanElement) throws ReflectiveOperationException {
        for (Element property : childElements(beanElement, "property")) {
            Object refObj = beanMap.get(property.getAttribute("ref"));
            Field field = findField(bean.getClass(), property.getAttribute("name"));
            field.setAccessible(true);
            // A resolvable ref wins; otherwise the literal value attribute is converted to the field type.
            field.set(bean, refObj != null ? refObj : convert(field.getType(), property.getAttribute("value")));
        }
    }

    /** Returns the direct child elements of {@code parent} whose tag name is {@code tagName}. */
    private static List<Element> childElements(Element parent, String tagName) {
        List<Element> result = new ArrayList<>();
        NodeList children = parent.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node child = children.item(i);
            if (child.getNodeType() == Node.ELEMENT_NODE && tagName.equals(child.getNodeName())) {
                result.add((Element) child);
            }
        }
        return result;
    }

    /** Returns the constructor's parameter names in declaration order. */
    private static List<String> parameterNames(Constructor<?> constructor) {
        List<String> names = new ArrayList<>();
        for (Parameter parameter : constructor.getParameters()) {
            names.add(parameter.getName());
        }
        return names;
    }

    /** Finds a declared field named {@code name} on {@code type} or any of its superclasses. */
    private static Field findField(Class<?> type, String name) throws NoSuchFieldException {
        for (Class<?> c = type; c != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Not declared here; keep walking up the hierarchy.
            }
        }
        throw new NoSuchFieldException(type.getName() + "." + name);
    }

    /**
     * Converts an XML attribute string to {@code type}.
     * Integer/int, Double/double and Boolean/boolean are parsed; any other type gets the raw string.
     */
    private static Object convert(Class<?> type, String raw) {
        if (type == Integer.class || type == int.class) {
            return Integer.parseInt(raw);
        }
        if (type == Double.class || type == double.class) {
            return Double.parseDouble(raw);
        }
        if (type == Boolean.class || type == boolean.class) {
            return Boolean.parseBoolean(raw);
        }
        return raw;
    }

    /** Returns the bean registered under {@code id}, or {@code null} if there is none. */
    @Override
    public Object getBean(String id) {
        return beanMap.get(id);
    }

    /**
     * Returns a bean whose class is exactly {@code aClass}. If there is none, returns a bean
     * assignable to it (for example, an implementation of an interface). Returns {@code null}
     * when nothing matches.
     */
    @Override
    public <T> T getBean(Class<T> aClass) {
        Object assignable = null;
        for (Object value : beanMap.values()) {
            if (value == null) {
                continue;
            }
            if (value.getClass().equals(aClass)) {
                return aClass.cast(value);
            }
            if (assignable == null && aClass.isInstance(value)) {
                assignable = value;
            }
        }
        return aClass.cast(assignable);
    }
}
对外暴露的工厂接口,用于按 id 或按类型获取 bean 实例:
MyBeanFactory.java
/**
 * Lookup contract of the hand-written container: fetch a managed bean by its configured id
 * or by its type.
 */
public interface MyBeanFactory {

    /**
     * Looks a bean up by id.
     *
     * @param id the bean id declared in the configuration
     * @return the bean registered under that id; may be {@code null} when nothing matches
     */
    Object getBean(String id);

    /**
     * Looks a bean up by type.
     *
     * @param beanType the class of the wanted bean
     * @param <T>      the bean type
     * @return a bean of that type; may be {@code null} when nothing matches
     */
    <T> T getBean(Class<T> beanType);
}
然后是持久层部分(对应 SSM 中 Spring 整合 MyBatis 的角色,这里用 JDBC + Druid 连接池手写实现)。
为了支持事务,博主这里使用 ThreadLocal,保证同一个线程内对数据库的多次操作使用的是同一个数据库连接。
ConnectionUtil.java
import com.alibaba.druid.pool.DruidDataSourceFactory;
import javax.sql.DataSource;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Properties;
public class ConnectionUtil {
private static ThreadLocal<Connection> tl = new ThreadLocal<>();
private static DataSource ds = null ;
static{
InputStream is = null;
try {
is = ConnectionUtil.class.getClassLoader().getResourceAsStream("druid_jdbc.properties");
if (is != null) {
Properties properties = new Properties();
properties.load(is);
ds = DruidDataSourceFactory.createDataSource(properties);
}
} catch (FileNotFoundException e1) {
e1.printStackTrace();
} catch (IOException e2) {
e2.printStackTrace();
} catch (Exception e3) {
e3.printStackTrace();
}
}
//获取连接对象
public static Connection getConn() throws SQLException {
Connection conn = tl.get();
if(conn==null){
conn = ds.getConnection();
tl.set(conn);
}
return tl.get();
}
}
下面通过反射实现简单的 ORM 自动映射:把查询结果中的每一列赋值给实体类中同名的字段。
BaseDAO.java
import com.alibaba.druid.pool.DruidDataSourceFactory;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import javax.sql.DataSource;
/**
 * Reflection-based JDBC helper.
 * <p>
 * A subclass fixes the entity type through the type argument, for example
 * {@code class FruitDAOImpl extends BaseDAO<Fruit>}. It then gets generic update, query and
 * load methods that copy each result-set column into the entity field with the same name
 * (the column label, so SQL aliases are honoured).
 * <p>
 * Connections come from {@link #getConn()} and are deliberately NOT closed here, because they
 * are thread-bound and shared with the surrounding transaction. Statements and result sets are
 * local to each call and are closed before it returns.
 *
 * @param <T> entity type mapped by this DAO
 */
public abstract class BaseDAO<T> {
    // Kept for source compatibility with subclasses. This class no longer stores per-call JDBC
    // state in fields, because one DAO instance shared across request threads would race on them.
    protected Connection conn;
    protected PreparedStatement psmt;
    protected ResultSet rs;

    /** Returns the current thread's connection. */
    protected Connection getConn() throws SQLException {
        return ConnectionUtil.getConn();
    }

    /** Binds {@code params} to the statement's {@code ?} placeholders, 1-based, in order. */
    private static void setParams(PreparedStatement psmt, Object... params) throws SQLException {
        if (params != null) {
            for (int i = 0; i < params.length; i++) {
                psmt.setObject(i + 1, params[i]);
            }
        }
    }

    /**
     * Executes an INSERT, UPDATE or DELETE statement.
     *
     * @return for an INSERT, the first generated key if the driver returns one; otherwise the
     *         number of affected rows
     * @throws FruitException if the statement fails (the SQLException is attached as the cause)
     */
    protected int executeUpdate(String sql, Object... params) {
        // Only INSERT statements request generated keys.
        boolean insertFlag = sql.trim().regionMatches(true, 0, "INSERT", 0, 6);
        try {
            Connection connection = getConn();
            try (PreparedStatement statement = insertFlag
                    ? connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS)
                    : connection.prepareStatement(sql)) {
                setParams(statement, params);
                int count = statement.executeUpdate();
                if (insertFlag) {
                    try (ResultSet keys = statement.getGeneratedKeys()) {
                        if (keys.next()) {
                            return (int) keys.getLong(1);
                        }
                    }
                }
                return count;
            }
        } catch (SQLException e) {
            FruitException ex = new FruitException("添加水果出错了...............");
            ex.initCause(e);
            throw ex;
        }
    }

    /**
     * Runs a query and maps every row to a new {@code T}.
     * On SQLException the error is printed and the rows mapped so far are returned.
     */
    protected List<T> executeQuery(String sql, Object... params) {
        List<T> list = new ArrayList<>();
        try (PreparedStatement statement = getConn().prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                ResultSetMetaData rsmd = resultSet.getMetaData();
                // Resolved once per query rather than once per row.
                Class<T> entityClass = getClassType();
                while (resultSet.next()) {
                    list.add(mapRow(resultSet, rsmd, entityClass));
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return list;
    }

    /** Builds one entity from the current row by copying each column into the field named by its label. */
    private T mapRow(ResultSet resultSet, ResultSetMetaData rsmd, Class<T> entityClass) throws SQLException {
        T entityObj = createInstance(entityClass);
        int colCount = rsmd.getColumnCount();
        for (int i = 1; i <= colCount; i++) {
            setValue(entityObj, rsmd.getColumnLabel(i), resultSet.getObject(i));
        }
        return entityObj;
    }

    /** Sets the declared field {@code filedName} of {@code instance}; missing or inaccessible fields are reported and skipped. */
    private void setValue(Object instance, String filedName, Object fieldValue) {
        Class<?> entityClass = instance.getClass();
        try {
            Field field = entityClass.getDeclaredField(filedName);
            field.setAccessible(true);
            field.set(instance, fieldValue);
        } catch (NoSuchFieldException e1) {
            e1.printStackTrace();
        } catch (IllegalAccessException e2) {
            e2.printStackTrace();
        }
    }

    /**
     * Returns the entity class given as the type argument of the direct subclass.
     * Assumes the concrete DAO extends {@code BaseDAO<Entity>} directly.
     */
    @SuppressWarnings("unchecked")
    private Class<T> getClassType() {
        Type type = getClass().getGenericSuperclass();
        Type actualType = ((ParameterizedType) type).getActualTypeArguments()[0];
        if (actualType instanceof Class) {
            return (Class<T>) actualType;
        }
        try {
            return (Class<T>) Class.forName(actualType.getTypeName());
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Creates an entity via its no-arg constructor; returns null (after printing the error) on failure. */
    private T createInstance(Class<T> entityClass) {
        try {
            return entityClass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            e.printStackTrace();
        }
        return null;
    }

    /** Runs a query and maps only the first row, or returns null if there is no row or an SQLException occurs. */
    protected T load(String sql, Object... params) {
        try (PreparedStatement statement = getConn().prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    return mapRow(resultSet, resultSet.getMetaData(), getClassType());
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Runs a query and returns the raw column values of the first row (e.g. for aggregates such
     * as count(*)). Returns null if there is no row or an SQLException occurs.
     */
    protected Object[] executeComplexQuery(String sql, Object... params) {
        try (PreparedStatement statement = getConn().prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    int colCount = resultSet.getMetaData().getColumnCount();
                    Object[] arr = new Object[colCount];
                    for (int i = 1; i <= colCount; i++) {
                        arr[i - 1] = resultSet.getObject(i);
                    }
                    return arr;
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }
}
TypeUtil.java
/**
 * Type-name classification helper.
 */
public class TypeUtil {
    /**
     * Tells whether {@code typeName} is a "custom" type rather than one of the built-in value
     * types int, java.lang.Integer, java.lang.String, java.sql.Date or java.util.Date.
     * NOTE(review): presumably meant to detect entity-reference fields during ORM mapping; confirm with callers.
     *
     * @param typeName a fully-qualified class name or primitive name; {@code null} counts as custom
     * @return {@code false} for the listed built-in types, {@code true} for anything else
     */
    public static boolean isMyType(String typeName) {
        if (typeName == null) {
            return true;
        }
        switch (typeName) {
            case "int":
            case "java.lang.Integer":
            case "java.lang.String":
            case "java.sql.Date":
            case "java.util.Date":
                return false;
            default:
                return true;
        }
    }
}
下面实现全局事务管理器。
TransactionManager.java
import java.sql.Connection;
import java.sql.SQLException;
/**
 * Transaction demarcation over the thread-bound connection provided by {@link ConnectionUtil}.
 * Intended sequence on one thread: {@link #start()}, DAO calls, then {@link #commit()}, or
 * {@link #rollback()} on failure.
 */
public class TransactionManager {
    private static final TransactionManager instance = new TransactionManager();

    private TransactionManager() {
    }

    /** Returns the shared instance, which is created when the class is loaded. */
    public static TransactionManager getInstance() {
        return instance;
    }

    /** Begins a transaction by turning off auto-commit on this thread's connection. */
    public void start() throws SQLException {
        ConnectionUtil.getConn().setAutoCommit(false);
    }

    /** Commits this thread's transaction and then releases its connection, even if the commit fails. */
    public void commit() throws SQLException {
        try {
            ConnectionUtil.getConn().commit();
        } finally {
            // Without this, a failed commit would leave the connection bound to the thread with auto-commit off.
            close();
        }
    }

    /** Rolls back this thread's transaction and then releases its connection, even if the rollback fails. */
    public void rollback() throws SQLException {
        try {
            ConnectionUtil.getConn().rollback();
        } finally {
            close();
        }
    }

    /** Releases this thread's connection via {@link ConnectionUtil#close()}. */
    public void close() throws SQLException {
        ConnectionUtil.close();
    }
}
最后是Spring MVC核心部分。
DispatcherServlet.java
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
/**
 * Front controller for every {@code *.action} request.
 * <p>
 * The servlet path selects the action bean: {@code /fruit.action} maps to bean id "fruit".
 * The {@code oper} request parameter selects the method on that bean (default "index").
 * Method parameters are filled by name. "request", "response" and "session" receive the
 * servlet objects; every other parameter is read from the request parameter with the same name.
 * The method's return value is passed to {@link ViewResolver}.
 * <p>
 * Binding by name requires the action classes to be compiled with {@code -parameters}, and the
 * action methods must be accessible (public).
 */
@WebServlet("*.action")
public class DispatcherServlet extends HttpServlet {
    private MyBeanFactory beanFactory;

    @Override
    public void init(ServletConfig config) throws ServletException {
        // Keep GenericServlet's config so getServletConfig()/getServletContext() keep working.
        super.init(config);
        ServletContext application = config.getServletContext();
        // The config file name (e.g. applicationContext.xml) is expected to have been put into
        // application scope beforehand, by a listener.
        String contextConfigLocation = (String) application.getAttribute("contextConfigLocation");
        beanFactory = new MyClassPathXmlApplicationContext(contextConfigLocation);
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // For an extension mapping, the servlet path excludes the context path, e.g. "/fruit.action"
        // even when the app is deployed under "/app". Using the raw URI would yield "app/fruit".
        String servletPath = request.getServletPath();
        String path = servletPath.substring(1, servletPath.indexOf(".action"));
        Object actionObj = beanFactory.getBean(path);
        if (actionObj == null) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        String oper = request.getParameter("oper");
        if (oper == null) {
            oper = "index";
        }
        Method method = findMethod(actionObj.getClass(), oper);
        if (method == null) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        try {
            Object viewObj = method.invoke(actionObj, resolveArguments(method, request, response));
            ViewResolver.getInstance().resolverView(request, response, viewObj);
        } catch (IllegalAccessException e) {
            throw new ServletException(e);
        } catch (InvocationTargetException e) {
            // Unwrap so the container reports the action's own exception.
            throw new ServletException(e.getCause());
        }
    }

    /** Returns the first declared method called {@code name}, or null; only one method is ever invoked per request. */
    private static Method findMethod(Class<?> type, String name) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(name)) {
                return m;
            }
        }
        return null;
    }

    /** Builds the argument array for {@code method} from the servlet objects and request parameters. */
    private static Object[] resolveArguments(Method method, HttpServletRequest request, HttpServletResponse response) {
        Parameter[] params = method.getParameters();
        Object[] paramObjs = new Object[params.length];
        for (int i = 0; i < params.length; i++) {
            String paramName = params[i].getName();
            if ("session".equals(paramName)) {
                HttpSession session = request.getSession();
                paramObjs[i] = session;
            } else if ("response".equals(paramName)) {
                paramObjs[i] = response;
            } else if ("request".equals(paramName)) {
                paramObjs[i] = request;
            } else {
                paramObjs[i] = convert(params[i].getType(), request.getParameter(paramName));
            }
        }
        return paramObjs;
    }

    /**
     * Converts a request-parameter string to the parameter type. Integer/int and Double/double
     * are parsed, and a missing or empty value becomes null for them. Any other type gets the raw string.
     */
    private static Object convert(Class<?> type, String raw) {
        boolean empty = raw == null || raw.isEmpty();
        if (type == Integer.class || type == int.class) {
            return empty ? null : Integer.valueOf(raw);
        }
        if (type == Double.class || type == double.class) {
            return empty ? null : Double.valueOf(raw);
        }
        return raw;
    }
}
视图解析器:根据 Action 方法的返回值进行请求转发或重定向。
ViewResolver.java
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
 * Turns an action method's return value into a response.
 * A String starting with "redirect:" sends a client redirect to the rest of the string. Any
 * other String is forwarded to on the server side. A null result does nothing.
 */
public class ViewResolver {
    private static final String REDIRECT_PREFIX = "redirect:";

    private static ViewResolver instance;

    private ViewResolver() {
    }

    /** Returns the shared resolver, creating it on first use; synchronized so only one instance is created. */
    public static synchronized ViewResolver getInstance() {
        ViewResolver current = instance;
        if (current == null) {
            current = new ViewResolver();
            instance = current;
        }
        return current;
    }

    /**
     * Forwards to or redirects to the view named by {@code viewObj}.
     *
     * @param viewObj the action's return value; expected to be a String (a non-String causes a ClassCastException)
     */
    public void resolverView(HttpServletRequest request, HttpServletResponse response, Object viewObj) throws ServletException, IOException {
        if (viewObj == null) {
            return;
        }
        String target = (String) viewObj;
        if (!target.startsWith(REDIRECT_PREFIX)) {
            request.getRequestDispatcher(target).forward(request, response);
            return;
        }
        response.sendRedirect(target.substring(REDIRECT_PREFIX.length()));
    }
}