package com.xxx;
import com.mongodb.BasicDBObject;
import com.mongodb.CommandResult;
import com.mongodb.DBCollection;
import com.mongodb.DBObject;
import com.tuya.arthas.core.domain.base.BaseMongoDO;
import com.tuya.arthas.core.domain.base.BathUpdateOptions;
import com.tuya.arthas.core.utils.CurrentTimeMillisClock;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.stream.Collectors;
import javax.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.FatalBeanException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Repository;
@Slf4j
@Repository("mongodbBaseDAO")
public class MongodbBaseDAO<T extends BaseMongoDO, P> implements InitializingBean {
@Resource
private MongoTemplate mongoTemplate;
private Class<T> entityClass;
private Class<T> getEntityClass() {
ParameterizedType pt = (ParameterizedType) this.getClass().getGenericSuperclass();
// 获取第一个类型参数的真实类型
entityClass = (Class<T>) pt.getActualTypeArguments()[0];
return entityClass;
}
public void save(T entity) {
if (null == entity) {
return;
}
try {
if (StringUtils.isEmpty(BeanUtils.getProperty(entity, "gmtCreate"))) {
BeanUtils.setProperty(entity, "gmtCreate", System.currentTimeMillis());
}
if (StringUtils.isEmpty(BeanUtils.getProperty(entity, "gmtModified"))) {
BeanUtils.setProperty(entity, "gmtModified", System.currentTimeMillis());
}
} catch (Exception e) {
log.error("MongodbBaseDAO save error", e);
throw new BaseException(e);
}
mongoTemplate.insert(entity);
}
public void save(Collection<T> entitys) {
if (CollectionUtils.isEmpty(entitys)) {
return;
}
entitys.stream().map(o -> {
try {
if (StringUtils.isEmpty(BeanUtils.getProperty(o, "gmtCreate"))) {
BeanUtils.setProperty(o, "gmtCreate", System.currentTimeMillis());
}
if (StringUtils.isEmpty(BeanUtils.getProperty(o, "gmtModified"))) {
BeanUtils.setProperty(o, "gmtModified", System.currentTimeMillis());
}
} catch (Exception e) {
log.error("MongodbBaseDAO save error", e);
throw new BaseException(e);
}
return o;
}).collect(Collectors.toList());
mongoTemplate.insertAll(entitys);
}
public void remove(Query query) {
if (null == query) {
return;
}
mongoTemplate.remove(query, getEntityClass());
}
public void dropCollection(Class<?> clazz) {
mongoTemplate.dropCollection(clazz);
}
public T getById(Object id) {
try {
if (null == id) {
return null;
}
return mongoTemplate.findById(id, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO getById error", e);
return null;
}
}
public T getByCustomId(P customId) {
try {
if (null == customId) {
return null;
}
Query query = new Query(Criteria.where("customId").is((customId)));
return mongoTemplate.findOne(query, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO getByCustomId error", e);
return null;
}
}
public List<T> getByCustomIds(List<P> customIds) {
try {
if (CollectionUtils.isEmpty(customIds)) {
return new ArrayList<>();
}
Query query = Query.query(Criteria.where("customId").in(customIds));
return mongoTemplate.find(query, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO getByCustomIds error", e);
return new ArrayList<>();
}
}
public T query(Query query) {
try {
if (null == query) {
return null;
}
return mongoTemplate.findOne(query, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO query error", e);
return null;
}
}
public List<T> queryList(Query query) {
try {
if (null == query) {
return new ArrayList<>();
}
return mongoTemplate.find(query, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO queryList error", e);
return new ArrayList<>();
}
}
public List<T> queryAll() {
try {
Query query = new Query();
return mongoTemplate.find(query, getEntityClass());
} catch (Exception e) {
log.error("MongodbBaseDAO queryAll error", e);
return new ArrayList<>();
}
}
public void upsert(T entity) {
if (null == entity) {
return;
}
Query query = new Query(Criteria.where("customId").is(entity.getCustomId()));
Update update = convertBean2Update(entity);
mongoTemplate.upsert(query, update, getEntityClass());
}
public Update convertBean2Update(T entity) {
if (null == entity) {
return new Update();
}
Update update = new Update();
Field[] fields = getAllFields(getEntityClass());
try {
for (Field field : fields) {
field.setAccessible(true);
Object value = field.get(entity);
if ("gmtModified".equals(field.getName()) && null == value) {
update.set(field.getName(), System.currentTimeMillis());
} else if ("id".equals(field.getName()) || "gmtCreate".equals(field.getName())) {
continue;
} else if (null == value) {
continue;
} else {
update.set(field.getName(), value);
}
}
} catch (Exception e) {
log.error("convertBean2Update save error", e);
throw new BaseException(e);
}
return update;
}
public int bathUpdate(Class<?> entityClass, List<BathUpdateOptions> options) {
if (null == entityClass || CollectionUtils.isEmpty(options)) {
return -1;
}
String collectionName = determineCollectionName(entityClass);
return doBathUpdate(mongoTemplate.getCollection(collectionName),
collectionName, options, true);
}
public int batchUpdateByCustomId(List<T> updateList) {
if (CollectionUtils.isEmpty(updateList)) {
return -1;
}
List<BathUpdateOptions> updateDatas = new LinkedList<>();
updateList.forEach(o -> {
updateDatas.add(new BathUpdateOptions(Query.query(Criteria.where("customId").is(o.getCustomId())),
convertBean2Update(o), true, true));
});
return bathUpdate(getEntityClass(), updateDatas);
}
/**
* 获取文档名
* @param entityClass
* @return
*/
private static String determineCollectionName(Class<?> entityClass) {
if (entityClass == null) {
throw new InvalidDataAccessApiUsageException(
"No class parameter provided, entity collection can't be determined!");
}
String collName = entityClass.getSimpleName();
if (entityClass.isAnnotationPresent(Document.class)) {
Document document = entityClass.getAnnotation(Document.class);
collName = document.collection();
} else {
collName = collName.replaceFirst(collName.substring(0, 1)
, collName.substring(0, 1).toLowerCase());
}
return collName;
}
private static int doBathUpdate(DBCollection dbCollection, String collName,
List<BathUpdateOptions> options, boolean ordered) {
DBObject command = new BasicDBObject();
command.put("update", collName);
List<BasicDBObject> updateList = new ArrayList<BasicDBObject>();
for (BathUpdateOptions option : options) {
BasicDBObject update = new BasicDBObject();
update.put("q", option.getQuery().getQueryObject());
update.put("u", option.getUpdate().getUpdateObject());
update.put("upsert", option.isUpsert());
update.put("multi", option.isMulti());
updateList.add(update);
}
command.put("updates", updateList);
command.put("ordered", ordered);
CommandResult commandResult = dbCollection.getDB().command(command);
return Integer.parseInt(commandResult.get("n").toString());
}
/**
* 获取本类及其父类的属性的方法
*
* @param clazz 当前类对象
* @return 字段数组
*/
private static Field[] getAllFields(Class<?> clazz) {
List<Field> fieldList = new ArrayList<>();
while (clazz != null) {
fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
clazz = clazz.getSuperclass();
}
Field[] fields = new Field[fieldList.size()];
return fieldList.toArray(fields);
}
@SuppressWarnings("unchecked")
private Class<T> resolveEntityClassByGeneric() {
Type type = this.getClass().getGenericSuperclass();
if (type != null && type instanceof ParameterizedType) {
ParameterizedType ptype = (ParameterizedType) type;
Type[] types = ptype.getActualTypeArguments();
if (types != null && types.length > 0) {
Type genericClass = types[0];
if (genericClass != null && genericClass instanceof Class) {
Class<T> genericClazz = (Class<T>) genericClass;
String genericClazzName = genericClazz.getName();
if (!(genericClazz.isInterface() || genericClazzName.startsWith("java."))) {
return genericClazz;
}
}
}
}
return null;
}
@Override
public void afterPropertiesSet() throws Exception {
if (entityClass == null) {
entityClass = resolveEntityClassByGeneric();
}
if (entityClass != null) {
PropertyDescriptor[] propertyDescriptors = null;
try {
propertyDescriptors = Introspector.getBeanInfo(entityClass).getPropertyDescriptors();
} catch (IntrospectionException e) {
throw new FatalBeanException(e.getMessage(), e);
}
}
}
}