import com.mongodb.MongoClient;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import org.bson.Document;
import org.bson.codecs.configuration.CodecRegistries;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
/**
 * Spring-managed MongoDB client wrapper.
 *
 * <p>On startup ({@link #afterPropertiesSet()}) a {@link MongoClient} is built from the
 * {@code mongodb.hosts}, {@code mongodb.user}, {@code mongodb.pwd} and {@code mongodb.database}
 * properties; on shutdown ({@link #destroy()}) that client is closed. Every CRUD helper operates on
 * the configured {@code mongodb.database}. POJO-typed helpers decode through a codec registry that
 * combines the driver's default codecs with an automatic {@link PojoCodecProvider}.
 *
 * @author zoult on 2018/01/07
 */
@Component
public class MongoDBClient implements InitializingBean, DisposableBean {

    private MongoClient client;
    private CodecRegistry pojoCodecRegistry;

    /** Comma-separated seed list, each entry {@code host} or {@code host:port}. */
    @Value("${mongodb.hosts}")
    private String hosts;

    @Value("${mongodb.user}")
    private String user;

    @Value("${mongodb.pwd}")
    private String password;

    /** Database used both for authentication and for all operations of this class. */
    @Value("${mongodb.database}")
    private String database;

    /**
     * Builds the driver client and the POJO codec registry once Spring has injected the properties.
     *
     * @throws IllegalArgumentException if {@code mongodb.hosts} is empty, contains a malformed
     *     entry, or yields no usable host
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        List<ServerAddress> addressList = parseServerAddresses(hosts);
        MongoCredential credential = MongoCredential.createCredential(user, database, password.toCharArray());
        client = new MongoClient(addressList, Arrays.asList(credential));
        // Default codecs come first so Document and other built-in types keep their standard codecs;
        // the automatic POJO provider covers any other class.
        pojoCodecRegistry = CodecRegistries.fromRegistries(MongoClient.getDefaultCodecRegistry(),
                CodecRegistries.fromProviders(PojoCodecProvider.builder()
                        .automatic(true)
                        .build()));
    }

    /**
     * Closes the driver client, if it was created.
     */
    @Override
    public void destroy() throws Exception {
        if (client != null) {
            client.close();
        }
    }

    /**
     * Inserts a single POJO document.
     *
     * @param collectionName collection name
     * @param pojoClass      POJO class used to encode the document
     * @param value          document to insert; must not be null
     * @param <T>            POJO type
     */
    public <T> void insertOne(String collectionName, Class<T> pojoClass, T value) {
        Assert.notNull(value, "value can't be null");
        MongoCollection<T> collection = getCollection(database, collectionName, pojoClass);
        collection.insertOne(value);
    }

    /**
     * Inserts a single {@link Document}.
     *
     * @param collectionName collection name
     * @param document       document to insert; must not be null
     */
    public void insertOne(String collectionName, Document document) {
        Assert.notNull(document, "document can't be null");
        // Resolve by Document.class (as the other Document overloads do) rather than the runtime class.
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        collection.insertOne(document);
    }

    /**
     * Inserts several POJO documents.
     *
     * @param collectionName collection name
     * @param pojoClass      POJO class used to encode the documents
     * @param values         documents to insert; must not be empty
     * @param <T>            POJO type
     */
    public <T> void insertMany(String collectionName, Class<T> pojoClass, List<T> values) {
        Assert.notEmpty(values, "values can't be empty");
        MongoCollection<T> collection = getCollection(database, collectionName, pojoClass);
        collection.insertMany(values);
    }

    /**
     * Inserts several {@link Document}s.
     *
     * @param collectionName collection name
     * @param documents      documents to insert; must not be empty
     */
    public void insertMany(String collectionName, List<Document> documents) {
        Assert.notEmpty(documents, "documents can't be empty");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        collection.insertMany(documents);
    }

    /**
     * Finds all documents matching a condition, decoded as POJOs.
     *
     * @param collectionName collection name
     * @param pojoClass      POJO class used to decode the results
     * @param condition      query filter; must not be null
     * @param <T>            POJO type
     * @return a new mutable list of matching documents; empty if nothing matches
     */
    public <T> List<T> find(String collectionName, Class<T> pojoClass, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<T> collection = getCollection(database, collectionName, pojoClass);
        // into() iterates the whole result and appends each element; no unchecked casts needed.
        return collection.find(condition).into(new ArrayList<T>());
    }

    /**
     * Finds all documents matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @return a new mutable list of matching documents; empty if nothing matches
     */
    public List<Document> find(String collectionName, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.find(condition).into(new ArrayList<Document>());
    }

    /**
     * Finds the first document matching a condition, decoded as a POJO.
     *
     * @param collectionName collection name
     * @param pojoClass      POJO class used to decode the result
     * @param condition      query filter; must not be null
     * @param <T>            POJO type
     * @return the first matching document as returned by the driver's {@code first()}, or null if none
     */
    public <T> T findFirst(String collectionName, Class<T> pojoClass, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<T> collection = getCollection(database, collectionName, pojoClass);
        return collection.find(condition).first();
    }

    /**
     * Finds the first document matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @return the first matching document as returned by the driver's {@code first()}, or null if none
     */
    public Document findFirst(String collectionName, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.find(condition).first();
    }

    /**
     * Updates the first document matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @param set            update specification; must not be null
     * @return the modified count reported by the driver's {@code UpdateResult}
     */
    public Long updateOne(String collectionName, Bson condition, Bson set) {
        Assert.notNull(condition, "condition can't be null");
        Assert.notNull(set, "set can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.updateOne(condition, set).getModifiedCount();
    }

    /**
     * Updates every document matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @param set            update specification; must not be null
     * @return the modified count reported by the driver's {@code UpdateResult}
     */
    public Long updateMany(String collectionName, Bson condition, Bson set) {
        Assert.notNull(condition, "condition can't be null");
        Assert.notNull(set, "set can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.updateMany(condition, set).getModifiedCount();
    }

    /**
     * Deletes the first document matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @return the deleted count reported by the driver's {@code DeleteResult}
     */
    public Long deleteOne(String collectionName, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.deleteOne(condition).getDeletedCount();
    }

    /**
     * Deletes every document matching a condition.
     *
     * @param collectionName collection name
     * @param condition      query filter; must not be null
     * @return the deleted count reported by the driver's {@code DeleteResult}
     */
    public Long deleteMany(String collectionName, Bson condition) {
        Assert.notNull(condition, "condition can't be null");
        MongoCollection<Document> collection = getCollection(database, collectionName, Document.class);
        return collection.deleteMany(condition).getDeletedCount();
    }

    /**
     * Parses the {@code mongodb.hosts} seed list.
     *
     * <p>Entries are comma-separated and may be {@code host} (driver default port) or
     * {@code host:port}; surrounding whitespace and empty entries (e.g. a trailing comma) are ignored.
     *
     * @param hosts raw property value
     * @return the parsed server addresses, never empty
     * @throws IllegalArgumentException if {@code hosts} is blank, an entry is malformed, or no host remains
     */
    private static List<ServerAddress> parseServerAddresses(String hosts) {
        Assert.hasText(hosts, "mongodb.hosts can't be empty");
        List<ServerAddress> addressList = new ArrayList<>();
        for (String rawEntry : hosts.split(",")) {
            String entry = rawEntry.trim();
            if (entry.isEmpty()) {
                continue;
            }
            String[] hostPort = entry.split(":");
            if (hostPort.length == 1) {
                // No port given: let the driver apply its default port instead of dropping the host.
                addressList.add(new ServerAddress(hostPort[0].trim()));
            } else if (hostPort.length == 2) {
                addressList.add(new ServerAddress(hostPort[0].trim(), Integer.parseInt(hostPort[1].trim())));
            } else {
                throw new IllegalArgumentException("invalid mongodb host entry: " + entry);
            }
        }
        Assert.notEmpty(addressList, "mongodb.hosts contains no valid host");
        return addressList;
    }

    /**
     * Returns a typed collection handle, attaching the POJO codec registry for non-Document types.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param pojoClass      document class of the returned collection
     * @param <T>            document type
     * @return the collection handle
     */
    private <T> MongoCollection<T> getCollection(String databaseName, String collectionName, Class<T> pojoClass) {
        Assert.hasText(databaseName, "database can't be empty");
        Assert.hasText(collectionName, "collection name can't be empty");
        Assert.notNull(pojoClass, "pojoClass can't be null");
        MongoDatabase mongoDatabase = client.getDatabase(databaseName);
        MongoCollection<T> collection = mongoDatabase.getCollection(collectionName, pojoClass);
        // Document (or a supertype of it) keeps the database's own codec registry; any other class is
        // treated as a POJO and gets the POJO-aware registry.
        if (pojoClass.isAssignableFrom(Document.class)) {
            return collection;
        }
        return collection.withCodecRegistry(pojoCodecRegistry);
    }
}