Spring Data JPA 自带的 Repository 有时不能完全满足需求，这时可以把自己的通用查询工具加进去。下面介绍如何对 JPA 的 Repository 进行自定义扩展。
1、进行扩展之后的Repository层
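其中 MyRepository 是自定义的通用 Repository 接口，使用时只需让业务 Repository 继承它即可。注意 import 的包名要与 MyRepository 实际所在的包一致（下文中为 com.supcon.mare.tankinfo.util.repository）。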
package com.supcon.mare.tankinfo.repository;
import com.supcon.mare.tankinfo.entity.TankMovementEntity;
import com.supcon.mare.tankinfo.util.repositoryutil.MyRepository;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.stereotype.Repository;
/**
 * Example repository for {@link TankMovementEntity}.
 *
 * <p>Extending {@link MyRepository} gives this interface both the standard Spring Data JPA
 * methods and the generic query helpers declared there (findByPage, findByConditions,
 * findOneByAttr, ...).
 *
 * @author: zhaoxu
 */
@Repository
public interface TestRepository
        extends MyRepository<TankMovementEntity, Long> {
    // No extra query methods: everything is inherited from MyRepository.
}
2、MyRepository
package com.supcon.mare.tankinfo.util.repository;
import org.springframework.data.domain.Page;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.NoRepositoryBean;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
 * Generic extensions to Spring Data JPA repositories.
 *
 * <p>Repository interfaces that extend this one get, in addition to the usual
 * {@link JpaRepository} / {@link JpaSpecificationExecutor} methods, a set of
 * map-driven query helpers. Query parameters are passed as {@code Map<String, String>}
 * (attribute name -> value); multiple values for an exact match are comma-separated.
 *
 * @author: zhaoxu
 */
@NoRepositoryBean
public interface MyRepository<T, ID extends Serializable> extends JpaRepository<T, ID>, JpaSpecificationExecutor<T> {
    /**
     * Paged conditional query with an explicit join mapping.
     *
     * <p>The default implementation (MyRepositoryImpl) reads the 1-based page number and the
     * page size from {@code tableMap} under the keys {@code Constants.CURRENT} and
     * {@code Constants.PAGE_SIZE}.
     *
     * @param tableMap    query conditions, including the paging keys
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @param joinField   foreign-key join mapping; may be null
     * @param sortAttr    sort specification; may be null
     * @return one page of matching entities
     * @deprecated use {@link #findByPage(Map, List, String)}; related entities can be filtered
     *             with dotted keys (e.g. {@code "association.attribute"}) in {@code tableMap}
     */
    @Deprecated
    Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr, Map joinField, String sortAttr);

    /**
     * Paged conditional query without the legacy join map.
     *
     * @param tableMap    query conditions, including the paging keys (see the deprecated overload)
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @param sortAttr    sort specification; may be null
     * @return one page of matching entities
     */
    Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr, String sortAttr);

    /**
     * Paged conditional query, unsorted.
     *
     * @param tableMap    query conditions, including the paging keys
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @return one page of matching entities
     */
    Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr);

    /**
     * Paged conditional query, unsorted, with LIKE matching on every String attribute.
     *
     * @param tableMap query conditions, including the paging keys
     * @return one page of matching entities
     */
    Page<T> findByPage(Map<String, String> tableMap);

    /**
     * Conditional query with an explicit join mapping.
     *
     * @param tableMap    query conditions
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @param joinField   foreign-key join mapping; may be null
     * @param sortAttr    sort specification; may be null
     * @return all matching entities
     * @deprecated use {@link #findByConditions(Map, List, String)}; related entities can be
     *             filtered with dotted keys (e.g. {@code "association.attribute"}) in {@code tableMap}
     */
    @Deprecated
    List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr, Map joinField, String sortAttr);

    /**
     * Conditional query without the legacy join map.
     *
     * @param tableMap    query conditions
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @param sortAttr    sort specification; may be null
     * @return all matching entities
     */
    List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr, String sortAttr);

    /**
     * Conditional query, unsorted.
     *
     * @param tableMap    query conditions
     * @param excludeAttr String attributes that should be matched exactly instead of with LIKE; may be null
     * @return all matching entities
     */
    List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr);

    /**
     * Conditional query, unsorted, with LIKE matching on every String attribute.
     *
     * @param tableMap query conditions
     * @return all matching entities
     */
    List<T> findByConditions(Map<String, String> tableMap);

    /**
     * Soft delete: marks the entities with the given primary keys as invalid instead of removing them.
     *
     * @param ids primary-key values separated by ","
     */
    void deleteValid(String ids);

    /**
     * Exact-match lookup of a single entity; if several match, only the first is returned.
     *
     * @param attr      attribute name, normally a unique key (id, code, ...)
     * @param condition value to match (e.g. 1, TK1000, ...)
     * @return the matching entity, or {@code null} when nothing matches
     */
    T findOneByAttr(String attr, String condition);

    /**
     * Exact-match query on one attribute.
     *
     * @param attr      attribute name (id, name, code, ...)
     * @param condition value to match (e.g. 1, 罐1, TK1000, ...)
     * @return all matching entities
     */
    List<T> findByAttr(String attr, String condition);

    /**
     * Exact-match query on one attribute for several values.
     *
     * @param attr       attribute name (id, code, ...)
     * @param conditions values to match, separated by ","
     * @return all entities matching any of the values
     */
    List<T> findByAttrs(String attr, String conditions);
}
工具类 ReflectUtil（MyRepositoryImpl 依赖它来生成查询条件；注意它的包名要与 MyRepositoryImpl 中 import 的包名保持一致）:
package com.zx.util.util;
import com.zx.util.constant.Constants;
import com.zx.util.constant.ErrorCodeEnum;
import com.zx.util.exception.MyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import javax.persistence.criteria.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
/**
* 反射封装工具类
*
* @author : zhaoxu
*/
@Component
public class ReflectUtil {
static final Logger logger = LoggerFactory.getLogger(ReflectUtil.class);
/**
* 生成全属性条件查询通用Specification
*
* @param tableMap 属性参数
* @param clazz 要查询的实体类或vo类
* @param excludeAttr 不使用模糊搜索的字符串属性
* @param map 外键关联查询
* @param <S> 泛型
* @return Specification
*/
@Deprecated
public <S> Specification<S> createSpecification(Map<String, String> tableMap, Class clazz, List<String> excludeAttr, Map map) {
Specification<S> specification = (root, query, cb) -> {
List<Predicate> predicates = new ArrayList<>();
//未删除的数据
try {
clazz.getDeclaredField(Constants.VALID);
if (!StringUtils.isEmpty(tableMap.get(Constants.VALID))) {
predicates.add(cb.equal(root.get(Constants.VALID), Integer.valueOf(tableMap.get(Constants.VALID))));
} else {
predicates.add(cb.equal(root.get(Constants.VALID), 1));
}
} catch (NoSuchFieldException e) {
logger.warn("没有找到属性:valid");
}
Field[] declaredFields = clazz.getDeclaredFields();
for (Field field : declaredFields) {
String fieldName = field.getName();
if (!StringUtils.isEmpty(tableMap.get(fieldName))) {
String typeName = field.getGenericType().getTypeName();
Class<?> aClass;
try {
aClass = Class.forName(typeName);
} catch (ClassNotFoundException e) {
throw new MyException(ErrorCodeEnum.CANNOT_FIND_ATTR_ERROR.getErrorCode(), "未找到属性类型");
}
//属性不包含特定的属性并且是字符串采用模糊搜索
boolean isLike = aClass == String.class && (CollectionUtils.isEmpty(excludeAttr) || !excludeAttr.contains(fieldName));
if (isLike) {
String queryFieldName = "%" + tableMap.get(fieldName).replace("/", "\\/")
.replaceAll("_", "\\\\_").replaceAll("%", "\\\\%") + "%";
predicates.add(cb.like(root.get(fieldName), queryFieldName));
} else {
predicates.add(cb.and(root.get(fieldName).in(Arrays.asList(tableMap.get(fieldName).split(",")))));
}
}
}
//外键关联查询,旧
if (!CollectionUtils.isEmpty(map)) {
Iterator iterator = map.keySet().iterator();
while (iterator.hasNext()) {
String sourceKey = iterator.next().toString();
Map mapping = (Map) map.get(sourceKey);
Join join = root.join(sourceKey, JoinType.INNER);
Iterator mappingItr = mapping.keySet().iterator();
while (mappingItr.hasNext()) {
String joinKey = mappingItr.next().toString();
String joinAttr = mapping.get(joinKey).toString();
if (!StringUtils.isEmpty(tableMap.get(joinKey))) {
predicates.add(cb.and(join.get(joinAttr).in(Arrays.asList(tableMap.get(joinKey).split(",")))));
}
}
}
}
Predicate[] pre = new Predicate[predicates.size()];
Predicate preAnd = cb.and(predicates.toArray(pre));
return query.where(preAnd).getRestriction();
};
return specification;
}
/**
* 生成全属性条件查询通用Specification
*
* @param tableMap 属性参数
* @param clazz 要查询的实体类或vo类
* @param excludeAttr 不使用模糊搜索的字符串属性
* @param <S> 泛型
* @return Specification
*/
public <S> Specification<S> createSpecification(Map<String, String> tableMap, Class clazz, List<String> excludeAttr) {
Specification<S> specification = (root, query, cb) -> {
List<Predicate> predicates = new ArrayList<>();
//未删除的数据
try {
clazz.getDeclaredField(Constants.VALID);
if (!StringUtils.isEmpty(tableMap.get(Constants.VALID))) {
predicates.add(cb.equal(root.get(Constants.VALID), Integer.valueOf(tableMap.get(Constants.VALID))));
} else {
predicates.add(cb.equal(root.get(Constants.VALID), 1));
}
} catch (NoSuchFieldException e) {
logger.warn("没有找到属性:valid");
}
Field[] declaredFields = clazz.getDeclaredFields();
for (Field field : declaredFields) {
String fieldName = field.getName();
if (!StringUtils.isEmpty(tableMap.get(fieldName))) {
String typeName = field.getGenericType().getTypeName();
Class<?> aClass;
try {
aClass = Class.forName(typeName);
} catch (ClassNotFoundException e) {
throw new MyException(ErrorCodeEnum.CANNOT_FIND_ATTR_ERROR.getErrorCode(), "未找到属性类型");
}
//属性不包含特定的属性并且是字符串采用模糊搜索
boolean isLike = aClass == String.class && (CollectionUtils.isEmpty(excludeAttr) || !excludeAttr.contains(fieldName));
if (isLike) {
String queryFieldName = "%" + tableMap.get(fieldName).replace("/", "\\/")
.replaceAll("_", "\\\\_").replaceAll("%", "\\\\%") + "%";
predicates.add(cb.like(root.get(fieldName), queryFieldName));
} else {
predicates.add(cb.and(root.get(fieldName).in(Arrays.asList(tableMap.get(fieldName).split(",")))));
}
}
}
//外键关联查询,新,可以省去map参数
if (!CollectionUtils.isEmpty(tableMap)) {
for (Map.Entry<String, String> entry : tableMap.entrySet()) {
if (entry.getKey().contains(".")) {
//递归解析真实的path
predicates.add(cb.and(getRootPath(root, null, entry.getKey()).in(Arrays.asList(entry.getValue().split(",")))));
}
}
}
Predicate[] pre = new Predicate[predicates.size()];
Predicate preAnd = cb.and(predicates.toArray(pre));
return query.where(preAnd).getRestriction();
};
return specification;
}
/**
* 指定条件查询
*
* @param attr 查询的字段
* @param condition 条件
* @param <S> 泛型
* @return Specification
*/
public <S> Specification<S> createOneSpecification(String attr, String condition) {
Specification<S> specification = (root, query, cb) -> {
List<Predicate> predicates = new ArrayList<>();
//未删除的数据
try {
if (Constants.VALID.equals(attr)) {
predicates.add(cb.equal(root.get(Constants.VALID), condition));
} else {
predicates.add(cb.equal(root.get(Constants.VALID), 1));
}
} catch (Exception e) {
logger.warn("没有找到属性:valid");
}
//外键关联查询
if (attr.contains(Constants.POINT)) {
//递归解析真实的path
predicates.add(cb.equal(getRootPath(root, null, attr), condition));
} else {
predicates.add(cb.equal(root.get(attr), condition));
}
Predicate[] pre = new Predicate[predicates.size()];
Predicate preAnd = cb.and(predicates.toArray(pre));
return query.where(preAnd).getRestriction();
};
return specification;
}
/**
* 获取关联查询真实path
*
* @param root root
* @param path path
* @param allPath allPath
* @param <S> S
* @return Path
*/
public <S> Path getRootPath(Root<S> root, Path path, String allPath) {
List<String> pathList = Arrays.asList(allPath.split("\\."));
//下一个解析的path
StringBuilder restPath = new StringBuilder();
Path nowPath = null;
if (!CollectionUtils.isEmpty(pathList)) {
if (root != null) {
nowPath = root.get(pathList.get(0));
//拥有下一个解析点
if (pathList.size() > 1) {
for (int i = 1; i < pathList.size(); i++) {
restPath.append(pathList.get(i));
if (i + 1 < pathList.size()) {
restPath.append(".");
}
}
//递归
nowPath = getRootPath(null, nowPath, restPath.toString());
}
} else {
nowPath = path.get(pathList.get(0));
//拥有下一个解析点
if (pathList.size() > 1) {
for (int i = 1; i < pathList.size(); i++) {
restPath.append(pathList.get(i));
if (i + 1 < pathList.size()) {
restPath.append(".");
}
}
//递归
nowPath = getRootPath(root, nowPath, restPath.toString());
}
}
}
return nowPath;
}
/**
* 通过方法名动态执行某个方法
*
* @param object object
* @param methodName 方法名
* @param parameters 参数
* @return Object
* @throws InvocationTargetException InvocationTargetException
* @throws IllegalAccessException IllegalAccessException
* @throws NoSuchMethodException NoSuchMethodException
*/
public Object executeMethod(Object object, String methodName, Object... parameters) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
Class<?> clazz = object.getClass();
ArrayList<Class<?>> paramTypeList = new ArrayList<>();
for (Object paramType : parameters) {
paramTypeList.add(paramType.getClass());
}
Class<?>[] classArray = new Class[paramTypeList.size()];
Method method = clazz.getMethod(methodName, paramTypeList.toArray(classArray));
Object invoke = method.invoke(object, parameters);
return invoke;
}
/**
* 获取所有属性值
*
* @param object object
* @return Map
* @throws IllegalAccessException IllegalAccessException
*/
public Map<String, Object> getFieldsValue(Object object) throws IllegalAccessException {
Class<?> clazz = object.getClass();
Map<String, Object> fieldValuesMap = new HashMap<>(16);
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
Object fieldValue = field.get(object);
fieldValuesMap.put(field.getName(), fieldValue);
}
return fieldValuesMap;
}
/**
* 设置属性值
*
* @param property 设置的字段
* @param value 值
* @param object object
* @return Boolean
*/
public Boolean setValue(Object object, String property, Object value) {
Class<?> clazz = object.getClass();
try {
Field declaredField = clazz.getDeclaredField(property);
declaredField.setAccessible(true);
declaredField.set(object, value);
} catch (IllegalAccessException e) {
e.printStackTrace();
} catch (NoSuchFieldException e) {
e.printStackTrace();
}
return true;
}
/**
* 获取对象所有属性及对应的类别
*
* @param object object
* @return Map
* @throws IllegalAccessException IllegalAccessException
*/
public Map<String, Class<?>> getFields(Object object) throws IllegalAccessException {
Class<?> clazz = object.getClass();
Map<String, Class<?>> attrMap = new HashMap<>(16);
if (clazz != null) {
Iterator<String> iterator = getValues(object).keySet().iterator();
while (iterator.hasNext()) {
attrMap.put(iterator.next(), Object.class);
}
}
return attrMap;
}
/**
* 获取所有属性值
*
* @param object object
* @return Map
* @throws IllegalAccessException IllegalAccessException
*/
public Map<String, Object> getValues(Object object) throws IllegalAccessException {
Map<String, Object> fieldValuesMap = new HashMap(16);
Class<?> clazz = object.getClass();
if (clazz != null) {
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
Object fieldValue = field.get(object);
fieldValuesMap.put(field.getName(), fieldValue);
}
return fieldValuesMap;
}
return fieldValuesMap;
}
/**
* 获取拥有指定注解的字段
*
* @param objectClass 对象
* @param annoClass 查询的注解
* @return List
*/
public List<Field> getTargetAnnoation(Class<?> objectClass, Class<? extends Annotation> annoClass) {
List<Field> fields = new ArrayList<>();
Field[] declaredFields = objectClass.getDeclaredFields();
for (Field field : declaredFields) {
field.setAccessible(true);
if (!field.isAnnotationPresent(annoClass)) {
continue;
} else {
fields.add(field);
}
}
if (!CollectionUtils.isEmpty(fields)) {
return fields;
} else {
return null;
}
}
}
3、MyRepositoryImpl实现类
package com.supcon.mare.tankinfo.util.repository;
import com.supcon.mare.tankinfo.constant.Constants;
import com.supcon.mare.tankinfo.util.ReflectUtil;
import com.supcon.mare.tankinfo.util.Utils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.transaction.annotation.Isolation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import javax.persistence.EntityManager;
import javax.persistence.Id;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.*;
/**
* @author: zhaoxu
* @description: JPA通用功能扩展
*/
public class MyRepositoryImpl<T, ID extends Serializable>
extends SimpleJpaRepository<T, ID> implements MyRepository<T, ID> {
private ReflectUtil reflectUtil = new ReflectUtil();
private EntityManager entityManager;
private Utils utils = new Utils();
private Class<T> clazz;
@Autowired(required = false)
public MyRepositoryImpl(JpaEntityInformation<T, ID> entityInformation, EntityManager entityManager) {
super(entityInformation, entityManager);
this.clazz = entityInformation.getJavaType();
this.entityManager = entityManager;
}
/**
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @param joinField 外键关联查询,可为空
* @param sortAttr 排序,可为空
* @return Page
*/
@Override
public Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr, Map joinField, String sortAttr) {
int current = Integer.valueOf(tableMap.get(Constants.CURRENT));
int pageSize = Integer.valueOf(tableMap.get(Constants.PAGE_SIZE));
Pageable pageable;
if (!StringUtils.isEmpty(sortAttr)) {
pageable = PageRequest.of(current - 1, pageSize, utils.sortAttr(tableMap, sortAttr));
} else {
pageable = PageRequest.of(current - 1, pageSize);
}
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr, joinField);
return this.findAll(specification, pageable);
}
/**
* 省去不必要的关联map
*
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @param sortAttr 排序,可为空
* @return Page
*/
@Override
public Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr, String sortAttr) {
int current = Integer.valueOf(tableMap.get(Constants.CURRENT));
int pageSize = Integer.valueOf(tableMap.get(Constants.PAGE_SIZE));
Pageable pageable;
if (!StringUtils.isEmpty(sortAttr)) {
pageable = PageRequest.of(current - 1, pageSize, utils.sortAttr(tableMap, sortAttr));
} else {
pageable = PageRequest.of(current - 1, pageSize);
}
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr);
return this.findAll(specification, pageable);
}
/**
* 省去map以及排序
*
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @return Page
*/
@Override
public Page<T> findByPage(Map<String, String> tableMap, List<String> excludeAttr) {
int current = Integer.valueOf(tableMap.get(Constants.CURRENT));
int pageSize = Integer.valueOf(tableMap.get(Constants.PAGE_SIZE));
Pageable pageable;
pageable = PageRequest.of(current - 1, pageSize);
//调用省去map参数的方法
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr);
return this.findAll(specification, pageable);
}
/**
* @param tableMap 查询条件
* @return Page
*/
@Override
public Page<T> findByPage(Map<String, String> tableMap) {
int current = Integer.valueOf(tableMap.get(Constants.CURRENT));
int pageSize = Integer.valueOf(tableMap.get(Constants.PAGE_SIZE));
Pageable pageable;
pageable = PageRequest.of(current - 1, pageSize);
//调用省去map参数的方法
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, null);
return this.findAll(specification, pageable);
}
/**
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @param joinField 外键关联查询,可为空
* @param sortAttr 排序,可为空
* @return List
*/
@Override
public List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr, Map joinField, String sortAttr) {
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr, joinField);
if (!StringUtils.isEmpty(sortAttr)) {
return this.findAll(specification, utils.sortAttr(tableMap, sortAttr));
} else {
return this.findAll(specification);
}
}
/**
* 省去不必要的关联map参数
*
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @param sortAttr 排序,可为空
* @return List
*/
@Override
public List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr, String sortAttr) {
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr);
if (!StringUtils.isEmpty(sortAttr)) {
return this.findAll(specification, utils.sortAttr(tableMap, sortAttr));
} else {
return this.findAll(specification);
}
}
/**
* @param tableMap 查询条件
* @param excludeAttr 是字符串类型,但是不使用模糊查询的字段,可为空
* @return List
*/
@Override
public List<T> findByConditions(Map<String, String> tableMap, List<String> excludeAttr) {
//调用省去map参数的方法
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, excludeAttr);
return this.findAll(specification);
}
/**
* @param tableMap 查询条件
* @return List
*/
@Override
public List<T> findByConditions(Map<String, String> tableMap) {
//调用省去map参数的方法
Specification<T> specification = reflectUtil.createSpecification(tableMap, clazz, null);
return this.findAll(specification);
}
@Override
@Transactional(isolation = Isolation.READ_COMMITTED, rollbackFor = Exception.class)
public void deleteValid(String ids) {
List<String> strings = Arrays.asList(ids.split(","));
if (!CollectionUtils.isEmpty(strings)) {
//获取主键
List<Field> idAnnoation = reflectUtil.getTargetAnnoation(clazz, Id.class);
if (!CollectionUtils.isEmpty(idAnnoation)) {
Field field = idAnnoation.get(0);
strings.stream().forEach(id -> {
T object = this.findOneByAttr(field.getName(), id);
if (object != null) {
reflectUtil.setValue(object, "valid", 0);
this.save(object);
}
});
}
}
}
@Override
public T findOneByAttr(String attr, String condition) {
Specification<T> specification = reflectUtil.createOneSpecification(attr, condition);
Optional<T> result = this.findOne(specification);
if (result.isPresent()) {
return result.get();
} else {
return null;
}
}
@Override
public List<T> findByAttr(String attr, String condition) {
Specification<T> specification = reflectUtil.createOneSpecification(attr, condition);
List<T> all = this.findAll(specification);
return all;
}
@Override
public List<T> findByAttrs(String attr, String conditions) {
List<T> results = new ArrayList<>();
if (!StringUtils.isEmpty(conditions)) {
List<String> cons = Arrays.asList(conditions.split(","));
cons.stream().forEach(condition -> {
List<T> byAttr = findByAttr(attr, condition);
if (byAttr != null) {
results.addAll(byAttr);
}
});
}
return results;
}
}
4、添加自定义工厂类:
package com.supcon.mare.tankinfo.util.repository;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactory;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean;
import org.springframework.data.jpa.repository.support.JpaRepositoryImplementation;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
import org.springframework.util.Assert;
import javax.persistence.EntityManager;
import java.io.Serializable;
/**
 * Repository factory bean that makes Spring Data JPA build every repository on top of
 * {@link MyRepositoryImpl} instead of SimpleJpaRepository. Interfaces extending MyRepository
 * therefore get its custom methods implemented automatically.
 *
 * <p>Register with
 * {@code @EnableJpaRepositories(repositoryFactoryBeanClass = BaseJpaRepositoryFactoryBean.class)}.
 *
 * @author: zhaoxu
 */
public class BaseJpaRepositoryFactoryBean<T extends JpaRepository<S, ID>, S, ID extends Serializable> extends JpaRepositoryFactoryBean<T, S, ID> {
    public BaseJpaRepositoryFactoryBean(Class<? extends T> repositoryInterface) {
        super(repositoryInterface);
    }

    @Override
    protected RepositoryFactorySupport createRepositoryFactory(EntityManager em) {
        return new BaseRepositoryFactory(em);
    }

    /** Factory that instantiates MyRepositoryImpl as the repository base class. */
    private static class BaseRepositoryFactory extends JpaRepositoryFactory {
        public BaseRepositoryFactory(EntityManager em) {
            super(em);
        }

        @Override
        protected JpaRepositoryImplementation<?, ?> getTargetRepository(RepositoryInformation information, EntityManager entityManager) {
            JpaEntityInformation<?, Serializable> entityInformation = this.getEntityInformation(information.getDomainType());
            // Calls the MyRepositoryImpl(JpaEntityInformation, EntityManager) constructor reflectively.
            Object repository = this.getTargetRepositoryViaReflection(information, entityInformation, entityManager);
            Assert.isInstanceOf(MyRepositoryImpl.class, repository);
            return (JpaRepositoryImplementation<?, ?>) repository;
        }

        @Override
        protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
            return MyRepositoryImpl.class;
        }
    }
}
并且在启动类添加注解:
@EnableJpaRepositories(repositoryFactoryBeanClass = BaseJpaRepositoryFactoryBean.class)
或者（与上面的启动类注解二选一即可）添加 config 配置类，注意 import 的包名要与 BaseJpaRepositoryFactoryBean 实际所在的包一致:
package com.supcon.mare.tankinfo.config;
import com.supcon.mare.tankinfo.util.repositoryutil.BaseJpaRepositoryFactoryBean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
/**
 * Enables Spring Data JPA repositories under {@code com.supcon.mare} and routes their creation
 * through BaseJpaRepositoryFactoryBean, so each repository is backed by MyRepositoryImpl.
 * This replaces putting the same annotation on the application class.
 *
 * @author: zhaoxu
 */
@Configuration
@EnableJpaRepositories(
        basePackages = "com.supcon.mare",
        repositoryFactoryBeanClass = BaseJpaRepositoryFactoryBean.class)
public class MyJpaRepositoryConfig {
    // All configuration is carried by the annotations above.
}
5、使用方式:
package com.supcon.mare.tankinfo.service.impl;
import com.supcon.mare.tankinfo.entity.TankMovementEntity;
import com.supcon.mare.tankinfo.repository.TestRepository;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import java.util.Map;
/**
* @author: zhaoxu
* @description:
*/
@Service
public class test {
@Autowired
TestRepository testRepository;
public void test(Map tableMap) {
Map all = testRepository.findByPage(tableMap,null,null,"taskDefineCode");
System.out.println(all);
}
}
这样既可以使用 JPA 自带的方法，也可以使用自定义的方法。
github地址:
https://github.com/zhao458114067/repository-util