由于在开启事务的情况下，使用 JDBC 进行大数据量批量插入的效率比 MyBatis 高很多，因此采用 JDBC 插入。但不同实体类的属性各不相同，插入时需要编写大量的 set 操作，重复代码多、可读性差，因此利用反射和注解将其封装成工具类来解决这个问题。
import java.lang.annotation.*;
/**
 * Maps an entity field to a database column for the JDBC batch-insert helpers, so that simple
 * insert SQL can be generated and field values bound automatically instead of by hand-written
 * {@code set} calls.
 *
 * @author jiangle
 * @date 2022-07-28
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface TableFieldName {

    /**
     * Name of the database column this field maps to.
     */
    String value();

    /**
     * Whether this column is appended to the {@code ON DUPLICATE KEY UPDATE} clause:
     * {@code true} (the default) means yes, {@code false} means no.
     */
    boolean flag() default true;
}
import com.baomidou.mybatisplus.annotation.TableName;
import com.craiditx.datacollector.common.annotation.TableFieldName;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
/**
 * Reflection helpers used to build a JDBC insert statement from an entity class annotated with
 * {@link TableName} / {@link TableFieldName}, and to bind entity field values onto a
 * {@link PreparedStatement}.
 *
 * @author jiangle
 * @date 2021/09/03
 */
public class ReflectUtil {
    private static final Logger LOG = LoggerFactory.getLogger(ReflectUtil.class);
    private static final String COMMA = ",";
    private static final String LEFT_BRACKET = "(";
    private static final String RIGHT_BRACKET = ")";
    private static final String QUESTION_MARK = "?";

    /**
     * Returns the string form of a field's value.
     *
     * @param obj       object to read from; may be {@code null}
     * @param fieldName Java field name, declared directly on {@code obj}'s class
     * @return {@code value.toString()}, or {@link StringUtils#EMPTY} when the value is
     *         {@code null}, the field does not exist, or {@code obj} is {@code null}
     */
    public static String getFieldValueByFieldName(Object obj, String fieldName) {
        Object value = getFieldValueObjectByFieldName(obj, fieldName);
        return value == null ? StringUtils.EMPTY : value.toString();
    }

    /**
     * Reads a field's value via reflection. Only fields declared directly on
     * {@code obj.getClass()} are considered; inherited fields are not searched.
     *
     * @param obj       object to read from; may be {@code null}
     * @param fieldName Java field name
     * @return the field value, or {@code null} when {@code obj}/{@code fieldName} is
     *         {@code null}, no such field is declared, or the field cannot be read
     */
    public static Object getFieldValueObjectByFieldName(Object obj, String fieldName) {
        if (obj == null || fieldName == null) {
            return null;
        }
        Field field = findDeclaredField(obj.getClass(), fieldName);
        if (field == null) {
            return null;
        }
        try {
            return field.get(obj);
        } catch (IllegalAccessException e) {
            LOG.warn("根据属性名称反射获取属性值时异常! field={}, class={}", fieldName, obj.getClass().getName(), e);
            return null;
        }
    }

    /**
     * Sets a String value through the field's public JavaBean setter ({@code setXxx(String)}).
     * Best-effort: if the setter is missing or fails, the call is skipped (logged at debug level).
     *
     * @param object    target object; nothing happens when {@code null}
     * @param fieldName field name; nothing happens when blank
     * @param value     value passed to the setter
     */
    public static void setValueByFieldName(Object object, String fieldName, String value) {
        if (object == null || StringUtils.isBlank(fieldName)) {
            return;
        }
        // Character.toUpperCase is locale-independent, unlike String.toUpperCase() (e.g. Turkish 'i').
        String methodName = "set" + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
        try {
            Method setter = object.getClass().getMethod(methodName, String.class);
            setter.invoke(object, value);
        } catch (Exception e) {
            // Deliberately best-effort (original contract), but leave a trace instead of swallowing silently.
            LOG.debug("Skipped setting field '{}' on {}: {}", fieldName, object.getClass().getName(), e.toString());
        }
    }

    /**
     * Builds an insert statement for {@code cls} and records which Java field feeds each
     * placeholder.
     * <p>
     * The table name comes from {@link TableName#value()}; each directly declared field annotated
     * with a non-blank {@link TableFieldName#value()} becomes one column. Columns whose
     * {@link TableFieldName#flag()} is {@code true} are also added to an
     * {@code ON DUPLICATE KEY UPDATE col=VALUES(col)} clause.
     * Example: {@code insert into t(a,b) values(?,?) ON DUPLICATE KEY UPDATE a=VALUES(a)}.
     *
     * @param cls               entity class
     * @param indexFieldNameMap output map, cleared first; afterwards key = 1-based placeholder
     *                          index, value = Java field name
     * @return the SQL, or {@link StringUtils#EMPTY} when {@code cls}/{@code indexFieldNameMap} is
     *         {@code null}, the class has no usable {@code @TableName}, or no field is mapped
     */
    public static String getSqlByObject(Class<?> cls, LinkedHashMap<Integer, String> indexFieldNameMap) {
        if (cls == null || indexFieldNameMap == null) {
            return StringUtils.EMPTY;
        }
        // The map must describe exactly the placeholders of the SQL returned from this call.
        indexFieldNameMap.clear();
        TableName tableName = cls.getAnnotation(TableName.class);
        if (tableName == null || StringUtils.isBlank(tableName.value())) {
            return StringUtils.EMPTY;
        }
        StringBuilder columns = new StringBuilder();
        StringBuilder placeholders = new StringBuilder();
        StringBuilder updates = new StringBuilder();
        int index = 1;
        for (Field field : cls.getDeclaredFields()) {
            TableFieldName tableFieldName = field.getAnnotation(TableFieldName.class);
            if (tableFieldName == null || StringUtils.isBlank(tableFieldName.value())) {
                // Unmapped or blank column names are left out entirely (including the update clause).
                continue;
            }
            String column = tableFieldName.value();
            appendWithComma(columns, column);
            appendWithComma(placeholders, QUESTION_MARK);
            indexFieldNameMap.put(index++, field.getName());
            if (tableFieldName.flag()) {
                appendWithComma(updates, column + "=VALUES(" + column + RIGHT_BRACKET);
            }
        }
        if (indexFieldNameMap.isEmpty()) {
            // Without any column the statement would be syntactically invalid.
            return StringUtils.EMPTY;
        }
        StringBuilder sql = new StringBuilder("insert into ")
                .append(tableName.value())
                .append(LEFT_BRACKET).append(columns).append(RIGHT_BRACKET)
                .append(" values(").append(placeholders).append(RIGHT_BRACKET);
        if (updates.length() > 0) {
            sql.append(" ON DUPLICATE KEY UPDATE ").append(updates);
        }
        return sql.toString();
    }

    /**
     * Binds the values of {@code obj} onto the placeholders of {@code prepareStatement}.
     *
     * @param prepareStatement  statement prepared from the SQL returned by {@link #getSqlByObject}
     * @param indexFieldNameMap placeholder index to Java field name, as filled by {@link #getSqlByObject}
     * @param obj               entity whose field values are bound; missing fields bind {@code null}
     * @throws SQLException if the driver rejects a value
     */
    public static void setTableFieldValue(PreparedStatement prepareStatement, LinkedHashMap<Integer, String> indexFieldNameMap, Object obj) throws SQLException {
        if (prepareStatement == null || MapUtils.isEmpty(indexFieldNameMap)) {
            return;
        }
        for (Map.Entry<Integer, String> entry : indexFieldNameMap.entrySet()) {
            prepareStatement.setObject(entry.getKey(), getFieldValueObjectByFieldName(obj, entry.getValue()));
        }
    }

    /** Returns the accessible field declared directly on {@code cls}, or {@code null} if absent. */
    private static Field findDeclaredField(Class<?> cls, String fieldName) {
        try {
            Field field = cls.getDeclaredField(fieldName);
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException e) {
            return null;
        }
    }

    /** Appends {@code part}, preceded by a comma when {@code sb} already has content. */
    private static void appendWithComma(StringBuilder sb, String part) {
        if (sb.length() > 0) {
            sb.append(COMMA);
        }
        sb.append(part);
    }
}
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.List;
/**
* jdbc批量插入封装类
*
* @author jiangle
* @date 2022-07-22
*/
@Component
@Slf4j
public class JdbcUtil {
@Autowired
private DataSource dataSource;
/**
* 获取数据库连接
*
* @return 数据库连接
*/
public Connection getConnection(){
Connection conn = null;
try {
conn = dataSource.getConnection();
} catch (Exception e) {
e.printStackTrace();
}
return conn;
}
/**
* 关闭流
*
* @param obj 流对象
*/
public void close(AutoCloseable obj) {
if (obj != null) {
try {
obj.close();
} catch (Exception e) {
log.info(e.getMessage());
}
}
}
/**
* 根据实体类和需要插入的数据集合进行批量插入到表中
*
* @param cls 实体类
* @param list 需要插入的数据集合
*/
public void batchInsertList(Class cls, List<?> list) {
Connection connection = getConnection();
PreparedStatement prepareStatement = null;
try {
connection.setAutoCommit(false);
LinkedHashMap<Integer, String> indexFieldNameMap = new LinkedHashMap<>();
String sql = ReflectUtil.getSqlByObject(cls, indexFieldNameMap);
log.info("jdbc的sql="+sql);
prepareStatement = connection.prepareStatement(sql);
for (Object object : list) {
ReflectUtil.setTableFieldValue(prepareStatement, indexFieldNameMap, object);
prepareStatement.addBatch();
}
prepareStatement.executeBatch();
connection.commit();
} catch (SQLException e) {
log.info(e.getMessage());
} finally {
close(prepareStatement);
close(connection);
}
}
}
测试实体类：
/**
 * Entity for table {@code ios_search_reports}. Each field annotated with {@link TableFieldName}
 * maps to the named column for JDBC batch inserts.
 * <p>
 * {@code callSuper} is {@code false}: the class extends {@code Object} directly, and Lombok
 * rejects {@code callSuper = true} in that case.
 */
@Data
@EqualsAndHashCode(callSuper = false)
@Accessors(chain = true)
@NoArgsConstructor
@AllArgsConstructor
@TableName("ios_search_reports")
public class IosSearchReports {
    @TableFieldName(value = "keyword_id")
    private Long keywordId;
    @TableFieldName(value = "keyword")
    private String keyword;
    @TableFieldName(value = "search_term_text")
    private String searchTermText;
    @TableFieldName(value = "country_or_region")
    private String countryOrRegion;
    @TableFieldName(value = "search_term_source")
    private String searchTermSource;
    @TableFieldName(value = "match_type")
    private String matchType;
    @TableFieldName(value = "ad_group_id")
    private Long adGroupId;
    @TableFieldName(value = "ad_group_name")
    private String adGroupName;
    @TableFieldName(value = "bid_amount")
    private String bidAmount;
    @TableFieldName(value = "local_spend")
    private String localSpend;
    @TableFieldName(value = "avg_cpt")
    private String avgCpt;
    @TableFieldName(value = "avg_cpa")
    private String avgCpa;
    @TableFieldName(value = "impressions")
    private Integer impressions;
    @TableFieldName(value = "taps")
    private Integer taps;
    @TableFieldName(value = "installs")
    private Integer installs;
    @TableFieldName(value = "conversion_rate")
    private BigDecimal conversionRate;
    @TableFieldName(value = "lat_on_installs")
    private Integer latOnInstalls;
    @TableFieldName(value = "lat_off_installs")
    private Integer latOffInstalls;
    @TableFieldName(value = "new_downloads")
    private Long newDownloads;
    @TableFieldName(value = "redownloads")
    private Long redownloads;
    @TableFieldName(value = "date")
    private String date;
}
调用示例（伪代码）：
@Service
@Slf4j
public class Demo {
    @Autowired
    private JdbcUtil jdbcUtil;

    /**
     * Example caller: builds a list of report entities and bulk-inserts it when it is non-empty.
     */
    public void test() {
        List<IosSearchReports> reports = new ArrayList<>();
        // Populate the list here (omitted in this example).
        if (reports.isEmpty()) {
            return;
        }
        jdbcUtil.batchInsertList(IosSearchReports.class, reports);
    }
}