<!-- Spring Data JPA starter: supplies the JpaRepository / SimpleJpaRepository types used by the base repository. -->
<!-- NOTE(review): no <version> elements are given here; presumably versions are managed by a parent POM/BOM - confirm. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<!-- Lombok (used for @Slf4j in JpaUtil); declared optional. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Jackson databind: JpaUtil uses ObjectMapper to turn result-row maps into objects. -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
package com.demo;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.NoRepositoryBean;
import javax.persistence.EntityManager;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
 * Base repository contract that adds HQL / native-SQL query helpers, SQL-based paging,
 * batch writes and lazy-property initialisation on top of {@link JpaRepository} and
 * {@link JpaSpecificationExecutor}.
 *
 * <p>In the methods taking {@code Object[] params}, the values are bound positionally
 * (element {@code i} is bound to position {@code i + 1}); in the methods taking
 * {@code Map<String, Object> params}, each entry is bound as a named parameter.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 */
@NoRepositoryBean // this interface itself is not proxied as a JPA repository
public interface BaseJpaRepository<T, ID extends Serializable> extends JpaRepository<T, ID>, JpaSpecificationExecutor<T>, Serializable {
/** Returns the {@link EntityManager} backing this repository. */
EntityManager getEntityManager();
/** Runs an HQL/JPQL query and returns its result list. */
<E> List<E> findByHql(String hql);
/** Runs a native SQL query without parameters; each row is returned as a map keyed by column alias. */
List<Map<?,?>> findBySql(String sql);
/** Runs a native SQL query with positional parameters; each row is returned as a map keyed by column alias. */
List<Map<?,?>> findBySql(String sql, Object[] params);
/** Runs a native SQL query with named parameters; each row is returned as a map keyed by column alias. */
List<Map<?,?>> findBySql(String sql, Map<String, Object> params);
/** Returns the first row of a native SQL query as a map keyed by column alias, or {@code null} when there is no row. */
Map<?,?> findBySqlFirst(String sql);
/** Positional-parameter variant of {@link #findBySqlFirst(String)}. */
Map<?,?> findBySqlFirst(String sql, Object[] params);
/** Named-parameter variant of {@link #findBySqlFirst(String)}. */
Map<?,?> findBySqlFirst(String sql, Map<String, Object> params);
/**
 * Runs a native SQL query and converts every row to {@code clazz}.
 *
 * @param basic {@code true} means the result is a basic (single-column) type, in which case
 *              the column value itself is returned instead of a mapped object
 */
<E> List<E> findBySql(String sql, Class<E> clazz, boolean basic);
/** Positional-parameter variant of {@link #findBySql(String, Class, boolean)}. */
<E> List<E> findBySql(String sql, Class<E> clazz, boolean basic, Object[] params);
/** Named-parameter variant of {@link #findBySql(String, Class, boolean)}. */
<E> List<E> findBySql(String sql, Class<E> clazz, boolean basic, Map<String, Object> params);
/**
 * Paged native SQL query. The total element count is computed from a count statement
 * derived from {@code sql}.
 *
 * @param basic {@code true} means the result is a basic (single-column) type
 */
<E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic);
/**
 * Paged native SQL query with a caller-supplied count statement.
 *
 * @param countSql statement producing the total element count; when it has no text a count
 *                 statement is derived from {@code sql} instead
 */
<E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic);
/** Positional-parameter variant of {@link #findPageBySql(String, Pageable, Class, boolean)}. */
<E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic, Object[] params);
/** Positional-parameter variant of {@link #findPageBySql(String, String, Pageable, Class, boolean)}. */
<E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic, Object[] params);
/** Named-parameter variant of {@link #findPageBySql(String, Pageable, Class, boolean)}. */
<E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic, Map<String, Object> params);
/** Named-parameter variant of {@link #findPageBySql(String, String, Pageable, Class, boolean)}. */
<E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic, Map<String, Object> params);
/**
 * Returns the first row of a native SQL query converted to {@code clazz}, or {@code null}
 * when there is no row.
 *
 * @param basic {@code true} means the result is a basic (single-column) type
 */
<E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic);
/** Positional-parameter variant of {@link #findBySqlFirst(String, Class, boolean)}. */
<E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic, Object[] params);
/** Named-parameter variant of {@link #findBySqlFirst(String, Class, boolean)}. */
<E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic, Map<String, Object> params);
/** Looks an entity up by id, returning {@code null} (rather than an {@code Optional}) when the id is {@code null} or nothing is found. */
T findByIdNew(ID id);
/**
 * Batch insert.
 *
 * @return the same {@code iterable} that was passed in
 */
<S extends T> Iterable<S> batchSave(Iterable<S> iterable);
/**
 * Batch update.
 *
 * @return the same {@code iterable} that was passed in
 */
<S extends T> Iterable<S> batchUpdate(Iterable<S> iterable);
/**
 * Initialises the named lazy properties on every entity in {@code l}; each name in
 * {@code fields} is resolved to a {@code getXxx()} accessor on {@code entityClazz}.
 */
void lazyInitialize(Class<T> entityClazz, List<T> l, String[] fields);
/** Initialises the named lazy properties on a single entity via its {@code getXxx()} accessors. */
void lazyInitialize(T obj, String[] fields);
}
package com.demo;
import com.qtsec.demo.ApplicationContextProvider;
import org.hibernate.Hibernate;
import org.hibernate.query.internal.NativeQueryImpl;
import org.hibernate.transform.Transformers;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StringUtils;
import javax.persistence.EntityManager;
import javax.persistence.Query;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.math.BigInteger;
import java.util.*;
/**
 * Base repository implementation that layers HQL / native-SQL helpers, SQL-based paging,
 * batch writes and lazy-property initialisation on top of {@link SimpleJpaRepository}.
 *
 * <p>Native query rows are fetched as alias-to-value maps (Hibernate
 * {@code Transformers.ALIAS_TO_ENTITY_MAP}) and converted by the {@code jpaUtil} bean.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 */
@SuppressWarnings("unchecked")
public class BaseJpaRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BaseJpaRepository<T, ID> {

    private static final long serialVersionUID = 5202242718223588507L;

    /** Number of entities written between flush/clear cycles in the batch methods. */
    private static final int BATCH_SIZE = 500;

    /**
     * Sort properties are concatenated into the SQL text (ORDER BY cannot be parameterised),
     * so only plain, optionally dot-qualified identifiers are accepted.
     */
    private static final java.util.regex.Pattern SAFE_SORT_PROPERTY =
            java.util.regex.Pattern.compile("[\\p{L}_][\\p{L}\\p{N}_$.]*");

    /** Standard-library logger; this class has no logging framework in scope. */
    private static final java.util.logging.Logger LOG =
            java.util.logging.Logger.getLogger(BaseJpaRepositoryImpl.class.getName());

    /** Entity manager received through the constructor. */
    private final EntityManager entityManager;

    public BaseJpaRepositoryImpl(JpaEntityInformation<T, ID> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.entityManager = entityManager;
    }

    @Override
    public EntityManager getEntityManager() {
        return entityManager;
    }

    @Override
    public <E> List<E> findByHql(String hql) {
        return (List<E>) entityManager.createQuery(hql).getResultList();
    }

    // ------------------------------------------------------------------ list queries

    @Override
    public List<Map<?,?>> findBySql(String sql) {
        return findBySql(sql, new HashMap<>());
    }

    @Override
    public List<Map<?,?>> findBySql(String sql, Object[] params) {
        Query nativeQuery = entityManager.createNativeQuery(sql);
        bindPositional(nativeQuery, params);
        return fetchRows(nativeQuery);
    }

    @Override
    public List<Map<?,?>> findBySql(String sql, Map<String, Object> params) {
        Query nativeQuery = entityManager.createNativeQuery(sql);
        bindNamed(nativeQuery, params);
        return fetchRows(nativeQuery);
    }

    @Override
    public <E> List<E> findBySql(String sql, Class<E> clazz, boolean basic) {
        return findBySql(sql, clazz, basic, new HashMap<>());
    }

    @Override
    public <E> List<E> findBySql(String sql, Class<E> clazz, boolean basic, Object[] params) {
        return getJpaUtil().mapListToObjectList(findBySql(sql, params), clazz, basic);
    }

    @Override
    public <E> List<E> findBySql(String sql, Class<E> clazz, boolean basic, Map<String, Object> params) {
        return getJpaUtil().mapListToObjectList(findBySql(sql, params), clazz, basic);
    }

    // ------------------------------------------------------------------ paged queries

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic) {
        return findPageBySql(sql, pageable, clazz, basic, new HashMap<>());
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic) {
        return findPageBySql(sql, countSql, pageable, clazz, basic, new HashMap<>());
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic, Object[] params) {
        return findPageBySql(sql, null, pageable, clazz, basic, params);
    }

    /**
     * Paged native query with positional parameters.
     *
     * <p>If {@code sql} contains no {@code order by} and {@code pageable} carries a sort, the
     * sort is appended. The total is read from {@code countSql}, or, when that has no text,
     * from {@code select count(*) from ( sql ) a}. The count statement is given the same
     * parameter values for every placeholder it declares.
     *
     * @throws IllegalArgumentException if a sort property is not a plain identifier
     * @throws ArithmeticException      if the page offset does not fit in an {@code int}
     */
    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic, Object[] params) {
        Query pageQuery = entityManager.createNativeQuery(withSort(sql, pageable.getSort()));
        applyPaging(pageQuery, pageable);
        bindPositional(pageQuery, params);
        List<E> content = getJpaUtil().mapListToObjectList(fetchRows(pageQuery), clazz, basic);
        long total = countRows(resolveCountSql(sql, countSql), params, null);
        return new PageImpl<>(content, pageable, total);
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class<E> clazz, boolean basic, Map<String, Object> params) {
        return findPageBySql(sql, null, pageable, clazz, basic, params);
    }

    /**
     * Paged native query with named parameters; see the positional overload for the sorting
     * and counting rules.
     *
     * @throws IllegalArgumentException if a sort property is not a plain identifier
     * @throws ArithmeticException      if the page offset does not fit in an {@code int}
     */
    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class<E> clazz, boolean basic, Map<String, Object> params) {
        Query pageQuery = entityManager.createNativeQuery(withSort(sql, pageable.getSort()));
        applyPaging(pageQuery, pageable);
        bindNamed(pageQuery, params);
        List<E> content = getJpaUtil().mapListToObjectList(fetchRows(pageQuery), clazz, basic);
        long total = countRows(resolveCountSql(sql, countSql), null, params);
        return new PageImpl<>(content, pageable, total);
    }

    // ------------------------------------------------------------------ first-row queries

    @Override
    public Map<?,?> findBySqlFirst(String sql) {
        return findBySqlFirst(sql, new HashMap<>());
    }

    @Override
    public Map<?,?> findBySqlFirst(String sql, Object[] params) {
        Query nativeQuery = entityManager.createNativeQuery(sql);
        bindPositional(nativeQuery, params);
        return fetchFirstRow(nativeQuery);
    }

    @Override
    public Map<?,?> findBySqlFirst(String sql, Map<String, Object> params) {
        Query nativeQuery = entityManager.createNativeQuery(sql);
        bindNamed(nativeQuery, params);
        return fetchFirstRow(nativeQuery);
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic) {
        return findBySqlFirst(sql, clazz, basic, new HashMap<>());
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic, Object[] params) {
        return getJpaUtil().mapToObject(findBySqlFirst(sql, params), clazz, basic);
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class<E> clazz, boolean basic, Map<String, Object> params) {
        return getJpaUtil().mapToObject(findBySqlFirst(sql, params), clazz, basic);
    }

    /** Returns the entity with the given id, or {@code null} when the id is {@code null} or unknown. */
    @Override
    public T findByIdNew(ID id) {
        if (id == null) {
            return null;
        }
        return findById(id).orElse(null);
    }

    // ------------------------------------------------------------------ batch writes

    /**
     * Persists every entity, flushing and clearing the persistence context every
     * {@value #BATCH_SIZE} entities and once more for any remainder.
     *
     * @return the {@code iterable} that was passed in
     */
    @Override
    @Transactional
    public <S extends T> Iterable<S> batchSave(Iterable<S> iterable) {
        return writeInBatches(iterable, entityManager::persist);
    }

    /**
     * Merges every entity, flushing and clearing the persistence context every
     * {@value #BATCH_SIZE} entities and once more for any remainder.
     *
     * @return the {@code iterable} that was passed in (not the merged instances)
     */
    @Override
    @Transactional
    public <S extends T> Iterable<S> batchUpdate(Iterable<S> iterable) {
        return writeInBatches(iterable, entityManager::merge);
    }

    // ------------------------------------------------------------------ lazy initialisation

    /**
     * Best-effort initialisation of the named lazy properties on every entity of the list.
     * Properties without a matching no-arg {@code getXxx()} accessor, and accessors that
     * fail, are logged and skipped.
     */
    @Override
    public void lazyInitialize(Class<T> entityClazz, List<T> l, String[] fields) {
        if (entityClazz == null || l == null || fields == null) {
            return;
        }
        for (String field : fields) {
            Method getter = findGetter(entityClazz, field);
            if (getter == null) {
                continue;
            }
            for (T entity : l) {
                if (entity != null) {
                    initializeProperty(getter, entity);
                }
            }
        }
    }

    /** Single-entity variant of {@link #lazyInitialize(Class, List, String[])}. */
    @Override
    public void lazyInitialize(T obj, String[] fields) {
        if (obj == null || fields == null) {
            return;
        }
        for (String field : fields) {
            Method getter = findGetter(obj.getClass(), field);
            if (getter != null) {
                initializeProperty(getter, obj);
            }
        }
    }

    // ------------------------------------------------------------------ private helpers

    /** Binds {@code params[i]} to position {@code i + 1}; a null or empty array binds nothing. */
    private static void bindPositional(Query query, Object[] params) {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            query.setParameter(i + 1, params[i]);
        }
    }

    /** Binds every map entry as a named parameter; a null or empty map binds nothing. */
    private static void bindNamed(Query query, Map<String, Object> params) {
        if (params == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            query.setParameter(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Binds only the placeholders the query actually declares. Used for count statements:
     * a caller-supplied count statement may declare fewer placeholders than the data query.
     */
    private static void bindDeclaredParameters(Query query, Object[] positional, Map<String, Object> named) {
        // Copy first so binding cannot interfere with iteration over the provider's set.
        for (javax.persistence.Parameter<?> declared : new ArrayList<>(query.getParameters())) {
            Integer position = declared.getPosition();
            String name = declared.getName();
            if (position != null) {
                if (positional != null && position >= 1 && position <= positional.length) {
                    query.setParameter(position.intValue(), positional[position - 1]);
                }
            } else if (name != null && named != null && named.containsKey(name)) {
                query.setParameter(name, named.get(name));
            }
        }
    }

    /** Executes the query and returns each row as an alias-to-value map. */
    private static List<Map<?,?>> fetchRows(Query query) {
        return query.unwrap(NativeQueryImpl.class)
                .setResultTransformer(Transformers.ALIAS_TO_ENTITY_MAP)
                .getResultList();
    }

    /** Returns the first row as an alias-to-value map, or {@code null} when the query yields no row. */
    private static Map<?,?> fetchFirstRow(Query query) {
        // The stream is backed by an open result cursor, so it must be closed after use.
        try (java.util.stream.Stream<?> rows = query.unwrap(NativeQueryImpl.class)
                .setResultTransformer(Transformers.ALIAS_TO_ENTITY_MAP)
                .stream()) {
            return (Map<?,?>) rows.findFirst().orElse(null);
        }
    }

    /**
     * Appends the sort as an ORDER BY clause unless the statement already contains one or
     * there is nothing to sort by.
     */
    private static String withSort(String sql, Sort sort) {
        if (sort == null || sort.isUnsorted() || sql.toLowerCase(Locale.ROOT).contains("order by")) {
            return sql;
        }
        StringBuilder sorted = new StringBuilder(sql).append(" order by ");
        String separator = "";
        for (Sort.Order order : sort) {
            String property = order.getProperty();
            if (!SAFE_SORT_PROPERTY.matcher(property).matches()) {
                throw new IllegalArgumentException("Unsafe sort property: " + property);
            }
            sorted.append(separator)
                    .append(property)
                    .append(" ")
                    .append(order.getDirection().name());
            separator = ",";
        }
        return sorted.toString();
    }

    /** Applies offset and page size; an unpaged request leaves the query unrestricted. */
    private static void applyPaging(Query query, Pageable pageable) {
        if (pageable.isUnpaged()) {
            return;
        }
        // toIntExact fails loudly instead of wrapping to a negative offset.
        query.setFirstResult(Math.toIntExact(pageable.getOffset()));
        query.setMaxResults(pageable.getPageSize());
    }

    /** Uses the caller's count statement when it has text, otherwise wraps the (unsorted) data statement. */
    private static String resolveCountSql(String sql, String countSql) {
        if (StringUtils.hasText(countSql)) {
            return countSql;
        }
        return "select count(*) from ( " + sql + " ) a";
    }

    /**
     * Runs the count statement and returns its single value as a {@code long}.
     * Returns 0 when the statement yields no row or a NULL value.
     */
    private long countRows(String countSql, Object[] positional, Map<String, Object> named) {
        Query countQuery = entityManager.createNativeQuery(countSql);
        bindDeclaredParameters(countQuery, positional, named);
        Map<?,?> row = fetchFirstRow(countQuery);
        if (row == null || row.isEmpty()) {
            return 0L;
        }
        Object value = row.values().iterator().next();
        if (value == null) {
            return 0L;
        }
        // The numeric type of count(*) varies by database/driver, so accept any Number.
        if (!(value instanceof Number)) {
            throw new IllegalStateException("Count query did not return a number: " + countSql);
        }
        return ((Number) value).longValue();
    }

    /** Applies {@code writer} to each entity, flushing and clearing after every full batch and after the remainder. */
    private <S extends T> Iterable<S> writeInBatches(Iterable<S> entities, java.util.function.Consumer<? super S> writer) {
        int pending = 0;
        for (S entity : entities) {
            writer.accept(entity);
            pending++;
            if (pending == BATCH_SIZE) {
                flushAndClear();
                pending = 0;
            }
        }
        if (pending > 0) {
            flushAndClear();
        }
        return entities;
    }

    private void flushAndClear() {
        entityManager.flush();
        entityManager.clear();
    }

    /**
     * Finds the no-arg {@code getXxx()} accessor for a property, searching the class and then
     * its superclasses. Returns {@code null} (after logging) when the name is empty or no
     * accessor exists.
     */
    private static Method findGetter(Class<?> type, String field) {
        if (field == null || field.isEmpty()) {
            LOG.warning("Ignoring empty property name for lazy initialisation on " + type.getName());
            return null;
        }
        String getterName = "get" + Character.toUpperCase(field.charAt(0)) + field.substring(1);
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredMethod(getterName);
            } catch (NoSuchMethodException ignored) {
                // Not declared at this level; keep looking in the superclass.
            }
        }
        LOG.warning("No accessor " + getterName + "() found on " + type.getName());
        return null;
    }

    /** Invokes the accessor and initialises whatever it returns; failures are logged, not propagated. */
    private static void initializeProperty(Method getter, Object target) {
        try {
            Hibernate.initialize(getter.invoke(target));
        } catch (ReflectiveOperationException | RuntimeException e) {
            LOG.log(java.util.logging.Level.WARNING,
                    "Could not initialise " + getter.getName() + "() on " + target.getClass().getName(), e);
        }
    }

    /** Looks up the {@code jpaUtil} bean used to convert result rows. */
    private JpaUtil getJpaUtil() {
        return (JpaUtil) ApplicationContextProvider.getBean("jpaUtil");
    }
}
- ApplicationContextProvider
package com.demo;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Component;
/**
 * Static access point to the Spring {@link ApplicationContext} for code that cannot use
 * dependency injection. The context is captured when Spring calls
 * {@link #setApplicationContext(ApplicationContext)} on this component.
 *
 * <p>The lookup and message helpers throw {@link IllegalStateException} when they are used
 * before the context has been captured.
 */
@Component
public class ApplicationContextProvider
        implements ApplicationContextAware {

    /** The captured application context; {@code null} until Spring injects it. */
    private static ApplicationContext applicationContext;

    /**
     * Returns the captured application context.
     *
     * @return the context, or {@code null} if it has not been injected yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextProvider.applicationContext = applicationContext;
    }

    /**
     * Looks a bean up by name.
     *
     * @param name bean name
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static Object getBean(String name) {
        return requireContext().getBean(name);
    }

    /**
     * Looks a bean up by type.
     *
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    /**
     * Looks a bean up by name and type.
     *
     * @param name  bean name
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return requireContext().getBean(name, clazz);
    }

    /**
     * Resolves a localized message for the locale held by {@link LocaleContextHolder}.
     *
     * @param code message code
     * @param args message arguments
     * @return the resolved message
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static String getMessage(String code, Object[] args) {
        return requireContext().getMessage(code, args, LocaleContextHolder.getLocale());
    }

    /**
     * Resolves a localized message for the locale held by {@link LocaleContextHolder},
     * passing a default message to the context's lookup.
     *
     * @param code           message code
     * @param args           message arguments
     * @param defaultMessage default message handed to the lookup
     * @return the resolved message
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static String getMessage(String code, Object[] args,
                                    String defaultMessage) {
        return requireContext().getMessage(code, args, defaultMessage,
                LocaleContextHolder.getLocale());
    }

    /** Returns the captured context or fails with a descriptive error instead of a bare NullPointerException. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException(
                    "ApplicationContext has not been injected into ApplicationContextProvider yet");
        }
        return context;
    }
}
package com.demo;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* ${DESCRIPTION}
*/
@Component("jpaUtil")
@Slf4j
public class JpaUtil {
@Autowired
ObjectMapper objectMapper;
/**
* 查询结果为List<Map>时,可以通过该方法转换为对象List,注意Map中key要与对象属性匹配,或者对象属性标注了@JsonProperty
*/
@SuppressWarnings("unchecked")
public <E> List<E> mapListToObjectList(List<Map<?,?>> mapList, Class<E> clazz, boolean basic) {
List<E> list = new ArrayList<>();
for (Map<?,?> map : mapList) {
if (basic) {
list.add((E)map.values().stream().findFirst().get());
} else {
try {
final String valueAsString = objectMapper.writeValueAsString(map);
E newInstance = (E) objectMapper.readValue(valueAsString, clazz);
list.add(newInstance);
} catch (JsonProcessingException e) {
log.error("",e);
}
}
}
return list;
}
/**
* 查询结果为Map时,可以通过该方法转换为对象,注意Map中key要与对象属性匹配,或者对象属性标注了@JsonProperty
*/
@SuppressWarnings("unchecked")
public <E> E mapToObject(Map<?,?> map, Class<E> clazz, boolean basic) {
if(map == null){
return null;
}
E newInstance = null;
//基本类型,说明返回值只有一列
if (basic) {
newInstance = (E) map.values().stream().findFirst().get();
} else {
try {
final String valueAsString = objectMapper.writeValueAsString(map);
newInstance = (E) objectMapper.readValue(valueAsString, clazz);
} catch (JsonProcessingException e) {
log.error("",e);
}
}
return newInstance;
}
}