让mybatis sql注解支持 IN 多参数传递

import java.util.List;
import java.util.Map;

import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.springframework.stereotype.Repository;

/**
 *
 */
@Repository
public interface UserDao {
	@Select("select count(1) from cms_user where role in (#{roles})")
	int countAllUsers(@Param("roles") int... roles);

	@Select("select count(1) from cms_user where role in #{roles} AND 1=1")
	int countAllUsers2(@Param("roles") List<Integer> roles);

	@Select("select count(1) from cms_user where name in #{names}")
	int countAllUsers3(@Param("names") String... names);

	@Select("SELECT id,task_id,current_count,`status`,update_time FROM user_task_center_record WHERE user_id in #{userIds} AND task_id in #{taskIds} AND update_time > curdate()")
	List<Map<Object, Object>> getTodayTaskRecords2(@Param("userIds") int[] userIds,@Param("taskIds") List<Integer> taskIds
			);
	@Select("SELECT id,task_id,current_count,`status`,update_time FROM user_task_center_record WHERE user_id in #{userIds} AND task_id in #{taskIds} AND update_time > curdate()")
	List<Map<Object, Object>> getTodayTaskRecords3(@Param("taskIds") List<Integer> taskIds,@Param("userIds") int[] userIds
			);

	@Select("SELECT id,task_id,current_count,`status`,update_time FROM user_task_center_record WHERE user_id=#{userId} AND task_id in #{taskIds} AND update_time > curdate()")
	List<Map<Object, Object>> getTodayTaskRecords(@Param("userIds") int userId,
			@Param("taskIds") List<Integer> taskIds);
}

mybatis 不支持上面这样的调用,需要自己配置sql构建类,比较麻烦。

一是繁琐,重复劳动,违反unix理念,能让机器做的就不要hack.

二是sql可读性变差,无法直观的通过方法注解看到sql语句,还要去找配置类。


如何解决这个问题呢?

一开始想到的方案是利用ps.setArray()方法,无奈mysql驱动这块没有做实现。

这里提供一个方法可以实现,将下面的类贴在项目中即可。

原理也比较简单,就是在处理PreparedStatement的sql参数之前把 sql里面的IN 参数替换成实际的参数,并且在变量列表中删去对应的项,这样的话ps就不用处理IN参数了。

其他的方案也可以尝试,比如 把in ? 替换成in(?,?,?...)再对ps进行set,不过显然更复杂。


/** 
 *    Copyright 2009-2015 the original author or authors. 
 * 
 *    Licensed under the Apache License, Version 2.0 (the "License"); 
 *    you may not use this file except in compliance with the License. 
 *    You may obtain a copy of the License at 
 * 
 *       http://www.apache.org/licenses/LICENSE-2.0 
 * 
 *    Unless required by applicable law or agreed to in writing, software 
 *    distributed under the License is distributed on an "AS IS" BASIS, 
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 *    See the License for the specific language governing permissions and 
 *    limitations under the License. 
 */
package org.apache.ibatis.builder;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.session.Configuration;

/**
 * @author Clinton Begin
 * @author houkx -修改历史:支持sql in 多参数处理
 */
public class StaticSqlSource implements SqlSource {

    private String sql;
    private List<ParameterMapping> parameterMappings;
    private Configuration configuration;
    private SqlSegEntry[] sqlCache;
    private List<ParameterMapping> parameterMappingsForIn;

    public StaticSqlSource(Configuration configuration, String sql) {
        this(configuration, sql, null);
    }

    public StaticSqlSource(Configuration configuration, String sql, List<ParameterMapping> parameterMappings) {
        this.sql = sql;
        this.parameterMappings = parameterMappings;
        this.configuration = configuration;
        if (parameterMappings != null && parameterMappings.size() > 0) {
            this.sqlCache = cache(sql);
            if (sqlCache != null) {
                final int LEN = parameterMappings.size();
                HashSet<Integer> indexSet = new HashSet<>();
                for (SqlSegEntry e : sqlCache) {
                    if (e.index >= 0) {
                        indexSet.add(e.index);
                    }
                }
                ArrayList<ParameterMapping> paramMappings = new ArrayList<>(LEN);
                for (int i = 0; i < LEN; i++) {
                    if (!indexSet.contains(i)) {
                        paramMappings.add(parameterMappings.get(i));
                    }
                }
                parameterMappingsForIn = paramMappings;
            }
        }
    }

    @Override
    public BoundSql getBoundSql(Object parameterObject) {
        String sql = this.sql;
        List<ParameterMapping> parameterMappings = this.parameterMappings;
        if (sqlCache != null && (parameterObject instanceof Map)) {
            Map<?, ?> argMap = (Map<?, ?>) parameterObject;
            StringBuilder sqlBuilder = new StringBuilder(sql.length() + 128);
            for (SqlSegEntry e : sqlCache) {
                sqlBuilder.append(e.segment);
                if (e.index >= 0) {
                    ParameterMapping m = parameterMappings.get(e.index);
                    Object obj = argMap.get(m.getProperty());
                    appendParameter(sqlBuilder, obj);
                }
            }
            sql = sqlBuilder.toString();
            parameterMappings = this.parameterMappingsForIn;
        }
        return new BoundSql(configuration, sql, parameterMappings, parameterObject);
    }

    private void appendParameter(StringBuilder sb, Object arrObj) {
        Class<?> clazz = arrObj.getClass();
        final char STR = '\'';
        if (clazz.isArray()) {
            sb.append('(');
            int len = Array.getLength(arrObj);
            if (len > 0) {
                Class<?> eClass = clazz.getComponentType();
                if (eClass.isPrimitive()) {
                    sb.append(Array.get(arrObj, 0));
                    for (int i = 1; i < len; i++) {
                        sb.append(',').append(Array.get(arrObj, i));
                    }
                } else {
                    Object[] args = (Object[]) arrObj;
                    if (CharSequence.class.isAssignableFrom(eClass)) {
                        sb.append(STR).append(args[0]).append(STR);
                        for (int i = 1; i < len; i++) {
                            sb.append(',').append(STR).append(args[i]).append(STR);
                        }
                    } else {
                        sb.append(args[0]);
                        for (int i = 1; i < len; i++) {
                            sb.append(',').append(args[i]);
                        }
                    }
                }
            }
            sb.append(')');
        } else if (Collection.class.isAssignableFrom(clazz)) {
            sb.append('(');
            Collection<?> col = (Collection<?>) arrObj;
            Iterator<?> i = col.iterator();
            if (i.hasNext()) {
                boolean isString = false;
                Object eObj = i.next();
                isString = eObj instanceof CharSequence;
                if (isString) {
                    sb.append(STR).append(eObj).append(STR);
                    while (i.hasNext()) {
                        sb.append(',').append(STR).append(i.next()).append(STR);
                    }
                } else {
                    sb.append(eObj);
                    while (i.hasNext()) {
                        sb.append(',').append(i.next());
                    }
                }
            }
            sb.append(')');
        }
    }

    private static SqlSegEntry[] cache(String sql) {
        String[] segs = sql.split("[?]");
        int countIn = 0;
        ArrayList<SqlSegEntry> list = new ArrayList<>(segs.length);
        StringBuilder sb = new StringBuilder(sql.length());
        for (int i = 0; i < segs.length; i++) {
            String s = segs[i], ts = s.trim().toLowerCase(), tm = ts;
            if (ts.endsWith("(")) {
                tm = ts.substring(0, ts.length() - 1).trim();
            }
            if (tm.endsWith(" in")) {
                int st = 0;
                if (ts.charAt(0) == ')') {
                    st = s.indexOf(')') + 1;
                }
                if (tm.length() != ts.length()) {
                    s = s.substring(st, s.lastIndexOf('('));
                } else if (st > 0) {
                    s = s.substring(st);
                }
                sb.append(s);
                SqlSegEntry e = new SqlSegEntry();
                e.index = i;
                e.segment = sb.toString();
                list.add(e);
                sb.delete(0, sb.length());
                countIn++;
            } else {
                if (ts.charAt(0) == ')') {
                    s = s.substring(s.indexOf(')') + 1);
                }
                sb.append(s);
                if (i < segs.length - 1) {
                    sb.append('?');
                }
            }
        }
        if (countIn == 0) {
            return null;
        }
        if (sb.length() > 0) {
            SqlSegEntry e = new SqlSegEntry();
            e.segment = sb.toString();
            list.add(e);
        }
        sb = null;
        return list.toArray(new SqlSegEntry[list.size()]);
    }

    // --------
    private static class SqlSegEntry {
        String segment;
        int index = -1;
    }
}






评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值