上一篇《【从零写javaweb框架】(十)加载AOP框架》中,我们为框架添加了AOP特性,现在可以利用这个特性来为框架添加事务处理。
定义一个事务注解:
package org.smart4j.framework.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Marker annotation for transactional methods.
* <p>
* It can only be placed on methods and is retained at runtime so that it can
* be looked up reflectively (TransactionProxy checks for it on the target
* method before wrapping the call in a transaction).
* <p>
* Created by Lon on 2018/2/17.
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Transaction {
}
Transaction注解可以标记在Service的方法上,表示是一个有事务的方法,如果抛出异常,则会进行回滚操作。
我们还需要一个操作数据库的工具类DatabaseHelper,它的底层是Apache Commons的DbUtils,用ThreadLocal为每个线程保存各自的数据库连接(从而保证线程安全),并用DBCP数据库连接池来提高性能:
package org.smart4j.framework.helper;
import org.apache.commons.dbcp2.BasicDataSource;
import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.BeanHandler;
import org.apache.commons.dbutils.handlers.BeanListHandler;
import org.apache.commons.dbutils.handlers.MapListHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.smart4j.framework.util.CollectionUtil;
import org.smart4j.framework.util.PropsUtil;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
/**
 * Database helper: static utility methods for queries, updates and
 * transaction control, built on Commons DbUtils ({@link QueryRunner}) and a
 * Commons DBCP2 connection pool ({@link BasicDataSource}).
 * <p>
 * Connection settings are read once, at class initialisation, from
 * {@code config.properties} (keys {@code jdbc.driver}, {@code jdbc.url},
 * {@code jdbc.username}, {@code jdbc.password}).
 * <p>
 * Each thread works on its own connection, kept in a {@link ThreadLocal}.
 * The connection is borrowed lazily by {@link #getConnection()} and is only
 * closed and unbound by {@link #commitTransaction()} or
 * {@link #rollbackTransaction()}.
 * <p>
 * NOTE(review): a thread that only runs queries/updates outside a transaction
 * keeps its pooled connection bound indefinitely; confirm this is acceptable
 * for the pool size and threading model in use.
 * <p>
 * Created by Lon on 2018/2/17.
 */
public final class DatabaseHelper {

    private static final Logger LOGGER = LoggerFactory.getLogger(DatabaseHelper.class);

    private static final String DRIVER;
    private static final String URL;
    private static final String USERNAME;
    private static final String PASSWORD;

    // DbUtils query executor; connections are passed in explicitly on every call.
    private static final QueryRunner QUERY_RUNNER = new QueryRunner();

    // Per-thread connection, so that threads never share a connection.
    private static final ThreadLocal<Connection> CONNECTION_HOLDER = new ThreadLocal<Connection>();

    // DBCP2 connection pool.
    private static final BasicDataSource DATA_SOURCE = new BasicDataSource();

    static {
        Properties conf = PropsUtil.loadProps("config.properties");
        DRIVER = conf.getProperty("jdbc.driver");
        URL = conf.getProperty("jdbc.url");
        USERNAME = conf.getProperty("jdbc.username");
        PASSWORD = conf.getProperty("jdbc.password");

        DATA_SOURCE.setDriverClassName(DRIVER);
        DATA_SOURCE.setUrl(URL);
        DATA_SOURCE.setUsername(USERNAME);
        DATA_SOURCE.setPassword(PASSWORD);
    }

    /**
     * Returns the connection bound to the current thread, borrowing one from
     * the pool and binding it if the thread has none yet.
     *
     * @return the current thread's connection, never {@code null}
     * @throws RuntimeException wrapping the {@link SQLException} if no
     *         connection can be obtained from the pool
     */
    public static Connection getConnection() {
        Connection conn = CONNECTION_HOLDER.get();
        if (conn == null) {
            try {
                conn = DATA_SOURCE.getConnection();
            } catch (SQLException e) {
                LOGGER.error("get connection failure", e);
                throw new RuntimeException(e);
            }
            // Bind only after a successful borrow; a failed borrow must not
            // leave a null entry behind in the ThreadLocal.
            CONNECTION_HOLDER.set(conn);
        }
        return conn;
    }

    /**
     * Runs a query and maps every row to an instance of {@code entityClass}.
     *
     * @param entityClass bean class each row is mapped to
     * @param sql         SQL text, with {@code ?} placeholders
     * @param params      values for the placeholders, in order
     * @return the mapped rows
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    public static <T> List<T> queryEntityList(Class<T> entityClass, String sql, Object... params) {
        try {
            return QUERY_RUNNER.query(getConnection(), sql, new BeanListHandler<T>(entityClass), params);
        } catch (SQLException e) {
            LOGGER.error("query entity list failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs a query and maps its result to a single instance of
     * {@code entityClass} using a {@link BeanHandler}.
     *
     * @param entityClass bean class the row is mapped to
     * @param sql         SQL text, with {@code ?} placeholders
     * @param params      values for the placeholders, in order
     * @return the mapped entity as produced by the {@link BeanHandler}
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    public static <T> T queryEntity(Class<T> entityClass, String sql, Object... params) {
        try {
            return QUERY_RUNNER.query(getConnection(), sql, new BeanHandler<T>(entityClass), params);
        } catch (SQLException e) {
            LOGGER.error("query entity failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs an arbitrary query (e.g. a join across tables) and returns each row
     * as a column-name to value map.
     *
     * @param sql    SQL text, with {@code ?} placeholders
     * @param params values for the placeholders, in order
     * @return one map per row
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    public static List<Map<String, Object>> executeQuery(String sql, Object... params) {
        try {
            return QUERY_RUNNER.query(getConnection(), sql, new MapListHandler(), params);
        } catch (SQLException e) {
            // Narrowed from Exception to match the sibling query methods and to
            // avoid re-wrapping a RuntimeException already raised by getConnection().
            LOGGER.error("execute query failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Executes an update statement (UPDATE / INSERT / DELETE).
     * <p>
     * The method name keeps its historical spelling because callers depend on it.
     *
     * @param sql    SQL text, with {@code ?} placeholders
     * @param params values for the placeholders, in order
     * @return number of affected rows
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    public static int excuteUpdate(String sql, Object... params) {
        try {
            return QUERY_RUNNER.update(getConnection(), sql, params);
        } catch (SQLException e) {
            LOGGER.error("execute update failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Inserts one row into the table of {@code entityClass}.
     * <p>
     * Column names are taken from the keys of {@code fieldMap} and concatenated
     * into the SQL text, so the keys must come from trusted code, never from
     * user input. Values are bound as parameters.
     *
     * @param entityClass entity class whose table receives the row
     * @param fieldMap    column name to value; must not be empty
     * @return {@code true} if exactly one row was inserted; {@code false} if
     *         {@code fieldMap} is empty or the affected row count is not 1
     */
    public static <T> boolean insertEntity(Class<T> entityClass, Map<String, Object> fieldMap) {
        if (CollectionUtil.isEmpty(fieldMap)) {
            LOGGER.error("can not insert entity");
            return false;
        }

        StringBuilder columns = new StringBuilder("(");
        StringBuilder values = new StringBuilder("(");
        List<Object> paramList = new ArrayList<Object>(fieldMap.size());
        // Single pass over the entries keeps column order and parameter order aligned.
        for (Map.Entry<String, Object> entry : fieldMap.entrySet()) {
            if (!paramList.isEmpty()) {
                columns.append(", ");
                values.append(", ");
            }
            columns.append(entry.getKey());
            values.append("?");
            paramList.add(entry.getValue());
        }
        columns.append(")");
        values.append(")");

        String sql = "INSERT INTO " + getTableName(entityClass) + columns + " VALUES " + values;
        return excuteUpdate(sql, paramList.toArray()) == 1;
    }

    /**
     * Updates the row with the given id in the table of {@code entityClass}.
     * <p>
     * Column names are taken from the keys of {@code fieldMap} and concatenated
     * into the SQL text, so the keys must come from trusted code, never from
     * user input. Values and the id are bound as parameters.
     *
     * @param entityClass entity class whose table is updated
     * @param id          value matched against the {@code id} column
     * @param fieldMap    column name to new value; must not be empty
     * @return {@code true} if exactly one row was updated; {@code false} if
     *         {@code fieldMap} is empty or the affected row count is not 1
     */
    public static <T> boolean updateEntity(Class<T> entityClass, long id, Map<String, Object> fieldMap) {
        if (CollectionUtil.isEmpty(fieldMap)) {
            LOGGER.error("can not update entity: fieldMap is empty");
            return false;
        }

        StringBuilder assignments = new StringBuilder();
        List<Object> paramList = new ArrayList<Object>(fieldMap.size() + 1);
        for (Map.Entry<String, Object> entry : fieldMap.entrySet()) {
            // Assignments must be comma-separated. The previous code appended
            // "=? " and then searched for ", ", which was never present, so
            // substring(0, -1) threw StringIndexOutOfBoundsException.
            if (assignments.length() > 0) {
                assignments.append(", ");
            }
            assignments.append(entry.getKey()).append("=?");
            paramList.add(entry.getValue());
        }
        paramList.add(id);

        String sql = "UPDATE " + getTableName(entityClass) + " SET " + assignments + " WHERE id=?";
        return excuteUpdate(sql, paramList.toArray()) == 1;
    }

    /**
     * Deletes the row with the given id from the table of {@code entityClass}.
     *
     * @param entityClass entity class whose table is affected
     * @param id          value matched against the {@code id} column
     * @return {@code true} if exactly one row was deleted
     */
    public static <T> boolean deleteEntity(Class<T> entityClass, long id) {
        String sql = "DELETE FROM " + getTableName(entityClass) + " WHERE id=?";
        return excuteUpdate(sql, id) == 1;
    }

    /**
     * Executes a SQL script loaded from the context class loader, treating
     * every non-blank line as one complete statement.
     *
     * @param filePath classpath-relative path of the script
     * @throws IllegalArgumentException if the resource cannot be found
     * @throws RuntimeException if reading the file or executing a statement fails
     */
    public static void executeSqlFile(String filePath) {
        InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(filePath);
        if (is == null) {
            // Previously this surfaced as a bare NullPointerException from the reader.
            throw new IllegalArgumentException("sql file not found on classpath: " + filePath);
        }
        // NOTE(review): the platform default charset is used to decode the file;
        // confirm whether scripts should be read with an explicit charset.
        BufferedReader reader = new BufferedReader(new InputStreamReader(is));
        try {
            String sql;
            while ((sql = reader.readLine()) != null) {
                if (sql.trim().length() == 0) {
                    continue; // a blank line is not a statement
                }
                excuteUpdate(sql);
            }
        } catch (Exception e) {
            LOGGER.error("execute sql file failure", e);
            throw new RuntimeException(e);
        } finally {
            try {
                reader.close(); // also closes the underlying resource stream
            } catch (Exception e) {
                LOGGER.error("close sql file failure", e);
            }
        }
    }

    /**
     * Resolves the table name for an entity class: its simple class name.
     */
    private static String getTableName(Class<?> entityClass) {
        return entityClass.getSimpleName();
    }

    /**
     * Begins a transaction on the current thread's connection by switching
     * auto-commit off. The connection is borrowed and bound if necessary.
     *
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    public static void beginTransaction() {
        // getConnection() already binds the connection to the current thread.
        Connection conn = getConnection();
        try {
            conn.setAutoCommit(false);
        } catch (SQLException e) {
            LOGGER.error("begin transaction failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Commits the current thread's transaction, then closes the connection and
     * unbinds it from the thread. Does nothing if the thread has no bound
     * connection.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the commit fails
     */
    public static void commitTransaction() {
        // Read the holder directly: borrowing a fresh connection just to commit
        // nothing on it would be pointless.
        Connection conn = CONNECTION_HOLDER.get();
        if (conn == null) {
            return;
        }
        try {
            conn.commit();
        } catch (SQLException e) {
            LOGGER.error("commit transaction failure", e);
            throw new RuntimeException(e);
        } finally {
            // Always hand the connection back, even when commit() failed;
            // otherwise it would be unbound without ever being closed.
            closeAndUnbind(conn);
        }
    }

    /**
     * Rolls back the current thread's transaction, then closes the connection
     * and unbinds it from the thread. Does nothing if the thread has no bound
     * connection.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the rollback fails
     */
    public static void rollbackTransaction() {
        Connection conn = CONNECTION_HOLDER.get();
        if (conn == null) {
            return;
        }
        try {
            conn.rollback();
        } catch (SQLException e) {
            LOGGER.error("rollback transaction failure", e);
            throw new RuntimeException(e);
        } finally {
            closeAndUnbind(conn);
        }
    }

    /**
     * Closes the connection and removes it from the current thread. A failure
     * to close is logged rather than thrown so it cannot mask the outcome of
     * the commit or rollback that preceded it.
     */
    private static void closeAndUnbind(Connection conn) {
        try {
            conn.close();
        } catch (SQLException e) {
            LOGGER.error("close connection failure", e);
        } finally {
            CONNECTION_HOLDER.remove();
        }
    }
}
有了上面的代码后,现在可以编写事务代理切面类了,它实现了Proxy接口:
package org.smart4j.framework.proxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.smart4j.framework.annotation.Transaction;
import org.smart4j.framework.helper.DatabaseHelper;
import java.lang.reflect.Method;
/**
* desc : 事务代理
* Created by Lon on 2018/2/18.
*/
public class TransactionProxy implements Proxy{
private static final Logger LOGGER = LoggerFactory.getLogger(TransactionProxy.class);
private static final ThreadLocal<Boolean> FLAG_HOLDER = new ThreadLocal<Boolean>(){
@Override
protected Boolean initialValue() {
return false;
}
};
public Object doProxy(ProxyChain proxyChain) throws Throwable {
Object result;
boolean flag = FLAG_HOLDER.get();
Method method = proxyChain.getTargetMethod();
if (!flag && method.isAnnotationPresent(Transaction.class)){
FLAG_HOLDER.set(true);
try {
DatabaseHelper.beginTransaction();
LOGGER.debug("begin transaction");
result = proxyChain.doProxyChain();
LOGGER.debug("commit transaction");
DatabaseHelper.commitTransaction();
} catch (Exception e){
DatabaseHelper.rollbackTransaction();
LOGGER.debug("rollback transaction");
throw e;
} finally {
FLAG_HOLDER.remove();
}
} else {
result = proxyChain.doProxyChain();
}
return result;
}
}
其中FLAG_HOLDER是一个线程本地的标志,它可以保证同一线程中事务控制相关逻辑只会执行一次(带有Transaction注解的方法嵌套调用时,不会重复开启事务)。通过ProxyChain对象可以获取目标方法,进而判断该方法是否带有Transaction注解。
最后就是在框架加载时,需要把事务切面也给加载进来,所以需要改动一下AopHelper类:
/**
 * Builds the mapping from each proxy class to the set of target (proxied)
 * classes it applies to. A single proxy class may apply to many target classes.
 * Ordinary aspect proxies are registered first, then the transaction proxy.
 */
private static Map<Class<?>, Set<Class<?>>> createProxyMap() throws Exception {
    final Map<Class<?>, Set<Class<?>>> targetsByProxy = new HashMap<Class<?>, Set<Class<?>>>();

    addAspectProxy(targetsByProxy);
    addTransactionProxy(targetsByProxy);

    return targetsByProxy;
}
/**
 * Registers the ordinary aspect proxies: every subclass of AspectProxy that
 * carries an Aspect annotation is mapped to the target classes selected by
 * that annotation.
 */
private static void addAspectProxy(Map<Class<?>, Set<Class<?>>> proxyMap) throws Exception {
    for (Class<?> candidate : ClassHelper.getClassSetBySuper(AspectProxy.class)) {
        Aspect aspect = candidate.getAnnotation(Aspect.class);
        if (aspect == null) {
            // Subclass of AspectProxy without the Aspect annotation: not registered.
            continue;
        }
        proxyMap.put(candidate, createTargetClassSet(aspect));
    }
}
/**
 * Registers the transaction proxy, mapping TransactionProxy to every class
 * annotated with Service.
 */
private static void addTransactionProxy(Map<Class<?>, Set<Class<?>>> proxyMap) {
    proxyMap.put(TransactionProxy.class, ClassHelper.getClassSetByAnnotation(Service.class));
}
一个简单的事务框架就此完成