这篇文章分享一下第一次启动项目时遇到的一个问题：Flyway 数据库迁移与查询数据库操作之间的先后顺序。Shiro 的配置类在启动时需要查询所有接口的权限，再把这些接口交给自定义的过滤器处理；但第一次启动时，表和视图都还没有被 Flyway 创建，所以启动时必然会报错：表 xxx 不存在。
package cn.edu.sgu.www.mhxysy.config;
import cn.edu.sgu.www.mhxysy.entity.system.Permission;
import cn.edu.sgu.www.mhxysy.filter.AuthorizationFilter;
import cn.edu.sgu.www.mhxysy.mapper.system.PermissionMapper;
import cn.edu.sgu.www.mhxysy.realm.UserRealm;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import javax.servlet.Filter;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Shiro configuration.
 *
 * <p>Defines the realm, the web security manager and the Shiro filter factory. The filter
 * chain is built from the permission table, so {@link #shiroFilter} is declared to depend on
 * the {@code flywayConfig} bean: the tables must exist before the permission query runs.
 */
@Configuration
public class ShiroConfig {

    private final PermissionMapper permissionMapper;

    @Autowired
    public ShiroConfig(PermissionMapper permissionMapper) {
        this.permissionMapper = permissionMapper;
    }

    /**
     * Realm used by the security manager.
     *
     * @return a new UserRealm
     */
    @Bean
    public UserRealm userRealm() {
        return new UserRealm();
    }

    /**
     * Configures the security manager.
     *
     * @param userRealm the realm bean named {@code userRealm}
     * @return DefaultWebSecurityManager backed by {@code userRealm}
     */
    @Bean(name = "securityManager")
    public DefaultWebSecurityManager securityManager(@Qualifier("userRealm") UserRealm userRealm) {
        DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
        securityManager.setRealm(userRealm);
        return securityManager;
    }

    /**
     * Configures the Shiro filter factory.
     *
     * <p>{@code /}, {@code /html/*} and {@code /index.html} require authentication
     * ({@code authc}). Every non-blank URL stored in the permission table is routed through
     * the custom {@code authorization} filter ({@link AuthorizationFilter}).
     *
     * <p>{@code @DependsOn("flywayConfig")} makes Spring initialize the bean named
     * {@code flywayConfig} (the Flyway configuration that creates the tables) before this
     * method runs, so the permission table exists on the very first start.
     *
     * @param securityManager the security manager
     * @return ShiroFilterFactoryBean
     */
    @Bean
    @DependsOn("flywayConfig")
    public ShiroFilterFactoryBean shiroFilter(@Qualifier("securityManager") DefaultWebSecurityManager securityManager) {
        ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
        // Register the security manager
        shiroFilterFactoryBean.setSecurityManager(securityManager);
        // Unauthenticated users requesting an authc-protected resource are redirected here
        shiroFilterFactoryBean.setLoginUrl("/login.html");

        // Register the custom filter under the name used in the chain definitions below
        Map<String, Filter> filters = shiroFilterFactoryBean.getFilters();
        filters.put("authorization", new AuthorizationFilter());
        shiroFilterFactoryBean.setFilters(filters);

        // Access rules. A LinkedHashMap keeps insertion order, which Shiro evaluates
        // first-match-wins.
        Map<String, String> map = new LinkedHashMap<>();
        map.put("/", "authc");
        map.put("/html/*", "authc");
        map.put("/index.html", "authc");

        List<Permission> permissions = permissionMapper.selectList(null);
        for (Permission permission : permissions) {
            String url = permission.getUrl();
            // A row without a usable URL would become a null/blank chain name. Shiro rejects
            // that while building the filter chains, which aborts startup, so skip such rows.
            if (url == null || url.trim().isEmpty()) {
                continue;
            }
            // NOTE: a permission URL equal to one of the authc patterns above replaces that
            // rule, because both use the same map key.
            map.put(url, "authorization");
        }
        shiroFilterFactoryBean.setFilterChainDefinitionMap(map);
        return shiroFilterFactoryBean;
    }
}
以上是 Shiro 的配置类。其中下面这段代码查询权限表中的数据，然后把每个权限对应的 URL 都交给自定义的过滤器 AuthorizationFilter 处理：
// Load every row of the permission table and route each permission's URL
// through the custom filter registered under the name "authorization".
List<Permission> permissions = permissionMapper.selectList(null);
for (Permission permission : permissions) {
map.put(permission.getUrl(), "authorization");
}
AuthorizationFilter的代码
package cn.edu.sgu.www.mhxysy.filter;
import cn.edu.sgu.www.mhxysy.handler.AnonymityAccessHandler;
import cn.edu.sgu.www.mhxysy.restful.JsonResult;
import cn.edu.sgu.www.mhxysy.restful.ResponseCode;
import com.alibaba.fastjson.JSON;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.subject.Subject;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
/**
* 鉴权过滤器
* @author heyunlin
* @version 1.0
*/
@WebFilter
public class AuthorizationFilter implements Filter {
@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws ServletException, IOException {
HttpServletRequest request = (HttpServletRequest) req;
String requestURI = request.getRequestURI();
// 匿名访问接口
List<String> list = AnonymityAccessHandler.getList();
// 跳过鉴权
if (list.contains(requestURI)) {
chain.doFilter(req, resp);
return;
}
Subject subject = SecurityUtils.getSubject();
if (subject != null && !subject.isPermitted(requestURI)) {
HttpServletResponse response = (HttpServletResponse) resp;
response.setContentType("application/json;charset=utf-8");
// 构建返回对象
JsonResult<Void> jsonResult= JsonResult.error(ResponseCode.UNAUTHORIZED, "正在访问未授权的资源");
String data = JSON.toJSONString(jsonResult);
response.getWriter().write(data);
return;
}
chain.doFilter(req, resp);
}
}
下面这行代码获取的是扫描 controller 包得到的、所有标注了 @AnonymityAccess 注解的控制器方法的请求路径；这些接口允许匿名访问，会跳过鉴权。
// Request paths of controller methods annotated with @AnonymityAccess,
// collected when the AnonymityAccessHandler bean is initialized.
List<String> list = AnonymityAccessHandler.getList();
AnonymityAccessHandler的代码
package cn.edu.sgu.www.mhxysy.handler;
import cn.edu.sgu.www.mhxysy.MhxysyApplication;
import cn.edu.sgu.www.mhxysy.annotation.AnonymityAccess;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import java.io.File;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Collects the request paths of controller methods annotated with {@link AnonymityAccess}.
 *
 * <p>When this bean is initialized, it walks the compiled classes of {@value #BASE_PACKAGE}
 * on the file system. For every {@code @AnonymityAccess} method it reads the Spring MVC
 * mapping annotations ({@code @RequestMapping}, {@code @GetMapping}, {@code @PostMapping})
 * and publishes the resulting URLs through {@link #getList()}.
 *
 * <p>Limitation: only an exploded class directory can be walked. When the class-path root is
 * not a {@code file:} URL (e.g. the application runs from a packaged jar), the list stays
 * empty, i.e. no endpoint is treated as anonymous.
 *
 * @author heyunlin
 * @version 1.0
 */
@Component
public class AnonymityAccessHandler implements InitializingBean {

    /** Package whose classes are scanned for {@code @AnonymityAccess} methods. */
    private static final String BASE_PACKAGE = "cn.edu.sgu.www.mhxysy.controller";

    /**
     * Anonymous URLs found by the last scan. The list is replaced wholesale and never mutated,
     * so readers always see a complete, unmodifiable snapshot.
     */
    private static volatile List<String> list = Collections.emptyList();

    /**
     * Scans {@link #BASE_PACKAGE} and replaces the published list of anonymous URLs.
     *
     * @throws ClassNotFoundException if a {@code .class} file found on disk cannot be loaded
     */
    @Override
    public void afterPropertiesSet() throws ClassNotFoundException {
        List<String> found = new ArrayList<>();
        File classesRoot = resolveClassesRoot();
        if (classesRoot != null) {
            List<File> classFiles = new ArrayList<>();
            File searchDir = new File(classesRoot, BASE_PACKAGE.replace('.', File.separatorChar));
            collectClassFiles(searchDir, classFiles);
            // Every class file lies below classesRoot. Its path relative to the root is the
            // binary class name, with separators instead of dots and a ".class" suffix.
            int rootPrefixLength = classesRoot.getPath().length() + 1;
            for (File classFile : classFiles) {
                String relative = classFile.getPath().substring(rootPrefixLength);
                String className = relative
                        .substring(0, relative.length() - ".class".length())
                        .replace(File.separatorChar, '.');
                collectAnonymousUrls(Class.forName(className), found);
            }
        }
        list = Collections.unmodifiableList(found);
    }

    /**
     * Locates the directory that holds the main (non-test) compiled classes.
     *
     * @return the classes directory, or {@code null} when the class-path root is not a plain
     *     {@code file:} location
     */
    private static File resolveClassesRoot() {
        URL rootUrl = MhxysyApplication.class.getResource("/");
        if (rootUrl == null || !"file".equals(rootUrl.getProtocol())) {
            return null;
        }
        try {
            // File(URI) decodes escapes such as %20 and handles Windows drive letters itself.
            // Stripping the leading '/' from URL#getPath() only works on Windows.
            String path = new File(rootUrl.toURI()).getPath();
            // Under tests the root is typically a test-classes directory; the controllers are
            // compiled into the sibling classes directory.
            return new File(path.replace("test-classes", "classes"));
        } catch (URISyntaxException e) {
            throw new IllegalStateException("Cannot resolve class path root: " + rootUrl, e);
        }
    }

    /** Recursively adds every {@code .class} file at or below {@code path} to {@code out}. */
    private static void collectClassFiles(File path, List<File> out) {
        if (path.isDirectory()) {
            File[] children = path.listFiles();
            if (children != null) {
                for (File child : children) {
                    collectClassFiles(child, out);
                }
            }
        } else if (path.getName().endsWith(".class")) {
            out.add(path);
        }
    }

    /** Adds the URL(s) of every {@code @AnonymityAccess} method declared by {@code cls} to {@code out}. */
    private static void collectAnonymousUrls(Class<?> cls, List<String> out) {
        RequestMapping classMapping = cls.getAnnotation(RequestMapping.class);
        String[] prefixes = classMapping == null
                ? new String[] {""}
                : declaredPaths(classMapping.value(), classMapping.path());
        for (Method method : cls.getDeclaredMethods()) {
            if (!method.isAnnotationPresent(AnonymityAccess.class)) {
                continue;
            }
            // Spring maps every (class path, method path) combination, so expose all of them.
            // A method without a supported mapping annotation contributes nothing.
            for (String prefix : prefixes) {
                for (String path : methodPaths(method)) {
                    out.add(joinPath(prefix, path));
                }
            }
        }
    }

    /**
     * Returns the paths declared on {@code method} by {@code @RequestMapping}, {@code @GetMapping}
     * or {@code @PostMapping} (checked in that order), or an empty array if none is present.
     */
    private static String[] methodPaths(Method method) {
        RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
        if (requestMapping != null) {
            return declaredPaths(requestMapping.value(), requestMapping.path());
        }
        GetMapping getMapping = method.getAnnotation(GetMapping.class);
        if (getMapping != null) {
            return declaredPaths(getMapping.value(), getMapping.path());
        }
        PostMapping postMapping = method.getAnnotation(PostMapping.class);
        if (postMapping != null) {
            return declaredPaths(postMapping.value(), postMapping.path());
        }
        return new String[0];
    }

    /**
     * Picks the declared paths of a mapping annotation.
     *
     * <p>Plain reflection does not apply Spring's {@code @AliasFor}, so {@code value} and
     * {@code path} must both be consulted.
     *
     * @return {@code value} if non-empty, else {@code path}, else a single empty path (an
     *     annotation without a path maps to its parent's path)
     */
    private static String[] declaredPaths(String[] value, String[] path) {
        if (value.length > 0) {
            return value;
        }
        if (path.length > 0) {
            return path;
        }
        return new String[] {""};
    }

    /**
     * Joins a class-level prefix and a method-level path into one URL.
     *
     * <p>The join uses exactly one '/' between the parts. The result always has a leading '/',
     * and a single trailing '/' is dropped unless the URL is just "/".
     */
    private static String joinPath(String prefix, String path) {
        String url;
        if (prefix.isEmpty() || path.isEmpty()) {
            url = prefix + path;
        } else if (prefix.endsWith("/") && path.startsWith("/")) {
            url = prefix + path.substring(1);
        } else if (!prefix.endsWith("/") && !path.startsWith("/")) {
            url = prefix + "/" + path;
        } else {
            url = prefix + path;
        }
        if (!url.startsWith("/")) {
            url = "/" + url;
        }
        if (url.length() > 1 && url.endsWith("/")) {
            url = url.substring(0, url.length() - 1);
        }
        return url;
    }

    /**
     * Returns the anonymous URLs found by the last scan.
     *
     * @return an unmodifiable snapshot; empty before the bean has been initialized
     */
    public static List<String> getList() {
        return list;
    }
}
好了,介绍完了之后进入正题。
要想让查询数据库的操作在flyway完成建表之后执行,需要以下步骤:
步骤一:自定义flyway的配置类。
package cn.edu.sgu.www.mhxysy.config;
import org.flywaydb.core.Flyway;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import javax.annotation.PostConstruct;
import javax.sql.DataSource;
/**
 * Flyway configuration (first version).
 *
 * <p>Runs the migrations found under {@code db/migration} as soon as this bean is
 * initialized. Beans declared with {@code @DependsOn("flywayConfig")} are therefore only
 * created after the schema exists. This version migrates unconditionally.
 *
 * @author heyunlin
 * @version 1.0
 */
@Configuration
public class FlywayConfig {

    private final DataSource dataSource;

    @Autowired
    public FlywayConfig(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /** Migrates the database behind {@link #dataSource}, baselining a non-empty schema first. */
    @PostConstruct
    public void migrate() {
        Flyway.configure()
                .dataSource(dataSource)
                .locations("db/migration")
                .baselineOnMigrate(true)
                .load()
                .migrate();
    }
}
但这样有一个小问题：无论配置中是否开启了 Flyway，都会执行 db/migration 下的 SQL 脚本。
改进：引入 FlywayProperties，读取配置的 spring.flyway.enabled 的值，只有它为 true 时才执行迁移。注意 FlywayProperties 中 enabled 的默认值是 true，也就是说，没有配置该项时仍然会执行迁移。
package cn.edu.sgu.www.mhxysy.config;
import org.flywaydb.core.Flyway;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.flyway.FlywayProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.annotation.PostConstruct;
import javax.sql.DataSource;
/**
 * Flyway configuration (second version).
 *
 * <p>Runs the migrations found under {@code db/migration} when this bean is initialized,
 * but only if {@link FlywayProperties#isEnabled()} is true. Beans declared with
 * {@code @DependsOn("flywayConfig")} are created after this check/migration has run.
 *
 * <p>NOTE(review): {@code locations} and {@code baselineOnMigrate} are hard-coded here and
 * not taken from {@link FlywayProperties}; only the enabled flag is honoured.
 *
 * @author heyunlin
 * @version 1.0
 */
@Configuration
public class FlywayConfig {

    private final DataSource dataSource;

    @Autowired
    public FlywayConfig(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Registers a {@link FlywayProperties} bean.
     *
     * <p>NOTE(review): binding to the {@code spring.flyway.*} keys presumably relies on
     * Spring Boot's {@code @ConfigurationProperties} processing of this bean, since the Flyway
     * auto-configuration is excluded. Confirm for your Boot version.
     *
     * @return a new, container-managed FlywayProperties
     */
    @Bean
    public FlywayProperties flywayProperties() {
        return new FlywayProperties();
    }

    /** Migrates the database behind {@link #dataSource} unless Flyway is disabled. */
    @PostConstruct
    public void migrate() {
        // Going through the @Bean method (rather than a plain "new") yields the
        // container-managed instance of this @Configuration class.
        boolean enabled = flywayProperties().isEnabled();
        if (!enabled) {
            return;
        }
        Flyway.configure()
                .dataSource(dataSource)
                .locations("db/migration")
                .baselineOnMigrate(true)
                .load()
                .migrate();
    }
}
步骤二:启动类上排除flyway的自动配置类
package cn.edu.sgu.www.mhxysy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration;
import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
/**
 * Application entry point.
 *
 * <p>{@link FlywayAutoConfiguration} is excluded so that Spring Boot does not run its own
 * migration; migrations are triggered by the project's {@code FlywayConfig} instead, whose
 * initialization other beans can order themselves after via {@code @DependsOn}.
 *
 * @author heyunlin
 * @version 1.0
 */
@EnableDiscoveryClient
@SpringBootApplication(exclude = FlywayAutoConfiguration.class)
public class MhxysyApplication {

    private static final Logger LOGGER = LoggerFactory.getLogger(MhxysyApplication.class);

    public static void main(String[] args) {
        boolean debug = LOGGER.isDebugEnabled();
        if (debug) {
            LOGGER.debug("启动梦幻西游手游管理......");
        }
        SpringApplication.run(MhxysyApplication.class, args);
    }
}
步骤三：在 Shiro 配置类中查询数据库的 @Bean 方法（shiroFilter）上添加 @DependsOn("flywayConfig")，让这个 Bean 依赖于 Flyway 的配置类。这样 Spring 会先初始化 flywayConfig、完成建表，然后再查询权限表。
好了,文章就分享到这里了,看完不要忘了点赞+收藏哦~