在 JPA 中手动拼接 where 条件时,需要用 Predicate 和 CriteriaBuilder 写一大堆 if-else,代码非常冗余,于是就有了用来生成 where 条件的帮助类。
JPA 的 where 条件生成帮助类有两个:
Query 注解用来指定字段使用什么方式查询;QueryHelp 用来根据这些注解生成条件语句。
QueryHelp 类(它依赖的 Query 注解位于 com.starcity.annotation.Query,本文未贴出其代码):
package com.starcity.utils;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.starcity.annotation.Query;
import lombok.extern.slf4j.Slf4j;
import org.hibernate.query.criteria.internal.ValueHandlerFactory;
import javax.persistence.criteria.*;
import java.lang.reflect.Field;
import java.util.*;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.joining;
/**
 * Builds query conditions from a criteria object whose fields carry the {@link Query} annotation.
 *
 * <p>Three entry points are offered:
 * <ul>
 *   <li>{@link #getPredicate} builds a JPA Criteria {@link Predicate};</li>
 *   <li>{@link #buildSqlWhere} appends {@code " and ..."} fragments to an existing SQL buffer;</li>
 *   <li>{@link #buildSqlOnlyWhere} returns the same fragments in a fresh buffer.</li>
 * </ul>
 *
 * <p>All three are best-effort: any exception raised while reading a field or building a condition
 * is logged and the conditions collected so far are returned.
 *
 * <p><b>Security note:</b> the SQL-building methods embed values as literals because their
 * signatures return SQL text. String values have single quotes doubled, which is the standard SQL
 * escape, but databases that also treat backslash as an escape character need additional care.
 * LIKE wildcards ({@code %}, {@code _}) inside values are not escaped. Prefer
 * {@link #getPredicate} (bound parameters) for untrusted input.
 *
 * @author WeiMaoMao
 */
@Slf4j
@SuppressWarnings({"unchecked", "all"})
public class QueryHelp {

    /**
     * Builds a conjunction of predicates from the annotated fields of {@code query}.
     *
     * @param root  criteria root the attribute paths and joins are resolved against
     * @param query criteria object; {@code null} yields an empty conjunction
     * @param cb    criteria builder used to create the predicates
     * @return {@code cb.and(...)} over every condition that could be built
     */
    public static <R, Q> Predicate getPredicate(Root<R> root, Q query, CriteriaBuilder cb) {
        List<Predicate> list = new ArrayList<>();
        if (query == null) {
            return cb.and(list.toArray(new Predicate[0]));
        }
        try {
            for (Field field : getAllFields(query.getClass(), new ArrayList<>())) {
                Query q = field.getAnnotation(Query.class);
                if (q == null) {
                    continue;
                }
                boolean accessible = field.isAccessible();
                field.setAccessible(true);
                try {
                    addPredicate(root, cb, q, field, field.get(query), list);
                } finally {
                    // Restore the flag on every path, including the early exits in addPredicate.
                    field.setAccessible(accessible);
                }
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
        return cb.and(list.toArray(new Predicate[0]));
    }

    /** Adds the predicate described by one annotated field to {@code list}; null/empty values are skipped. */
    private static <R> void addPredicate(Root<R> root, CriteriaBuilder cb, Query q, Field field,
                                         Object val, List<Predicate> list) {
        if (ObjectUtil.isNull(val) || "".equals(val)) {
            return;
        }
        String propName = q.propName();
        String attributeName = isBlank(propName) ? field.getName() : propName;
        Class<?> fieldType = field.getType();

        // Multi-column fuzzy search: OR together a LIKE on every listed attribute of the root.
        String blurry = q.blurry();
        if (ObjectUtil.isNotEmpty(blurry)) {
            List<Predicate> orPredicate = new ArrayList<>();
            for (String s : blurry.split(",")) {
                orPredicate.add(cb.like(root.get(s)
                        .as(String.class), "%" + val.toString() + "%"));
            }
            list.add(cb.or(orPredicate.toArray(new Predicate[0])));
            return;
        }

        Join join = resolveJoin(root, q);
        switch (q.type()) {
            case EQUAL:
                list.add(cb.equal(getExpression(attributeName, join, root)
                        .as((Class<? extends Comparable>) fieldType), val));
                break;
            case GREATER_THAN:
                // Inclusive comparison (>=), matching the SQL builder below.
                list.add(cb.greaterThanOrEqualTo(getExpression(attributeName, join, root)
                        .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                break;
            case LESS_THAN:
                // Inclusive comparison (<=); LESS_THAN_NQ is the strict variant.
                list.add(cb.lessThanOrEqualTo(getExpression(attributeName, join, root)
                        .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                break;
            case LESS_THAN_NQ:
                list.add(cb.lessThan(getExpression(attributeName, join, root)
                        .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                break;
            case INNER_LIKE:
                list.add(cb.like(getExpression(attributeName, join, root)
                        .as(String.class), "%" + val.toString() + "%"));
                break;
            case LEFT_LIKE:
                list.add(cb.like(getExpression(attributeName, join, root)
                        .as(String.class), "%" + val.toString()));
                break;
            case RIGHT_LIKE:
                list.add(cb.like(getExpression(attributeName, join, root)
                        .as(String.class), val.toString() + "%"));
                break;
            case IN:
                if (CollUtil.isNotEmpty((Collection<Long>) val)) {
                    list.add(getExpression(attributeName, join, root).in((Collection<Long>) val));
                }
                break;
            case NOT_EQUAL:
                list.add(cb.notEqual(getExpression(attributeName, join, root), val));
                break;
            case NOT_NULL:
                // Applied whenever the field holds any non-empty value.
                list.add(cb.isNotNull(getExpression(attributeName, join, root)));
                break;
            case NULL:
                // Only a Boolean TRUE switches the condition on; other values are ignored.
                if (Boolean.TRUE.equals(val)) {
                    list.add(cb.isNull(getExpression(attributeName, join, root)));
                }
                break;
            case BETWEEN:
                List<Object> between = new ArrayList<>((List<Object>) val);
                if (between.size() < 2) {
                    log.warn("BETWEEN on '{}' needs two bounds but got {}; condition skipped",
                            attributeName, between.size());
                    break;
                }
                list.add(cb.between(getExpression(attributeName, join, root)
                                .as((Class<? extends Comparable>) between.get(0).getClass()),
                        (Comparable) between.get(0), (Comparable) between.get(1)));
                break;
            default:
                break;
        }
    }

    /**
     * Walks the {@code >}-separated join path of the annotation and returns the last join,
     * or {@code null} when no path is given or the join kind is neither LEFT nor RIGHT.
     */
    private static <R> Join resolveJoin(Root<R> root, Query q) {
        String joinName = q.joinName();
        if (!ObjectUtil.isNotEmpty(joinName)) {
            return null;
        }
        JoinType joinType;
        switch (q.join()) {
            case LEFT:
                joinType = JoinType.LEFT;
                break;
            case RIGHT:
                joinType = JoinType.RIGHT;
                break;
            default:
                return null;
        }
        Join join = null;
        for (String name : joinName.split(">")) {
            if (join != null) {
                join = join.join(name, joinType);
            } else {
                join = root.join(name, joinType);
            }
        }
        return join;
    }

    /** Resolves {@code attributeName} on the join when one exists, otherwise on the root. */
    @SuppressWarnings("unchecked")
    private static <T, R> Expression<T> getExpression(String attributeName, Join join, Root<R> root) {
        if (join != null) {
            return join.get(attributeName);
        }
        return root.get(attributeName);
    }

    /** Returns {@code true} when {@code cs} is null, empty, or contains only whitespace. */
    private static boolean isBlank(final CharSequence cs) {
        if (cs == null) {
            return true;
        }
        for (int i = 0; i < cs.length(); i++) {
            if (!Character.isWhitespace(cs.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    /** Collects the declared fields of {@code clazz} and of every superclass into {@code fields}. */
    private static List<Field> getAllFields(Class clazz, List<Field> fields) {
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            fields.addAll(Arrays.asList(current.getDeclaredFields()));
        }
        return fields;
    }

    /**
     * Appends one {@code " and ..."} fragment per populated, annotated field of {@code query}
     * to {@code sql}.
     *
     * <p>NOTE(review): when the annotation specifies a join path, {@code " left join x"} /
     * {@code " right join x"} text is appended at the same position as the conditions; callers
     * should confirm the resulting statement is what they expect.
     *
     * @param query criteria object; {@code null} leaves {@code sql} untouched
     * @param sql   buffer to append to
     * @return the same {@code sql} buffer
     */
    public static <Q> StringBuffer buildSqlWhere(Q query, StringBuffer sql) {
        if (query == null) {
            return sql;
        }
        try {
            for (Field field : getAllFields(query.getClass(), new ArrayList<>())) {
                Query q = field.getAnnotation(Query.class);
                if (q == null) {
                    continue;
                }
                boolean accessible = field.isAccessible();
                field.setAccessible(true);
                try {
                    appendCondition(sql, q, field, field.get(query));
                } finally {
                    field.setAccessible(accessible);
                }
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
        return sql;
    }

    /**
     * Builds only the condition fragments, in a new buffer.
     *
     * @param query criteria object; {@code null} yields an empty buffer
     * @return a new buffer holding the same fragments {@link #buildSqlWhere} would append
     */
    public static <Q> StringBuffer buildSqlOnlyWhere(Q query) {
        // Delegating keeps both builders in lock-step; the former copy had drifted in its IN handling.
        return buildSqlWhere(query, new StringBuffer());
    }

    /** Appends the SQL fragment described by one annotated field; null/empty values are skipped. */
    private static void appendCondition(StringBuffer sql, Query q, Field field, Object val) {
        if (ObjectUtil.isNull(val) || "".equals(val)) {
            return;
        }
        String propName = q.propName();
        String attributeName = isBlank(propName) ? field.getName() : propName;
        Class<?> fieldType = field.getType();

        // Multi-column fuzzy search: (col1 like '%v%' or col2 like '%v%' ...).
        String blurry = q.blurry();
        if (ObjectUtil.isNotEmpty(blurry)) {
            String[] blurrys = blurry.split(",");
            String pattern = "'%" + escapeSql(val) + "%'";
            sql.append(" and (");
            for (int i = 0; i < blurrys.length; i++) {
                if (i != 0) {
                    sql.append(" or ");
                }
                sql.append(" " + blurrys[i] + " like " + pattern);
            }
            sql.append(")");
            return;
        }

        String joinName = q.joinName();
        if (ObjectUtil.isNotEmpty(joinName)) {
            for (String name : joinName.split(">")) {
                switch (q.join()) {
                    case LEFT:
                        sql.append(" left join " + name);
                        break;
                    case RIGHT:
                        sql.append(" right join " + name);
                        break;
                    default:
                        break;
                }
            }
        }

        switch (q.type()) {
            case EQUAL:
                checkStringAndNum(fieldType, sql, "=", attributeName, val);
                break;
            case GREATER_THAN:
                checkStringAndNum(fieldType, sql, ">=", attributeName, val);
                break;
            case LESS_THAN:
                checkStringAndNum(fieldType, sql, "<=", attributeName, val);
                break;
            case LESS_THAN_NQ:
                checkStringAndNum(fieldType, sql, "<", attributeName, val);
                break;
            case INNER_LIKE:
                sql.append(" and " + attributeName + " like '%" + escapeSql(val) + "%' ");
                break;
            case LEFT_LIKE:
                sql.append(" and " + attributeName + " like '%" + escapeSql(val) + "' ");
                break;
            case RIGHT_LIKE:
                sql.append(" and " + attributeName + " like '" + escapeSql(val) + "%' ");
                break;
            case IN:
                Collection<Object> values = (Collection<Object>) val;
                if (CollUtil.isNotEmpty(values)) {
                    // Each element is rendered on its own, so null or mixed-type elements are safe.
                    sql.append(" and " + attributeName + " in ("
                            + values.stream().map(QueryHelp::toSqlLiteral).collect(joining(",")) + ")");
                }
                break;
            case NOT_EQUAL:
                checkStringAndNum(fieldType, sql, "!=", attributeName, val);
                break;
            case NOT_NULL:
                sql.append(" and " + attributeName + " IS NOT NULL ");
                break;
            case NULL:
                if (Boolean.TRUE.equals(val)) {
                    sql.append(" and " + attributeName + " IS NULL ");
                }
                break;
            case BETWEEN:
                List<Object> between = new ArrayList<>((List<Object>) val);
                if (between.size() < 2) {
                    log.warn("BETWEEN on '{}' needs two bounds but got {}; condition skipped",
                            attributeName, between.size());
                    break;
                }
                // The declared field type is a List here, so quoting is decided per bound value.
                sql.append(" and " + attributeName + " BETWEEN " + toSqlLiteral(between.get(0))
                        + " AND " + toSqlLiteral(between.get(1)) + " ");
                break;
            default:
                break;
        }
    }

    /**
     * Appends {@code " and <attributeName> <type> <value>"}, quoting the value when needed.
     *
     * @param fieldType     declared type of the field; numeric types are written unquoted
     * @param sql           buffer to append to
     * @param type          comparison operator, e.g. {@code =}, {@code >=}, {@code <=}
     * @param attributeName column / attribute name
     * @param val           field value; booleans are written as 1 / 0, everything else as a quoted string
     */
    private static void checkStringAndNum(Class<?> fieldType, StringBuffer sql, String type,
                                          String attributeName, Object val) {
        String prefix = " and " + attributeName + " " + type + " ";
        if (ValueHandlerFactory.isNumeric(fieldType)) {
            sql.append(prefix + val);
        } else if (ValueHandlerFactory.isBoolean(val)) {
            sql.append(prefix + ((Boolean) val ? 1 : 0));
        } else {
            sql.append(prefix + "'" + escapeSql(val) + "' ");
        }
    }

    /** Renders a value as a SQL literal: {@code NULL}, a bare number, or a quoted, escaped string. */
    private static String toSqlLiteral(Object val) {
        if (val == null) {
            return "NULL";
        }
        if (val instanceof Number) {
            return val.toString();
        }
        return "'" + escapeSql(val) + "'";
    }

    /**
     * Doubles single quotes so the value cannot terminate the surrounding string literal.
     * This is a mitigation only: it does not handle backslash escapes used by some databases.
     */
    private static String escapeSql(Object val) {
        return String.valueOf(val).replace("'", "''");
    }
}
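上面的 QueryHelp 提供了三个公共方法:
第一个方法 getPredicate:配合 JPA 的 Criteria 查询(Root、CriteriaBuilder)使用,返回 Predicate
第二个方法 buildSqlWhere:把条件追加到已有的 SQL 后面,配合 entityManager.createNativeQuery 使用
第三个方法 buildSqlOnlyWhere:只生成 where 条件部分,用于自己拼接 SQL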
举例说明:
第一种:
第二种:
第三种不出示例喽。