老项目如何改造权限、数据权限系统(三 一 一拦截器 实现)

老项目如何改造权限、数据权限系统(三 一 一拦截器 实现)

序言:

​ 最近项目有个数据权限的业务需求,要求大致为每个单位只能查看本级单位及下属单位的数据,例如:一个集团军下属十二个旅,那么军级用户可以看到所有数据,而每个旅则只能看到本旅部的数据,以此类推;

​ 当然通过这个办法也可以实现数据的过滤,但这样的话相比大家也都有同感,那就是每个业务模块 每个人都要进行SQL改动,这次是根据单位过滤、明天又再根据其他的属性过滤,意味着要不停的改来改去,可谓是场面壮观也,而且这种集体改造耗费了时间精力不说,还会有很多不确定因素,比如SQL写错,存在漏网之鱼等等。因此这个解决方案肯定是直接PASS掉咯;

拦截器的使用

​ 由于项目大部分采用的持久层框架是Mybatis,也是使用的Mybatis进行分页拦截处理,因此直接采用了Mybatis拦截器实现数据权限过滤。

前面两节已经说了AOP 了 这里就直接上拦截器

​ 原理图

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1On1uVOR-1621221576903)(/Users/paramland/Library/Application Support/typora-user-images/image-20210517110358494.png)]

拦截器:

/**
 * 分页拦截器
 * @author GaoYuan
 * @author lihaoshan 增加了数据权限的拦截过滤
 * @datetime 2017/12/1 下午5:43
 */
@Component
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class , Integer.class}),
		@Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class DataAuthorityInterceptor implements Interceptor {

	private static final Logger log = LoggerFactory.getLogger(DataAuthorityInterceptor.class);

	/**数据库类型,不同的数据库有不同的分页方法*/
	protected String databaseType="MYSQL";

	@SuppressWarnings("rawtypes")
	protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<>();


	public String getDatabaseType() {
		return databaseType;
	}

	public void setDatabaseType(String databaseType) {
		this.databaseType = databaseType;
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	@Override
	public void setProperties(Properties properties) {
		databaseType="MYSQL";
		setDatabaseType(databaseType);
	}

	@Override
	@SuppressWarnings({ "unchecked", "rawtypes" })
	public Object intercept(Invocation invocation) throws Throwable {

		long id = Thread.currentThread().getId();
		PermissionContext permissionContext = RecorderUtils.get(id);

		// 控制SQL和查询总数的地方
		if (invocation.getTarget() instanceof StatementHandler) {
			Page page = pageThreadLocal.get();
			RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
			StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
			BoundSql boundSql = delegate.getBoundSql();
			Connection connection = (Connection) invocation.getArgs()[0];
			// 准备数据库类型
			prepareAndCheckDatabaseType(connection);
			MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
			String sql = boundSql.getSql();

			/** 单位数据权限拦截 begin */
			//获取需要进行拦截的DAO层namespace拼接串

			if (permissionContext==null) {
				return invocation.proceed();
			}
			String value = permissionContext.getRe();

			if (StringUtils.isNotBlank(value) && REQUIRED.equals(value)){
				List<?> dataIds = permissionContext.getDataIds();
				if(log.isInfoEnabled()){
					log.info("数据权限拦截【拼接SQL】...");
				}
				//返回拦截包装后的sql
				sql = permissionSql(sql,dataIds);
				ReflectUtil.setFieldValue(boundSql, "sql", sql);
				//不是分页查询
				if (page == null) {
					if (permissionContext!=null){
						RecorderUtils.remove(id);
					}
					return invocation.proceed();
				}

				if (page.getCurrent() > -1) {
					if (log.isTraceEnabled()) {
						log.trace("已经设置了总页数, 不需要再查询总数.");
					}
				} else {
					Object parameterObj = boundSql.getParameterObject();
					queryTotalRecord(page, parameterObj, mappedStatement, sql,connection);
				}

				String pageSql = buildPageSql(page, sql);
				if (log.isDebugEnabled()) {
					log.debug("分页时, 生成分页pageSql......");
				}
				ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
				if (permissionContext!=null){
					RecorderUtils.remove(id);
				}
				return invocation.proceed();
			}else{
				if (permissionContext!=null){
					RecorderUtils.remove(id);
				}
				return invocation.proceed();
			}
		} else { // 查询结果的地方
			// 获取是否有分页Page对象
			Page<?> page = findPageObject(invocation.getArgs()[1]);
			if (page == null) {
				if (log.isTraceEnabled()) {
					log.trace("没有Page对象作为参数, 不是分页查询.");
				}
				return invocation.proceed();
			} else {
				if (log.isTraceEnabled()) {
					log.trace("检测到分页Page对象, 使用分页查询.");
				}
			}
			//设置真正的parameterObj
			invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
			pageThreadLocal.set(page);
			try {
				// Executor.query(..)
				Object resultObj = invocation.proceed();
				if (resultObj instanceof List) {
					/* @SuppressWarnings({ "unchecked", "rawtypes" }) */
					page.setRecords((List) resultObj);
				}
				return resultObj;
			} finally {
				pageThreadLocal.remove();
			}
		}
	}

	protected Page<?> findPageObject(Object parameterObj) {
		if (parameterObj instanceof Page<?>) {
			return (Page<?>) parameterObj;
		} else if (parameterObj instanceof Map) {
			for (Object val : ((Map<?, ?>) parameterObj).values()) {
				if (val instanceof Page<?>) {
					return (Page<?>) val;
				}
			}
		}
		return null;
	}

	/**
	 * <pre>
	 * 把真正的参数对象解析出来
	 * Spring会自动封装对个参数对象为Map<String, Object>对象
	 * 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值
	 * 而对于没有@Param指定时,Spring会使用0,1作为主键
	 * 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析,
	 * 此时解析出真正的参数作为根对象
	 * </pre>
	 * @param parameterObj
	 * @return
	 */
	protected Object extractRealParameterObject(Object parameterObj) {
		if (parameterObj instanceof Map<?, ?>) {
			Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
			if (parameterMap.size() == 2) {
				boolean springMapWithNoParamName = true;
				for (Object key : parameterMap.keySet()) {
					if (!(key instanceof String)) {
						springMapWithNoParamName = false;
						break;
					}
					String keyStr = (String) key;
					if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
						springMapWithNoParamName = false;
						break;
					}
				}
				if (springMapWithNoParamName) {
					for (Object value : parameterMap.values()) {
						if (!(value instanceof Page<?>)) {
							return value;
						}
					}
				}
			}
		}
		return parameterObj;
	}

	protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
		if (databaseType == null) {
			String productName = connection.getMetaData().getDatabaseProductName();
			if (log.isTraceEnabled()) {
				log.trace("Database productName: " + productName);
			}
			productName = productName.toLowerCase();
			if (productName.indexOf(MYSQL) != -1) {
				databaseType = MYSQL;
			} else if (productName.indexOf(ORACLE) != -1) {
				databaseType = ORACLE;
			} else {
				throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
			}
			if (log.isInfoEnabled()) {
				log.info("自动检测到的数据库类型为: " + databaseType);
			}
		}
	}

	/**
	 * <pre>
	 * 生成分页SQL
	 * </pre>
	 *
	 * @param page
	 * @param sql
	 * @return
	 */
	protected String buildPageSql(Page<?> page, String sql) {
		if (MYSQL.equalsIgnoreCase(databaseType)) {
			return buildMysqlPageSql(page, sql);
		} else if (ORACLE.equalsIgnoreCase(databaseType)) {
			return buildOraclePageSql(page, sql);
		}
		return sql;
	}

	/**
	 * <pre>
	 * 生成Mysql分页查询SQL
	 * </pre>
	 *
	 * @param page
	 * @param sql
	 * @return
	 */
	protected String buildMysqlPageSql(Page<?> page, String sql) {
		// 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
		Long offset = (page.getCurrent() - 1) * page.getSize();
		if(offset<0){
			return " limit 0 ";
		}
		return new StringBuilder(sql).append(" limit ").append("?").append(",").append("?").toString();
	}

	/**
	 * <pre>
	 * 生成Oracle分页查询SQL
	 * </pre>
	 *
	 * @param page
	 * @param sql
	 * @return
	 */
	protected String buildOraclePageSql(Page<?> page, String sql) {
		// 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
		long offset = (page.getCurrent() - 1) * page.getSize() + 1;
		StringBuilder sb = new StringBuilder(sql);
		sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getSize());
		sb.insert(0, "select * from (").append(") where r >= ").append(offset);
		return sb.toString();
	}

	/**
	 * <pre>
	 * 查询总数
	 * </pre>
	 *
	 * @param page
	 * @param parameterObject
	 * @param mappedStatement
	 * @param sql
	 * @param connection
	 * @throws SQLException
	 */
	protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, String sql, Connection connection) throws SQLException {
		BoundSql boundSql = mappedStatement.getBoundSql(page);
///        String sql = boundSql.getSql();

		String countSql = this.buildCountSql(sql);
		if (log.isDebugEnabled()) {
			log.debug("分页时, 生成countSql......");
		}

		List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
		BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
		ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
		PreparedStatement pstmt = null;
		ResultSet rs = null;
		try {
			pstmt = connection.prepareStatement(countSql);
			parameterHandler.setParameters(pstmt);
			rs = pstmt.executeQuery();
			if (rs.next()) {
				long totalRecord = rs.getLong(1);
				page.setTotal(totalRecord);
			}
		} finally {
			if (rs != null) {
				try {
					rs.close();
				} catch (Exception e) {
					if (log.isWarnEnabled()) {
						log.warn("关闭ResultSet时异常.", e);
					}
				}
			}
			if (pstmt != null) {
				try {
					pstmt.close();
				} catch (Exception e) {
					if (log.isWarnEnabled()) {
						log.warn("关闭PreparedStatement时异常.", e);
					}
				}
			}
		}
	}

	/**
	 * 根据原Sql语句获取对应的查询总记录数的Sql语句
	 *
	 * @param sql
	 * @return
	 */
	protected String buildCountSql(String sql) {
		//查出第一个from,先转成小写
		sql = sql.toLowerCase();
		int index = sql.indexOf("from");
		return "select count(0) " + sql.substring(index);
	}

	/**
	 * 数据权限sql包装【只能查看本级单位及下属单位的数据】
	 * @author lihaoshan
	 * @date 2018-07-19
	 */
	protected String permissionSql(String sql,List<?> ids) {
		if(sql.contains("LIMIT")){
			sql = sql.substring(0,sql.indexOf("LIMIT"));
			System.out.println(sql);
		}
		StringBuilder sbSql = new StringBuilder(sql);
		//获取当前登录人
		//获取当前登录人所属单位标识
		//如果有动态参数 orgId
		if(ids != null && !ids.isEmpty()){

			sbSql = new StringBuilder("select * from (")
					.append(sbSql)
					.append(" ) s ")
					.append(" where id in ("+ ids.toString().substring(1,ids.toString().length()-1) +") ");
		}
		if(ids == null || ids.isEmpty()){

			sbSql = new StringBuilder("select * from (")
					.append(sbSql)
					.append(" ) s ")
					.append(" where id in (NULL) ");
		}
		return sbSql.toString();
	}
}
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值