通常service层会调用多个DAO层的方法来完成一个事务,因此必须保证这些DAO方法使用的是同一个数据库连接。这里使用ThreadLocal来确保一个事务(同一个请求线程)之内的所有数据库操作都使用同一个数据库连接。
13.1 ThreadLocal原理
每一个线程(Thread对象)内部都维护一个ThreadLocalMap(首次使用时才创建),它以ThreadLocal对象为键,存储本线程中所有ThreadLocal对象对应的值。
ThreadLocal的两个主要的方法是set和get方法。
13.1.1 ThreadLocal中的set方法
/**
 * Excerpt of java.lang.ThreadLocal#set from the JDK: stores value as this ThreadLocal's
 * value for the calling thread only.
 */
public void set(T value) {
Thread t = Thread.currentThread();
// The calling thread's own ThreadLocalMap; null when it has not been created yet.
ThreadLocalMap map = getMap(t);
if (map != null) {
// Key is this ThreadLocal instance, value is the caller's object.
map.set(this, value);
} else {
// No map on this thread yet: create it, seeded with this entry.
createMap(t, value);
}
}
由源码可以看出,set方法分为三步:
- 获取当前线程t
- 获取当前线程的ThreadLocalMap
- 若该ThreadLocalMap已存在,则将当前的ThreadLocal对象作为键、value作为值存入ThreadLocalMap中;若还不存在(map为null),则调用createMap为当前线程创建ThreadLocalMap并存入该值
13.1.2 ThreadLocal中的get方法
/**
 * Excerpt of java.lang.ThreadLocal#get from the JDK: returns this ThreadLocal's value for
 * the calling thread.
 * When the thread has no map, or the map has no entry for this ThreadLocal, it falls back to
 * setInitialValue(), which per the JDK Javadoc initializes the value from initialValue()
 * (null unless overridden).
 */
public T get() {
Thread t = Thread.currentThread();
// The calling thread's own ThreadLocalMap; null when it has not been created yet.
ThreadLocalMap map = getMap(t);
if (map != null) {
// Look up the entry keyed by this ThreadLocal instance.
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
// Entry.value is stored untyped, hence the unchecked cast back to T.
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
由源码可以看出,get方法取出对象分为三步:
- 获取当前线程t
- 获取当前线程的ThreadLocalMap
- 若ThreadLocalMap存在且其中有以当前ThreadLocal对象为键的条目,则返回对应的值;否则调用setInitialValue(),将initialValue()的返回值(默认为null)存入并返回
13.2 使用ThreadLocal确保数据库事务使用同一个数据库连接
修改JDBCUtils代码,加入ThreadLocal私有对象,保证JDBCUtils.getConnection()方法从ThreadLocal中获取连接,若获取连接为空,则从数据库连接池中取出一个连接放入ThreadLocal中。
public class JDBCUtils {
private static ThreadLocal<Connection> threadLocal = new ThreadLocal<>();
private static DataSource dataSource;
static {
try {
//在javaweb项目中不能使用ClassLoader.getSystemClassLoader().getResourceAsStream()来进行配置文件的读取
InputStream is = JDBCUtils.class.getClassLoader().getResourceAsStream("druid.properties");
Properties properties = new Properties();
properties.load(is);
dataSource = DruidDataSourceFactory.createDataSource(properties);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
 * Returns the JDBC connection bound to the calling thread.
 *
 * <p>On the first call in a thread, a connection is borrowed from the Druid data source and
 * bound to that thread, so every later call on the same thread (for example, several DAO
 * methods inside one transaction) receives the very same connection.
 *
 * @return the connection bound to the current thread
 * @throws Exception if the data source cannot supply a connection
 */
public static Connection getConnection() throws Exception {
    Connection bound = threadLocal.get();
    if (bound != null) {
        return bound;
    }
    bound = dataSource.getConnection();
    threadLocal.set(bound);
    return bound;
}
/**
 * Closes the connection bound to the current thread and removes it from the ThreadLocal.
 *
 * <p>Does nothing when no connection is bound. The binding is removed in a {@code finally}
 * block, so it is cleared even if the connection was already closed or {@code close()} throws.
 * Otherwise, the next {@code getConnection()} on this thread would keep returning the dead
 * connection.
 *
 * @throws SQLException if checking or closing the connection fails
 */
public static void closeConnection() throws SQLException {
    Connection connection = threadLocal.get();
    if (connection == null) {
        return;
    }
    try {
        if (!connection.isClosed()) {
            // For a connection borrowed from a pool, close() is expected to hand it back.
            connection.close();
        }
    } finally {
        // remove() rather than set(null): drops this thread's map entry entirely, which
        // matters when request threads are pooled and reused across requests.
        threadLocal.remove();
    }
}
...
}
创建TransactionManager
package trans;
import util.JDBCUtils;
public class TransactionManager {

    /**
     * Begins a transaction for the current thread by disabling auto-commit on the connection
     * that JDBCUtils binds to this thread. DAO calls made later on the same thread use that same
     * connection and therefore join the transaction.
     *
     * @throws Exception if a connection cannot be obtained or auto-commit cannot be changed
     */
    public static void beginTrans() throws Exception {
        JDBCUtils.getConnection().setAutoCommit(false);
    }

    /**
     * Commits the current thread's transaction, then closes and unbinds its connection,
     * because the transaction is finished.
     *
     * <p>If the commit itself fails, the connection is intentionally left bound. A following
     * {@link #rollback()} (as issued by the transaction filter) can then roll back on the same
     * connection before releasing it.
     *
     * @throws Exception if committing or closing fails
     */
    public static void commit() throws Exception {
        JDBCUtils.getConnection().commit();
        JDBCUtils.closeConnection();
    }

    /**
     * Rolls back the current thread's transaction, then closes and unbinds its connection,
     * because the transaction is aborted.
     *
     * <p>The connection is released in a {@code finally} block. Otherwise, a failing rollback
     * would leave the connection bound to this thread and never returned.
     *
     * @throws Exception if rolling back or closing fails
     */
    public static void rollback() throws Exception {
        try {
            JDBCUtils.getConnection().rollback();
        } finally {
            JDBCUtils.closeConnection();
        }
    }
}
创建OpenSessionInViewFilter对所有请求进行事务控制
package filter;
import trans.TransactionManager;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import java.io.IOException;
@WebFilter("*.do")
public class OpenSessionInViewFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        Filter.super.init(filterConfig);
    }

    /**
     * Runs every {@code *.do} request inside one transaction.
     *
     * <p>It begins the transaction, runs the rest of the chain, and commits. If any step
     * throws, it rolls back and rethrows the failure. The container then reports an error
     * instead of sending a response that looks successful.
     *
     * <p>Lower layers should let exceptions propagate (avoid catching them there); otherwise
     * this filter cannot see the failure and will commit.
     *
     * @throws ServletException if request processing or transaction handling failed
     * @throws IOException      if request processing failed with an I/O error
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        try {
            System.out.println("开启事务...");
            TransactionManager.beginTrans();
            filterChain.doFilter(servletRequest, servletResponse);
            System.out.println("提交事务...");
            TransactionManager.commit();
        } catch (Exception e) {
            try {
                System.out.println("回滚事务...");
                TransactionManager.rollback();
            } catch (Exception rollbackFailure) {
                // Keep the original failure primary; record the rollback failure alongside it.
                e.addSuppressed(rollbackFailure);
            }
            // Rethrow with the original type where the signature allows, otherwise wrapped.
            if (e instanceof ServletException) {
                throw (ServletException) e;
            }
            if (e instanceof IOException) {
                throw (IOException) e;
            }
            if (e instanceof RuntimeException) {
                throw (RuntimeException) e;
            }
            throw new ServletException("请求处理出现异常,事务已回滚...", e);
        }
    }

    @Override
    public void destroy() {
        Filter.super.destroy();
    }
}
修改BaseDAO使数据库操作能够顺利抛出异常
package dao;
import exception.BaseDAOException;
import util.JDBCUtils;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.util.ArrayList;
import java.util.List;
/**
* 由于已经在过滤器中进行事务的处理,所以BaseDAO中所有的方法都将使用同一个Connection
* @param <T>
*/
public abstract class BaseDAO<T> {
/**
 * The entity type {@code T} this DAO works with, resolved from the subclass declaration
 * (e.g. a DAO declared as {@code extends BaseDAO<Fruit>} yields {@code Fruit.class}).
 *
 * <p>Why omitting {@code = null} would not cause an error: instance fields always receive a
 * default value ({@code null} for references) before any initializer runs. Only local
 * variables must be explicitly assigned before use.
 *
 * <p>Note that field initializers and instance initializer blocks run in textual order. This
 * declaration must therefore stay above the block that assigns it; otherwise the explicit
 * {@code = null} would overwrite the resolved class.
 */
private Class<T> clazz = null;

{
    clazz = resolveEntityClass(getClass());
}

/**
 * Returns the type argument that the direct subclass of BaseDAO supplied for {@code T}.
 * It walks up the hierarchy first, so a DAO that extends another concrete DAO also works.
 *
 * @param daoClass the runtime class of the DAO being constructed
 * @return the entity class bound to {@code T}
 * @throws IllegalStateException if BaseDAO is extended raw or with a non-class type argument
 */
@SuppressWarnings("unchecked") // checked by the instanceof test right before the cast
private static <E> Class<E> resolveEntityClass(Class<?> daoClass) {
    Class<?> current = daoClass;
    // BaseDAO is abstract, so daoClass is a strict subclass and this loop reaches BaseDAO.
    while (current.getSuperclass() != BaseDAO.class) {
        current = current.getSuperclass();
    }
    Type superType = current.getGenericSuperclass();
    if (superType instanceof ParameterizedType) {
        Type argument = ((ParameterizedType) superType).getActualTypeArguments()[0];
        if (argument instanceof Class) {
            return (Class<E>) argument;
        }
    }
    throw new IllegalStateException(daoClass.getName()
            + " must extend BaseDAO with a concrete type argument, e.g. BaseDAO<Fruit>");
}
/**
 * Executes an INSERT, UPDATE or DELETE statement on the current thread's connection.
 *
 * <p>The connection comes from {@code JDBCUtils.getConnection()} and is deliberately not
 * closed here, because the surrounding transaction owns it. Only the statement is released.
 *
 * @param sql  SQL text with {@code ?} placeholders
 * @param args values bound to the placeholders in order (JDBC indices start at 1)
 * @return the row count returned by {@link PreparedStatement#executeUpdate()}
 * @throws BaseDAOException if any database step fails; the original exception is its cause
 */
public int update(String sql, Object ...args){
    PreparedStatement ps = null;
    try {
        Connection connection = JDBCUtils.getConnection();
        // Precompile the SQL statement.
        ps = connection.prepareStatement(sql);
        // Fill in the placeholders.
        for (int i = 0; i < args.length; i++) {
            ps.setObject(i + 1, args[i]);
        }
        // executeUpdate() returns the number of affected rows
        // (0 for statements that return nothing).
        return ps.executeUpdate();
    } catch (Exception e) {
        // Surface DB failures as the project's unchecked exception so they reach the transaction
        // filter; chain the original so the real SQL error is not lost.
        BaseDAOException failure = new BaseDAOException("BaseDAO层update操作出现异常...");
        failure.initCause(e);
        throw failure;
    } finally {
        // Release only the statement; null is passed for the connection so it stays bound
        // to this thread's transaction.
        JDBCUtils.closeResource(null, ps);
    }
}
...
}
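另外,在DispatcherServlet中要注意异常能否传播到Filter:method.invoke会把controller方法抛出的任何异常包装成InvocationTargetException(受检异常)。下面的写法只对method.invoke这一句进行try catch: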
package servlet;
import dao.impl.FruitDAOImpl;
import ioc.BeanFactory;
import ioc.ClassPathXmlApplicationContext;
import org.junit.Test;
import pojo.Fruit;
import util.JDBCUtils;
import util.StringUtils;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.sql.Connection;
@WebServlet("*.do")
public class DispatcherServlet extends ViewBaseServlet {

    private BeanFactory beanFactory;

    public DispatcherServlet() {
    }

    /**
     * Creates the bean factory that supplies controller objects by id.
     */
    @Override
    public void init() throws ServletException {
        super.init();
        this.beanFactory = new ClassPathXmlApplicationContext();
    }

    /**
     * Dispatches a {@code *.do} request to one controller method and renders its result.
     *
     * <p>The servlet path, without the leading "/" and the ".do" suffix, selects the controller
     * bean; for example, {@code /fruit.do} selects bean {@code fruit}. The {@code operate}
     * request parameter (default {@code "index"}) selects the method. A returned
     * {@code "redirect:<url>"} triggers a redirect; any other non-null value is rendered as a
     * template.
     *
     * <p>Failures are propagated rather than caught here, so the transaction filter around this
     * servlet sees them and rolls back.
     *
     * @throws ServletException if no controller or method matches, or the controller method
     *                          failed with a checked exception
     * @throws IOException      if the controller method, the redirect or rendering fails with I/O
     */
    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        request.setCharacterEncoding("UTF-8");
        String servletPath = request.getServletPath();
        servletPath = servletPath.substring(1);
        servletPath = servletPath.substring(0, servletPath.lastIndexOf(".do"));
        // The controller object whose method handles this request.
        Object controllerBeanObj = beanFactory.getBean(servletPath);
        if (controllerBeanObj == null) {
            throw new ServletException("No controller bean found for: " + servletPath);
        }
        String operate = request.getParameter("operate");
        if (StringUtils.isEmpty(operate)) {
            operate = "index";
        }
        for (Method method : controllerBeanObj.getClass().getDeclaredMethods()) {
            if (!operate.equals(method.getName())) {
                continue;
            }
            Object[] parametersValues = resolveArguments(method, request, response);
            method.setAccessible(true);
            String methodReturn = invokeController(method, controllerBeanObj, parametersValues);
            if (methodReturn == null) {
                // Nothing to render (e.g. a void handler that wrote the response itself).
                return;
            }
            if (methodReturn.startsWith("redirect:")) {
                response.sendRedirect(methodReturn.substring("redirect:".length()));
            } else {
                processTemplate(methodReturn, request, response);
            }
            // Dispatch to exactly one method, even if the controller overloads the name.
            return;
        }
        throw new ServletException("No method '" + operate + "' on controller "
                + controllerBeanObj.getClass().getName());
    }

    /**
     * Builds the argument array for a controller method, matching parameters by name.
     *
     * <p>{@code request}, {@code response} and {@code session} receive the servlet objects. Any
     * other parameter receives the request parameter of the same name: it is converted to an
     * integer when declared as {@code Integer} or {@code int}, and left {@code null} when the
     * request parameter is absent.
     *
     * <p>Real parameter names are only available at runtime when controllers are compiled with
     * {@code javac -parameters}; otherwise they are reported as {@code arg0}, {@code arg1}, ...
     *
     * @throws NumberFormatException if an integer parameter's value is not a valid number
     */
    private static Object[] resolveArguments(Method method, HttpServletRequest request, HttpServletResponse response) {
        Parameter[] parameters = method.getParameters();
        Object[] values = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            String name = parameters[i].getName();
            if ("request".equals(name)) {
                values[i] = request;
            } else if ("response".equals(name)) {
                values[i] = response;
            } else if ("session".equals(name)) {
                values[i] = request.getSession();
            } else {
                String raw = request.getParameter(name);
                if (raw != null) {
                    Class<?> type = parameters[i].getType();
                    if (type == Integer.class || type == int.class) {
                        values[i] = Integer.parseInt(raw);
                    } else {
                        values[i] = raw;
                    }
                }
            }
        }
        return values;
    }

    /**
     * Invokes the controller method and returns its result as a view string.
     *
     * <p>{@link Method#invoke} wraps anything the controller throws in an
     * {@link InvocationTargetException}. The real cause is unwrapped and rethrown, so callers
     * (and the transaction filter) see the actual failure, not a reflection wrapper.
     */
    private static String invokeController(Method method, Object controller, Object[] args) throws ServletException, IOException {
        try {
            return (String) method.invoke(controller, args);
        } catch (IllegalAccessException e) {
            throw new ServletException("Cannot access controller method " + method, e);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            if (cause instanceof ServletException) {
                throw (ServletException) cause;
            }
            throw new ServletException("Controller method " + method.getName() + " failed", cause);
        }
    }
}
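而如果像下面这样,用一个try catch(Exception)包住整段分发逻辑,则会出现问题:InvocationTargetException也会被这个catch捕获,若只打印而不重新抛出,service方法就会正常返回,Filter认为请求处理成功而提交事务,不会回滚: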
package servlet;
import dao.impl.FruitDAOImpl;
import ioc.BeanFactory;
import ioc.ClassPathXmlApplicationContext;
import org.junit.Test;
import pojo.Fruit;
import util.JDBCUtils;
import util.StringUtils;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.sql.Connection;
@WebServlet("*.do")
public class DispatcherServlet extends ViewBaseServlet {

    private BeanFactory beanFactory;

    public DispatcherServlet() {
    }

    /**
     * Creates the bean factory that supplies controller objects by id.
     */
    @Override
    public void init() throws ServletException {
        super.init();
        this.beanFactory = new ClassPathXmlApplicationContext();
    }

    /**
     * Dispatches a {@code *.do} request to the controller method named by the {@code operate}
     * parameter (default {@code "index"}) and renders or redirects according to its result.
     *
     * <p>The whole dispatch is wrapped in one try/catch. That is only safe because the catch
     * rethrows, so the transaction filter still sees the failure and rolls back.
     *
     * @throws ServletException if dispatch or the controller failed with a non-I/O checked exception
     * @throws IOException      if the controller, the redirect or rendering failed with I/O
     */
    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        request.setCharacterEncoding("UTF-8");
        String servletPath = request.getServletPath();
        servletPath = servletPath.substring(1);
        servletPath = servletPath.substring(0, servletPath.lastIndexOf(".do"));
        // The controller object whose method handles this request.
        Object controllerBeanObj = beanFactory.getBean(servletPath);
        String operate = request.getParameter("operate");
        if (StringUtils.isEmpty(operate)) {
            operate = "index";
        }
        try {
            Method[] methods = controllerBeanObj.getClass().getDeclaredMethods();
            for (Method method : methods) {
                if (operate.equals(method.getName())) {
                    Parameter[] parameters = method.getParameters();
                    Object[] parametersValues = new Object[parameters.length];
                    for (int i = 0; i < parameters.length; i++) {
                        if ("request".equals(parameters[i].getName())) {
                            parametersValues[i] = request;
                        } else if ("response".equals(parameters[i].getName())) {
                            parametersValues[i] = response;
                        } else if ("session".equals(parameters[i].getName())) {
                            parametersValues[i] = request.getSession();
                        } else {
                            String parameterName = parameters[i].getName();
                            String parameterValue = request.getParameter(parameterName);
                            String typeName = parameters[i].getType().getName();
                            if (parameterValue != null) {
                                if ("java.lang.Integer".equals(typeName)) {
                                    parametersValues[i] = Integer.parseInt(parameterValue);
                                } else {
                                    parametersValues[i] = parameterValue;
                                }
                            }
                        }
                    }
                    method.setAccessible(true);
                    String methodReturn = (String) method.invoke(controllerBeanObj, parametersValues);
                    if (methodReturn.startsWith("redirect:")) {
                        String redirect = methodReturn.substring("redirect:".length());
                        response.sendRedirect(redirect);
                    } else {
                        processTemplate(methodReturn, request, response);
                    }
                }
            }
        } catch (Exception e) {
            // Why the exception used to "disappear": method.invoke wraps whatever the controller
            // throws in a checked InvocationTargetException, and this catch (Exception) catches it.
            // If it only printed the stack trace, service() returned normally, so the filter
            // committed instead of rolling back. Unwrap the real cause and rethrow it.
            Throwable failure = (e instanceof InvocationTargetException) ? e.getCause() : e;
            if (failure instanceof RuntimeException) {
                throw (RuntimeException) failure;
            }
            if (failure instanceof Error) {
                throw (Error) failure;
            }
            if (failure instanceof IOException) {
                throw (IOException) failure;
            }
            if (failure instanceof ServletException) {
                throw (ServletException) failure;
            }
            throw new ServletException("DispatcherServlet出现异常...", failure);
        }
    }
}