配置AOP拦截所有controller请求：根据当前登录用户获取其需要访问的MongoDB数据库名称，并把对应的MongoTemplate存入当前线程
package com.scy.aspect;
import com.scy.mongo.MyMongoTemplate;
import com.scy.mongo.ZbkMongoTemplate;
import com.scy.shiro.MySessionManager;
import com.scy.vo.PersonPermissionVo;
import org.apache.commons.lang.StringUtils;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.web.util.WebUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.AfterReturning;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.util.Map;
/**
 * Aspect that selects the MongoDB database for each controller call.
 *
 * <p>Before any public method of a class under {@code com.scy.controller} runs, the user stored in
 * the Shiro session attribute {@code "person"} is looked up (only for requests carrying the
 * {@link MySessionManager#TOKEN} header). If that user has a disease code, the matching
 * {@link MyMongoTemplate} is stored in {@link ZbkMongoTemplate}'s thread-local slot. The slot is
 * always cleared after the controller method finishes.
 */
@Aspect
@Component
public class TestAspect {

    @Resource
    private ZbkMongoTemplate zbkMongoTemplate;

    /** Matches every public method of every class in {@code com.scy.controller} and its sub-packages. */
    @Pointcut("execution (public * com.scy.controller..*.*(..))")
    public void controllerMethodPointcut() {
    }

    /**
     * Binds the requesting user's MongoTemplate to the current thread.
     *
     * <p>Does nothing when no request is bound to the thread, when the token header is missing or
     * empty, or when the session has no {@code "person"} attribute or its disease code is empty.
     *
     * @param point the intercepted join point (not used)
     */
    @Before("controllerMethodPointcut()")
    public void controller(JoinPoint point) {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            // No servlet request is bound to this thread, so there is no user to resolve.
            return;
        }
        HttpServletRequest request = attributes.getRequest();
        String token = request.getHeader(MySessionManager.TOKEN);
        if (StringUtils.isEmpty(token)) {
            return;
        }
        PersonPermissionVo personPermissionVo =
                (PersonPermissionVo) SecurityUtils.getSubject().getSession().getAttribute("person");
        if (personPermissionVo != null && StringUtils.isNotEmpty(personPermissionVo.getDiseaseCode())) {
            MyMongoTemplate mongoTemplate = zbkMongoTemplate.getMongoTemplateByDBName(personPermissionVo.getDiseaseCode());
            zbkMongoTemplate.set(mongoTemplate);
        }
    }

    /**
     * Clears the thread-local MongoTemplate once the controller method completes.
     *
     * <p>Uses {@code @After} (finally semantics) rather than {@code @AfterReturning}. With the latter,
     * a controller that threw left the previous user's template on the thread. A later request on
     * the same reused thread without a token (or without a disease code) would then read another
     * user's database via {@link ZbkMongoTemplate#get()}. The method name is kept for compatibility.
     */
    @After("controllerMethodPointcut()")
    public void doAfterReturing() {
        zbkMongoTemplate.remove();
    }
}
核心步骤: 根据数据库名称获取MongoTemplate
package com.scy.mongo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.SimpleMongoClientDatabaseFactory;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Registry of per-database {@link MyMongoTemplate}s plus a thread-local "current template" slot.
 *
 * <p>Templates are created lazily from {@code spring.data.mongodb.uri}, with every {@code #} in the
 * URI replaced by the database name. They are then cached in this bean and reused for later
 * lookups of the same database name.
 */
@Component
public class ZbkMongoTemplate {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Template selected for the current thread via {@link #set}; cleared via {@link #remove}. */
    private static final ThreadLocal<MongoTemplate> mongoTemplateThreadLocal = new ThreadLocal<>();

    /** Connection URI template; {@code #} is the database-name placeholder. */
    @Value("${spring.data.mongodb.uri}")
    private String uri;

    /** Cache of templates keyed by database name. */
    private final Map<String, MyMongoTemplate> templateMuliteMap = new ConcurrentHashMap<>();

    /**
     * Returns the cached template for {@code dbName}, creating and caching it on first use.
     *
     * <p>Creation happens at most once per database name even under concurrent first access. The
     * previous unchecked lock let racing threads each build a client factory, and all but the last
     * one were silently discarded without being closed.
     *
     * @param dbName database name substituted for {@code #} in the configured URI; must not be null
     * @return the template bound to that database
     * @throws NullPointerException if {@code dbName} is null
     */
    public MyMongoTemplate getMongoTemplateByDBName(String dbName) {
        // Lock-free fast path for the common "already cached" case; computeIfAbsent only on a miss.
        MyMongoTemplate cached = templateMuliteMap.get(dbName);
        if (cached != null) {
            return cached;
        }
        return templateMuliteMap.computeIfAbsent(dbName, this::createMongoTemplate);
    }

    /** Builds a new template whose connection URI targets {@code dbName}. */
    private MyMongoTemplate createMongoTemplate(String dbName) {
        SimpleMongoClientDatabaseFactory simpleMongoClientDbFactory =
                new SimpleMongoClientDatabaseFactory(this.uri.replace("#", dbName));
        MyMongoTemplate mongoTemplate = new MyMongoTemplate(simpleMongoClientDbFactory);
        logger.info("生成数据库" + dbName + "的mongoTemplate");
        return mongoTemplate;
    }

    /**
     * Returns the names of all databases for which a template has been created so far.
     *
     * @return a read-only live view of the cache's keys (callers can no longer evict entries through it)
     */
    public Set<String> getDiseaseCodeSet() {
        return java.util.Collections.unmodifiableSet(templateMuliteMap.keySet());
    }

    /** Returns the template bound to the current thread, or {@code null} if none is bound. */
    public MongoTemplate get() {
        return mongoTemplateThreadLocal.get();
    }

    /** Binds {@code mongoTemplate} to the current thread. */
    public void set(MongoTemplate mongoTemplate) {
        mongoTemplateThreadLocal.set(mongoTemplate);
    }

    /** Removes the current thread's binding. */
    public void remove() {
        mongoTemplateThreadLocal.remove();
    }
}
MongoDB URI 配置（其中 `#` 是数据库名占位符，运行时会被替换为实际的数据库名称）
spring.data.mongodb.uri=mongodb://127.0.0.1:20000/#
使用
- 注入ZbkMongoTemplate
- 获取MongoTemplate
- 能获取到登录用户的token时（切面已在controller方法执行前把该用户对应的MongoTemplate放入当前线程；若会话中没有用户信息或其diseaseCode为空，则返回null）
MongoTemplate myMongoTemplate = zbkMongoTemplate.get();
- 获取不到登录用户的token时
MongoTemplate myMongoTemplate = zbkMongoTemplate.getMongoTemplateByDBName("数据库名称");