Java + PostgreSQL:根据 model 类自动生成插入语句并批量插入数据
- 所有的类可以继承自一个BaseObject类(非必需);
- 数据库表中的列名及类型必须与 model 类中的字段一一对应;
- 调用通用的方法(对数据进行批量插入);
- 调用通用的方法进行数据查询(适用于字段名与数据库列名一一对应的情况:通过 Field 获取字段类型,再根据类型从 ResultSet 中动态取值);
- 递归获取父类子类所有的属性字段名;
String.format("insert into %s(%s) values(%s)", tableName, fieldNames, marks);
主要用到反射:获取 model 类(含父类)的所有字段名;
拼接sql中的字段名,及占位符;
然后再次利用反射对每一个字段设置值。
源码
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/*************************************
 * Class Name: BaseObject
 * Description: <Base model class; entity classes may extend it (optional)>
 * @author: Seminar
 * @create: 2023/7/18
 * @since 1.0.0
 *************************************/
// @SuperBuilder (not @Builder) so that subclasses can also declare a builder: with plain
// @Builder on both parent and child, Child.builder() cannot hide Parent.builder() because the
// builder return types are incompatible, which is a compile error.
// NOTE(review): @SuperBuilder requires Lombok 1.18.2+ — confirm the project's Lombok version.
@Data
@lombok.experimental.SuperBuilder
@AllArgsConstructor
@NoArgsConstructor
public class BaseObject {
    private String id;
    private String name;
}
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
/*************************************
 * Class Name: Student
 * Description: <Subclass of BaseObject; inherits id and name>
 * @author: Seminar
 * @create: 2023/7/18
 * @since 1.0.0
 *************************************/
// callSuper = true so equals/hashCode/toString also take the inherited id and name into account;
// @Data alone would generate them from Student's own fields only.
// @SuperBuilder matches the parent's @SuperBuilder so Student.builder() can also set id/name.
@Data
@EqualsAndHashCode(callSuper = true)
@ToString(callSuper = true)
@SuperBuilder
@AllArgsConstructor
@NoArgsConstructor
public class Student extends BaseObject {
    private String no;
    private Integer age;
    private Byte[] geom;
    private Long idcard;
    private Double score;
    private Boolean isGood;
}
package com.test;
package com.test.dao;
import com.test.model.BaseObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringJoiner;
/*************************************
* Class Name: CommonDao
* Description:〈通用数据Dao〉
* @author Seminar
* @Date 2023.01.18
************************************/
@Slf4j
public class CommonDao {
@Autowired
JdbcTemplate jdbcTemplate;
/**
* 根据字段数组拼接问号
*
* @param fields 字段数组
* @return
*/
protected String getQueryMark(String fields) {
String[] arr = fields.split(",");
int len = arr.length;
StringJoiner joiner = new StringJoiner(",");
while (len-- > 0) {
joiner.add("?");
}
return joiner.toString();
}
/**
* 获取表记录总数
*
* @param tableName 表名
* @return 记录总数
*/
public Integer getTotal(String tableName) {
return getTotal(tableName, null);
}
/**
* 根据条件获取表记录总数
*
* @param tableName 表名
* @param whereSql 关键字
* @return 记录总数
*/
public Integer getTotal(String tableName, String whereSql) {
String sql = String.format("select count(*) as total from %s", tableName);
if (!StringUtils.isEmpty(whereSql)) {
sql += " where " + whereSql;
}
return jdbcTemplate.queryForObject(sql, Integer.class);
}
/**
* 判断查询结果集中是否存在某列
*
* @param rs 查询结果集
* @param columnName 列名
* @return true 存在; false 不存咋
*/
public boolean isExistColumn(ResultSet rs, String columnName) {
try {
if (rs.findColumn(columnName) > 0) {
return true;
}
} catch (SQLException e) {
return false;
}
return false;
}
/**
* 根据字段名获取按?拼接的字符串
*
* @param fields 字段名,每一个占位一个?,特殊的集合类型,特殊处理
* @return
*/
protected String getQueryMark2(String fields) {
String[] fieldNames = fields.split(",");
String delimeter = ",";
StringBuilder sb = new StringBuilder();
for (String field : fieldNames) {
if ("geometry".equals(field) || "s_index".equals(field)) {
sb.append("ST_GeomFromText(?)" + delimeter);
} else {
sb.append("?" + delimeter);
}
}
return sb.substring(0, sb.length() - 1);
}
/**
* 批量插入通用方法(注意需要Object类中所有字段名与数据库中表列名的一致)
*
* @param objs 对象lists
* @param tableName 插入的表名
* @param <T> 通用BaseObject类
*/
public <T> void batchInsertRdClassT(List<T> objs, String tableName) {
if (objs == null || objs.isEmpty()) {
return;
}
String fieldNames = getFieldNames(objs.get(0).getClass());
String marks = getQueryMark2(fieldNames);
String sql = String.format("insert into %s(%s) values(%s)", tableName, fieldNames, marks);
log.info("{}", sql);
int[] res = jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
T item = objs.get(i);
Field[] fieldVals = item.getClass().getDeclaredFields();
for (int j = 0; j < fieldVals.length; j++) {
Field field = fieldVals[j];
field.setAccessible(true);
try {
setValue(ps, j + 1, field.get(item));
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
@Override
public int getBatchSize() {
return objs.size();
}
});
}
/**
* 对每一个字段设置值
*
* @param ps
* @param index
* @param column
* @throws SQLException
*/
private void setValue(PreparedStatement ps, int index, Object column) throws SQLException {
if (column == null) {
ps.setNull(index, 0);
} else {
if (column instanceof Long) {
ps.setLong(index, (Long) column);
} else if (column instanceof String) {
ps.setString(index, (String) column);
} else if (column instanceof Integer) {
ps.setInt(index, (Integer) column);
} else if (column instanceof Double) {
ps.setDouble(index, (Double) column);
} else if (column instanceof Boolean) {
ps.setBoolean(index, (Boolean) column);
} else if (column instanceof Byte[]) {
ps.setBytes(index, (byte[]) column);
}
}
}
/**
* 获取类的所有字段名,并,拼接
* 同时获取父类的属性字段
*
* @param clazz 实体类名
* @return
*/
public static String getFieldNames(Class clazz) {
List<Field> allFields = getAllFieldName(clazz);
String[] fieldNames = new String[allFields.size()];
for (int i = 0; i < allFields.size(); i++) {
fieldNames[i] = allFields.get(i).getName();
}
return StringUtils.join(fieldNames, ",");
}
/**
* 递归获取子类父类所有的public/private属性
*
* @param clazz
* @return
*/
private static List<Field> getAllFieldName(Class clazz) {
List<Field> allFields = new ArrayList<>();
// 获取当前对象的所有属性字段
// clazz.getFields():获取public修饰的字段
// clazz.getDeclaredFields(): 获取所有的字段包括private修饰的字段
allFields.addAll(Arrays.asList(clazz.getDeclaredFields()));
// 获取所有父类的字段, 父类中的字段需要逐级获取
Class clazzSuper = clazz.getSuperclass();
// 如果父类不是object,表明其继承的有其他类。 逐级获取所有父类的字段
while (clazzSuper != Object.class) {
allFields.addAll(Arrays.asList(clazzSuper.getDeclaredFields()));
clazzSuper = clazzSuper.getSuperclass();
}
return allFields;
}
private BaseObject getBaseObject(String id, Class clazz, String tableName) {
String fields = getFieldNames(clazz);
String sql = String.format("select %s from %s where id = '%s'", fields, tableName, id);
List<BaseObject> list = jdbcTemplate.query(sql, (rs, i) -> toBaseObject(rs, clazz));
if (list != null && list.size() > 0) {
return list.get(0);
} else {
return null;
}
}
private BaseObject toBaseObject(ResultSet rs, Class clazz) {
Object obj = null;
try {
obj = (BaseObject) clazz.newInstance();
List<Field> fields = getAllFieldName(clazz);
for (Field field : fields) {
field.setAccessible(true);
Object column = field.getType();
if (rs.getString(field.getName()) == null || StringUtils.isEmpty(rs.getString(field.getName()))) {
continue;
} else {
if (column == String.class) {
field.set(obj, rs.getString(field.getName()));
} else if (column == Long.class) {
field.set(obj, rs.getLong(field.getName()));
} else if (column == Integer.class) {
field.set(obj, rs.getInt(field.getName()));
} else if (column == Boolean.class) {
field.set(obj, rs.getBoolean(field.getName()));
} else if (column == Double.class) {
field.set(obj, rs.getDouble(field.getName()));
} else if (column == byte[].class) {
// 如果是geometry需要特殊处理
if ("geometry".equalsIgnoreCase(field.getName())) {
// **不能少了这一句
byte[] bytes = org.geotools.data.postgis.WKBReader.hexToBytes(rs.getString(field.getName()));
field.set(obj, bytes);
// 下边转用org.geotools.data.postgis.WKBReader、com.vividsolutions.jts.io.WKBReader、org.locationtech.jts.io.WKBReader转都可以
//org.geotools.data.postgis.WKBReader wkbReader = new org.geotools.data.postgis.WKBReader();
} else {
field.set(obj, rs.getBytes(field.getName()));
}
} else if (column == Double.class) {
field.set(obj, rs.getDouble(field.getName()));
} else if (column == Boolean.class) {
field.set(obj, rs.getBoolean(field.getName()));
}
}
}
} catch (InstantiationException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
} catch (SQLException e) {
throw new RuntimeException(e);
}
return (BaseObject) obj;
}
}
<!-- NOTE(review): presumably pulled in for org.geotools.data.postgis.WKBReader, which CommonDao
     uses (WKBReader.hexToBytes) to decode "geometry" columns read as text. GeoTools artifacts are
     published to the OSGeo repository rather than Maven Central, so that repository may need
     to be declared in the pom - confirm. -->
<dependency>
<groupId>org.geotools.jdbc</groupId>
<artifactId>gt-jdbc-postgis</artifactId>
<version>17.1</version>
</dependency>