接上篇 https://blog.csdn.net/a184838158/article/details/82629337
JPA 在简单的数据库操作中非常方便,但是涉及到动态条件查询的时候就会非常麻烦。之前我的做法是声明多个接口方法,再在运行时动态选择其中一个执行,但是当动态查询的条件较多时,代码量会非常庞大。
首先说一下我的设计思路,
- 编写自定义注解
- 封装一个bean作为查询对象
- 该bean每一个属性作为一个查询条件
- 每个属性利用注解设置查询属性
- 继承JpaRepositoryFactoryBean编写BaseRepositoryFactoryBean,用它替换JpaRepository的默认实现
- 在JpaRepository的默认实现中,编写处理查询对象的参数拼接
注解
package com.fewstrong.reposiotry.support;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a field of a query object as one query condition.
 *
 * @author Fewstrong
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface QueryField {

    /**
     * Name of the entity attribute the condition applies to. When left empty,
     * the repository implementation uses the annotated field's own name.
     */
    String name() default "";

    /** Kind of comparison to build for the annotated field. */
    QueryType type();
}
查询类型枚举
package com.fewstrong.reposiotry.support;
/**
 * Kinds of condition that a {@code QueryField}-annotated field can express.
 *
 * <p>The misspelled constant names ({@code BEWTEEN}, {@code GREATEROR_THAN},
 * {@code GREATEROR_THAN_EQUAL}) are part of the public API and are kept as-is.
 *
 * @author Fewstrong
 */
public enum QueryType {
    EQUAL(false),
    BEWTEEN(false),
    LESS_THAN(false),
    LESS_THAN_EQUAL(false),
    GREATEROR_THAN(false),
    GREATEROR_THAN_EQUAL(false),
    NOT_EQUAL(false),
    IS_NULL(true),
    IS_NOT_NULL(true),
    RIGHT_LIKE(false),
    LEFT_LIKE(false),
    FULL_LIKE(false),
    DEFAULT_LIKE(false),
    NOT_LIKE(false),
    IN(false);

    // Whether the condition may be built when the query field's value is null.
    // Enum constants are shared singletons, so this state must never change.
    private final boolean isCanBeNull;

    QueryType(boolean isCanBeNull) {
        this.isCanBeNull = isCanBeNull;
    }

    /**
     * @return {@code true} when a {@code null} field value means the condition
     *         has to be skipped
     */
    public boolean isNotCanBeNull() {
        return !this.isCanBeNull;
    }

    /**
     * @return {@code true} when the condition does not need a field value
     */
    public boolean isCanBeNull() {
        return this.isCanBeNull;
    }
}
这个类用于比较查询(区间、大于、小于等)。其中 after 表示下界(即"在该值之后"),before 表示上界(即"在该值之前");大于、小于等单边比较只使用 after 的值。
package com.fewstrong.reposiotry.support;
/**
 * Value holder for comparison conditions.
 *
 * <p>The repository implementation passes {@link #after} as the lower bound and
 * {@link #before} as the upper bound of a BETWEEN condition, and uses
 * {@link #after} alone for the one-sided comparisons.
 *
 * @param <T> type of the compared values
 * @author Fewstrong
 */
public class QueryBetween<T extends Comparable<?>> {

    /** Upper bound; public because the repository reads it directly. */
    public T before;

    /** Lower bound; public because the repository reads it directly. */
    public T after;

    /** @return the upper bound, possibly {@code null} */
    public T getBefore() {
        return this.before;
    }

    /** @return the lower bound, possibly {@code null} */
    public T getAfter() {
        return this.after;
    }

    /** @param before the upper bound */
    public void setBefore(T before) {
        this.before = before;
    }

    /** @param after the lower bound */
    public void setAfter(T after) {
        this.after = after;
    }
}
自定义接口,所有的动态查询条件必须实现该接口
package com.fewstrong.reposiotry.support;
/**
 * Marker interface for dynamic query objects.
 *
 * <p>It declares no methods. The repository implementation reads the
 * {@code QueryField}-annotated fields of the implementing object through
 * reflection to build the query conditions.
 *
 * @author Fewstrong
 */
public interface DataQueryObject {
}
扩展类,分别用于排序查询和分页查询
package com.fewstrong.reposiotry.support;
/**
 * Query object that additionally carries a single sort instruction.
 *
 * @author Fewstrong
 */
public class DataQueryObjectSort implements DataQueryObject {

    /** Entity attribute to sort by; no sorting is applied when null or blank. */
    protected String propertyName;

    /** Sort direction: ascending when {@code true} (the default), descending otherwise. */
    protected boolean ascending = true;

    /** @return the attribute to sort by, possibly {@code null} */
    public String getPropertyName() {
        return this.propertyName;
    }

    /** @return {@code true} for ascending order */
    public boolean isAscending() {
        return this.ascending;
    }

    /** @param propertyName the attribute to sort by */
    public void setPropertyName(String propertyName) {
        this.propertyName = propertyName;
    }

    /** @param ascending {@code true} for ascending, {@code false} for descending order */
    public void setAscending(boolean ascending) {
        this.ascending = ascending;
    }
}
package com.fewstrong.reposiotry.support;
/**
 * Query object that carries paging information on top of the sort instruction.
 *
 * @author Fewstrong
 */
public class DataQueryObjectPage extends DataQueryObjectSort {

    /** Page index handed to the page request; defaults to 0. */
    protected Integer page = 0;

    /** Page size handed to the page request; defaults to 10. */
    protected Integer size = 10;

    /** @return the requested page index */
    public Integer getPage() {
        return this.page;
    }

    /** @return the requested page size */
    public Integer getSize() {
        return this.size;
    }

    /** @param page the requested page index */
    public void setPage(Integer page) {
        this.page = page;
    }

    /** @param size the requested page size */
    public void setSize(Integer size) {
        this.size = size;
    }
}
一些基础bean已经完成,现在需要完成重写spring的一些类
重写 JpaRepository
package com.fewstrong.reposiotry;
import java.io.Serializable;
import java.util.List;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Transactional;
import com.fewstrong.reposiotry.support.DataQueryObject;
import com.fewstrong.reposiotry.support.DataQueryObjectPage;
import com.fewstrong.reposiotry.support.DataQueryObjectSort;
/**
 * Base repository contract that adds query-object based finders on top of
 * {@link JpaRepository} and {@link JpaSpecificationExecutor}.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 * @author fewstorng
 */
@NoRepositoryBean
@Transactional(readOnly = true, rollbackFor = Exception.class)
public interface BaseRepository<T, ID extends Serializable>
        extends JpaRepository<T, ID>, JpaSpecificationExecutor<T> {

    /**
     * Plain query.
     *
     * @param query object whose annotated fields describe the conditions
     * @return all matching entities
     */
    List<T> findAll(DataQueryObject query);

    /**
     * Sorted query with an explicit sort.
     *
     * @param dataQueryObject object whose annotated fields describe the conditions
     * @param sort            sort to apply
     * @return all matching entities
     */
    List<T> findAll(DataQueryObject dataQueryObject, Sort sort);

    /**
     * Sorted query whose sort instruction is carried by the query object itself.
     *
     * @param dataQueryObjectSort conditions plus sort instruction
     * @return all matching entities
     */
    List<T> findAll(DataQueryObjectSort dataQueryObjectSort);

    /**
     * Paged query with an explicit page request.
     *
     * @param query object whose annotated fields describe the conditions
     * @param page  page request to apply
     * @return the requested page of matching entities
     */
    Page<T> findAll(DataQueryObject query, Pageable page);

    /**
     * Paged query whose paging and sort instructions are carried by the query object itself.
     *
     * @param dataQueryObjectpage conditions plus paging and sort instruction
     * @return the requested page of matching entities
     */
    Page<T> findAll(DataQueryObjectPage dataQueryObjectpage);
}
重写JpaRepositoryFactoryBean
package com.fewstrong.reposiotry;
import java.io.Serializable;
import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactory;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
/**
 * Repository factory bean that makes every repository created through it use
 * {@link BaseRepositoryImpl} as its backing implementation.
 *
 * @author fewstrong
 **/
@SuppressWarnings({"rawtypes","unchecked"})
public class BaseRepositoryFactoryBean<R extends JpaRepository<T, I>, T,
I extends Serializable> extends JpaRepositoryFactoryBean<R, T, I> {
/**
* Creates a new {@link JpaRepositoryFactoryBean} for the given repository interface.
*
* @param repositoryInterface must not be {@literal null}.
*/
public BaseRepositoryFactoryBean(Class<? extends R> repositoryInterface) {
super(repositoryInterface);
}
/**
* Supplies the custom factory instead of the default one.
*
* @param em entity manager handed to the factory
* @return a new {@code BaseRepositoryFactory}
*/
@Override
protected RepositoryFactorySupport createRepositoryFactory(EntityManager em) {
return new BaseRepositoryFactory(em);
}
/**
* Nested factory class; it is private because nothing outside this factory bean needs it.
* @param <T> entity type
* @param <I> identifier type
*/
private static class BaseRepositoryFactory<T, I extends Serializable>
extends JpaRepositoryFactory {
private final EntityManager em;
public BaseRepositoryFactory(EntityManager em) {
super(em);
this.em = em;
}
/**
* Uses BaseRepositoryImpl as the concrete repository implementation.
* NOTE(review): whether this one-argument method can be overridden depends on the
* Spring Data JPA version in use — confirm against the project's dependency.
* @param information repository information providing the domain type
* @return a new BaseRepositoryImpl for the domain type
*/
@Override
protected Object getTargetRepository(RepositoryInformation information) {
return new BaseRepositoryImpl<T, I>((Class<T>) information.getDomainType(), em);
}
/**
* Reports the class of the concrete repository implementation.
* @param metadata repository metadata (not used)
* @return {@code BaseRepositoryImpl.class}
*/
@Override
protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
return BaseRepositoryImpl.class;
}
}
}
接下来是重头戏,实现自己写的BaseRepository并利用JpaSpecificationExecutor来完成条件的动态封装
package com.fewstrong.reposiotry;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import javax.persistence.EntityManager;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaBuilder.In;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Direction;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.util.StringUtils;
import com.fewstrong.reposiotry.support.DataQueryObject;
import com.fewstrong.reposiotry.support.DataQueryObjectPage;
import com.fewstrong.reposiotry.support.DataQueryObjectSort;
import com.fewstrong.reposiotry.support.QueryBetween;
import com.fewstrong.reposiotry.support.QueryField;
import com.fewstrong.reposiotry.support.QueryType;
/**
 * Repository implementation that turns a {@link DataQueryObject} into a JPA
 * criteria predicate and runs it through the specification-based finders
 * inherited from {@link SimpleJpaRepository}.
 *
 * <p>Every field of the query object that carries a {@link QueryField}
 * annotation contributes one condition; all conditions are combined with AND.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 * @author fewstrong
 **/
public class BaseRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID>
        implements BaseRepository<T, ID>, JpaSpecificationExecutor<T> {

    static final Logger logger = LoggerFactory.getLogger(BaseRepositoryImpl.class);

    // Kept from the original class; neither field is read inside this class.
    private final EntityManager entityManager;
    private final Class<T> clazz;

    /**
     * @param domainClass   entity class handled by this repository
     * @param entityManager entity manager used by the inherited implementation
     */
    public BaseRepositoryImpl(Class<T> domainClass, EntityManager entityManager) {
        super(domainClass, entityManager);
        this.clazz = domainClass;
        this.entityManager = entityManager;
    }

    /**
     * Sorted query; falls back to the unsorted query when the query object names
     * no sort property.
     */
    @Override
    public List<T> findAll(DataQueryObjectSort dataQueryObjectSort) {
        final DataQueryObject dqo = dataQueryObjectSort;
        if (hasSortProperty(dataQueryObjectSort)) {
            return this.findAll(dqo, toSort(dataQueryObjectSort));
        }
        return this.findAll(dqo);
    }

    /** Sorted query with an explicit sort. */
    @Override
    public List<T> findAll(DataQueryObject dataQueryObject, Sort sort) {
        return this.findAll(toSpecification(dataQueryObject), sort);
    }

    /** Plain query. */
    @Override
    public List<T> findAll(DataQueryObject dataQueryObject) {
        return this.findAll(toSpecification(dataQueryObject));
    }

    /**
     * Paged query; page index, page size and the optional sort are taken from
     * the query object.
     */
    @Override
    public Page<T> findAll(DataQueryObjectPage dataQueryObjectpage) {
        Pageable pageable;
        if (hasSortProperty(dataQueryObjectpage)) {
            pageable = new PageRequest(dataQueryObjectpage.getPage(), dataQueryObjectpage.getSize(),
                    toSort(dataQueryObjectpage));
        } else {
            pageable = new PageRequest(dataQueryObjectpage.getPage(), dataQueryObjectpage.getSize());
        }
        return this.findAll(dataQueryObjectpage, pageable);
    }

    /** Paged query with an explicit page request. */
    @Override
    public Page<T> findAll(DataQueryObject dataQueryObject, Pageable page) {
        return this.findAll(toSpecification(dataQueryObject), page);
    }

    /**
     * Core method: builds the combined predicate for a query object.
     *
     * <p>The (misspelled) method name is kept because subclasses may override it.
     * Annotated fields declared in superclasses of the query object are honoured
     * as well, so query objects can be extended without losing conditions.
     *
     * @param root  query root of the entity
     * @param query criteria query being built (not used here)
     * @param cb    criteria builder
     * @param dqo   query object whose annotated fields describe the conditions
     * @return the AND of all conditions; an empty conjunction when there is none
     */
    protected Predicate getPredocate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder cb,
            DataQueryObject dqo) {
        List<Predicate> predicates = new ArrayList<>();
        for (Class<?> type = dqo.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                QueryField annotation = field.getAnnotation(QueryField.class);
                // Fields without the annotation are not conditions and are never touched.
                if (annotation == null) {
                    continue;
                }
                Predicate predicate = buildPredicate(root, cb, dqo, field, annotation);
                if (predicate != null) {
                    predicates.add(predicate);
                }
            }
        }
        // With no predicates this is cb.and(), i.e. a condition that matches everything.
        return cb.and(predicates.toArray(new Predicate[0]));
    }

    /** Wraps a query object into a specification that delegates to {@link #getPredocate}. */
    private Specification<T> toSpecification(final DataQueryObject dqo) {
        return new Specification<T>() {
            @Override
            public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder cb) {
                return getPredocate(root, query, cb, dqo);
            }
        };
    }

    /** Tells whether the query object names a non-blank sort property. */
    private static boolean hasSortProperty(DataQueryObjectSort sortQuery) {
        String property = sortQuery.getPropertyName();
        return property != null && property.trim().length() != 0;
    }

    /** Builds the single-property sort described by the query object. */
    private static Sort toSort(DataQueryObjectSort sortQuery) {
        Direction direction = sortQuery.isAscending() ? Direction.ASC : Direction.DESC;
        return new Sort(new Order(direction, sortQuery.getPropertyName()));
    }

    /**
     * Builds the predicate for one annotated field.
     *
     * @return the predicate, or {@code null} when the field yields no condition
     */
    private Predicate buildPredicate(Root<T> root, CriteriaBuilder cb, DataQueryObject dqo,
            Field field, QueryField annotation) {
        // An empty annotation name means: use the field name as the attribute name.
        String queryFiled = annotation.name().isEmpty() ? field.getName() : annotation.name();
        QueryType queryType = annotation.type();

        field.setAccessible(true);
        Object value = null;
        try {
            value = field.get(dqo);
        } catch (IllegalArgumentException | IllegalAccessException e) {
            // Best effort: the field is treated as unset instead of failing the whole query.
            logger.error("Unable to read query field " + field.getName(), e);
        }

        if (value == null && queryType.isNotCanBeNull()) {
            logger.debug("查询类型:" + queryType + "不允许为空。");
            return null;
        }

        switch (queryType) {
        case EQUAL:
            return cb.equal(this.<Object>resolvePath(queryFiled, root), value);
        case NOT_EQUAL:
            return cb.notEqual(this.<Object>resolvePath(queryFiled, root), value);
        case IS_NULL:
            return cb.isNull(this.<Object>resolvePath(queryFiled, root));
        case IS_NOT_NULL:
            return cb.isNotNull(this.<Object>resolvePath(queryFiled, root));
        case BEWTEEN:
        case LESS_THAN:
        case LESS_THAN_EQUAL:
        case GREATEROR_THAN:
        case GREATEROR_THAN_EQUAL:
            return buildComparison(cb, queryType, queryFiled, root, value);
        // The value is used verbatim in all LIKE variants, so '%' and '_' inside it
        // act as wildcards.
        case LEFT_LIKE:
            return cb.like(this.<String>resolvePath(queryFiled, root), "%" + value.toString());
        case RIGHT_LIKE:
            return cb.like(this.<String>resolvePath(queryFiled, root), value.toString() + "%");
        case FULL_LIKE:
            return cb.like(this.<String>resolvePath(queryFiled, root), "%" + value.toString() + "%");
        case DEFAULT_LIKE:
            return cb.like(this.<String>resolvePath(queryFiled, root), value.toString());
        case NOT_LIKE:
            // Fixed: this used to build a plain LIKE.
            return cb.notLike(this.<String>resolvePath(queryFiled, root), value.toString());
        case IN:
            return buildIn(cb, this.<Object>resolvePath(queryFiled, root), value);
        default:
            return null;
        }
    }

    /**
     * Builds the range and one-sided comparison predicates. The value must be a
     * {@link QueryBetween}: {@code after} is the lower bound of BETWEEN and the
     * operand of every one-sided comparison, {@code before} is the upper bound
     * of BETWEEN.
     *
     * @return the predicate, or {@code null} when the value is not a
     *         {@link QueryBetween} or the required bound is missing
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private Predicate buildComparison(CriteriaBuilder cb, QueryType queryType, String queryFiled,
            Root<T> root, Object value) {
        Path<Comparable> path = resolvePath(queryFiled, root);
        if (!(value instanceof QueryBetween)) {
            logger.debug("Skipping {} on {}: value is not a QueryBetween", queryType, queryFiled);
            return null;
        }
        QueryBetween bounds = (QueryBetween) value;
        Comparable lower = bounds.after;
        Comparable upper = bounds.before;

        if (queryType == QueryType.BEWTEEN) {
            // A missing bound degrades the range to a half-open one instead of
            // handing null to the criteria builder.
            if (lower != null && upper != null) {
                return cb.between(path, lower, upper);
            }
            if (lower != null) {
                return cb.greaterThanOrEqualTo(path, lower);
            }
            if (upper != null) {
                return cb.lessThanOrEqualTo(path, upper);
            }
            return null;
        }

        if (lower == null) {
            logger.debug("Skipping {} on {}: comparison value is null", queryType, queryFiled);
            return null;
        }
        switch (queryType) {
        case LESS_THAN:
            return cb.lessThan(path, lower);
        case LESS_THAN_EQUAL:
            return cb.lessThanOrEqualTo(path, lower);
        case GREATEROR_THAN:
            return cb.greaterThan(path, lower);
        case GREATEROR_THAN_EQUAL:
            // Fixed: this used to build lessThanOrEqualTo.
            return cb.greaterThanOrEqualTo(path, lower);
        default:
            return null;
        }
    }

    /**
     * Builds an IN predicate from any {@link Iterable} value.
     *
     * @return the predicate; {@code null} when the value is not iterable; an
     *         always-false predicate when it has no elements, because nothing
     *         can be a member of an empty set
     */
    private Predicate buildIn(CriteriaBuilder cb, Path<Object> path, Object value) {
        if (!(value instanceof Iterable)) {
            logger.debug("Skipping IN condition: value of type {} is not iterable", value.getClass());
            return null;
        }
        In<Object> in = cb.in(path);
        boolean hasValues = false;
        for (Object candidate : (Iterable<?>) value) {
            in.value(candidate);
            hasValues = true;
        }
        return hasValues ? in : cb.disjunction();
    }

    /**
     * Resolves a possibly dotted attribute name ({@code "a.b.c"}) by calling
     * {@code get} once per segment, starting at the root.
     *
     * @param <Y> attribute type expected by the caller; not verified at runtime
     */
    @SuppressWarnings("unchecked")
    private <Y> Path<Y> resolvePath(String queryFiled, Root<T> root) {
        Path<?> path = root;
        // Limit -1 keeps empty segments so malformed names still fail in the provider.
        for (String attribute : queryFiled.split("\\.", -1)) {
            path = path.get(attribute);
        }
        return (Path<Y>) path;
    }
}
至此,动态条件查询的框架部分已经完成,利用DataQueryObject我们可以完成复杂动态条件的查询。
在entity中添加age用于测试排序
package com.fewstrong.Entity;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
/**
 * User entity; {@code age} was added so that sorting can be tested.
 */
@Entity(name = "db_user")
public class User {

    /** Generated primary key. */
    @Id
    @Column
    @GeneratedValue(strategy = GenerationType.AUTO)
    private Long id;

    /** User name. */
    @Column(length = 64)
    private String name;

    /** User age. */
    @Column(length = 5)
    private Integer age;

    public Long getId() {
        return this.id;
    }

    public String getName() {
        return this.name;
    }

    public Integer getAge() {
        return this.age;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setAge(Integer age) {
        this.age = age;
    }
}
查询对象
package com.fewstrong.dto.qo;
import com.fewstrong.reposiotry.support.DataQueryObjectPage;
import com.fewstrong.reposiotry.support.QueryBetween;
import com.fewstrong.reposiotry.support.QueryField;
import com.fewstrong.reposiotry.support.QueryType;
/**
 * Query object for users. Each annotated field is one optional condition;
 * paging and sorting come from {@link DataQueryObjectPage}.
 */
public class UserQo extends DataQueryObjectPage {

    /** Substring match on the {@code name} attribute. */
    @QueryField(type = QueryType.FULL_LIKE, name = "name")
    private String name;

    /** Range condition on the {@code age} attribute. */
    @QueryField(type = QueryType.BEWTEEN, name = "age")
    private QueryBetween<Integer> betweenAge;

    /** Exact match on the {@code age} attribute. */
    @QueryField(type = QueryType.EQUAL, name = "age")
    private Integer equalAge;

    public String getName() {
        return this.name;
    }

    public QueryBetween<Integer> getBetweenAge() {
        return this.betweenAge;
    }

    public Integer getEqualAge() {
        return this.equalAge;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setBetweenAge(QueryBetween<Integer> betweenAge) {
        this.betweenAge = betweenAge;
    }

    public void setEqualAge(Integer equalAge) {
        this.equalAge = equalAge;
    }
}
修改 UserRepository,改为继承 BaseRepository
package com.fewstrong.jpa;
import com.fewstrong.Entity.User;
import com.fewstrong.reposiotry.BaseRepository;
/**
 * Repository for {@link User} entities with {@code Long} identifiers. It declares
 * no methods of its own; all finders come from {@link BaseRepository}.
 */
public interface UserRepository extends BaseRepository<User, Long>{
}
在 UserController 和 UserService中添加代码
package com.fewstrong.controller;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import com.fewstrong.Entity.User;
import com.fewstrong.dto.qo.UserQo;
import com.fewstrong.service.UserService;
/**
 * REST endpoints for users under {@code /user}.
 */
@RestController
@RequestMapping("/user")
public class UserController {

    @Autowired
    UserService userService;

    /**
     * POST {@code /user/all}: runs the dynamic query described by the request body.
     *
     * @param userQo query object deserialized from the request body
     * @return the requested page of matching users
     */
    @RequestMapping(value = "/all", method = { RequestMethod.POST })
    public Page<User> all(@RequestBody UserQo userQo) {
        Page<User> matches = userService.findAll(userQo);
        return matches;
    }

    /**
     * GET {@code /user/all}: returns every user.
     *
     * @return all users
     */
    @RequestMapping(value = "/all", method = { RequestMethod.GET })
    public List<User> all() {
        List<User> everyone = userService.findAll();
        return everyone;
    }
}
package com.fewstrong.service;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.stereotype.Service;
import com.fewstrong.Entity.User;
import com.fewstrong.dto.qo.UserQo;
import com.fewstrong.jpa.UserRepository;
/**
 * Service layer for users; both methods delegate directly to the repository.
 */
@Service
public class UserService {

    @Autowired
    UserRepository userRepository;

    /**
     * Runs the dynamic, paged query described by the query object.
     *
     * @param userQo conditions plus paging and sort instruction
     * @return the requested page of matching users
     */
    public Page<User> findAll(UserQo userQo) {
        Page<User> matches = userRepository.findAll(userQo);
        return matches;
    }

    /**
     * @return all users
     */
    public List<User> findAll() {
        List<User> everyone = userRepository.findAll();
        return everyone;
    }
}
在启动类(main方法所在的类)上添加注解
package com.fewstrong;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.domain.EntityScan;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
import com.fewstrong.reposiotry.BaseRepositoryFactoryBean;
// Application entry point. Component, entity and repository scanning all start
// at the com.fewstrong package.
@ComponentScan(basePackages = { "com.fewstrong" })
@EnableAutoConfiguration
@EntityScan("com.fewstrong")
// Use the custom factory bean so that repositories are backed by BaseRepositoryImpl
@EnableJpaRepositories(basePackages = "com.fewstrong",
repositoryFactoryBeanClass = BaseRepositoryFactoryBean.class)
@SpringBootApplication
public class TestApplication {
/**
* Starts the Spring Boot application.
*
* @param args command line arguments passed on to Spring Boot
*/
public static void main(String[] args) {
SpringApplication.run(TestApplication.class, args);
}
}
运行项目
在数据库中完善数据
由于使用的是post请求,所以利用python做测试
# Manual smoke test: posts two query objects to the /user/all endpoint and
# prints each request body followed by the raw response.
import json
import requests

url = 'http://localhost:10001/user/all'
headers = {'content-type': "application/json"}


def post_and_show(payload):
    """POST payload as JSON to url, then print the payload and the response body."""
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    print(payload)
    print(response.content)
    return response


# First request: age range only.
print("第一次请求-----------------------")
body = {
    "betweenAge" : {
        "before" : 15,
        "after" : 13
    }
}
r = post_and_show(body)

# Second request: exact age plus name match.
print("第二次请求-----------------------")
body1 = {
    "equalAge" : 10,
    "name" : "jerry"
}
r = post_and_show(body1)
请求结果
如图,测试成功!
附上代码链接 https://github.com/fewstrong87/test/tree/feature%23jpa_specification
如果需要使用关联查询,需要先在实体中配置一对多、多对多等关联关系,然后把查询对象中对应属性的 QueryField 的 name 写成 'xxx.yyyy' 的形式即可(xxx 为实体中的关联属性名,yyyy 为关联实体中的属性名)。