注意点:
此种方法是通过反射,拼接完整的merge into语句,然后通过mybatis去执行sql。
这种方法在更新时需要注意,传入的实体类对于属性值为null的,也会将原来对应的字段值修改为null。所以使用此种方法批量更新,需要将所有的值都塞入实体类。
对于Oracle,需要主键自增的表必须配置序列和触发器。其次,无论是DB2还是Oracle,都需要使用@TableId注解来指定主键,实体类中多余的(非表字段)属性需要用@TableField(exist=false)来排除。
附上源码,仅供参考:
import com.cmbchina.cc.mc.aptms.infrastructure.helper.McCmsDbSysParmCacheHelper;
import com.cmbchina.cc.mc.aptms.infrastructure.repository.db2.BatchMapper;
import com.cmbchina.cc.mc.aptms.infrastructure.util.SqlTemplateUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
* @author LGQ
* @Description 通用批量操作方法
* @date 2022/8/2 19:04
*/
@Service
@Slf4j
public class CommonBatchService {
@Value("${batchSize:1000}")
private int batchSize;
@Autowired
private BatchMapper batchMapper;
@Autowired
private McCmsDbSysParmCacheHelper mcCmsDbSysParmCacheHelper;
public <T> Boolean batchSaveOrUpdate(List<T> sources, Class<T> clazz) {
return batchSaveOrUpdate(sources, clazz, batchSize);
}
public <T> Boolean batchSaveOrUpdate(List<T> sources, Class<T> clazz, int size) {
if (CollectionUtils.isEmpty(sources)) {
return false;
}
try {
Boolean isOracleDataSource = mcCmsDbSysParmCacheHelper.isChangeDatabase();
int sourceSize = sources.size();
for (int i = 0; i < sourceSize; i = i + size) {
List<T> subList;
if (i + size > sourceSize) {
subList = sources.subList(i, sourceSize);
} else {
subList = sources.subList(i, i + size);
}
String batchSaveOrUpdateSql = SqlTemplateUtils.batchSaveOrUpdateSql(subList, clazz, isOracleDataSource);
batchMapper.mergeInto(batchSaveOrUpdateSql);
}
} catch (Exception e) {
log.error(e.toString());
return false;
}
return true;
}
}
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Update;
/**
 * MyBatis mapper that executes a pre-built SQL statement passed in as text.
 * <p>
 * The statement is injected with {@code ${sql}} string substitution rather than a bound
 * parameter, so the text is executed verbatim. NOTE(review): the caller is fully responsible
 * for the safety of that text — any untrusted value concatenated into it is an SQL-injection risk.
 * NOTE(review): {@code BaseMapper} is extended as a raw type; confirm whether an entity type
 * argument is intended.
 *
 * @author LGQ
 * @date 2022/8/2 19:02
 */
@Mapper
public interface BatchMapper extends BaseMapper {
/**
 * Executes {@code sql} through MyBatis' {@code @Update} mapping; the affected-row count is discarded.
 * In this codebase the callers pass generated {@code MERGE INTO} statements.
 *
 * @param sql complete SQL statement text, substituted verbatim via {@code ${sql}}
 */
@Update("${sql}")
void mergeInto(@Param("sql") String sql);
}
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.cmbchina.cc.mc.aptms.enums.MCSysErrorCodeType;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.StringJoiner;
/**
 * Builds a single {@code MERGE INTO} statement that upserts a list of entities.
 * <p>
 * Generated shape:
 * {@code MERGE INTO <table> u USING ( SELECT ... FROM <dummy> UNION ALL ... ) t on (u.<id> = t.<id>)
 * WHEN MATCHED THEN UPDATE SET ... WHEN NOT MATCHED THEN INSERT ( ... ) VALUES ( ... )}.
 * Values are inlined as SQL literals (string values have their single quotes doubled), so the
 * result is meant to be executed as plain SQL text.
 * <p>
 * Column mapping covers the instance fields declared directly on the entity class; static fields
 * and fields annotated {@code @TableField(exist = false)} are skipped, inherited fields are not read.
 * The column name is the {@code @TableId}/{@code @TableField} value when one is given, otherwise
 * the field name converted from camelCase to UPPER_SNAKE_CASE.
 *
 * @author LGQ
 * @date 2022/8/1 10:44
 */
@Slf4j
public class SqlTemplateUtils {

    /** Text form used for java.util.Date values; matches {@link #ORACLE_DATE_MASK}. */
    private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Oracle TO_DATE mask paired with {@link #DATE_TIME_FORMAT}. */
    private static final String ORACLE_DATE_MASK = "yyyy-mm-dd hh24:mi:ss";

    /** One mapped entity field with its resolved column name. */
    private static final class ColumnMapping {
        private final Field field;
        private final String column;
        private final boolean id;

        private ColumnMapping(Field field, String column, boolean id) {
            this.field = field;
            this.column = column;
            this.id = id;
        }
    }

    /**
     * Builds the MERGE statement that upserts {@code sources} into the table of {@code clazz}.
     *
     * @param sources            entities to upsert; null or empty yields {@code null}
     * @param clazz              entity class; must carry {@code @TableName} and a {@code @TableId} field
     * @param isOracleDataSource true for the Oracle dialect (FROM DUAL, TO_DATE, id inserted explicitly),
     *                           false for DB2 (FROM sysibm.sysdummy1, id omitted from the INSERT)
     * @return the SQL text, or {@code null} when there is nothing to merge
     * @throws RuntimeException     when the table name, a column name or the id column cannot be resolved
     * @throws NullPointerException when {@code isOracleDataSource} is null
     */
    public static <T> String batchSaveOrUpdateSql(List<T> sources, Class<T> clazz, Boolean isOracleDataSource) {
        if (sources == null || sources.isEmpty()) {
            return null;
        }
        String tableName = StringUtils.getTableName(clazz);
        boolean oracle = Objects.requireNonNull(isOracleDataSource, "isOracleDataSource");
        List<ColumnMapping> columns = resolveColumns(clazz);
        String idColumn = resolveIdColumn(columns);
        if (StringUtils.isEmptyString(idColumn)) {
            throw new RuntimeException(MCSysErrorCodeType.UN_CAUGHT_TABLE_ID.getErrorMessage());
        }
        return "MERGE INTO " + tableName + " u USING ( "
                + buildSourceRows(sources, columns, oracle)
                + " ) t on (" + String.format("u.%s = t.%s) ", idColumn, idColumn)
                + buildMergeActions(columns, idColumn, oracle);
    }

    /** Collects the persistent fields of {@code clazz} in declaration order, made accessible. */
    private static List<ColumnMapping> resolveColumns(Class<?> clazz) {
        List<ColumnMapping> columns = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            // Static/synthetic fields (e.g. serialVersionUID) are not table columns.
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            TableField tableField = field.getAnnotation(TableField.class);
            if (tableField != null && !tableField.exist()) {
                continue;
            }
            String column = getColumnByField(field);
            if (StringUtils.isEmpty(column)) {
                throw new RuntimeException(MCSysErrorCodeType.UN_CAUGHT_COLUMNS.getErrorMessage());
            }
            field.setAccessible(true);
            columns.add(new ColumnMapping(field, column, field.getAnnotation(TableId.class) != null));
        }
        return columns;
    }

    /** Returns the column of the last {@code @TableId} field, or "" when there is none. */
    private static String resolveIdColumn(List<ColumnMapping> columns) {
        String idColumn = "";
        for (ColumnMapping column : columns) {
            if (column.id) {
                idColumn = column.column;
            }
        }
        return idColumn;
    }

    /** Renders the USING source: one SELECT of literals per entity, joined with UNION ALL. */
    private static <T> String buildSourceRows(List<T> rows, List<ColumnMapping> columns, boolean oracle) {
        String dummyTable = oracle ? " FROM DUAL" : " FROM sysibm.sysdummy1";
        StringJoiner union = new StringJoiner(" UNION ALL ");
        for (T row : rows) {
            StringJoiner select = new StringJoiner(", ", "SELECT ", dummyTable);
            for (ColumnMapping column : columns) {
                Object value = readValue(column.field, row);
                select.add(toSqlLiteral(column, value, oracle) + " AS " + column.column);
            }
            union.add(select.toString());
        }
        return union.toString();
    }

    /** Renders the WHEN MATCHED / WHEN NOT MATCHED clauses. */
    private static String buildMergeActions(List<ColumnMapping> columns, String idColumn, boolean oracle) {
        StringJoiner setClause = new StringJoiner(", ");
        StringJoiner insertColumns = new StringJoiner(", ", "( ", ") ");
        StringJoiner insertValues = new StringJoiner(", ", "VALUES ( ", ")");
        for (ColumnMapping column : columns) {
            if (column.id) {
                // The id column is never updated; on DB2 it is also left out of the INSERT
                // (NOTE(review): presumably an identity column there — confirm).
                if (oracle) {
                    insertColumns.add(column.column);
                    insertValues.add("t." + column.column);
                }
                continue;
            }
            setClause.add(String.format("u.%s = t.%s", column.column, column.column));
            insertColumns.add(column.column);
            insertValues.add("t." + column.column);
        }
        String whereSql = oracle ? String.format(" WHERE u.%s = t.%s ", idColumn, idColumn) : "";
        return " WHEN MATCHED THEN UPDATE SET " + setClause + whereSql
                + " WHEN NOT MATCHED THEN INSERT " + insertColumns + insertValues;
    }

    /** Reads {@code field} from {@code row}, failing loudly instead of emitting malformed SQL. */
    private static Object readValue(Field field, Object row) {
        try {
            return field.get(row);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot read field " + field.getName(), e);
        }
    }

    /**
     * Renders {@code value} as an inline SQL literal, choosing the form from the field's declared type
     * so every row of the UNION produces the same SQL type for a column.
     */
    private static String toSqlLiteral(ColumnMapping column, Object value, boolean oracle) {
        if (value == null) {
            // DB2: a null id is sent as -1 so the ON clause (presumably) finds no row and the entity
            // is inserted without an id. Everything else is NULL.
            return !oracle && column.id ? "-1" : "NULL";
        }
        Class<?> type = column.field.getType();
        if (CharSequence.class.isAssignableFrom(type) || type == Character.class || type == char.class) {
            return quote(value.toString());
        }
        if (type == Timestamp.class) {
            return "TIMESTAMP" + quote(value.toString());
        }
        if (type == java.util.Date.class || type == java.sql.Date.class) {
            // java.sql.Date prints as yyyy-MM-dd; a plain java.util.Date must be formatted explicitly
            // because its toString() ("Tue Aug 02 ...") is not parseable by TO_DATE or DB2.
            String text = value instanceof java.sql.Date
                    ? value.toString()
                    : formatDateTime((java.util.Date) value);
            return oracle ? "TO_DATE(" + quote(text) + ",'" + ORACLE_DATE_MASK + "')" : quote(text);
        }
        return String.valueOf(value);
    }

    /** Formats {@code date} as yyyy-MM-dd HH:mm:ss in the JVM default time zone. */
    private static String formatDateTime(java.util.Date date) {
        return Instant.ofEpochMilli(date.getTime()).atZone(ZoneId.systemDefault()).format(DATE_TIME_FORMAT);
    }

    /** Wraps {@code text} in single quotes, doubling embedded quotes so the literal stays intact. */
    private static String quote(String text) {
        return "'" + text.replace("'", "''") + "'";
    }

    /**
     * Resolves the column name of {@code field}: the explicit {@code @TableId}/{@code @TableField}
     * value when present, otherwise camelCase converted to UPPER_SNAKE_CASE ("userName" -> "USER_NAME").
     * Only uppercase letters after the first character start a new word; digits and underscores do not.
     */
    private static String getColumnByField(Field field) {
        TableId tableId = field.getAnnotation(TableId.class);
        if (tableId != null && !StringUtils.isEmptyString(tableId.value())) {
            return tableId.value();
        }
        TableField tableField = field.getAnnotation(TableField.class);
        if (tableField != null && !StringUtils.isEmptyString(tableField.value())) {
            return tableField.value();
        }
        String fieldName = field.getName();
        StringBuilder column = new StringBuilder(fieldName.length() + 4);
        for (int i = 0; i < fieldName.length(); i++) {
            char c = fieldName.charAt(i);
            if (i > 0 && Character.isUpperCase(c)) {
                column.append('_');
            }
            // Character.toUpperCase(char) is locale-independent (unlike String.toUpperCase()).
            column.append(Character.toUpperCase(c));
        }
        return column.toString();
    }
}
import com.baomidou.mybatisplus.annotation.TableName;
/**
 * Small string and annotation helpers shared by the persistence utilities.
 *
 * @date 2020/09/18
 **/
public class StringUtils {

    private StringUtils() {
        // static helpers only
    }

    /**
     * Reads the table name declared by {@code @TableName} on {@code clazz}.
     *
     * @param clazz annotated entity class
     * @return the annotation value
     * @throws RuntimeException when the annotation is missing or its value is blank
     */
    public static <T> String getTableName(Class<T> clazz) {
        TableName annotation = clazz.getAnnotation(TableName.class);
        String name = (annotation == null) ? null : annotation.value();
        if (isEmptyString(name)) {
            throw new RuntimeException("未查询到表名");
        }
        return name;
    }

    /** True when {@code str} is null or exactly the empty String. */
    public static boolean isEmpty(Object str) {
        return str == null || (str instanceof String && ((String) str).isEmpty());
    }

    /** Negation of {@link #isEmpty(Object)}. */
    public static boolean isNotEmpty(Object str) {
        return !isEmpty(str);
    }

    /**
     * Converts null or whitespace-only text to "".
     *
     * @author LGQ
     * @date 2022/6/28 8:47
     */
    public static String nullToString(String s) {
        return (s == null || s.trim().isEmpty()) ? "" : s;
    }

    /** True when {@code s} is null or empty after {@link String#trim()}. */
    public static boolean isEmptyString(String s) {
        return s == null || s.trim().isEmpty();
    }

    /**
     * Checks whether {@code toFind} occurs in {@code src} at least {@code atLeastCount} times
     * (non-overlapping matches).
     *
     * @param src          text to search
     * @param toFind       text to look for
     * @param atLeastCount minimum number of occurrences
     * @return false when either string is null/blank or {@code atLeastCount < 1}
     */
    public static boolean occurAtLeastCount(String src, String toFind, int atLeastCount) {
        if (atLeastCount < 1 || isEmptyString(src) || isEmptyString(toFind)) {
            return false;
        }
        int found = 0;
        for (int pos = src.indexOf(toFind); pos != -1; pos = src.indexOf(toFind, pos + toFind.length())) {
            found++;
            if (found >= atLeastCount) {
                return true;
            }
        }
        return false;
    }

    /** True when {@code string} is null, empty, or consists only of whitespace characters. */
    public static boolean isBlank(String string) {
        return string == null || string.chars().allMatch(Character::isWhitespace);
    }

    /** Null-safe {@link String#trim()}. */
    public static String trim(String string) {
        return (string == null) ? null : string.trim();
    }
}