因工作需求,需要根据用户的数据权限来查询并展示相应的数据,这就需要动态拦截 SQL,再根据用户权限做相应的处理,因此需要一个通用拦截器,并以注解方式实现。本文只做查询拦截,如有其他需求,可根据实际情况做相应修改。
该注解是方法级注解,因此需要标注在 DAO 层方法上;如有需要,也可改为类级注解(此时拦截器中判断注解的逻辑也需相应修改)。
注解:
/**
 * Marks a DAO (mapper) query method whose SQL must be restricted by the current user's
 * data permissions.
 *
 * <p>Retained at runtime so that it can be discovered reflectively on the mapper method.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Permission {
}
定义拦截器,实现 MyBatis 的 Interceptor 接口并重写其 intercept 方法:
@Intercepts({ // @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}) // @Signature( type = Executor.class, method = "update",args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query",args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query",args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}) }) @Component public class PermissionInterceptor implements Interceptor { @Override public Object intercept(Invocation invocation) throws Throwable { } }
拦截所有查询 SQL 请求,并获取相应的 SQL 语法树(Statement)和 MappedStatement:
@Intercepts({ // @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}) // @Signature( type = Executor.class, method = "update",args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query",args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query",args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}) }) @Component public class PermissionInterceptor implements Interceptor { @Override public Object intercept(Invocation invocation) throws Throwable { String processSql = ExecutorPluginUtils.getSqlByInvocation(invocation); // 执行自定义修改sql操作 // 获取sql String sql2Reset = processSql; Statement statement = CCJSqlParserUtil.parse(processSql); MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0]; } }
如果后端未使用分页插件(PageHelper),这一步可以省略;否则需要在项目启动类中完成以下配置,以保证该拦截器在 PageHelper 之前执行:
//得到spring上下文 ConfigurableApplicationContext run = SpringApplication.run(Application.class, args); Interceptor permissionInterceptor = (Interceptor) run.getBean("permissionInterceptor"); //这种方式添加mybatis拦截器保证在pageHelper前执行 run.getBean(SqlSessionFactory.class).getConfiguration().addInterceptor(permissionInterceptor);
工具类
package com.ydy.common.utils; import com.ydy.common.annotation.Permission; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.SqlCommandType; import org.apache.ibatis.mapping.SqlSource; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.reflection.DefaultReflectorFactory; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.factory.DefaultObjectFactory; import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory; import java.lang.reflect.Method; import java.lang.reflect.Type; import java.sql.SQLException; import java.util.Arrays; import java.util.Objects; public class ExecutorPluginUtils { /** * 获取sql语句 * @param invocation * @return */ public static String getSqlByInvocation(Invocation invocation) { final Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement) args[0]; Object parameterObject = args[1]; BoundSql boundSql = ms.getBoundSql(parameterObject); return boundSql.getSql(); } /** * 包装sql后,重置到invocation中 * @param invocation * @param sql * @throws SQLException */ public static void resetSql2Invocation(Invocation invocation, String sql) throws SQLException { final Object[] args = invocation.getArgs(); MappedStatement statement = (MappedStatement) args[0]; Object parameterObject = args[1]; BoundSql boundSql = statement.getBoundSql(parameterObject); MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql)); MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory()); msObject.setValue("sqlSource.boundSql.sql", sql); args[0] = newStatement; } private static MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) { MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType()); 
builder.resource(ms.getResource()); builder.fetchSize(ms.getFetchSize()); builder.statementType(ms.getStatementType()); builder.keyGenerator(ms.getKeyGenerator()); if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) { StringBuilder keyProperties = new StringBuilder(); for (String keyProperty : ms.getKeyProperties()) { keyProperties.append(keyProperty).append(","); } keyProperties.delete(keyProperties.length() - 1, keyProperties.length()); builder.keyProperty(keyProperties.toString()); } builder.timeout(ms.getTimeout()); builder.parameterMap(ms.getParameterMap()); builder.resultMaps(ms.getResultMaps()); builder.resultSetType(ms.getResultSetType()); builder.cache(ms.getCache()); builder.flushCacheRequired(ms.isFlushCacheRequired()); builder.useCache(ms.isUseCache()); return builder.build(); } /** * 是否标记为区域字段 * @return */ public static boolean isAreaTag( MappedStatement mappedStatement) throws ClassNotFoundException { String id = mappedStatement.getId(); //获取类名 String className = id.substring(0, id.lastIndexOf(".")); Class clazz = Class.forName(className); //获取方法名 String methodName = id.substring(id.lastIndexOf(".") + 1); //这里是博主工作需求,防止pagehelper那里未生效 if(methodName.contains("_COUNT")){ methodName=methodName.replace("_COUNT",""); } String m=methodName; Class> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf("."))); //获取对应拦截方法名 String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1); //这里是博主工作需求,防止pagehelper那里未生效 if(mName.contains("_COUNT")){ mName=mName.replace("_COUNT",""); } boolean ignore = false; //获取该类(接口)的所有方法,如果你查询的方法就写在该类,就不需要下面的if判断 Method[] declaredMethods = classType.getDeclaredMethods(); Method declaredMethod = Arrays.stream(declaredMethods).filter(it -> it.getName().equals(m)).findFirst().orElse(null); //该判断是拿到该接口的超类的方法,博主的查询方法就在超类里,因此需要利用下面代码来获取对应方法 if (declaredMethod == null) { Type[] genericInterfaces = clazz.getGenericInterfaces(); declaredMethod = 
Arrays.stream(genericInterfaces).map(e -> { Method[] declaredMethods1 = ((Class) e).getDeclaredMethods(); return Arrays.stream(declaredMethods1).filter(it -> it.getName().equals(m)).findFirst().orElse(null); }).filter(Objects::nonNull).findFirst().orElse(null); } if(declaredMethod!=null){ //查询方法是否被permission标记注解 ignore = declaredMethod.isAnnotationPresent(Permission.class); } return ignore; } /** * 是否标记为区域字段 * @return */ public static boolean isAreaTagIngore( MappedStatement mappedStatement) throws ClassNotFoundException { String id = mappedStatement.getId(); String className = id.substring(0, id.lastIndexOf(".")); Class clazz = Class.forName(className); String methodName = id.substring(id.lastIndexOf(".") + 1); Class> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf("."))); //获取对应拦截方法名 String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1); boolean ignore = false; Method[] declaredMethods = classType.getDeclaredMethods(); Method declaredMethod = Arrays.stream(declaredMethods).filter(it -> it.getName().equals(methodName)).findFirst().orElse(null); if (declaredMethod == null) { Type[] genericInterfaces = clazz.getGenericInterfaces(); declaredMethod = Arrays.stream(genericInterfaces).map(e -> { Method[] declaredMethods1 = ((Class) e).getDeclaredMethods(); return Arrays.stream(declaredMethods1).filter(it -> it.getName().equals(methodName)).findFirst().orElse(null); }).filter(Objects::nonNull).findFirst().orElse(null); } ignore = declaredMethod.isAnnotationPresent(Permission.class); return ignore; } public static String getOperateType(Invocation invocation) { final Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement) args[0]; SqlCommandType commondType = ms.getSqlCommandType(); if (commondType.compareTo(SqlCommandType.SELECT) == 0) { return "select"; } return null; } // 定义一个内部辅助类,作用是包装sq static class BoundSqlSqlSource implements SqlSource { private BoundSql boundSql; public 
BoundSqlSqlSource(BoundSql boundSql) { this.boundSql = boundSql; } @Override public BoundSql getBoundSql(Object parameterObject) { return boundSql; } } }
如果方法被 @Permission 注解标记,则进入 if 分支:查询当前用户的数据权限,拼接过滤条件并替换原 SQL;否则直接放行。
if (ExecutorPluginUtils.isAreaTag(mappedStatement)) { //获取该用户所具有的角色的数据权限dataScope //因数据敏感省略 //获取该用户的所在公司或部门下的所有人 //例如 StringBuffer orgBuffer = new StringBuffer(); // orgBuffer.append("("); //String collect = allUserByOrgs.stream().map(String::valueOf).collect(Collectors.joining(",")); //orgBuffer.append(collect).append(")"); //String orgsUser = orgBuffer.toString(); try { if (statement instanceof Select) { Select selectStatement = (Select) statement; //其中的PlainSelect 可以拿到sql语句的全部节点信息,具体各位可以看源码 PlainSelect plain = (PlainSelect) selectStatement.getSelectBody(); //获取所有外连接 Listjoins = plain.getJoins(); //获取到原始sql语句 String sql = processSql; StringBuffer whereSql = new StringBuffer(); switch (dataScope) { //这里dataScope 范围 1 所有数据权限 2 本人 3,部门及分部门(递归) 4.公司及分公司(递归) //所有数据权限作用在人上,因此sql用 in case 1: whereSql.append("1=1"); break; case 2: for (Join join : joins) { Table rightItem = (Table) join.getRightItem(); //匹配表名 if(rightItem.getName().equals("sec_user")){ //获取别名 if(rightItem.getAlias()!=null){ whereSql.append(rightItem.getAlias().getName()).append(".id = ").append(SecurityUtils.getLoginUser().getId()); }else { whereSql.append("id = ").append(deptsUser); } } } break; case 3: for (Join join : joins) { Table rightItem = (Table) join.getRightItem(); if(rightItem.getName().equals("sec_user")){ if(rightItem.getAlias()!=null){ whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(deptsUser); }else { whereSql.append("id in ").append(deptsUser); } } } break; case 4: for (Join join : joins) { Table rightItem = (Table) join.getRightItem(); if(rightItem.getName().equals("sec_user")){ if(rightItem.getAlias()!=null){ whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(orgsUser); }else { whereSql.append("id in ").append(deptsUser); } } } break; } //获取where节点 Expression where = plain.getWhere(); if (where == null) { if (whereSql.length() > 0) { Expression expression = CCJSqlParserUtil .parseCondExpression(whereSql.toString()); Expression 
whereExpression = (Expression) expression; plain.setWhere(whereExpression); } } else { if (whereSql.length() > 0) { //where条件之前存在,需要重新进行拼接 whereSql.append(" and ( " + where.toString() + " )"); } else { //新增片段不存在,使用之前的sql whereSql.append(where.toString()); } Expression expression = CCJSqlParserUtil .parseCondExpression(whereSql.toString()); plain.setWhere(expression); } sql2Reset = selectStatement.toString(); } } catch (Exception e) { e.printStackTrace(); } } // 替换sql ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset); //放行 Object proceed = invocation.proceed(); return proceed;