老项目如何改造权限、数据权限系统(三 一 一拦截器 实现)
序言:
最近项目有个数据权限的业务需求,要求大致为每个单位只能查看本级单位及下属单位的数据,例如:一个集团军下属十二个旅,那么军级用户可以看到所有数据,而每个旅则只能看到本旅部的数据,以此类推;
当然通过这个办法也可以实现数据的过滤,但这样的话相比大家也都有同感,那就是每个业务模块 每个人都要进行SQL改动,这次是根据单位过滤、明天又再根据其他的属性过滤,意味着要不停的改来改去,可谓是场面壮观也,而且这种集体改造耗费了时间精力不说,还会有很多不确定因素,比如SQL写错,存在漏网之鱼等等。因此这个解决方案肯定是直接PASS掉咯;
拦截器的使用
由于项目大部分采用的持久层框架是Mybatis,也是使用的Mybatis进行分页拦截处理,因此直接采用了Mybatis拦截器实现数据权限过滤。
前面两节已经说了AOP 了 这里就直接上拦截器
原理图
拦截器:
/**
* 分页拦截器
* @author GaoYuan
* @author lihaoshan 增加了数据权限的拦截过滤
* @datetime 2017/12/1 下午5:43
*/
@Component
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class , Integer.class}),
@Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class DataAuthorityInterceptor implements Interceptor {
private static final Logger log = LoggerFactory.getLogger(DataAuthorityInterceptor.class);
/**数据库类型,不同的数据库有不同的分页方法*/
protected String databaseType="MYSQL";
@SuppressWarnings("rawtypes")
protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<>();
public String getDatabaseType() {
return databaseType;
}
public void setDatabaseType(String databaseType) {
this.databaseType = databaseType;
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
databaseType="MYSQL";
setDatabaseType(databaseType);
}
@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object intercept(Invocation invocation) throws Throwable {
long id = Thread.currentThread().getId();
PermissionContext permissionContext = RecorderUtils.get(id);
// 控制SQL和查询总数的地方
if (invocation.getTarget() instanceof StatementHandler) {
Page page = pageThreadLocal.get();
RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
BoundSql boundSql = delegate.getBoundSql();
Connection connection = (Connection) invocation.getArgs()[0];
// 准备数据库类型
prepareAndCheckDatabaseType(connection);
MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
String sql = boundSql.getSql();
/** 单位数据权限拦截 begin */
//获取需要进行拦截的DAO层namespace拼接串
if (permissionContext==null) {
return invocation.proceed();
}
String value = permissionContext.getRe();
if (StringUtils.isNotBlank(value) && REQUIRED.equals(value)){
List<?> dataIds = permissionContext.getDataIds();
if(log.isInfoEnabled()){
log.info("数据权限拦截【拼接SQL】...");
}
//返回拦截包装后的sql
sql = permissionSql(sql,dataIds);
ReflectUtil.setFieldValue(boundSql, "sql", sql);
//不是分页查询
if (page == null) {
if (permissionContext!=null){
RecorderUtils.remove(id);
}
return invocation.proceed();
}
if (page.getCurrent() > -1) {
if (log.isTraceEnabled()) {
log.trace("已经设置了总页数, 不需要再查询总数.");
}
} else {
Object parameterObj = boundSql.getParameterObject();
queryTotalRecord(page, parameterObj, mappedStatement, sql,connection);
}
String pageSql = buildPageSql(page, sql);
if (log.isDebugEnabled()) {
log.debug("分页时, 生成分页pageSql......");
}
ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
if (permissionContext!=null){
RecorderUtils.remove(id);
}
return invocation.proceed();
}else{
if (permissionContext!=null){
RecorderUtils.remove(id);
}
return invocation.proceed();
}
} else { // 查询结果的地方
// 获取是否有分页Page对象
Page<?> page = findPageObject(invocation.getArgs()[1]);
if (page == null) {
if (log.isTraceEnabled()) {
log.trace("没有Page对象作为参数, 不是分页查询.");
}
return invocation.proceed();
} else {
if (log.isTraceEnabled()) {
log.trace("检测到分页Page对象, 使用分页查询.");
}
}
//设置真正的parameterObj
invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
pageThreadLocal.set(page);
try {
// Executor.query(..)
Object resultObj = invocation.proceed();
if (resultObj instanceof List) {
/* @SuppressWarnings({ "unchecked", "rawtypes" }) */
page.setRecords((List) resultObj);
}
return resultObj;
} finally {
pageThreadLocal.remove();
}
}
}
protected Page<?> findPageObject(Object parameterObj) {
if (parameterObj instanceof Page<?>) {
return (Page<?>) parameterObj;
} else if (parameterObj instanceof Map) {
for (Object val : ((Map<?, ?>) parameterObj).values()) {
if (val instanceof Page<?>) {
return (Page<?>) val;
}
}
}
return null;
}
/**
* <pre>
* 把真正的参数对象解析出来
* Spring会自动封装对个参数对象为Map<String, Object>对象
* 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值
* 而对于没有@Param指定时,Spring会使用0,1作为主键
* 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析,
* 此时解析出真正的参数作为根对象
* </pre>
* @param parameterObj
* @return
*/
protected Object extractRealParameterObject(Object parameterObj) {
if (parameterObj instanceof Map<?, ?>) {
Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
if (parameterMap.size() == 2) {
boolean springMapWithNoParamName = true;
for (Object key : parameterMap.keySet()) {
if (!(key instanceof String)) {
springMapWithNoParamName = false;
break;
}
String keyStr = (String) key;
if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
springMapWithNoParamName = false;
break;
}
}
if (springMapWithNoParamName) {
for (Object value : parameterMap.values()) {
if (!(value instanceof Page<?>)) {
return value;
}
}
}
}
}
return parameterObj;
}
protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
if (databaseType == null) {
String productName = connection.getMetaData().getDatabaseProductName();
if (log.isTraceEnabled()) {
log.trace("Database productName: " + productName);
}
productName = productName.toLowerCase();
if (productName.indexOf(MYSQL) != -1) {
databaseType = MYSQL;
} else if (productName.indexOf(ORACLE) != -1) {
databaseType = ORACLE;
} else {
throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
}
if (log.isInfoEnabled()) {
log.info("自动检测到的数据库类型为: " + databaseType);
}
}
}
/**
* <pre>
* 生成分页SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildPageSql(Page<?> page, String sql) {
if (MYSQL.equalsIgnoreCase(databaseType)) {
return buildMysqlPageSql(page, sql);
} else if (ORACLE.equalsIgnoreCase(databaseType)) {
return buildOraclePageSql(page, sql);
}
return sql;
}
/**
* <pre>
* 生成Mysql分页查询SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildMysqlPageSql(Page<?> page, String sql) {
// 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
Long offset = (page.getCurrent() - 1) * page.getSize();
if(offset<0){
return " limit 0 ";
}
return new StringBuilder(sql).append(" limit ").append("?").append(",").append("?").toString();
}
/**
* <pre>
* 生成Oracle分页查询SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildOraclePageSql(Page<?> page, String sql) {
// 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
long offset = (page.getCurrent() - 1) * page.getSize() + 1;
StringBuilder sb = new StringBuilder(sql);
sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getSize());
sb.insert(0, "select * from (").append(") where r >= ").append(offset);
return sb.toString();
}
/**
* <pre>
* 查询总数
* </pre>
*
* @param page
* @param parameterObject
* @param mappedStatement
* @param sql
* @param connection
* @throws SQLException
*/
protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, String sql, Connection connection) throws SQLException {
BoundSql boundSql = mappedStatement.getBoundSql(page);
/// String sql = boundSql.getSql();
String countSql = this.buildCountSql(sql);
if (log.isDebugEnabled()) {
log.debug("分页时, 生成countSql......");
}
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = connection.prepareStatement(countSql);
parameterHandler.setParameters(pstmt);
rs = pstmt.executeQuery();
if (rs.next()) {
long totalRecord = rs.getLong(1);
page.setTotal(totalRecord);
}
} finally {
if (rs != null) {
try {
rs.close();
} catch (Exception e) {
if (log.isWarnEnabled()) {
log.warn("关闭ResultSet时异常.", e);
}
}
}
if (pstmt != null) {
try {
pstmt.close();
} catch (Exception e) {
if (log.isWarnEnabled()) {
log.warn("关闭PreparedStatement时异常.", e);
}
}
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
*
* @param sql
* @return
*/
protected String buildCountSql(String sql) {
//查出第一个from,先转成小写
sql = sql.toLowerCase();
int index = sql.indexOf("from");
return "select count(0) " + sql.substring(index);
}
/**
* 数据权限sql包装【只能查看本级单位及下属单位的数据】
* @author lihaoshan
* @date 2018-07-19
*/
protected String permissionSql(String sql,List<?> ids) {
if(sql.contains("LIMIT")){
sql = sql.substring(0,sql.indexOf("LIMIT"));
System.out.println(sql);
}
StringBuilder sbSql = new StringBuilder(sql);
//获取当前登录人
//获取当前登录人所属单位标识
//如果有动态参数 orgId
if(ids != null && !ids.isEmpty()){
sbSql = new StringBuilder("select * from (")
.append(sbSql)
.append(" ) s ")
.append(" where id in ("+ ids.toString().substring(1,ids.toString().length()-1) +") ");
}
if(ids == null || ids.isEmpty()){
sbSql = new StringBuilder("select * from (")
.append(sbSql)
.append(" ) s ")
.append(" where id in (NULL) ");
}
return sbSql.toString();
}
}