通过反射注解批量插入数据到DB

版权声明:本文为博主原创文章,未经博主允许可以转载。 https://blog.csdn.net/u013673976/article/details/44646135
批量导入思路

最近遇到一个需要批量导入数据问题。后来考虑运用反射做成一个工具类,思路是首先定义注解接口,在bean类上加注解,运行时通过反射获取传入Bean的注解,自动生成需要插入DB的SQL,根据设置的参数值批量提交。不需要写具体的SQL,也没有DAO的实现,这样一来批量导入的实现就和具体的数据库表彻底解耦。实际批量执行的SQL如下:
insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) VALUES (?,?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE type=?,weight=?,score=?

第一步,定义注解接口

注解接口Table中定义了数据库名和表名。RetentionPolicy.RUNTIME表示该注解保存到运行时,因为我们需要在运行时,去读取注解参数来生成具体的SQL。

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Table {

    /**
     * 表名
     * @return
     */
    String tableName() default "";

    /**
     * 数据库名称
     * @return
     */
    String dbName();
}

注解接口TableField中定义了数据库表名的各个具体字段名称,以及该字段是否忽略(忽略的话就会以数据库表定义默认值填充,DB非null字段的注解不允许出现把ignore注解设置为true)。update注解是在主键在DB重复时,需要更新的字段。

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface TableField {

    /**
     * 对应数据库字段名称
     * 
     * @return
     */
    String fieldName() default "";

    /**
     * 是否是主键
     * 
     * @return
     */
    boolean pk() default false;

    /**
     * 是否忽略该字段
     * 
     * @return
     */
    boolean ignore() default false;

    /**
     * 当数据存在时,是否更新该字段
     * 
     * @return
     */
    boolean update() default false;
}
第二步,给Bean添加注解

给Bean添加注解(为了简洁省略了import和set/get方法以及其他属性),@TableField(fieldName = "company_id")表示companyId字段对应DB表的字段名为"company_id",其中updateTime属性的注解含有ignore=true,表示该属性值会被忽略。另外serialVersionUID属性由于没有@TableField注解,在更新DB时也会被忽略。
代码如下:

@Table(dbName = "company", tableName = "company_candidate")
public class CompanyCandidateModel implements Serializable{
	private static final long serialVersionUID = -1234554321773322135L;
	@TableField(fieldName = "company_id")
	private int companyId;
	@TableField(fieldName = "user_id")
	private int userId;
	//名片id
	@TableField(fieldName = "card_id")
	private int cardId;
	//facebookId
	@TableField(fieldName = "facebook_id")
	private long facebookId;
    @TableField(fieldName="type", update = true)
	private int type;
	@TableField(fieldName = "create_time")
	private Date createTime;
	@TableField(fieldName = "update_time", ignore=true)
	private Date updateTime;
	// 权重
    @TableField(fieldName="weight", update = true)
	private int weight;
	// 分值
    @TableField(fieldName="score", update = true)
	private double score;
第三步,读取注解的反射工具类

读取第二步Bean类的注解的反射工具类。利用反射getAnnotation(TableField.class)读取注解信息,为批量SQL的拼接最好准备。
getTableBeanFieldMap()方法里生成一个LinkedHashMap对象,是为了保证生成插入SQL的field顺序,之后也能按同样的顺序给参数赋值,避免错位。getSqlParamFields()方法也类似,是为了给PreparedStatement设置参数用。
代码如下:

public class ReflectUtil {
    /**
     * <Class,<表定义Field名,Bean定义Field>>的map缓存
     */
    private static final Map<Class<?>, Map<string field="">> classTableBeanFieldMap = new HashMap<Class<?>, Map<string field="">>();
    // 用来按顺序填充SQL参数,其中存储的Field和classTableBeanFieldMap保存同样的顺序,但数量多出ON DUPLICATE KEY UPDATE部分Field
    private static final Map<Class<?>, List<field>> sqlParamFieldsMap = new HashMap<Class<?>, List<field>>(); 
    private ReflectUtil(){};
	
    /**
     * 获取该类上所有@TableField注解,且没有忽略的字段的Map。
     * <br />返回一个有序的LinkedHashMap类型
     * <br />其中key为DB表中的字段,value为Bean类里的属性Field对象
     * @param clazz
     * @return
     */
    public static Map<string field=""> getTableBeanFieldMap(Class<?> clazz) {
    	// 从缓存获取
    	Map<string field=""> fieldsMap = classTableBeanFieldMap.get(clazz);
    	if (fieldsMap == null) {
    		fieldsMap = new LinkedHashMap<string field="">();
            for (Field field : clazz.getDeclaredFields()) {// 获得所有声明属性数组的一个拷贝
            	TableField annotation = field.getAnnotation(TableField.class);
                if (annotation != null && !annotation.ignore() && !"".equals(annotation.fieldName())) {
                    field.setAccessible(true);// 方便后续获取私有域的值
                	fieldsMap.put(annotation.fieldName(), field);
                }
			}
            // 放入缓存
            classTableBeanFieldMap.put(clazz, fieldsMap);
    	}
    	return fieldsMap;
    }
	
    /**
     * 获取该类上所有@TableField注解,且没有忽略的字段的Map。ON DUPLICATE KEY UPDATE后需要更新的字段追加在list最后,为了填充参数值准备
     * <br />返回一个有序的ArrayList类型
     * <br />其中key为DB表中的字段,value为Bean类里的属性Field对象
     * @param clazz
     * @return
     */
    public static List<field> getSqlParamFields(Class<?> clazz) {

    	// 从缓存获取
    	List<field> sqlParamFields = sqlParamFieldsMap.get(clazz);
    	if (sqlParamFields == null) {
    		// 获取所有参数字段
        	Map<string field=""> fieldsMap = getTableBeanFieldMap(clazz);
    		sqlParamFields = new ArrayList<field>(fieldsMap.size() * 2);
        	// SQL后段ON DUPLICATE KEY UPDATE需要更新的字段
        	List<field> updateParamFields = new ArrayList<field>();

    		Iterator<Entry<string field="">> iter = fieldsMap.entrySet().iterator();
    		while (iter.hasNext()) {
    			Entry<string field=""> entry = (Entry<string field="">) iter.next();
    			Field field = entry.getValue();
    			// insert语句对应sql参数字段
    			sqlParamFields.add(field);

                // ON DUPLICATE KEY UPDATE后面语句对应sql参数字段
                TableField annotation = field.getAnnotation(TableField.class);
    			if (annotation != null && !annotation.ignore() && annotation.update()) {
    				updateParamFields.add(field);
    			}
    		}
    		sqlParamFields.addAll(updateParamFields);

            // 放入缓存
    		sqlParamFieldsMap.put(clazz, sqlParamFields);
    	}
    	return sqlParamFields;
    }

    /**
     * 获取表名,对象中使用@Table的tableName来标记对应数据库的表名,若未标记则自动将类名转成小写
     * 
     * @param clazz
     * @return
     */
    public static String getTableName(Class<?> clazz) {
        Table table = clazz.getAnnotation(Table.class);
        if (table != null && table.tableName() != null && !"".equals(table.tableName())) {
            return table.tableName();
        }
        // 当未配置@Table的tableName,自动将类名转成小写
        return clazz.getSimpleName().toLowerCase();
    }

    /**
     * 获取数据库名,对象中使用@Table的dbName来标记对应数据库名
     * 
     * @param clazz
     * @return
     */
    public static String getDBName(Class<?> clazz) {
        Table table = clazz.getAnnotation(Table.class);
        if (table != null && table.dbName() != null) {
            // 注解@Table的dbName
            return table.dbName();
        }
        return "";
    }
第四步,生成SQL语句

根据上一步的方法,生成真正执行的SQL语句。insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) VALUES (?,?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE type=?,weight=?,score=?
代码如下:

public class SQLUtil {
    private static final char COMMA = ',';
    private static final char BRACKETS_BEGIN = '(';
    private static final char BRACKETS_END = ')';
    private static final char QUESTION_MARK = '?';
    private static final char EQUAL_SIGN = '=';

    private static final String INSERT_BEGIN = "INSERT INTO ";
    private static final String INSERT_VALURS = " VALUES ";
    private static final String DUPLICATE_UPDATE = " ON DUPLICATE KEY UPDATE ";

    // 数据库表名和对应insertupdateSQL的缓存
    private static final Map<string string=""> tableInsertSqlMap = new HashMap<string string="">();

    /**
     * 获取插入的sql语句,对象中使用@TableField的fieldName来标记对应数据库的列名,若未标记则忽略
     * 必须标记@TableField(fieldName = "company_id")注解
     * 
     * @param tableName
     * @param fieldsMap
     * @return
     * @throws Exception
     */
    public static String getInsertSql(String tableName, Map<string field=""> fieldsMap) throws Exception {

    	String sql = tableInsertSqlMap.get(tableName);
    	if (sql == null) {
    		StringBuilder sbSql = new StringBuilder(300).append(INSERT_BEGIN);
    		StringBuilder sbValue = new StringBuilder(INSERT_VALURS);
    		StringBuilder sbUpdate = new StringBuilder(100).append(DUPLICATE_UPDATE);
    		sbSql.append(tableName);
    		sbSql.append(BRACKETS_BEGIN);
    		sbValue.append(BRACKETS_BEGIN);

    		Iterator<Entry<string field="">> iter = fieldsMap.entrySet().iterator();
    		while (iter.hasNext()) {
    			Entry<string field=""> entry = (Entry<string field="">) iter.next();
    			String tableFieldName = entry.getKey();
    			Field field = entry.getValue();
			
    			sbSql.append(tableFieldName);
    			sbSql.append(COMMA);

    			sbValue.append(QUESTION_MARK);
    			sbValue.append(COMMA);
            
    			TableField tableField = field.getAnnotation(TableField.class);
    			if (tableField != null && tableField.update()) {
    				sbUpdate.append(tableFieldName);
    				sbUpdate.append(EQUAL_SIGN);
    				sbUpdate.append(QUESTION_MARK);
    				sbUpdate.append(COMMA);
    			}
    		}
    		// 去掉最后的逗号
    		sbSql.deleteCharAt(sbSql.length() - 1);
    		sbValue.deleteCharAt(sbValue.length() - 1);
		
    		sbSql.append(BRACKETS_END);
    		sbValue.append(BRACKETS_END);
    		sbSql.append(sbValue);
        
    		if (!sbUpdate.toString().equals(DUPLICATE_UPDATE)) {
    			sbUpdate.deleteCharAt(sbUpdate.length() - 1);
    			sbSql.append(sbUpdate);
    		}
    		sql = sbSql.toString();
    		tableInsertSqlMap.put(tableName, sql);
    	}
        return sql;
    }
第五步,批量SQL插入实现

从连接池获取Connection,SQLUtil.getInsertSql()获取执行的SQL语句,根据sqlParamFields来为PreparedStatement填充参数值。当循环的值集合到达batchNum时就提交一次。
代码如下:

    /**
     * 批量插入,如果主键一致则更新。结果返回更新记录条数<br />
     * @param dataList
     *            要插入的对象List
     * @param batchNum
     *            每次批量插入条数
     * @return 更新记录条数
     */
    public int batchInsertSQL(List<? extends Object> dataList, int batchNum) throws Exception {
    	if (dataList == null || dataList.isEmpty()) {
    		return 0;
    	}
        Class<?> clazz = dataList.get(0).getClass();
        String tableName = ReflectUtil.getTableName(clazz);
        String dbName = ReflectUtil.getDBName(clazz);

        Connection connnection = null;
        PreparedStatement preparedStatement = null;
        // 获取所有需要更新到DB的属性域
        Map<string field=""> fieldsMap = ReflectUtil.getTableBeanFieldMap(dataList.get(0).getClass());
        // 根据需要插入更新的字段生成SQL语句
        String sql = SQLUtil.getInsertSql(tableName, fieldsMap);
        log.debug("prepare to start batch operation , sql = " + sql + " , dbName = " + dbName);
        // 获取和SQL语句同样顺序的填充参数Fields
        List<field> sqlParamFields = ReflectUtil.getSqlParamFields(dataList.get(0).getClass());
        // 最终更新结果条数
        int result = 0;
        int parameterIndex = 1;// SQL填充参数开始位置为1
        // 执行错误的对象
        List<object> errorsRecords = new ArrayList</object><object>(batchNum);//指定数组大小
        // 计数器,batchNum提交后内循环累计次数
        int innerCount = 0;
        try {
            connnection = this.getConnection(dbName);
            // 设置非自动提交
            connnection.setAutoCommit(false);
            preparedStatement = connnection.prepareStatement(sql);
            // 当前操作的对象
            Object object = null;
            int totalRecordCount = dataList.size();
            for (int current = 0; current < totalRecordCount; current++) {
                innerCount++;
                object = dataList.get(current);
            	parameterIndex = 1;// 开始参数位置为1
            	for(Field field : sqlParamFields) {
            		// 放入insert语句对应sql参数
                    preparedStatement.setObject(parameterIndex++, field.get(object));
            	}

            	errorsRecords.add(object);
                preparedStatement.addBatch();
                // 达到批量次数就提交一次
                if (innerCount >= batchNum || current >= totalRecordCount - 1) {
                    // 执行batch操作
                    preparedStatement.executeBatch();
                    preparedStatement.clearBatch();
                    // 提交
                    connnection.commit();
                    // 记录提交成功条数
                    result += innerCount;
                    innerCount = 0;
                    errorsRecords.clear();
                }
                // 尽早让GC回收
                dataList.set(current, null);
            }
            return result;
        } catch (Exception e) {
            // 失败后处理方法
            CallBackImpl.getInstance().exectuer(sql, errorsRecords, e);
            BatchDBException be = new BatchDBException("batch run error , dbName = " + dbName + " sql = " + sql, e);
            be.initCause(e);
            throw be;
        } finally {
            // 关闭
            if (preparedStatement != null) {
            	preparedStatement.clearBatch();
                preparedStatement.close();
            }
            if (connnection != null)
                connnection.close();
        }
    }
最后,批量工具类使用例子

在mysql下的开发环境下测试,5万条数据大概13秒。

List<companycandidatemodel> updateDataList = new ArrayList<companycandidatemodel>(50000);
// ...为updateDataList填充数据
int result = batchJdbcTemplate.batchInsertSQL(updateDataList, 50);
阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页