shiro权限控制(二):分布式架构中shiro的实现

时间:2022-03-22 10:38:57

前言:
前段时间在搭建公司游戏框架安全验证的时候,就想到之前web最火的shiro框架,虽然后面实践发现在netty中不太适用,最后自己模仿shiro写了一个缩减版的,但是中间花费两天时间弄出来的shiro可不能白费,这里给大家出个简单的教程说明吧。

shiro的基本介绍这里就不再说了,可以自行翻阅博主之前写的shiro教程,这篇文章主要说明分布式架构下shiro的session共享问题。

一、原理描述

无论分布式、还是集群下,项目都需要获取登录用户的信息,而不可能做的就是让客户在每个系统或者每个模块中反复登录,也不存在让客户端存载用户信息给服务端,这是很常识的问题

而单机模式下,我们用shiro做了登录验证,他的主要方式就是在第一次登陆的时候,把我们设置的用户信息保存在cache(内存)中和自带的ehcahe(缓存管理器)中,然后给客户端一个cookie,在每次客户端访问时获取cookie值,从而得到用户信息。

好了,那么逻辑就清楚了,分布式架构下,要与多系统共享用户信息,其实就是共享shiro保存的cache。

要在多项目*享,内存是不可能的了,ehcache对分布式支持不太好,或者说根本不支持。那么剩下只能是我么熟悉的mysql,redis,mongdb啥的数据库了。这么一对比,不用我说大家也明白了,最适合的无疑是redis了,速度快,主从啥的。

二、流程描述

查看源码我们可以知道,cacheManager最终会被set到sessionDAO中,所以我们要自己写sessionDAO。有两个类去操作保存的,那么我们只需要重写,实现这两个类,然后在注册的时候声明即可。

1.shiroCache:cache类,可以自己写一个定时消除的MAP存放更好,文章结尾我会给出map的代码。而这里的代码我是放在redis的。

package com.result.shiro.distributed;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import com.result.redis.RedisKey;
import com.result.redis.RedisUtil;
import com.result.tools.KyroUtil;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月8日 下午9:33:23
* cache共享
*/
@SuppressWarnings("unchecked")
public class ShiroCache<K, V> implements Cache<K, V> {
private static final String REDIS_SHIRO_CACHE = RedisKey.CACHEKEY;
private String cacheKey;
private long globExpire = 30;
@SuppressWarnings("rawtypes")
public ShiroCache(String name) {
this.cacheKey = REDIS_SHIRO_CACHE + name + ":";
}
@Override
public V get(K key) throws CacheException {
Object obj = RedisUtil.get(KyroUtil.serialization(getCacheKey(key)));
if(obj==null){
return null;
}
return (V) KyroUtil.deserialization((String)obj);
}
@Override
public V put(K key, V value) throws CacheException {
V old = get(key);
RedisUtil.setex(KyroUtil.serialization(getCacheKey(key)), 18000, KyroUtil.serialization(value));
return old;
}
@Override
public V remove(K key) throws CacheException {
V old = get(key);
RedisUtil.del(KyroUtil.serialization(getCacheKey(key)));
return old;
}
@Override
public void clear() throws CacheException {
for(String key : (Set<String>)keys()){
RedisUtil.del(key);
}
}
@Override
public int size() {
return keys().size();
}
@Override
public Set<K> keys() {
return (Set<K>) RedisUtil.keys(KyroUtil.serialization(getCacheKey("*")));
}
@Override
public Collection<V> values() {
Set<K> set = keys();
List<V> list = new ArrayList<>();
for (K s : set) {
list.add(get(s));
}
return list;
}
private K getCacheKey(Object k) {
return (K) (this.cacheKey + k);
}
}

2.session操作类:这里用来把用户信息存放在redis*享的。

package com.result.shiro.distributed;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月6日 上午10:12:42
* redis实现共享session
*/
import java.io.Serializable;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.result.redis.RedisKey;
import com.result.redis.RedisUtil;
import com.result.tools.KyroUtil;
import com.result.tools.SerializationUtil;
public class RedisSessionDao extends EnterpriseCacheSessionDAO {
private static Logger logger = LoggerFactory.getLogger(RedisSessionDao.class);
@Override
public void update(Session session) throws UnknownSessionException {
this.saveSession(session);
}
/**
* 删除session
*/
@Override
public void delete(Session session) {
if (session == null || session.getId() == null) {
logger.error("==========session或sessionI 不存在");
return;
}
RedisUtil.del(KyroUtil.serialization(RedisKey.SESSIONKEY + session.getId()));
}
/**
* 获取存活的sessions
*/
@Override
public Collection<Session> getActiveSessions() {
Set<Session> sessions = new HashSet<>();
Set<String> keys = RedisUtil.keys(KyroUtil.serialization(RedisKey.SESSIONKEY + "*"));
for(String key:keys){
sessions.add((Session)KyroUtil.deserialization((String)RedisUtil.get(key)));
}
return sessions;
}
/**
* 创建session
*/
@Override
protected Serializable doCreate(Session session) {
Serializable sessionId = this.generateSessionId(session);
this.assignSessionId(session, sessionId);
this.saveSession(session);
return sessionId;
}
/**
* 获取session
*/
@Override
protected Session doReadSession(Serializable sessionId) {
if(sessionId == null){
logger.error("==========session id 不存在");
return null;
}
Object obj = RedisUtil.get(KyroUtil.serialization(RedisKey.SESSIONKEY + sessionId));
if(obj==null){
return null;
}
Session s = (Session)KyroUtil.deserialization((String)obj);
return s;
}
/**
* 保存session并存储过期时间
* @param session
* @throws UnknownSessionException
*/
public static void saveSession(String sessionId,Object obj) throws UnknownSessionException{
if (obj == null) {
logger.error("要存入的session为空");
return;
}
//设置过期时间
int expireTime = 1800;
RedisUtil.setex(sessionId,expireTime,SerializationUtil.serializeToString(obj));
}
}
然后还有一个类也是必要的 package com.result.shiro.distributed;
import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import org.apache.shiro.cache.CacheManager;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月8日 下午9:32:41
* 类说明
*/
public class RedisCacheManager implements CacheManager {
@Override
public <K, V> Cache<K, V> getCache(String name) throws CacheException {
return new ShiroCache<K, V>(name);
}
}

三:辅助类说明

用户信息的session存放在redis中肯定是需要序列化的,然而用json这种可读性太强的东西安全性显得极低,而且长度太大,浪费存储空间和IO。所以需要找其他的序列化工具。

常规的好用的序列化工具有kyro,protobuff,这些是性能极高而且序列化之后长度极小的序列化工具,其中protobuf支持跨语言。不过这些在之后的文章再和大家介绍去了,因为~!!session不支持这两种操作(因为上面两个类中操作的session实际是一个接口)。

那么序列化用的什么,emmmm~一个很原生的东西,测试效率也挺高的,和protobuf差不太多。下面贴出的代码实际就是上面类中kyroUtils中的方法,因为shiro分布式在项目中被废掉了,我也没去改名字了。大家自己看仔细点就可以了。

被注释掉的代码是kyro的序列化工具。

package com.result.tools;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月6日 下午2:22:14
* Kryo工具类
*/
public class KyroUtil {
private static Logger logger = LoggerFactory.getLogger(KyroUtil.class);
//private static KryoPool pool;
//原本打算使用kyro序列化session,后来发现kyro对session序列化不支持,反序列后得不到value。 这种out序列化测试性能消耗时间更短,但是长度变大4倍意思,待优化
// static{
// KryoFactory factory = new KryoFactory() {
// public Kryo create() {
// Kryo kryo = new Kryo();
// kryo.setReferences(false);
// //把shiroSession的结构注册到Kryo注册器里面,提高序列化/反序列化效率
// kryo.register(Session.class, new JavaSerializer());
// kryo.register(String.class, new JavaSerializer());
// kryo.register(User.class, new JavaSerializer());
// kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
// return kryo;
// }
// };
// pool = new KryoPool.Builder(factory).build();
// logger.info("KryoPool初始化成功====================================");
// }
/**
* 对象编码
* @param value
* @return
*/
public static String serialization(Object value) {
// String str ="";
// try {
// Kryo kryo = pool.borrow();
// ByteArrayOutputStream baos = new ByteArrayOutputStream();
// Output output = new Output(baos);
// kryo.writeClassAndObject(output, value);
// output.flush();
// output.close();
// byte[] b = baos.toByteArray();
// baos.flush();
// baos.close();
// str = new String(b, "ISO8859-1");
// } catch (IOException e) {
// e.printStackTrace();
// }
// return str;
//
ByteArrayOutputStream bos = null;
ObjectOutputStream oos = null;
try {
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(bos);
oos.writeObject(value);
return new String(bos.toByteArray(), "ISO8859-1");
} catch (Exception e) {
throw new RuntimeException("serialize session error", e);
} finally {
try {
oos.close();
bos.close();
} catch (IOException e) {
e.printStackTrace();
}
}
// return new String(new Base64().encode(b));
}
/**
* 对象解码
* @param <T>
* @param <T>
* @param obj
* @param clazz
* @return
*/
public static Object deserialization(String obj) {
// try {
// Kryo kryo = pool.borrow();
// ByteArrayInputStream bais;
// bais = new ByteArrayInputStream(obj.getBytes("ISO8859-1"));
// //new Base64().decode(obj));
// Input input = new Input(bais);
// return kryo.readClassAndObject(input);
// } catch (UnsupportedEncodingException e) {
// // TODO Auto-generated catch block
// e.printStackTrace();
// }
// return null;
ByteArrayInputStream bis = null;
ObjectInputStream ois = null;
try {
bis = new ByteArrayInputStream(obj.getBytes("ISO8859-1"));
ois = new ObjectInputStream(bis);
return ois.readObject();
} catch (Exception e) {
throw new RuntimeException("deserialize session error", e);
} finally {
try {
ois.close();
bis.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}

四、注册

好了,该重写的都重写了,那么最后一步就是整合spring的时候我们要告诉spring,我们要用的是我们重写过的sessiondao了。

我这里用的是代码的方式,因为某些原因在写框架的时候不太好用xml去整合。

反正原理都差不多,大家看看就明白了:

package com.business.shiro;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.shiro.authc.credential.HashedCredentialsMatcher;
import org.apache.shiro.cache.CacheManager;
import org.apache.shiro.cache.ehcache.EhCacheManager;
import org.apache.shiro.codec.Base64;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.session.mgt.ExecutorServiceSessionValidationScheduler;
import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO;
import org.apache.shiro.spring.LifecycleBeanPostProcessor;
import org.apache.shiro.spring.security.interceptor.AuthorizationAttributeSourceAdvisor;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.CookieRememberMeManager;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.apache.shiro.web.servlet.SimpleCookie;
import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
import org.springframework.aop.framework.autoproxy.DefaultAdvisorAutoProxyCreator;
import org.springframework.beans.factory.config.MethodInvokingFactoryBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import com.result.shiro.distributed.RedisCacheManager;
import com.result.shiro.distributed.RedisSessionDao;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月8日 下午8:29:12
* 类说明
*/
@Configuration
public class ShiroConfiguration {
private static Map<String, String> filterChainDefinitionMap = new LinkedHashMap<String, String>();
@Bean(name = "cacheShiroManager")
public CacheManager getCacheManage() {
return new RedisCacheManager();
}
@Bean(name = "lifecycleBeanPostProcessor")
public LifecycleBeanPostProcessor getLifecycleBeanPostProcessor() {
return new LifecycleBeanPostProcessor();
}
@Bean(name = "sessionValidationScheduler")
public ExecutorServiceSessionValidationScheduler getExecutorServiceSessionValidationScheduler() {
ExecutorServiceSessionValidationScheduler scheduler = new ExecutorServiceSessionValidationScheduler();
scheduler.setInterval(900000);
return scheduler;
}
@Bean(name = "hashedCredentialsMatcher")
public HashedCredentialsMatcher getHashedCredentialsMatcher() {
HashedCredentialsMatcher credentialsMatcher = new HashedCredentialsMatcher();
credentialsMatcher.setHashAlgorithmName("MD5");
credentialsMatcher.setHashIterations(1);
credentialsMatcher.setStoredCredentialsHexEncoded(true);
return credentialsMatcher;
}
@Bean(name = "sessionIdCookie")
public SimpleCookie getSessionIdCookie() {
SimpleCookie cookie = new SimpleCookie("sid");
cookie.setHttpOnly(true);
cookie.setMaxAge(-1);
return cookie;
}
@Bean(name = "rememberMeCookie")
public SimpleCookie getRememberMeCookie() {
SimpleCookie simpleCookie = new SimpleCookie("rememberMe");
simpleCookie.setHttpOnly(true);
simpleCookie.setMaxAge(2592000);
return simpleCookie;
}
@Bean
public CookieRememberMeManager getRememberManager(){
CookieRememberMeManager meManager = new CookieRememberMeManager();
meManager.setCipherKey(Base64.decode("4AvVhmFLUs0KTA3Kprsdag=="));
meManager.setCookie(getRememberMeCookie());
return meManager;
}
@Bean(name = "sessionManager")
public DefaultWebSessionManager getSessionManage() {
DefaultWebSessionManager sessionManager = new DefaultWebSessionManager();
sessionManager.setGlobalSessionTimeout(1800000);
sessionManager.setSessionValidationScheduler(getExecutorServiceSessionValidationScheduler());
sessionManager.setSessionValidationSchedulerEnabled(true);
sessionManager.setDeleteInvalidSessions(true);
sessionManager.setSessionIdCookieEnabled(true);
sessionManager.setSessionIdCookie(getSessionIdCookie());
RedisSessionDao cacheSessionDAO = new RedisSessionDao();
cacheSessionDAO.setCacheManager(getCacheManage());
sessionManager.setSessionDAO(cacheSessionDAO);
// -----可以添加session 创建、删除的监听器
return sessionManager;
}
@Bean(name = "myRealm")
public AuthorizingRealm getShiroRealm() {
MyRealm realm = new MyRealm();
// realm.setName("shiro_auth_cache");
// realm.setAuthenticationCache(getCacheManage().getCache(realm.getName()));
// realm.setAuthenticationTokenClass(UserAuthenticationToken.class);
return realm;
}
@Bean(name = "securityManager")
public DefaultWebSecurityManager getSecurityManager() {
DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
securityManager.setCacheManager(getCacheManage());
securityManager.setSessionManager(getSessionManage());
securityManager.setRememberMeManager(getRememberManager());
securityManager.setRealm(getShiroRealm());
return securityManager;
}
@Bean
public MethodInvokingFactoryBean getMethodInvokingFactoryBean(){
MethodInvokingFactoryBean factoryBean = new MethodInvokingFactoryBean();
factoryBean.setStaticMethod("org.apache.shiro.SecurityUtils.setSecurityManager");
factoryBean.setArguments(new Object[]{getSecurityManager()});
return factoryBean;
}
@Bean
@DependsOn("lifecycleBeanPostProcessor")
public DefaultAdvisorAutoProxyCreator getAutoProxyCreator(){
DefaultAdvisorAutoProxyCreator creator = new DefaultAdvisorAutoProxyCreator();
creator.setProxyTargetClass(true);
return creator;
}
@Bean
public AuthorizationAttributeSourceAdvisor getAuthorizationAttributeSourceAdvisor(){
AuthorizationAttributeSourceAdvisor advisor = new AuthorizationAttributeSourceAdvisor();
advisor.setSecurityManager(getSecurityManager());
return advisor;
}
/**
* @return
*/
@Bean(name = "shiroFilter")
public ShiroFilterFactoryBean getShiroFilterFactoryBean(){
ShiroFilterFactoryBean factoryBean = new ShiroFilterFactoryBean();
factoryBean.setSecurityManager(getSecurityManager());
factoryBean.setLoginUrl("/toLogin");
factoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
return factoryBean;
}
}

优化:伪定时消除map,最好配合quartz清楚,不然内存中MAP如果不访问就不消除,容易累计。

package com.result.security;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import com.result.NettyGoConstant;
/**
* @author 作者 huangxinyu
* @version 创建时间:2018年1月29日 上午10:31:50 类说明
*/
public class ExpiryMap<K, V> extends HashMap<K, V> {
private static final long serialVersionUID = 1L;
/**
* default expiry time 2m
*/
private long EXPIRY = NettyGoConstant.LOGINSESSIONTIMEOUT;
private HashMap<K, Long> expiryMap = new HashMap<>();
public ExpiryMap() {
super();
}
public ExpiryMap(long defaultExpiryTime) {
this(1 << 4, defaultExpiryTime);
}
public ExpiryMap(int initialCapacity, long defaultExpiryTime) {
super(initialCapacity);
this.EXPIRY = defaultExpiryTime;
}
public V put(K key, V value) {
expiryMap.put(key, System.currentTimeMillis() + EXPIRY);
return super.put(key, value);
}
public boolean containsKey(Object key) {
return !checkExpiry(key, true) && super.containsKey(key);
}
/**
* @param key
* @param value
* @param expiryTime
* 键值对有效期 毫秒
* @return
*/
public V put(K key, V value, long expiryTime) {
expiryMap.put(key, System.currentTimeMillis() + expiryTime);
return super.put(key, value);
}
public int size() {
return entrySet().size();
}
public boolean isEmpty() {
return entrySet().size() == 0;
}
public boolean containsValue(Object value) {
if (value == null)
return Boolean.FALSE;
Set<java.util.Map.Entry<K, V>> set = super.entrySet();
Iterator<java.util.Map.Entry<K, V>> iterator = set.iterator();
while (iterator.hasNext()) {
java.util.Map.Entry<K, V> entry = iterator.next();
if (value.equals(entry.getValue())) {
if (checkExpiry(entry.getKey(), false)) {
iterator.remove();
return Boolean.FALSE;
} else
return Boolean.TRUE;
}
}
return Boolean.FALSE;
}
public Collection<V> values() {
Collection<V> values = super.values();
if (values == null || values.size() < 1)
return values;
Iterator<V> iterator = values.iterator();
while (iterator.hasNext()) {
V next = iterator.next();
if (!containsValue(next))
iterator.remove();
}
return values;
}
public V get(Object key) {
if (key == null)
return null;
if (checkExpiry(key, true))
return null;
return super.get(key);
}
/**
*
* @Description: 是否过期
* @param key
* @return null:不存在或key为null -1:过期 存在且没过期返回value 因为过期的不是实时删除,所以稍微有点作用
*/
public Object isInvalid(Object key) {
if (key == null)
return null;
if (!expiryMap.containsKey(key)) {
return null;
}
long expiryTime = expiryMap.get(key);
boolean flag = System.currentTimeMillis() > expiryTime;
if (flag) {
super.remove(key);
expiryMap.remove(key);
return -1;
}
return super.get(key);
}
public void putAll(Map<? extends K, ? extends V> m) {
for (Map.Entry<? extends K, ? extends V> e : m.entrySet())
expiryMap.put(e.getKey(), System.currentTimeMillis() + EXPIRY);
super.putAll(m);
}
public Set<Map.Entry<K, V>> entrySet() {
Set<java.util.Map.Entry<K, V>> set = super.entrySet();
Iterator<java.util.Map.Entry<K, V>> iterator = set.iterator();
while (iterator.hasNext()) {
java.util.Map.Entry<K, V> entry = iterator.next();
if (checkExpiry(entry.getKey(), false))
iterator.remove();
}
return set;
}
/**
*
* @Description: 是否过期
* @author: qd-ankang
* @date: 2016-11-24 下午4:05:02
* @param expiryTime
* true 过期
* @param isRemoveSuper
* true super删除
* @return
*/
private boolean checkExpiry(Object key, boolean isRemoveSuper) {
if (!expiryMap.containsKey(key)) {
return Boolean.FALSE;
}
long expiryTime = expiryMap.get(key);
boolean flag = System.currentTimeMillis() > expiryTime;
if (flag) {
if (isRemoveSuper)
super.remove(key);
expiryMap.remove(key);
}
return flag;
}
/**
* 删除
* @param key
*/
public void del(Object key){
super.remove(key);
expiryMap.remove(key);
}
public static void main(String[] args) throws InterruptedException {
ExpiryMap<String, String> map = new ExpiryMap<>(10);
map.put("test", "ankang");
map.put("test1", "ankang");
map.put("test2", "ankang", 3000);
System.out.println("test1" + map.get("test"));
Thread.sleep(1000);
System.out.println("isInvalid:" + map.isInvalid("test"));
System.out.println("size:" + map.size());
System.out.println("size:" + ((HashMap<String, String>) map).size());
for (Map.Entry<String, String> m : map.entrySet()) {
System.out.println("isInvalid:" + map.isInvalid(m.getKey()));
map.containsKey(m.getKey());
System.out.println("key:" + m.getKey() + " value:" + m.getValue());
}
System.out.println("test1" + map.get("test"));
}
/**
* 是否超过过期的一半时间
* @param key
* @return
*/
public boolean isHalfExpiryTime(Object key ){
if (!expiryMap.containsKey(key)) {
return false;
}
long expiryTime = expiryMap.get(key);
boolean flag = System.currentTimeMillis()-(expiryTime-NettyGoConstant.LOGINSESSIONTIMEOUT)>=NettyGoConstant.LOGINSESSIONTIMEOUT/2;
return flag;
}
}