前言 在使用 Shiro 的过程中,遇到一个痛点,就是对 restful 支持不太好,也查了很多资料,各种各样的方法都有,要不就是功能不完整,要不就是解释不清楚,还有一些对原有功能的侵入性太强,经过一番探索,算是最简的配置下完成了需要的功能,这里给大家分享下。大家如果又更好的方案,也可以在评论区留言,互相探讨下。
虽然深入到了源码进行分析,但过程并不复杂,希望大家可以跟着我的思路捋顺了耐心看下去,而不是看见源码贴就抵触。
分析 首先先回顾下 Shiro 的过滤器链,一般我们都有如下配置:
1 2 3 4 /login.html = anon /login = anon /users = perms[user:list] /** = authc
不太熟悉的朋友可以了解下这篇文章:Shiro 过滤器 。
其中 /users
请求对应到 perms
过滤器,对应的类: org.apache.shiro.web.filter.authz.PermissionsAuthorizationFilter
,其中的 onAccessDenied
方法是在没有权限时被调用的, 源码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 protected boolean onAccessDenied (ServletRequest request, ServletResponse response) throws IOException { Subject subject = getSubject(request, response); if (subject.getPrincipal() == null ) { saveRequestAndRedirectToLogin(request, response); } else { String unauthorizedUrl = getUnauthorizedUrl(); if (StringUtils.hasText(unauthorizedUrl)) { WebUtils.issueRedirect(request, response, unauthorizedUrl); } else { WebUtils.toHttp(response).sendError(HttpServletResponse.SC_UNAUTHORIZED); } } return false ; }
我们可以在这里可以判断当前请求是否时 AJAX 请求,如果是,则不跳转到 logoUrl 或 UnauthorizedUrl 页面,而是返回 JSON 数据。
还有一个方法是 pathsMatch,是将当前请求的 url 与所有配置的 perms 过滤器链进行匹配,是则进行权限检查,不是则接着与下一个过滤器链进行匹配,源码如下:
1 2 3 4 5 protected boolean pathsMatch (String path, ServletRequest request) { String requestURI = getPathWithinApplication(request); log.trace("Attempting to match pattern '{}' with current requestURI '{}'..." , path, requestURI); return pathsMatch(path, requestURI); }
方法 了解完这两个方法,我来说说如何利用这两个方法来实现功能。
我们可以从配置的过滤器链来入手,原先的配置如:
1 /users = perms[user:list]
我们可以改为 /user==GET
,/user==POST
方式。==
用来分隔, 后面的部分指 HTTP Method
。
使用这种方式还要注意一个方法,即:org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver
中的 getChain
方法,用来获取当前请求的 URL 应该使用的过滤器,源码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 public FilterChain getChain (ServletRequest request, ServletResponse response, FilterChain originalChain) { FilterChainManager filterChainManager = getFilterChainManager(); if (!filterChainManager.hasChains()) { return null ; } String requestURI = getPathWithinApplication(request); for (String pathPattern : filterChainManager.getChainNames()) { if (pathMatches(pathPattern, requestURI)) { if (log.isTraceEnabled()) { log.trace("Matched path pattern [" + pathPattern + "] for requestURI [" + requestURI + "]. " + "Utilizing corresponding filter chain..." ); } return filterChainManager.proxy(originalChain, pathPattern); } } return null ; }
这里大家需要注意,第四步的判断,我们已经将过滤器链,也就是这里的 pathPattern
改为了 /xxx==GET
这种方式,而请求的 URL 却仅包含 /xxx
,那么这里的 pathMatches
方法是肯定无法匹配成功,所以我们需要在第四步判断的时候,只判断前面的 URL
部分。
整个过程如下:
在过滤器链上对 restful 请求配置需要的 HTTP Method
,如:/user==DELETE
。
修改 PathMatchingFilterChainResolver
的 getChain
方法,当前请求的 URL 与过滤器链匹配时,过滤器只取 URL 部分进行判断。
修改过滤器的 pathsMatch
方法,判断当前请求的 URL 与请求方式是否与过滤器链中配置的一致。
修改过滤器的 onAccessDenied
方法,当访问被拒绝时,根据普通请求和 AJAX
请求分别返回 HTML
和 JSON
数据。
下面我们逐步来实现:
实现 过滤器链添加 http method 在我的项目中是从数据库获取的过滤器链,所以有如下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 public Map<String, String> getUrlPermsMap () { Map<String, String> filterChainDefinitionMap = new LinkedHashMap<>(); filterChainDefinitionMap.put("/favicon.ico" , "anon" ); filterChainDefinitionMap.put("/css/**" , "anon" ); filterChainDefinitionMap.put("/fonts/**" , "anon" ); filterChainDefinitionMap.put("/images/**" , "anon" ); filterChainDefinitionMap.put("/js/**" , "anon" ); filterChainDefinitionMap.put("/lib/**" , "anon" ); filterChainDefinitionMap.put("/login" , "anon" ); List<Menu> menus = selectAll(); for (Menu menu : menus) { String url = menu.getUrl(); if (menu.getMethod() != null && !"" .equals(menu.getMethod())) { url += ("==" + menu.getMethod()); } String perms = "perms[" + menu.getPerms() + "]" ; filterChainDefinitionMap.put(url, perms); } filterChainDefinitionMap.put("/**" , "authc" ); return filterChainDefinitionMap; }
如: /xxx==GET = perms[user:list]
这里的 getUrl
,getMethod
和 getPerms
分别对应 /xxx
,GET
和 user:list
。
不过需要注意的是,如果在 XML 里配置,会被 Shiro 解析成 /xxx
和 =GET = perms[user:list]
,解决办法是使用其他符号代替 ==
。
修改 PathMatchingFilterChainResolver 的 getChain 方法 由于 Shiro 没有提供相应的接口,且我们不能直接修改源码,所以我们需要新建一个类继承 PathMatchingFilterChainResolver
并重写 getChain
方法,然后替换掉 PathMatchingFilterChainResolver
即可。
首先继承并重写方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 package im.zhaojun.shiro;import org.apache.shiro.web.filter.mgt.FilterChainManager;import org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import javax.servlet.FilterChain;import javax.servlet.ServletRequest;import javax.servlet.ServletResponse;public class RestPathMatchingFilterChainResolver extends PathMatchingFilterChainResolver { private static final Logger log = LoggerFactory.getLogger(RestPathMatchingFilterChainResolver.class); @Override public FilterChain getChain (ServletRequest request, ServletResponse response, FilterChain originalChain) { FilterChainManager filterChainManager = getFilterChainManager(); if (!filterChainManager.hasChains()) { return null ; } String requestURI = getPathWithinApplication(request); for (String pathPattern : filterChainManager.getChainNames()) { String[] pathPatternArray = pathPattern.split("==" ); if (pathMatches(pathPatternArray[0 ], requestURI)) { if (log.isTraceEnabled()) { log.trace("Matched path pattern [" + pathPattern + "] for requestURI [" + requestURI + "]. " + "Utilizing corresponding filter chain..." ); } return filterChainManager.proxy(originalChain, pathPattern); } } return null ; } }
然后替换掉 PathMatchingFilterChainResolver
,它是在 ShiroFilterFactoryBean
的 createInstance
方法里初始化的。
所以同样的套路,继承 ShiroFilterFactoryBean
并重写 createInstance
方法,将 new PathMatchingFilterChainResolver();
改为 new RestPathMatchingFilterChainResolver();
即可。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 package im.zhaojun.shiro;import org.apache.shiro.mgt.SecurityManager;import org.apache.shiro.spring.web.ShiroFilterFactoryBean;import org.apache.shiro.web.filter.mgt.FilterChainManager;import org.apache.shiro.web.filter.mgt.FilterChainResolver;import org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver;import org.apache.shiro.web.mgt.WebSecurityManager;import org.apache.shiro.web.servlet.AbstractShiroFilter;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.beans.factory.BeanInitializationException;public class RestShiroFilterFactoryBean extends ShiroFilterFactoryBean { private static final Logger log = LoggerFactory.getLogger(RestShiroFilterFactoryBean.class); @Override protected AbstractShiroFilter createInstance () { log.debug("Creating Shiro Filter instance." ); SecurityManager securityManager = getSecurityManager(); if (securityManager == null ) { String msg = "SecurityManager property must be set." ; throw new BeanInitializationException(msg); } if (!(securityManager instanceof WebSecurityManager)) { String msg = "The security manager does not implement the WebSecurityManager interface." ; throw new BeanInitializationException(msg); } FilterChainManager manager = createFilterChainManager(); PathMatchingFilterChainResolver chainResolver = new RestPathMatchingFilterChainResolver(); chainResolver.setFilterChainManager(manager); return new SpringShiroFilter((WebSecurityManager) securityManager, chainResolver); } private static final class SpringShiroFilter extends AbstractShiroFilter { protected SpringShiroFilter (WebSecurityManager webSecurityManager, FilterChainResolver resolver) { super (); if (webSecurityManager == null ) { throw new IllegalArgumentException("WebSecurityManager property cannot be null." ); } setSecurityManager(webSecurityManager); if (resolver != null ) { setFilterChainResolver(resolver); } } } }
最后记得将 ShiroFilterFactoryBean
改为 RestShiroFilterFactoryBean
。
XML 方式:
1 2 3 <bean id ="shiroFilter" class ="im.zhaojun.shiro.RestShiroFilterFactoryBean" > </bean >
Bean 方式:
1 2 3 4 5 6 @Bean public ShiroFilterFactoryBean shirFilter (SecurityManager securityManager) { ShiroFilterFactoryBean shiroFilterFactoryBean = new RestShiroFilterFactoryBean(); return shiroFilterFactoryBean; }
修改过滤器的 pathsMatch 方法 同样新建一个类继承原有的 PermissionsAuthorizationFilter
并重写 pathsMatch
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 package im.zhaojun.shiro.filter;import org.apache.shiro.subject.Subject;import org.apache.shiro.util.StringUtils;import org.apache.shiro.web.filter.authz.PermissionsAuthorizationFilter;import org.apache.shiro.web.util.WebUtils;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import javax.servlet.ServletRequest;import javax.servlet.ServletResponse;import javax.servlet.http.HttpServletResponse;import java.io.IOException;import java.util.HashMap;import java.util.Map;public class RestAuthorizationFilter extends PermissionsAuthorizationFilter { private static final Logger log = LoggerFactory .getLogger(RestAuthorizationFilter.class); @Override protected boolean pathsMatch (String path, ServletRequest request) { String requestURI = this .getPathWithinApplication(request); String[] strings = path.split("==" ); if (strings.length <= 1 ) { return this .pathsMatch(strings[0 ], requestURI); } else { String httpMethod = WebUtils.toHttp(request).getMethod().toUpperCase(); return httpMethod.equals(strings[1 ].toUpperCase()) && this .pathsMatch(strings[0 ], requestURI); } } }
修改过滤器的 onAccessDenied 方法 同样是上一步的类,重写 onAccessDenied
方法即可:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 @Override protected boolean onAccessDenied (ServletRequest request, ServletResponse response) throws IOException { Subject subject = getSubject(request, response); if (subject.getPrincipal() == null ) { if (im.zhaojun.util.WebUtils.isAjaxRequest(WebUtils.toHttp(request))) { if (log.isDebugEnabled()) { log.debug("用户: [{}] 请求 restful url : {}, 未登录被拦截." , subject.getPrincipal(), this .getPathWithinApplication(request)); } Map<String, Object> map = new HashMap<>(); map.put("code" , -1 ); im.zhaojun.util.WebUtils.writeJson(map, response); } else { saveRequestAndRedirectToLogin(request, response); } } else { if (im.zhaojun.util.WebUtils.isAjaxRequest(WebUtils.toHttp(request))) { if (log.isDebugEnabled()) { log.debug("用户: [{}] 请求 restful url : {}, 无权限被拦截." , subject.getPrincipal(), this .getPathWithinApplication(request)); } Map<String, Object> map = new HashMap<>(); map.put("code" , -2 ); map.put("msg" , "没有权限啊!" ); im.zhaojun.util.WebUtils.writeJson(map, response); } else { String unauthorizedUrl = getUnauthorizedUrl(); if (StringUtils.hasText(unauthorizedUrl)) { WebUtils.issueRedirect(request, response, unauthorizedUrl); } else { WebUtils.toHttp(response).sendError(HttpServletResponse.SC_UNAUTHORIZED); } } } return false ; }
重写完 pathsMatch
和 onAccessDenied
方法后,将这个类替换原有的 perms
过滤器的类:
XML 方式:
1 2 3 4 5 6 7 8 9 10 <bean id ="shiroFilter" class ="im.zhaojun.shiro.RestShiroFilterFactoryBean" > <property name ="filters" > <map > <entry key ="perms" value-ref ="restAuthorizationFilter" /> </map > </property > </bean > <bean id ="restAuthorizationFilter" class ="im.zhaojun.shiro.filter.RestAuthorizationFilter" />
Bean 方式:
1 2 3 4 5 6 7 8 9 @Bean public ShiroFilterFactoryBean shirFilter (SecurityManager securityManager) { ShiroFilterFactoryBean shiroFilterFactoryBean = new RestShiroFilterFactoryBean(); Map<String, Filter> filters = shiroFilterFactoryBean.getFilters(); filters.put("perms" , new RestAuthorizationFilter()); return shiroFilterFactoryBean; }
这里只改了 perms
过滤器,对于其他过滤器也是同样的道理,重写过滤器的 pathsMatch
和 onAccessDenied
方法,并覆盖原有过滤器即可。
附 上面用到的工具类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 package im.zhaojun.util;import cn.hutool.json.JSONUtil;import com.alibaba.druid.support.json.JSONUtils;import org.springframework.boot.jackson.JsonComponent;import javax.servlet.ServletResponse;import javax.servlet.http.HttpServletRequest;import java.io.IOException;import java.io.PrintWriter;import java.util.Map;public class WebUtils { public static boolean isAjaxRequest (HttpServletRequest request) { String requestedWith = request.getHeader("x-requested-with" ); return requestedWith != null && "XMLHttpRequest" .equalsIgnoreCase(requestedWith); } public static void writeJson (Object obj, ServletResponse response) { PrintWriter out = null ; try { response.setCharacterEncoding("UTF-8" ); response.setContentType("application/json; charset=utf-8" ); out = response.getWriter(); out.write(JSONUtil.toJsonStr(obj)); } catch (IOException e) { e.printStackTrace(); } finally { if (out != null ) { out.close(); } } } }
结语 基本的过程就是这些,这是我在学习 Shiro 的过程中的一些见解,希望可以帮助到大家。具体应用的项目地址为:https://github.com/zhaojun1998/Shiro-Action ,功能在不断完善中,代码可能有些粗糙,还请见谅。