JAVA Postgresql 根据model类自动生成插入语句,批量插入数据,查询数据

27 篇文章 1 订阅
22 篇文章 0 订阅

JAVA Postgresql 根据model类自动生成插入语句,批量插入数据

  1. 所有的类可以继承自一个BaseObject类(非必需);
  2. 数据库表中的字段名及类型必须与Object类中的一一对应;
  3. 调用通用的方法(对数据进行批量插入);
  4. 调用通用的方法进行数据查询(适用于字段名与数据库一一对应的情况,field获取类型以及rs动态根据类型获取值);
  5. 递归获取父类子类所有的属性字段名;

String.format(“insert into %s(%s) values(%s)”, tableName, fieldNames, marks);

主要涉及到应用反射,获取model类的所有字段名;
拼接sql中的字段名,及占位符;
然后再次利用反射对每一个字段设置值。

源码

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/*************************************
 *Class Name: BaseObject
 *Description: <基础类>
 *@author: Seminar
 *@create: 2023/7/18
 *@since 1.0.0
 *************************************/
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class BaseObject {

    private String id;
    private String name;
}
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/*************************************
 *Class Name: Student
 *Description: <子类>
 *@author: Semeinar
 *@create: 2023/7/18
 *@since 1.0.0
 *************************************/
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class Student extends BaseObject {
    private String no;
    private Integer age;
    private Byte[] geom;
    private Long idcard;
    private Double score;
    private Boolean isGood;
}
package com.test;

package com.test.dao;

import com.test.model.BaseObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;

import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringJoiner;


/*************************************
 * Class Name: CommonDao
 * Description:〈通用数据Dao〉
 * @author Seminar
 * @Date 2023.01.18
 ************************************/
@Slf4j
public class CommonDao {
    @Autowired
    JdbcTemplate jdbcTemplate;

    /**
     * 根据字段数组拼接问号
     *
     * @param fields 字段数组
     * @return
     */
    protected String getQueryMark(String fields) {
        String[] arr = fields.split(",");
        int len = arr.length;
        StringJoiner joiner = new StringJoiner(",");
        while (len-- > 0) {
            joiner.add("?");
        }
        return joiner.toString();
    }

    /**
     * 获取表记录总数
     *
     * @param tableName 表名
     * @return 记录总数
     */
    public Integer getTotal(String tableName) {
        return getTotal(tableName, null);
    }

    /**
     * 根据条件获取表记录总数
     *
     * @param tableName 表名
     * @param whereSql  关键字
     * @return 记录总数
     */
    public Integer getTotal(String tableName, String whereSql) {
        String sql = String.format("select count(*) as total from %s", tableName);
        if (!StringUtils.isEmpty(whereSql)) {
            sql += " where " + whereSql;
        }
        return jdbcTemplate.queryForObject(sql, Integer.class);
    }

    /**
     * 判断查询结果集中是否存在某列
     *
     * @param rs         查询结果集
     * @param columnName 列名
     * @return true 存在; false 不存咋
     */
    public boolean isExistColumn(ResultSet rs, String columnName) {
        try {
            if (rs.findColumn(columnName) > 0) {
                return true;
            }
        } catch (SQLException e) {
            return false;
        }
        return false;
    }

    /**
     * 根据字段名获取按?拼接的字符串
     *
     * @param fields 字段名,每一个占位一个?,特殊的集合类型,特殊处理
     * @return
     */
    protected String getQueryMark2(String fields) {
        String[] fieldNames = fields.split(",");
        String delimeter = ",";
        StringBuilder sb = new StringBuilder();
        for (String field : fieldNames) {
            if ("geometry".equals(field) || "s_index".equals(field)) {
                sb.append("ST_GeomFromText(?)" + delimeter);
            } else {
                sb.append("?" + delimeter);
            }
        }

        return sb.substring(0, sb.length() - 1);
    }

    /**
     * 批量插入通用方法(注意需要Object类中所有字段名与数据库中表列名的一致)
     *
     * @param objs      对象lists
     * @param tableName 插入的表名
     * @param <T>       通用BaseObject类
     */
    public <T> void batchInsertRdClassT(List<T> objs, String tableName) {
        if (objs == null || objs.isEmpty()) {
            return;
        }
        String fieldNames = getFieldNames(objs.get(0).getClass());
        String marks = getQueryMark2(fieldNames);
        String sql = String.format("insert into %s(%s) values(%s)", tableName, fieldNames, marks);
        log.info("{}", sql);
        int[] res = jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                T item = objs.get(i);
                Field[] fieldVals = item.getClass().getDeclaredFields();
                for (int j = 0; j < fieldVals.length; j++) {
                    Field field = fieldVals[j];
                    field.setAccessible(true);

                    try {
                        setValue(ps, j + 1, field.get(item));
                    } catch (IllegalAccessException e) {
                        throw new RuntimeException(e);
                    }
                }
            }

            @Override
            public int getBatchSize() {
                return objs.size();
            }
        });
    }

    /**
     * 对每一个字段设置值
     *
     * @param ps
     * @param index
     * @param column
     * @throws SQLException
     */
    private void setValue(PreparedStatement ps, int index, Object column) throws SQLException {
        if (column == null) {
            ps.setNull(index, 0);
        } else {
            if (column instanceof Long) {
                ps.setLong(index, (Long) column);
            } else if (column instanceof String) {
                ps.setString(index, (String) column);
            } else if (column instanceof Integer) {
                ps.setInt(index, (Integer) column);
            } else if (column instanceof Double) {
                ps.setDouble(index, (Double) column);
            } else if (column instanceof Boolean) {
                ps.setBoolean(index, (Boolean) column);
            } else if (column instanceof Byte[]) {
                ps.setBytes(index, (byte[]) column);
            }
        }
    }

    /**
     * 获取类的所有字段名,并,拼接
     * 同时获取父类的属性字段
     *
     * @param clazz 实体类名
     * @return
     */
    public static String getFieldNames(Class clazz) {
        List<Field> allFields = getAllFieldName(clazz);
        String[] fieldNames = new String[allFields.size()];
        for (int i = 0; i < allFields.size(); i++) {
            fieldNames[i] = allFields.get(i).getName();
        }
        return StringUtils.join(fieldNames, ",");
    }

    /**
     * 递归获取子类父类所有的public/private属性
     *
     * @param clazz
     * @return
     */
    private static List<Field> getAllFieldName(Class clazz) {
        List<Field> allFields = new ArrayList<>();

        // 获取当前对象的所有属性字段
        // clazz.getFields():获取public修饰的字段
        // clazz.getDeclaredFields(): 获取所有的字段包括private修饰的字段
        allFields.addAll(Arrays.asList(clazz.getDeclaredFields()));

        // 获取所有父类的字段, 父类中的字段需要逐级获取
        Class clazzSuper = clazz.getSuperclass();

        // 如果父类不是object,表明其继承的有其他类。 逐级获取所有父类的字段
        while (clazzSuper != Object.class) {
            allFields.addAll(Arrays.asList(clazzSuper.getDeclaredFields()));
            clazzSuper = clazzSuper.getSuperclass();
        }
        return allFields;
    }

    private BaseObject getBaseObject(String id, Class clazz, String tableName) {
        String fields = getFieldNames(clazz);
        String sql = String.format("select %s from %s where id = '%s'", fields, tableName, id);


        List<BaseObject> list = jdbcTemplate.query(sql, (rs, i) -> toBaseObject(rs, clazz));
        if (list != null && list.size() > 0) {
            return list.get(0);
        } else {
            return null;
        }
    }

    private BaseObject toBaseObject(ResultSet rs, Class clazz) {
        Object obj = null;
        try {
            obj = (BaseObject) clazz.newInstance();
            List<Field> fields = getAllFieldName(clazz);
            for (Field field : fields) {
                field.setAccessible(true);
                Object column = field.getType();

                if (rs.getString(field.getName()) == null || StringUtils.isEmpty(rs.getString(field.getName()))) {
                    continue;
                } else {
                    if (column == String.class) {
                        field.set(obj, rs.getString(field.getName()));
                    } else if (column == Long.class) {
                        field.set(obj, rs.getLong(field.getName()));
                    } else if (column == Integer.class) {
                        field.set(obj, rs.getInt(field.getName()));
                    } else if (column == Boolean.class) {
                        field.set(obj, rs.getBoolean(field.getName()));
                    } else if (column == Double.class) {
                        field.set(obj, rs.getDouble(field.getName()));
                    } else if (column == byte[].class) {
                        // 如果是geometry需要特殊处理
                        if ("geometry".equalsIgnoreCase(field.getName())) {
                            // **不能少了这一句
                            byte[] bytes = org.geotools.data.postgis.WKBReader.hexToBytes(rs.getString(field.getName()));
                            field.set(obj, bytes);

                            // 下边转用org.geotools.data.postgis.WKBReader、com.vividsolutions.jts.io.WKBReader、org.locationtech.jts.io.WKBReader转都可以
                            //org.geotools.data.postgis.WKBReader wkbReader = new org.geotools.data.postgis.WKBReader();
                        } else {
                            field.set(obj, rs.getBytes(field.getName()));
                        }
                    } else if (column == Double.class) {
                        field.set(obj, rs.getDouble(field.getName()));
                    } else if (column == Boolean.class) {
                        field.set(obj, rs.getBoolean(field.getName()));
                    }
                }
            }
        } catch (InstantiationException e) {
            throw new RuntimeException(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return (BaseObject) obj;
    }
}
		<dependency>
            <groupId>org.geotools.jdbc</groupId>
            <artifactId>gt-jdbc-postgis</artifactId>
            <version>17.1</version>
        </dependency>
        

参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序媛一枚~

您的鼓励是我创作的最大动力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值