JavaWeb:(十三)ThreadLocal实现事务管理

通常service层会调用多个DAO层的方法来完成一个事务,所以必须要保证多个DAO层的数据库连接是同一个。这里使用ThreadLocal来确保一个事务之内所有的数据库操作都使用同一个数据库连接。

13.1 ThreadLocal原理

每一个线程之中都维护一个ThreadLocalMap,它存储本线程中所有ThreadLocal对象及其对应的值。

ThreadLocal的两个主要的方法是set和get方法。

13.1.1 ThreadLocal中的set方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

由源码可以看出,set方法分为三步:

  1. 获取当前线程t,
  2. 获取当前线程的ThreadLocalMap
  3. 将当前的ThreadLocal对象作为键,存储对象value作为值存入ThreadLocalMap中

13.1.2 ThreadLocal中的get方法

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

由源码可以看出,get方法取出对象分为三步:

  1. 获取当前线程t
  2. 获取当前线程的ThreadLocalMap
  3. 将当前的ThreadLocal对象作为键,查询ThreadLocalMap中对应的值,并返回

13.2 使用ThreadLocal确保数据库事务使用同一个数据库连接

修改JDBCUtils代码,加入ThreadLocal私有对象,保证JDBCUtils.getConnection()方法从ThreadLocal中获取连接,若获取连接为空,则从数据库连接池中取出一个连接放入ThreadLocal中。

public class JDBCUtils {
    private static ThreadLocal<Connection> threadLocal = new ThreadLocal<>();
    private static DataSource dataSource;
    static {
        try {
            //在javaweb项目中不能使用ClassLoader.getSystemClassLoader().getResourceAsStream()来进行配置文件的读取
            InputStream is = JDBCUtils.class.getClassLoader().getResourceAsStream("druid.properties");
            Properties properties = new Properties();
            properties.load(is);
            dataSource = DruidDataSourceFactory.createDataSource(properties);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    
    /**
     * 使用druid数据库连接池技术来获取数据库连接
     */
    public static Connection getConnection() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null){
            connection = dataSource.getConnection();
            threadLocal.set(connection);
        }
        return connection;
    }
    
    /**
     * 从ThreadLocal中关闭数据库连接
     */
    public static void closeConnection() throws SQLException {
        Connection connection = threadLocal.get();
        if(connection == null){
            return;
        }
        if (!connection.isClosed()){
            connection.close();
            threadLocal.set(null);
        }
    }
    ...
}

创建TransactionManager

package trans;

import util.JDBCUtils;

public class TransactionManager {

    public static void beginTrans() throws Exception {
        JDBCUtils.getConnection().setAutoCommit(false);
    }

    public static void commit() throws Exception {
        JDBCUtils.getConnection().commit();
        // 事务提交后表示当前事务完成,关闭当前事务数据库连接
        JDBCUtils.closeConnection();
    }

    public static void rollback() throws Exception {
        JDBCUtils.getConnection().rollback();
        // 事务回滚后表示当前事务终止,关闭当前事务数据库连接
        JDBCUtils.closeConnection();
    }
}

创建OpenSessionInViewFilter对所有请求进行事务控制

package filter;

import trans.TransactionManager;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import java.io.IOException;

@WebFilter("*.do")
public class OpenSessionInViewFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        Filter.super.init(filterConfig);
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        // 要保证下层的代码能够顺利抛出异常,所以之后的代码尽量少使用try catch
        try{
            System.out.println("开启事务...");
            TransactionManager.beginTrans();
            filterChain.doFilter(servletRequest,servletResponse);
            System.out.println("提交事务...");
            TransactionManager.commit();
        }catch (Exception e){
            e.printStackTrace();
            try {
                System.out.println("回滚事务...");
                TransactionManager.rollback();
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }
    }

    @Override
    public void destroy() {
        Filter.super.destroy();
    }
}

修改BaseDAO使数据库操作能够顺利抛出异常

package dao;

import exception.BaseDAOException;
import util.JDBCUtils;

import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.util.ArrayList;
import java.util.List;

/**
 * 由于已经在过滤器中进行事务的处理,所以BaseDAO中所有的方法都将使用同一个Connection
 * @param <T>
 */
public abstract class BaseDAO<T> {

    /**
     * 不显式的赋值null也不会报错,为什么?
     */
    private Class<T> clazz = null;

    {
        Type genericSuperclass = this.getClass().getGenericSuperclass();
        ParameterizedType paramType = (ParameterizedType) genericSuperclass;

        Type[] actualTypeArguments = paramType.getActualTypeArguments();
        clazz = (Class<T>) actualTypeArguments[0];
    }

    public int update(String sql, Object ...args){
        PreparedStatement ps = null;
        Connection connection = null;
        try {
            connection = JDBCUtils.getConnection();
            //预编译sql语句
            ps = connection.prepareStatement(sql);
            //填充占位符
            for (int i = 0; i < args.length; i++) {
                ps.setObject(i + 1, args[i]);
            }
            /*
             * 如果执行的是查询操作,有返回结果,则返回true
             * 如果执行的是增删改操作,没有返回结果,则返回false
             */
            return ps.executeUpdate();
        } catch (Exception e) {
            e.printStackTrace();
            // 数据库操作出现问题时,抛出自定义异常,其他所有数据库方法中亦如此
            throw new BaseDAOException("BaseDAO层update操作出现异常...");
        } finally {
            JDBCUtils.closeResource(null, ps);
        }
    }
    ...
}

另外,我发现在DispatcherServlet中如果这样

package servlet;

import dao.impl.FruitDAOImpl;
import ioc.BeanFactory;
import ioc.ClassPathXmlApplicationContext;
import org.junit.Test;
import pojo.Fruit;
import util.JDBCUtils;
import util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.sql.Connection;

    @WebServlet("*.do")
    public class DispatcherServlet extends ViewBaseServlet {

        private BeanFactory beanFactory;

        public DispatcherServlet(){
        }

        @Override
        public void init() throws ServletException {
            super.init();
            this.beanFactory = new ClassPathXmlApplicationContext();
        }

    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        request.setCharacterEncoding("UTF-8");
        String servletPath = request.getServletPath();
        servletPath = servletPath.substring(1);
        servletPath = servletPath.substring(0, servletPath.lastIndexOf(".do"));

        // 获取要执行的方法的controller类对象
        Object controllerBeanObj = beanFactory.getBean(servletPath);

        String operate = request.getParameter("operate");
        if(StringUtils.isEmpty(operate)){
            operate = "index";
        }
        // 获取controller类对象中要执行的方法
        Method[] methods = controllerBeanObj.getClass().getDeclaredMethods();
        for (Method method : methods) {
            if (operate.equals(method.getName())){
                Parameter[] parameters = method.getParameters();
                Object[] parametersValues = new Object[parameters.length];
                for (int i = 0; i < parameters.length; i++) {
                    if ("request".equals(parameters[i].getName())){
                        parametersValues[i] = request;
                    } else if ("response".equals(parameters[i].getName())){
                        parametersValues[i] = response;
                    } else if ("session".equals(parameters[i].getName())){
                        parametersValues[i] = request.getSession();
                    } else {
                        String parameterName = parameters[i].getName();
                        String parameterValue = request.getParameter(parameterName);
                        String typeName = parameters[i].getType().getName();
                        if (parameterValue != null){
                            if ("java.lang.Integer".equals(typeName)){
                                parametersValues[i] = Integer.parseInt(parameterValue);
                            } else {
                                parametersValues[i] = parameterValue;
                            }
                        }
                    }
                }

                method.setAccessible(true);
                String methodReturn = null;
                try {
                    //直接try catch这句代码,发生的异常不会被捕获,会抛出给Filter捕获
                    methodReturn = (String) method.invoke(controllerBeanObj, parametersValues);
                } catch (IllegalAccessException e) {
                    //此处无需抛出异常即可被Filter捕获
                    e.printStackTrace();
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                }
                if (methodReturn.startsWith("redirect:")){
                    String redirect = methodReturn.substring("redirect:".length());
                    response.sendRedirect(redirect);
                } else{
                    processTemplate(methodReturn, request, response);
                }
            }
        }
    }
}

如果这样,则会出现问题

package servlet;

import dao.impl.FruitDAOImpl;
import ioc.BeanFactory;
import ioc.ClassPathXmlApplicationContext;
import org.junit.Test;
import pojo.Fruit;
import util.JDBCUtils;
import util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.sql.Connection;

    @WebServlet("*.do")
    public class DispatcherServlet extends ViewBaseServlet {

        private BeanFactory beanFactory;

        public DispatcherServlet(){

        }

        @Override
        public void init() throws ServletException {
            super.init();
            this.beanFactory = new ClassPathXmlApplicationContext();
        }

    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        request.setCharacterEncoding("UTF-8");
        String servletPath = request.getServletPath();
        servletPath = servletPath.substring(1);
        servletPath = servletPath.substring(0, servletPath.lastIndexOf(".do"));

        // 获取要执行的方法的controller类对象
        Object controllerBeanObj = beanFactory.getBean(servletPath);
        
        String operate = request.getParameter("operate");
        if(StringUtils.isEmpty(operate)){
            operate = "index";
        }
        // 获取controller类对象中要执行的方法
        // 直接从这里try catch,则下面method.invoke发生的异常就会被捕获为InvocationTargetException
        // 导致无法抛出到Filter,不知为什么,又没有读者了解的解释一下
        try{
        	Method[] methods = controllerBeanObj.getClass().getDeclaredMethods();
            for (Method method : methods) {
                if (operate.equals(method.getName())){
                    Parameter[] parameters = method.getParameters();
                    Object[] parametersValues = new Object[parameters.length];
                    for (int i = 0; i < parameters.length; i++) {
                        if ("request".equals(parameters[i].getName())){
                            parametersValues[i] = request;
                        } else if ("response".equals(parameters[i].getName())){
                            parametersValues[i] = response;
                        } else if ("session".equals(parameters[i].getName())){
                            parametersValues[i] = request.getSession();
                        } else {
                            String parameterName = parameters[i].getName();
                            String parameterValue = request.getParameter(parameterName);
                            String typeName = parameters[i].getType().getName();
                            if (parameterValue != null){
                                if ("java.lang.Integer".equals(typeName)){
                                    parametersValues[i] = Integer.parseInt(parameterValue);
                                } else {
                                    parametersValues[i] = parameterValue;
                                }
                            }
                        }
                    }

                    method.setAccessible(true);
                    String methodReturn = (String) method.invoke(controllerBeanObj, parametersValues);       
                    if (methodReturn.startsWith("redirect:")){
                        String redirect = methodReturn.substring("redirect:".length());
                        response.sendRedirect(redirect);
                    } else{
                        processTemplate(methodReturn, request, response);
                    }
                }
            }
        } catch (Exception e) {
        	e.printStackTrace();
        	// 这里应该添加以下代码保证能够顺利抛出异常
        	// throw new DispatcherServletException("DispatcherServlet出现异常...");
        }  
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值