思路
执行一条sql都有一个通用的模板顺序:
1. 打开连接
2. 预处理 SQL(构造带占位符的预处理 SQL 字符串)
3. 获取 PreparedStatement 对象
4. 填充预处理 SQL 中的占位符
5. 执行 SQL
6. 处理 SQL 执行结果(结果集或受影响行数)
7. 关闭连接
其中第 1、3、4、5、7 步在每次执行中都是固定不变的,变化的只有第 2、6 步(构造 SQL 与处理结果)。由此我们可以构造一个执行 SQL 的模板方法。
实现方式有两种:
- 抽象父类方式:由抽象父类实现第 1、3、4、5、7 步以及模板方法本身,子类实现第 2、6 步
- lambda 方式:全部在一个类中实现,只需将第 2、6 步抽象为函数式接口(作为参数传入)即可。此方式可以在一定程度上避免类爆炸。
下面采用lambda 表达式来实现
实现代码
执行sql的模板方法
/**
 * Template method that executes a single SQL statement.
 *
 * <p>The fixed steps (obtain a connection, prepare the statement, bind the placeholders, execute,
 * release resources) live here; the variable steps (building the SQL text and interpreting the
 * outcome) are supplied by the caller as functions.
 *
 * @param booleanSupplier          supplies {@code true} to run the statement with
 *                                 {@code executeUpdate} (the callback then receives the update
 *                                 count as an {@code Integer}), or {@code false} to run it with
 *                                 {@code executeQuery} (the callback then receives the
 *                                 {@code ResultSet})
 * @param preprocessingSQLSupplier supplies the SQL text containing {@code ?} placeholders
 * @param resultCallBack           turns the update count or result set into the return value; the
 *                                 result set is closed as soon as this callback returns, so it must
 *                                 not be retained
 * @param args                     values bound to the placeholders, in order; may be empty
 * @param <T>                      result type
 * @return whatever {@code resultCallBack} returns
 * @throws RuntimeException wrapping the {@link SQLException} if preparing, binding or executing the
 *                          statement fails; the message contains the SQL text (not the bound values)
 */
public <T> T execute(Supplier<Boolean> booleanSupplier, Supplier<String> preprocessingSQLSupplier,
                     Function<Object, T> resultCallBack, Object... args) {
    Connection connection = null;
    PreparedStatement preparedStatement = null;
    ResultSet resultSet = null;
    String preSql = null;
    try {
        // 1. Obtain a connection.
        connection = DBUtil.getConnection();
        // 2. Build the SQL with ? placeholders.
        preSql = preprocessingSQLSupplier.get();
        // 3. Prepare the statement.
        preparedStatement = connection.prepareStatement(preSql);
        // 4. Bind the placeholders; JDBC parameter indexes are 1-based.
        if (args != null) {
            for (int i = 0; i < args.length; i++) {
                preparedStatement.setObject(i + 1, args[i]);
            }
        }
        // 5 + 6. Execute, and let the callback interpret the outcome while everything is still open.
        if (booleanSupplier.get()) {
            return resultCallBack.apply(preparedStatement.executeUpdate());
        }
        resultSet = preparedStatement.executeQuery();
        return resultCallBack.apply(resultSet);
    } catch (SQLException e) {
        throw new RuntimeException("Failed to execute SQL: " + preSql, e);
    } finally {
        // 7. Release the result set, statement and connection.
        DBUtil.close(connection, preparedStatement, resultSet);
    }
}
添加操作调用模板方法
/**
 * Inserts {@code t} into the table named after its simple class name.
 *
 * <p>Every non-static, non-synthetic declared field becomes a column of the same name, and its
 * current value is bound to the matching {@code ?} placeholder.
 *
 * @param t   the entity to insert
 * @param <T> entity type
 * @return {@code true} if at least one row was inserted
 * @throws RuntimeException if a field cannot be read or the SQL fails
 */
public <T> boolean add(T t) {
    Class<?> clz = t.getClass();
    List<Field> fields = new ArrayList<>();
    for (Field field : clz.getDeclaredFields()) {
        // Static fields (e.g. serialVersionUID) and synthetic fields generated by the compiler or
        // by bytecode instrumentation are not columns.
        if (!java.lang.reflect.Modifier.isStatic(field.getModifiers()) && !field.isSynthetic()) {
            // Entity fields are usually private, so open them up for reflective reads.
            field.setAccessible(true);
            fields.add(field);
        }
    }
    // Read the values up front instead of as a side effect of building the SQL.
    Object[] args = new Object[fields.size()];
    try {
        for (int i = 0; i < args.length; i++) {
            args[i] = fields.get(i).get(t);
        }
    } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
    }
    String columns = fields.stream().map(Field::getName).collect(Collectors.joining(","));
    String placeholders = fields.stream().map(f -> "?").collect(Collectors.joining(","));
    String sql = String.format("insert into %s (%s) values(%s)", clz.getSimpleName(), columns, placeholders);
    return execute(() -> true, () -> sql, o -> ((Integer) o) != 0, args);
}
更新操作调用模板方法
/**
 * Updates the row whose {@code id} column equals the entity's {@code id} field, writing every
 * other non-static, non-synthetic declared field to the column of the same name.
 *
 * @param t   the entity carrying the new values and the id
 * @param <T> entity type
 * @return {@code true} if at least one row was updated
 * @throws IllegalArgumentException if the entity has no {@code id} field, or no other field to set
 * @throws RuntimeException         if a field cannot be read or the SQL fails
 */
public <T> boolean update(T t) {
    Class<?> clz = t.getClass();
    Field idField = null;
    List<Field> bindOrder = new ArrayList<>();
    for (Field field : clz.getDeclaredFields()) {
        // Static and synthetic fields are not columns.
        if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
            continue;
        }
        // Entity fields are usually private, so open them up for reflective reads.
        field.setAccessible(true);
        if ("id".equals(field.getName())) {
            idField = field;
        } else {
            bindOrder.add(field);
        }
    }
    if (idField == null) {
        throw new IllegalArgumentException(clz.getName() + " has no 'id' field");
    }
    if (bindOrder.isEmpty()) {
        throw new IllegalArgumentException(clz.getName() + " has no field to update besides 'id'");
    }
    String assignments = bindOrder.stream().map(f -> f.getName() + "=?").collect(Collectors.joining(","));
    String sql = String.format("update %s set %s where %s=?", clz.getSimpleName(), assignments, idField.getName());
    // The id is bound last because its placeholder follows all the SET placeholders.
    bindOrder.add(idField);
    Object[] args = new Object[bindOrder.size()];
    try {
        for (int i = 0; i < args.length; i++) {
            args[i] = bindOrder.get(i).get(t);
        }
    } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
    }
    return execute(() -> true, () -> sql, o -> ((Integer) o) != 0, args);
}
查询操作调用模板方法
/**
 * Selects every row of the table named after {@code clz}'s simple name that matches all
 * {@code column = value} pairs in {@code params}, mapping each row onto a new {@code clz} instance.
 *
 * @param clz    entity class; needs an accessible no-arg constructor
 * @param params equality filters keyed by column name; {@code null} or empty selects every row.
 *               Keys are spliced into the SQL text, so only plain identifiers (letters, digits,
 *               {@code _}, {@code $}) are accepted; values are bound as parameters.
 * @param <T>    entity type
 * @return the mapped rows, never {@code null}
 * @throws IllegalArgumentException if a key is not a plain identifier
 * @throws RuntimeException         if the SQL fails or a row cannot be mapped onto {@code clz}
 */
public <T> List<T> query(Class<T> clz, Map<String, Object> params) {
    // Only instance fields map to columns; static and synthetic fields are skipped.
    String columns = Arrays.stream(clz.getDeclaredFields())
            .filter(f -> !java.lang.reflect.Modifier.isStatic(f.getModifiers()) && !f.isSynthetic())
            .map(Field::getName)
            .collect(Collectors.joining(","));
    List<String> keyList = params == null ? new ArrayList<>() : new ArrayList<>(params.keySet());
    Object[] args = new Object[keyList.size()];
    StringBuilder conditions = new StringBuilder();
    for (int i = 0; i < keyList.size(); i++) {
        String key = keyList.get(i);
        // The key becomes part of the SQL text, so reject anything but a plain identifier.
        if (key == null || key.isEmpty()
                || !key.codePoints().allMatch(c -> Character.isLetterOrDigit(c) || c == '_' || c == '$')) {
            throw new IllegalArgumentException("Illegal column name in query params: " + key);
        }
        args[i] = params.get(key);
        conditions.append(" and ").append(key).append("= ?");
    }
    String sql = String.format("select %s from %s where 1=1 %s", columns, clz.getSimpleName(), conditions);
    return execute(() -> false, () -> sql, result -> {
        try {
            ResultSet resultSet = (ResultSet) result;
            ResultSetMetaData metaData = resultSet.getMetaData();
            int columnCount = metaData.getColumnCount();
            // Resolve each column's target field once instead of once per row. The label is the
            // name as written in the SELECT list (the field name); getColumnName may instead
            // report the table's own spelling of the column.
            Field[] targets = new Field[columnCount];
            for (int i = 0; i < columnCount; i++) {
                targets[i] = clz.getDeclaredField(metaData.getColumnLabel(i + 1));
                targets[i].setAccessible(true);
            }
            List<T> list = new ArrayList<>();
            while (resultSet.next()) {
                T t = clz.getDeclaredConstructor().newInstance();
                for (int i = 0; i < columnCount; i++) {
                    targets[i].set(t, resultSet.getObject(i + 1));
                }
                list.add(t);
            }
            return list;
        } catch (SQLException | ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }, args);
}
数据库工具类
/**
 * Static JDBC helper: loads the MySQL driver, keeps the connection settings passed to
 * {@link #init}, and opens and closes connections.
 */
public class DBUtil {
    // Connection settings supplied through init(...); null until init has been called.
    private static String username;
    private static String password;
    private static String url;

    static {
        try {
            // Loading the driver class is enough. Per the java.sql.Driver contract, a driver
            // registers an instance of itself with DriverManager when its class is loaded, so a
            // manual registerDriver call would only add a duplicate entry.
            Class.forName("com.mysql.cj.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Stores the default connection settings used by {@link #getConnection()}.
     *
     * @param _url      JDBC URL
     * @param _username database user
     * @param _password database password
     */
    public static void init(String _url, String _username, String _password) {
        url = _url;
        username = _username;
        password = _password;
    }

    /**
     * Opens a new connection with explicit settings.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the connection cannot be opened
     */
    public static Connection getConnection(String url, String username, String password) {
        try {
            return DriverManager.getConnection(url, username, password);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Opens a new connection with the settings given to {@link #init}.
     *
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws RuntimeException      wrapping the {@link SQLException} if the connection fails
     */
    public static Connection getConnection() {
        if (url == null) {
            throw new IllegalStateException("DBUtil.init(url, username, password) has not been called");
        }
        return getConnection(url, username, password);
    }

    /**
     * Closes the given resources in the order result set, statement, connection. Any argument may
     * be {@code null}. A failure to close one resource is printed and does not prevent the others
     * from being closed.
     */
    public static void close(Connection connection, Statement statement, ResultSet resultSet) {
        try {
            if (Objects.nonNull(resultSet)) resultSet.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        try {
            if (Objects.nonNull(statement)) statement.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        try {
            if (Objects.nonNull(connection)) connection.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}
JdbcTemplate 的简单实现类
/**
 * Minimal reflection-based JDBC template.
 *
 * <p>An entity maps to the table named after its simple class name, and every non-static,
 * non-synthetic declared field maps to the column of the same name. All statements run through
 * {@link #execute}, which obtains connections from {@link DBUtil}.
 */
public class JdbcTemplate {

    /**
     * Inserts an entity.
     *
     * @param t   the entity to insert; each mapped field's current value is bound to a placeholder
     * @param <T> entity type
     * @return {@code true} if at least one row was inserted
     * @throws RuntimeException if a field cannot be read or the SQL fails
     */
    public <T> boolean add(T t) {
        Class<?> clz = t.getClass();
        List<Field> fields = persistentFields(clz);
        String columns = fields.stream().map(Field::getName).collect(Collectors.joining(","));
        String placeholders = fields.stream().map(f -> "?").collect(Collectors.joining(","));
        String sql = String.format("insert into %s (%s) values(%s)", clz.getSimpleName(), columns, placeholders);
        return execute(() -> true, () -> sql, JdbcTemplate::affectedAnyRow, readValues(fields, t));
    }

    /**
     * Updates the row whose {@code id} column equals the entity's {@code id} field, writing every
     * other mapped field.
     *
     * @param t   the entity carrying the new values and the id
     * @param <T> entity type
     * @return {@code true} if at least one row was updated
     * @throws IllegalArgumentException if the entity has no {@code id} field, or no other field to set
     * @throws RuntimeException         if a field cannot be read or the SQL fails
     */
    public <T> boolean update(T t) {
        Class<?> clz = t.getClass();
        Field idField = null;
        List<Field> bindOrder = new ArrayList<>();
        for (Field field : persistentFields(clz)) {
            if ("id".equals(field.getName())) {
                idField = field;
            } else {
                bindOrder.add(field);
            }
        }
        if (idField == null) {
            throw new IllegalArgumentException(clz.getName() + " has no 'id' field");
        }
        if (bindOrder.isEmpty()) {
            throw new IllegalArgumentException(clz.getName() + " has no field to update besides 'id'");
        }
        String assignments = bindOrder.stream().map(f -> f.getName() + "=?").collect(Collectors.joining(","));
        String sql = String.format("update %s set %s where %s=?", clz.getSimpleName(), assignments, idField.getName());
        // The id is bound last because its placeholder follows all the SET placeholders.
        bindOrder.add(idField);
        return execute(() -> true, () -> sql, JdbcTemplate::affectedAnyRow, readValues(bindOrder, t));
    }

    /**
     * Selects every row matching all {@code column = value} pairs in {@code params}, mapping each
     * row onto a new {@code clz} instance.
     *
     * @param clz    entity class; needs an accessible no-arg constructor
     * @param params equality filters keyed by column name; {@code null} or empty selects every row.
     *               Keys are spliced into the SQL text, so only plain identifiers (letters, digits,
     *               {@code _}, {@code $}) are accepted; values are bound as parameters.
     * @param <T>    entity type
     * @return the mapped rows, never {@code null}
     * @throws IllegalArgumentException if a key is not a plain identifier
     * @throws RuntimeException         if the SQL fails or a row cannot be mapped onto {@code clz}
     */
    public <T> List<T> query(Class<T> clz, Map<String, Object> params) {
        List<String> keys = params == null ? new ArrayList<>() : new ArrayList<>(params.keySet());
        Object[] args = new Object[keys.size()];
        StringBuilder conditions = new StringBuilder();
        for (int i = 0; i < keys.size(); i++) {
            String key = requireColumnName(keys.get(i));
            args[i] = params.get(key);
            conditions.append(" and ").append(key).append("= ?");
        }
        String columns = persistentFields(clz).stream().map(Field::getName).collect(Collectors.joining(","));
        String sql = String.format("select %s from %s where 1=1 %s", columns, clz.getSimpleName(), conditions);
        return execute(() -> false, () -> sql, rs -> mapRows((ResultSet) rs, clz), args);
    }

    /**
     * Template method that executes a single SQL statement.
     *
     * <p>The fixed steps (obtain a connection, prepare the statement, bind the placeholders,
     * execute, release resources) live here; the variable steps (building the SQL text and
     * interpreting the outcome) are supplied by the caller as functions.
     *
     * @param booleanSupplier          supplies {@code true} to run the statement with
     *                                 {@code executeUpdate} (the callback then receives the update
     *                                 count as an {@code Integer}), or {@code false} to run it with
     *                                 {@code executeQuery} (the callback then receives the
     *                                 {@code ResultSet})
     * @param preprocessingSQLSupplier supplies the SQL text containing {@code ?} placeholders
     * @param resultCallBack           turns the update count or result set into the return value;
     *                                 the result set is closed as soon as this callback returns, so
     *                                 it must not be retained
     * @param args                     values bound to the placeholders, in order; may be empty
     * @param <T>                      result type
     * @return whatever {@code resultCallBack} returns
     * @throws RuntimeException wrapping the {@link SQLException} if preparing, binding or executing
     *                          the statement fails; the message contains the SQL text (not the
     *                          bound values)
     */
    public <T> T execute(Supplier<Boolean> booleanSupplier, Supplier<String> preprocessingSQLSupplier,
                         Function<Object, T> resultCallBack, Object... args) {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        String preSql = null;
        try {
            // 1. Obtain a connection.
            connection = DBUtil.getConnection();
            // 2. Build the SQL with ? placeholders.
            preSql = preprocessingSQLSupplier.get();
            // 3. Prepare the statement.
            preparedStatement = connection.prepareStatement(preSql);
            // 4. Bind the placeholders; JDBC parameter indexes are 1-based.
            if (args != null) {
                for (int i = 0; i < args.length; i++) {
                    preparedStatement.setObject(i + 1, args[i]);
                }
            }
            // 5 + 6. Execute, and let the callback interpret the outcome while everything is open.
            if (booleanSupplier.get()) {
                return resultCallBack.apply(preparedStatement.executeUpdate());
            }
            resultSet = preparedStatement.executeQuery();
            return resultCallBack.apply(resultSet);
        } catch (SQLException e) {
            throw new RuntimeException("Failed to execute SQL: " + preSql, e);
        } finally {
            // 7. Release the result set, statement and connection.
            DBUtil.close(connection, preparedStatement, resultSet);
        }
    }

    /**
     * Returns the declared fields of {@code clz} that map to columns, made accessible. Static
     * fields (e.g. serialVersionUID) and synthetic fields generated by the compiler or by bytecode
     * instrumentation are excluded.
     */
    private static List<Field> persistentFields(Class<?> clz) {
        List<Field> fields = Arrays.stream(clz.getDeclaredFields())
                .filter(f -> !java.lang.reflect.Modifier.isStatic(f.getModifiers()) && !f.isSynthetic())
                .collect(Collectors.toList());
        // Entity fields are usually private, so open them up for reflective access.
        fields.forEach(f -> f.setAccessible(true));
        return fields;
    }

    /** Reads the value of each (already accessible) field from {@code target}, in list order. */
    private static Object[] readValues(List<Field> fields, Object target) {
        Object[] values = new Object[fields.size()];
        try {
            for (int i = 0; i < values.length; i++) {
                values[i] = fields.get(i).get(target);
            }
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
        return values;
    }

    /** Interprets an {@code executeUpdate} count: {@code true} when any row was affected. */
    private static boolean affectedAnyRow(Object updateCount) {
        return ((Integer) updateCount) != 0;
    }

    /** Returns {@code name} if it is a plain identifier; otherwise rejects it to prevent SQL injection. */
    private static String requireColumnName(String name) {
        if (name == null || name.isEmpty()
                || !name.codePoints().allMatch(c -> Character.isLetterOrDigit(c) || c == '_' || c == '$')) {
            throw new IllegalArgumentException("Illegal column name in query params: " + name);
        }
        return name;
    }

    /** Maps every remaining row of {@code resultSet} onto a new instance of {@code clz}. */
    private static <T> List<T> mapRows(ResultSet resultSet, Class<T> clz) {
        try {
            ResultSetMetaData metaData = resultSet.getMetaData();
            int columnCount = metaData.getColumnCount();
            // Resolve each column's target field once instead of once per row. The label is the
            // name as written in the SELECT list (the field name); getColumnName may instead
            // report the table's own spelling of the column.
            Field[] targets = new Field[columnCount];
            for (int i = 0; i < columnCount; i++) {
                targets[i] = clz.getDeclaredField(metaData.getColumnLabel(i + 1));
                targets[i].setAccessible(true);
            }
            List<T> list = new ArrayList<>();
            while (resultSet.next()) {
                T row = clz.getDeclaredConstructor().newInstance();
                for (int i = 0; i < columnCount; i++) {
                    targets[i].set(row, resultSet.getObject(i + 1));
                }
                list.add(row);
            }
            return list;
        } catch (SQLException | ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}
测试类
/**
 * Manual smoke test for {@link JdbcTemplate} against a local MySQL database.
 */
public class Client {
    public static void main(String[] args) {
        String jdbcUrl = "jdbc:mysql://localhost:3307/test?useSSL=false&useUnicode=true&characterEncoding=utf-8&zeroDateTimeBehavior=convertToNull&transformedBitIsBoolean=true&serverTimezone=GMT%2B8";
        DBUtil.init(jdbcUrl, "root", "root");
        JdbcTemplate template = new JdbcTemplate();

        // Insert example (uncomment to run):
        // User user1 = new User(10000047L,"张三", (byte) 1,1000,"xxxxa@foxmail.com",18,"13988888888");
        // boolean result1 = template.add(user1);
        // System.out.println("是否新增成功:"+result1);

        // Update example (uncomment to run):
        // User user2 = new User(10000047L,"李四", (byte) 1,1000,"xxxxfff@foxmail.com",18,"13988888888");
        // boolean result2 = template.update(user2);
        // System.out.println("是否修改成功:"+result2);

        // Query example: every user with age = 18 and balance = 1000, printed one per line.
        HashMap<String, Object> conditions = new HashMap<>();
        conditions.put("age", 18);
        conditions.put("balance", 1000);
        for (User user : template.query(User.class, conditions)) {
            System.out.println(user);
        }
    }
}