模板方法模式-实现简单的JdbcTemplate

文章展示了如何利用Java的Lambda表达式实现一个SQL执行的模板方法,该方法用于简化数据库的增删改查操作。通过抽象出SQL预处理和结果处理为函数式接口,减少了代码重复,提高了代码复用性。
摘要由CSDN通过智能技术生成

思路

        执行一条sql都有一个通用的模板顺序:

  1. 打开连接
  2. 预处理sql(构造预处理sql的字符串)
  3. 获取preparedStatement对象
  4. 填充预处理sql中的占位符
  5. 执行sql
  6. 处理sql执行结果集
  7. 关闭连接

        其中1、3、4、5、7在执行流程中是不会变化的,变化的是2、6。由此我们可以构造一个执行sql的模板方法。

实现方式有两种:

  1. 抽象父类方式:其中抽象类实现步骤1、3、4、5、7和模板方法,实现类实现2、6
  2. lambda方式:全部在一个类中实现,将步骤2、6抽象为函数式接口即可。(此方式可以一定程度避免类爆炸)

下面采用lambda 表达式来实现

 实现代码

执行sql的模板方法

/**
     * Template method for running one SQL statement.
     * <p>
     * The fixed steps are: obtain a connection from {@code DBUtil.getConnection()}, prepare
     * the SQL text produced by {@code preprocessingSQLSupplier}, bind {@code args} to the
     * placeholders in order, execute, hand the raw result to {@code resultCallBack}, and
     * finally release the JDBC resources through {@code DBUtil.close}.
     *
     * @param booleanSupplier          yields {@code true} to run {@code executeUpdate()},
     *                                 {@code false} to run {@code executeQuery()}
     * @param preprocessingSQLSupplier supplies the SQL text with {@code ?} placeholders; it is
     *                                 invoked before {@code args} is read
     * @param resultCallBack           receives either the update count (boxed as an
     *                                 {@code Integer}) or the still-open {@code ResultSet}
     * @param args                     values bound to the placeholders; element {@code i} goes
     *                                 to parameter index {@code i + 1}
     * @param <T>                      result type
     * @return the value produced by {@code resultCallBack}
     * @throws RuntimeException wrapping any {@link SQLException} raised inside the try block
     */
    public <T> T execute(Supplier<Boolean> booleanSupplier, Supplier<String> preprocessingSQLSupplier, Function<Object, T> resultCallBack, Object... args) {
        Connection conn = null;
        PreparedStatement stmt = null;
        ResultSet rows = null;
        try {
            // Steps 1-3: connection, SQL text, prepared statement.
            conn = DBUtil.getConnection();
            stmt = conn.prepareStatement(preprocessingSQLSupplier.get());

            // Step 4: JDBC parameter indexes are 1-based.
            int position = 1;
            for (Object arg : args) {
                stmt.setObject(position++, arg);
            }

            // Steps 5-6: execute and let the callback interpret the outcome.
            if (!booleanSupplier.get()) {
                rows = stmt.executeQuery();
                return resultCallBack.apply(rows);
            }
            int affected = stmt.executeUpdate();
            return resultCallBack.apply(affected);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            // Step 7: always release resources, including after a failure.
            DBUtil.close(conn, stmt, rows);
        }
    }

 添加操作调用模板方法

/**
 * Inserts one row built from the entity's declared instance fields.
 * <p>
 * The table name is the entity's simple class name and each column name is the
 * corresponding field name. Static and synthetic fields are skipped because they are
 * not per-row data (for example {@code serialVersionUID}).
 *
 * @param t   entity to insert; must not be {@code null}
 * @param <T> entity type
 * @return {@code true} if the statement reported at least one affected row
 * @throws RuntimeException if a field cannot be read reflectively or the statement fails
 */
public <T> boolean add(T t) {
        Class<?> clz = t.getClass();
        String tableName = clz.getSimpleName();

        // Keep only instance fields; getDeclaredFields() also returns static and
        // compiler/agent-generated members that have no matching column.
        Field[] declared = clz.getDeclaredFields();
        Field[] kept = new Field[declared.length];
        int columnCount = 0;
        for (Field field : declared) {
            boolean isStatic = java.lang.reflect.Modifier.isStatic(field.getModifiers());
            if (!isStatic && !field.isSynthetic()) {
                kept[columnCount++] = field;
            }
        }
        Field[] fields = new Field[columnCount];
        System.arraycopy(kept, 0, fields, 0, columnCount);

        // Sized exactly to the number of placeholders; filled by the SQL supplier below,
        // which execute() invokes before it binds the arguments.
        Object[] args = new Object[columnCount];

        return execute(() -> true, () -> {
            try {
                StringBuilder columns = new StringBuilder();
                StringBuilder placeholders = new StringBuilder();
                for (int i = 0; i < fields.length; i++) {
                    if (i > 0) {
                        columns.append(",");
                        placeholders.append(",");
                    }
                    // Allow reading private fields.
                    fields[i].setAccessible(true);
                    args[i] = fields[i].get(t);
                    columns.append(fields[i].getName());
                    placeholders.append("?");
                }
                return String.format("insert into %s (%s) values(%s)", tableName, columns, placeholders);
            } catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }, o -> ((Integer) o) != 0, args);
    }

更新操作调用模板方法

/**
 * Updates the row identified by the entity's {@code id} field.
 * <p>
 * Every declared instance field other than {@code id} becomes a {@code column=?}
 * assignment; the {@code id} value is bound to the trailing {@code where id=?}.
 * Static and synthetic fields are skipped because they are not per-row data.
 *
 * @param t   entity carrying the new values and the {@code id} of the row to update
 * @param <T> entity type
 * @return {@code true} if the statement reported at least one affected row
 * @throws RuntimeException if the entity has no {@code id} field, a field cannot be
 *                          read reflectively, or the statement fails
 */
public <T> boolean update(T t) {
        Class<?> clz = t.getClass();
        String tableName = clz.getSimpleName();

        // Columns for the SET clause: instance fields except the key itself.
        List<Field> fieldList = Arrays.stream(clz.getDeclaredFields())
                .filter(field -> !field.getName().equals("id"))
                .filter(field -> !java.lang.reflect.Modifier.isStatic(field.getModifiers()))
                .filter(field -> !field.isSynthetic())
                .collect(Collectors.toList());

        // One slot per assignment plus the last one for the id in the WHERE clause.
        // Filled by the SQL supplier, which execute() invokes before binding.
        Object[] args = new Object[fieldList.size() + 1];

        return execute(() -> true, () -> {
            try {
                StringBuilder assignments = new StringBuilder();
                for (int i = 0; i < fieldList.size(); i++) {
                    Field field = fieldList.get(i);
                    if (i > 0) {
                        assignments.append(",");
                    }
                    // Allow reading private fields.
                    field.setAccessible(true);
                    args[i] = field.get(t);
                    assignments.append(field.getName());
                    assignments.append("=?");
                }
                Field idField = clz.getDeclaredField("id");
                idField.setAccessible(true);
                args[args.length - 1] = idField.get(t);
                return String.format("update %s set %s where %s=%s", tableName, assignments, idField.getName(), "?");
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }, o -> ((Integer) o) != 0, args);
    }

查询操作调用模板方法

/**
 * Selects rows from the table named after {@code clz} and maps each row to a new instance.
 * <p>
 * The select list is the class's declared instance fields in declaration order, and
 * result columns are mapped back to those fields by position. This avoids a
 * {@code NoSuchFieldException} when the driver reports a column name whose case differs
 * from the Java field name.
 *
 * @param clz    entity class; needs a no-argument constructor
 * @param params equality filters as column name to value; each entry adds
 *               {@code and <key>= ?}. {@code null} or empty means no filter
 * @param <T>    entity type
 * @return the mapped rows, empty if nothing matched
 * @throws RuntimeException if instantiation, field assignment or the statement fails
 */
public <T> List<T> query(Class<T> clz, Map<String, Object> params) {
        String tableName = clz.getSimpleName();

        // Only instance fields are columns; static and synthetic members are skipped.
        List<Field> fields = new ArrayList<>();
        for (Field field : clz.getDeclaredFields()) {
            boolean isStatic = java.lang.reflect.Modifier.isStatic(field.getModifiers());
            if (!isStatic && !field.isSynthetic()) {
                fields.add(field);
            }
        }

        // Snapshot the keys so placeholder order and argument order cannot diverge.
        List<String> keyList = new ArrayList<>();
        if (params != null) {
            keyList.addAll(params.keySet());
        }
        // Filled by the SQL supplier, which execute() invokes before binding.
        Object[] args = new Object[keyList.size()];

        return execute(() -> false, () -> {
            StringBuilder columns = new StringBuilder();
            for (int i = 0; i < fields.size(); i++) {
                if (i > 0) {
                    columns.append(",");
                }
                columns.append(fields.get(i).getName());
            }

            // NOTE(review): keys are concatenated into the SQL text as column names; only the
            // values are bound. Callers must not pass untrusted keys.
            StringBuilder conditions = new StringBuilder();
            for (int i = 0; i < keyList.size(); i++) {
                String key = keyList.get(i);
                args[i] = params.get(key);
                conditions.append(" and ");
                conditions.append(key);
                conditions.append("= ?");
            }
            return String.format("select %s from %s where 1=1  %s", columns, tableName, conditions);
        }, result -> {
            try {
                ResultSet resultSet = (ResultSet) result;
                for (Field field : fields) {
                    // Allow writing private fields; done once instead of per row.
                    field.setAccessible(true);
                }
                List<T> list = new ArrayList<>();
                while (resultSet.next()) {
                    T row = clz.getDeclaredConstructor().newInstance();
                    // Column i + 1 is the i-th entry of the select list built above.
                    for (int i = 0; i < fields.size(); i++) {
                        fields.get(i).set(row, resultSet.getObject(i + 1));
                    }
                    list.add(row);
                }
                return list;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }, args);
}

数据库工具类

/**
 * Static helper that holds the JDBC connection settings and opens/closes JDBC resources.
 * <p>
 * {@link #init(String, String, String)} must be called before the no-argument
 * {@link #getConnection()}. The settings are plain static fields with no synchronization.
 */
public class DBUtil {

    private static String username;

    private static String password;

    private static String url;

    static {
        try {
            // Loading the driver class is sufficient: per the Connector/J documentation the
            // driver registers itself with DriverManager when it is initialized, so creating
            // and registering a second instance by hand is unnecessary.
            Class.forName("com.mysql.cj.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Stores the connection settings used by {@link #getConnection()}.
     *
     * @param _url      JDBC URL
     * @param _username database user
     * @param _password database password
     */
    public static void init(String _url, String _username, String _password){
        url =  _url;
        username =  _username;
        password =  _password;
    }

    /**
     * Opens a connection with explicit settings.
     *
     * @param url      JDBC URL
     * @param username database user
     * @param password database password
     * @return a new connection from {@code DriverManager}
     * @throws RuntimeException wrapping the {@link SQLException} if the connection fails
     */
    public static Connection getConnection(String url, String username, String password){
        try {
            return DriverManager.getConnection(url, username, password);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Opens a connection with the settings stored by {@link #init(String, String, String)}.
     *
     * @return a new connection
     * @throws IllegalStateException if {@code init} has not supplied a URL yet
     * @throws RuntimeException      wrapping the {@link SQLException} if the connection fails
     */
    public static Connection getConnection() {
        if (Objects.isNull(url)) {
            throw new IllegalStateException("DBUtil.init(url, username, password) must be called before getConnection()");
        }
        return getConnection(url, username, password);
    }

    /**
     * Closes the given resources in the order result set, statement, connection.
     * {@code null} arguments are skipped. A failure to close one resource is reported
     * but does not prevent the remaining ones from being closed.
     *
     * @param connection connection to close, may be {@code null}
     * @param statement  statement to close, may be {@code null}
     * @param resultSet  result set to close, may be {@code null}
     */
    public static void close(Connection connection, Statement statement, ResultSet resultSet) {
        closeQuietly(resultSet);
        closeQuietly(statement);
        closeQuietly(connection);
    }

    /** Closes one resource best-effort, reporting but not propagating any failure. */
    private static void closeQuietly(AutoCloseable resource) {
        if (Objects.isNull(resource)) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            // Deliberately best-effort: this runs from finally blocks, so it must neither
            // mask the primary exception nor stop the other resources from being closed.
            e.printStackTrace();
        }
    }

}

JdbcTemplate简单的实现类

/**
 * Minimal reflection-based JDBC helper built around one template method,
 * {@link #execute(Supplier, Supplier, Function, Object...)}.
 * <p>
 * Mapping convention used by {@code add}, {@code update} and {@code query}: the table name
 * is the entity's simple class name and each column name is a declared instance field name.
 * Static and synthetic fields are never mapped.
 */
public class JdbcTemplate {

    /**
     * Inserts one row built from the entity's declared instance fields.
     *
     * @param t   entity to insert; must not be {@code null}
     * @param <T> entity type
     * @return {@code true} if the statement reported at least one affected row
     * @throws RuntimeException if a field cannot be read reflectively or the statement fails
     */
    public <T> boolean add(T t) {
        Class<?> clz = t.getClass();
        String tableName = clz.getSimpleName();
        List<Field> fields = persistentFields(clz);
        // Sized exactly to the number of placeholders; filled by the SQL supplier below,
        // which execute() invokes before it binds the arguments.
        Object[] args = new Object[fields.size()];

        return execute(() -> true, () -> {
            try {
                StringBuilder columns = new StringBuilder();
                StringBuilder placeholders = new StringBuilder();
                for (int i = 0; i < fields.size(); i++) {
                    Field field = fields.get(i);
                    if (i > 0) {
                        columns.append(",");
                        placeholders.append(",");
                    }
                    // Allow reading private fields.
                    field.setAccessible(true);
                    args[i] = field.get(t);
                    columns.append(field.getName());
                    placeholders.append("?");
                }
                return String.format("insert into %s (%s) values(%s)", tableName, columns, placeholders);
            } catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }, o -> ((Integer) o) != 0, args);
    }


    /**
     * Updates the row identified by the entity's {@code id} field. Every other mapped
     * field becomes a {@code column=?} assignment.
     *
     * @param t   entity carrying the new values and the {@code id} of the row to update
     * @param <T> entity type
     * @return {@code true} if the statement reported at least one affected row
     * @throws RuntimeException if the entity has no {@code id} field, a field cannot be
     *                          read reflectively, or the statement fails
     */
    public <T> boolean update(T t) {
        Class<?> clz = t.getClass();
        String tableName = clz.getSimpleName();
        List<Field> fieldList = persistentFields(clz).stream()
                .filter(field -> !field.getName().equals("id"))
                .collect(Collectors.toList());
        // One slot per assignment plus the last one for the id in the WHERE clause.
        Object[] args = new Object[fieldList.size() + 1];

        return execute(() -> true, () -> {
            try {
                StringBuilder assignments = new StringBuilder();
                for (int i = 0; i < fieldList.size(); i++) {
                    Field field = fieldList.get(i);
                    if (i > 0) {
                        assignments.append(",");
                    }
                    // Allow reading private fields.
                    field.setAccessible(true);
                    args[i] = field.get(t);
                    assignments.append(field.getName());
                    assignments.append("=?");
                }
                Field idField = clz.getDeclaredField("id");
                idField.setAccessible(true);
                args[args.length - 1] = idField.get(t);
                return String.format("update %s set %s where %s=%s", tableName, assignments, idField.getName(), "?");
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }, o -> ((Integer) o) != 0, args);
    }

    /**
     * Selects rows from the table named after {@code clz} and maps each row to a new
     * instance. Result columns are mapped to fields by position in the generated select
     * list, so the column names reported by the driver do not have to match the field
     * names character for character.
     *
     * @param clz    entity class; needs a no-argument constructor
     * @param params equality filters as column name to value; each entry adds
     *               {@code and <key>= ?}. {@code null} or empty means no filter
     * @param <T>    entity type
     * @return the mapped rows, empty if nothing matched
     * @throws RuntimeException if instantiation, field assignment or the statement fails
     */
    public <T> List<T> query(Class<T> clz, Map<String, Object> params) {
        String tableName = clz.getSimpleName();
        List<Field> fields = persistentFields(clz);

        // Snapshot the keys so placeholder order and argument order cannot diverge.
        List<String> keyList = new ArrayList<>();
        if (params != null) {
            keyList.addAll(params.keySet());
        }
        Object[] args = new Object[keyList.size()];

        return execute(() -> false, () -> {
            StringBuilder columns = new StringBuilder();
            for (int i = 0; i < fields.size(); i++) {
                if (i > 0) {
                    columns.append(",");
                }
                columns.append(fields.get(i).getName());
            }

            // NOTE(review): keys are concatenated into the SQL text as column names; only the
            // values are bound. Callers must not pass untrusted keys.
            StringBuilder conditions = new StringBuilder();
            for (int i = 0; i < keyList.size(); i++) {
                String key = keyList.get(i);
                args[i] = params.get(key);
                conditions.append(" and ");
                conditions.append(key);
                conditions.append("= ?");
            }
            return String.format("select %s from %s where 1=1  %s", columns, tableName, conditions);
        }, result -> {
            try {
                ResultSet resultSet = (ResultSet) result;
                for (Field field : fields) {
                    // Allow writing private fields; done once instead of per row.
                    field.setAccessible(true);
                }
                List<T> list = new ArrayList<>();
                while (resultSet.next()) {
                    T row = clz.getDeclaredConstructor().newInstance();
                    // Column i + 1 is the i-th entry of the select list built above.
                    for (int i = 0; i < fields.size(); i++) {
                        fields.get(i).set(row, resultSet.getObject(i + 1));
                    }
                    list.add(row);
                }
                return list;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }, args);
    }

    /**
     * Template method for running one SQL statement.
     * <p>
     * The fixed steps are: obtain a connection from {@code DBUtil.getConnection()}, prepare
     * the SQL text produced by {@code preprocessingSQLSupplier}, bind {@code args} to the
     * placeholders in order, execute, hand the raw result to {@code resultCallBack}, and
     * finally release the JDBC resources through {@code DBUtil.close}.
     *
     * @param booleanSupplier          yields {@code true} to run {@code executeUpdate()},
     *                                 {@code false} to run {@code executeQuery()}
     * @param preprocessingSQLSupplier supplies the SQL text with {@code ?} placeholders; it is
     *                                 invoked before {@code args} is read
     * @param resultCallBack           receives either the update count (boxed as an
     *                                 {@code Integer}) or the still-open {@code ResultSet}
     * @param args                     values bound to the placeholders; element {@code i} goes
     *                                 to parameter index {@code i + 1}
     * @param <T>                      result type
     * @return the value produced by {@code resultCallBack}
     * @throws RuntimeException wrapping any {@link SQLException} raised inside the try block
     */
    public <T> T execute(Supplier<Boolean> booleanSupplier, Supplier<String> preprocessingSQLSupplier, Function<Object, T> resultCallBack, Object... args) {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
            // Steps 1-3: connection, SQL text, prepared statement.
            connection = DBUtil.getConnection();
            String preSql = preprocessingSQLSupplier.get();
            preparedStatement = connection.prepareStatement(preSql);

            // Step 4: JDBC parameter indexes are 1-based.
            for (int i = 0; i < args.length; i++) {
                preparedStatement.setObject(i + 1, args[i]);
            }

            // Steps 5-6: execute and let the callback interpret the outcome.
            if (booleanSupplier.get()) {
                int affected = preparedStatement.executeUpdate();
                return resultCallBack.apply(affected);
            }
            resultSet = preparedStatement.executeQuery();
            return resultCallBack.apply(resultSet);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            // Step 7: always release resources, including after a failure.
            DBUtil.close(connection, preparedStatement, resultSet);
        }
    }

    /**
     * Returns the declared fields of {@code clz} that represent per-row data, in
     * declaration order as returned by {@code getDeclaredFields()}. Static and synthetic
     * fields are excluded.
     */
    private static List<Field> persistentFields(Class<?> clz) {
        return Arrays.stream(clz.getDeclaredFields())
                .filter(field -> !java.lang.reflect.Modifier.isStatic(field.getModifiers()))
                .filter(field -> !field.isSynthetic())
                .collect(Collectors.toList());
    }

}

测试类

/**
 * Manual demo for {@code JdbcTemplate}: configures {@code DBUtil}, then runs a filtered
 * query and prints every returned {@code User}. The insert and update demos are kept
 * below as disabled samples.
 */
public class Client {
    public static void main(String[] args)  {

        String jdbcUrl = "jdbc:mysql://localhost:3307/test?useSSL=false&useUnicode=true&characterEncoding=utf-8&zeroDateTimeBehavior=convertToNull&transformedBitIsBoolean=true&serverTimezone=GMT%2B8";
        DBUtil.init(jdbcUrl, "root", "root");
        JdbcTemplate template = new JdbcTemplate();

        // Insert demo (disabled)
      /*
        User user1 = new User(10000047L,"Zhang San", (byte) 1,1000,"xxxxa@foxmail.com",18,"13988888888");
        boolean result1 = jdbcTemplate.add(user1);
        System.out.println("Insert succeeded: "+result1);
        */

        // Update demo (disabled)
       /*
        User user2 = new User(10000047L,"Li Si", (byte) 1,1000,"xxxxfff@foxmail.com",18,"13988888888");
        boolean result2 = jdbcTemplate.update(user2);
        System.out.println("Update succeeded: "+result2);
        */

        // Query demo: filter on age = 18 and balance = 1000.
        HashMap<String, Object> filters = new HashMap<>();
        filters.put("age",18);
        filters.put("balance",1000);
        List<User> matches = template.query(User.class, filters);
        for (User match : matches) {
            System.out.println(match);
        }
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值