前言:
前段時間在搭建公司遊戲框架安全驗證的時候,就想到之前web最火的shiro框架,雖然後面實踐發現在netty中不太適用,最後自己模仿shiro寫了一個縮減版的,但是中間花費兩天時間弄出來的shiro可不能白費,這裡給大家出個簡單的教程說明吧。
shiro的基本介紹這裡就不再說了,可以自行翻閱博主之前寫的shiro教程,這篇文章主要說明分佈式架構下shiro的session共享問題。
一、原理描述
無論是分佈式還是集群部署,項目都需要獲取登錄用戶的信息。我們不可能讓用戶在每個系統或者每個模塊中反覆登錄,也不能讓客戶端保存用戶信息再交給服務端,這是常識性的問題。
而單機模式下,我們用shiro做了登錄驗證,它的主要方式就是在第一次登錄的時候,把我們設置的用戶信息保存在cache(內存)和自帶的ehcache(緩存管理器)中,然後給客戶端一個保存sessionId的cookie,之後每次客戶端訪問時根據cookie值找到對應的session,從而得到用戶信息。
好了,那麼邏輯就清楚了:分佈式架構下,要在多個系統之間共享用戶信息,其實就是共享shiro保存的cache。
要在多個項目之間共享,放在單機內存裡是不可能的了;ehcache對分佈式的支持也不太好(原生基本只適合單機,集群需要額外依賴Terracotta等方案)。那麼剩下的只能是我們熟悉的mysql、redis、mongodb這類數據庫了。這麼一對比,不用我說大家也明白了,最適合的無疑是redis:速度快,還支持主從複製。
二、流程描述
查看源碼我們可以知道,cacheManager最終會被set到sessionDAO中,所以我們要自己寫sessionDAO。真正負責保存操作的有兩個類(Cache和SessionDAO),我們只需要重寫、實現這兩個類,然後在註冊的時候聲明即可。
1.ShiroCache:cache類。也可以自己寫一個定時清除的Map來存放(不過本地Map無法在多個節點之間共享),文章結尾我會給出Map的代碼。而這裡的代碼是把數據存放在redis中的。
package com.result.shiro.distributed; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Set; import org.apache.shiro.cache.Cache; import org.apache.shiro.cache.CacheException; import com.result.redis.RedisKey; import com.result.redis.RedisUtil; import com.result.tools.KyroUtil; /** * @author 作者 huangxinyu * @version 創建時間:2018年1月8日 下午9:33:23 * cache共享 */ @SuppressWarnings("unchecked") public class ShiroCache<K, V> implements Cache<K, V> { private static final String REDIS_SHIRO_CACHE = RedisKey.CACHEKEY; private String cacheKey; private long globExpire = 30; @SuppressWarnings("rawtypes") public ShiroCache(String name) { this.cacheKey = REDIS_SHIRO_CACHE + name + ":"; } @Override public V get(K key) throws CacheException { Object obj = RedisUtil.get(KyroUtil.serialization(getCacheKey(key))); if(obj==null){ return null; } return (V) KyroUtil.deserialization((String)obj); } @Override public V put(K key, V value) throws CacheException { V old = get(key); RedisUtil.setex(KyroUtil.serialization(getCacheKey(key)), 18000, KyroUtil.serialization(value)); return old; } @Override public V remove(K key) throws CacheException { V old = get(key); RedisUtil.del(KyroUtil.serialization(getCacheKey(key))); return old; } @Override public void clear() throws CacheException { for(String key : (Set<String>)keys()){ RedisUtil.del(key); } } @Override public int size() { return keys().size(); } @Override public Set<K> keys() { return (Set<K>) RedisUtil.keys(KyroUtil.serialization(getCacheKey("*"))); } @Override public Collection<V> values() { Set<K> set = keys(); List<V> list = new ArrayList<>(); for (K s : set) { list.add(get(s)); } return list; } private K getCacheKey(Object k) { return (K) (this.cacheKey + k); } }
2.session操作類:用來把用戶的session存放在redis中實現共享。
package com.result.shiro.distributed; /** * @author 作者 huangxinyu * @version 創建時間:2018年1月6日 上午10:12:42 * redis實現共享session */ import java.io.Serializable; import java.util.Collection; import java.util.HashSet; import java.util.Set; import org.apache.shiro.session.Session; import org.apache.shiro.session.UnknownSessionException; import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.result.redis.RedisKey; import com.result.redis.RedisUtil; import com.result.tools.KyroUtil; import com.result.tools.SerializationUtil; public class RedisSessionDao extends EnterpriseCacheSessionDAO { private static Logger logger = LoggerFactory.getLogger(RedisSessionDao.class); @Override public void update(Session session) throws UnknownSessionException { this.saveSession(session); } /** * 刪除session */ @Override public void delete(Session session) { if (session == null || session.getId() == null) { logger.error("==========session或sessionI 不存在"); return; } RedisUtil.del(KyroUtil.serialization(RedisKey.SESSIONKEY + session.getId())); } /** * 獲取存活的sessions */ @Override public Collection<Session> getActiveSessions() { Set<Session> sessions = new HashSet<>(); Set<String> keys = RedisUtil.keys(KyroUtil.serialization(RedisKey.SESSIONKEY + "*")); for(String key:keys){ sessions.add((Session)KyroUtil.deserialization((String)RedisUtil.get(key))); } return sessions; } /** * 創建session */ @Override protected Serializable doCreate(Session session) { Serializable sessionId = this.generateSessionId(session); this.assignSessionId(session, sessionId); this.saveSession(session); return sessionId; } /** * 獲取session */ @Override protected Session doReadSession(Serializable sessionId) { if(sessionId == null){ logger.error("==========session id 不存在"); return null; } Object obj = RedisUtil.get(KyroUtil.serialization(RedisKey.SESSIONKEY + sessionId)); if(obj==null){ return null; } Session s = 
(Session)KyroUtil.deserialization((String)obj); return s; } /** * 保存session並存儲過期時間 * @param session * @throws UnknownSessionException */ public static void saveSession(String sessionId,Object obj) throws UnknownSessionException{ if (obj == null) { logger.error("要存入的session為空"); return; } //設置過期時間 int expireTime = 1800; RedisUtil.setex(sessionId,expireTime,SerializationUtil.serializeToString(obj)); } } 然后還有一個類也是必要的 package com.result.shiro.distributed; import org.apache.shiro.cache.Cache; import org.apache.shiro.cache.CacheException; import org.apache.shiro.cache.CacheManager; /** * @author 作者 huangxinyu * @version 創建時間:2018年1月8日 下午9:32:41 * 類說明 */ public class RedisCacheManager implements CacheManager { @Override public <K, V> Cache<K, V> getCache(String name) throws CacheException { return new ShiroCache<K, V>(name); } }
三、輔助類說明
用戶信息的session存放在redis中肯定是需要序列化的,然而用json這種可讀性太強的東西安全性顯得極低,而且長度太大,浪費存儲空間和IO。所以需要找其他的序列化工具。
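常規又好用的序列化工具有Kryo、protobuf,它們性能極高,而且序列化之後長度極小,其中protobuf還支持跨語言。不過這些留到之後的文章再和大家介紹,因為session不支持這兩種方式:protobuf需要預先定義schema;而上面兩個類中操作的session實際是一個接口,運行時的實現類SimpleSession的字段都是transient的,依靠自定義的writeObject/readObject完成序列化,Kryo默認的FieldSerializer會跳過transient字段,所以反序列化後拿不到值。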
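那麼序列化用的是什麼呢?emmmm~其實是一個很原生的東西:JDK自帶的ObjectOutputStream/ObjectInputStream序列化,測試效率也挺高的,和protobuf差不太多,只是序列化後的長度要大一些。下面貼出的代碼實際就是上面類中KyroUtil裡的方法,因為shiro分佈式在項目中被廢棄了,我也沒去改名字,大家看的時候仔細點就可以了。注意:JDK原生反序列化存在遠程代碼執行的安全風險,務必保證redis只能被可信的服務訪問。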
被注釋掉的代碼是Kryo版本的序列化實現。
package com.result.tools; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author 作者 huangxinyu * @version 創建時間:2018年1月6日 下午2:22:14 * Kryo工具類 */ public class KyroUtil { private static Logger logger = LoggerFactory.getLogger(KyroUtil.class); //private static KryoPool pool; //原本打算使用kyro序列化session,后來發現kyro對session序列化不支持,反序列后得不到value。 這種out序列化測試性能消耗時間更短,但是長度變大4倍意思,待優化 // static{ // KryoFactory factory = new KryoFactory() { // public Kryo create() { // Kryo kryo = new Kryo(); // kryo.setReferences(false); // //把shiroSession的結構注冊到Kryo注冊器里面,提高序列化/反序列化效率 // kryo.register(Session.class, new JavaSerializer()); // kryo.register(String.class, new JavaSerializer()); // kryo.register(User.class, new JavaSerializer()); // kryo.setInstantiatorStrategy(new StdInstantiatorStrategy()); // return kryo; // } // }; // pool = new KryoPool.Builder(factory).build(); // logger.info("KryoPool初始化成功===================================="); // } /** * 對象編碼 * @param value * @return */ public static String serialization(Object value) { // String str =""; // try { // Kryo kryo = pool.borrow(); // ByteArrayOutputStream baos = new ByteArrayOutputStream(); // Output output = new Output(baos); // kryo.writeClassAndObject(output, value); // output.flush(); // output.close(); // byte[] b = baos.toByteArray(); // baos.flush(); // baos.close(); // str = new String(b, "ISO8859-1"); // } catch (IOException e) { // e.printStackTrace(); // } // return str; // ByteArrayOutputStream bos = null; ObjectOutputStream oos = null; try { bos = new ByteArrayOutputStream(); oos = new ObjectOutputStream(bos); oos.writeObject(value); return new String(bos.toByteArray(), "ISO8859-1"); } catch (Exception e) { throw new RuntimeException("serialize session error", e); } finally { try { oos.close(); bos.close(); } catch (IOException e) { 
e.printStackTrace(); } } // return new String(new Base64().encode(b)); } /** * 對象解碼 * @param <T> * @param <T> * @param obj * @param clazz * @return */ public static Object deserialization(String obj) { // try { // Kryo kryo = pool.borrow(); // ByteArrayInputStream bais; // bais = new ByteArrayInputStream(obj.getBytes("ISO8859-1")); // //new Base64().decode(obj)); // Input input = new Input(bais); // return kryo.readClassAndObject(input); // } catch (UnsupportedEncodingException e) { // // TODO Auto-generated catch block // e.printStackTrace(); // } // return null; ByteArrayInputStream bis = null; ObjectInputStream ois = null; try { bis = new ByteArrayInputStream(obj.getBytes("ISO8859-1")); ois = new ObjectInputStream(bis); return ois.readObject(); } catch (Exception e) { throw new RuntimeException("deserialize session error", e); } finally { try { ois.close(); bis.close(); } catch (IOException e) { e.printStackTrace(); } } } }
四、註冊
好了,該重寫的都重寫了,那麼最後一步就是在整合spring的時候告訴spring,我們要用的是自己重寫過的sessionDAO。
我這裡用的是Java代碼配置的方式,因為某些原因,在寫框架的時候不太好用xml去整合。
反正原理都差不多,大家看看就明白了:
package com.business.shiro; import java.util.LinkedHashMap; import java.util.Map; import org.apache.shiro.authc.credential.HashedCredentialsMatcher; import org.apache.shiro.cache.CacheManager; import org.apache.shiro.cache.ehcache.EhCacheManager; import org.apache.shiro.codec.Base64; import org.apache.shiro.realm.AuthorizingRealm; import org.apache.shiro.session.mgt.ExecutorServiceSessionValidationScheduler; import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO; import org.apache.shiro.spring.LifecycleBeanPostProcessor; import org.apache.shiro.spring.security.interceptor.AuthorizationAttributeSourceAdvisor; import org.apache.shiro.spring.web.ShiroFilterFactoryBean; import org.apache.shiro.web.mgt.CookieRememberMeManager; import org.apache.shiro.web.mgt.DefaultWebSecurityManager; import org.apache.shiro.web.servlet.SimpleCookie; import org.apache.shiro.web.session.mgt.DefaultWebSessionManager; import org.springframework.aop.framework.autoproxy.DefaultAdvisorAutoProxyCreator; import org.springframework.beans.factory.config.MethodInvokingFactoryBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.DependsOn; import com.result.shiro.distributed.RedisCacheManager; import com.result.shiro.distributed.RedisSessionDao; /** * @author 作者 huangxinyu * @version 創建時間:2018年1月8日 下午8:29:12 * 類說明 */ @Configuration public class ShiroConfiguration { private static Map<String, String> filterChainDefinitionMap = new LinkedHashMap<String, String>(); @Bean(name = "cacheShiroManager") public CacheManager getCacheManage() { return new RedisCacheManager(); } @Bean(name = "lifecycleBeanPostProcessor") public LifecycleBeanPostProcessor getLifecycleBeanPostProcessor() { return new LifecycleBeanPostProcessor(); } @Bean(name = "sessionValidationScheduler") public ExecutorServiceSessionValidationScheduler getExecutorServiceSessionValidationScheduler() { 
ExecutorServiceSessionValidationScheduler scheduler = new ExecutorServiceSessionValidationScheduler(); scheduler.setInterval(900000); return scheduler; } @Bean(name = "hashedCredentialsMatcher") public HashedCredentialsMatcher getHashedCredentialsMatcher() { HashedCredentialsMatcher credentialsMatcher = new HashedCredentialsMatcher(); credentialsMatcher.setHashAlgorithmName("MD5"); credentialsMatcher.setHashIterations(1); credentialsMatcher.setStoredCredentialsHexEncoded(true); return credentialsMatcher; } @Bean(name = "sessionIdCookie") public SimpleCookie getSessionIdCookie() { SimpleCookie cookie = new SimpleCookie("sid"); cookie.setHttpOnly(true); cookie.setMaxAge(-1); return cookie; } @Bean(name = "rememberMeCookie") public SimpleCookie getRememberMeCookie() { SimpleCookie simpleCookie = new SimpleCookie("rememberMe"); simpleCookie.setHttpOnly(true); simpleCookie.setMaxAge(2592000); return simpleCookie; } @Bean public CookieRememberMeManager getRememberManager(){ CookieRememberMeManager meManager = new CookieRememberMeManager(); meManager.setCipherKey(Base64.decode("4AvVhmFLUs0KTA3Kprsdag==")); meManager.setCookie(getRememberMeCookie()); return meManager; } @Bean(name = "sessionManager") public DefaultWebSessionManager getSessionManage() { DefaultWebSessionManager sessionManager = new DefaultWebSessionManager(); sessionManager.setGlobalSessionTimeout(1800000); sessionManager.setSessionValidationScheduler(getExecutorServiceSessionValidationScheduler()); sessionManager.setSessionValidationSchedulerEnabled(true); sessionManager.setDeleteInvalidSessions(true); sessionManager.setSessionIdCookieEnabled(true); sessionManager.setSessionIdCookie(getSessionIdCookie()); RedisSessionDao cacheSessionDAO = new RedisSessionDao(); cacheSessionDAO.setCacheManager(getCacheManage()); sessionManager.setSessionDAO(cacheSessionDAO); // -----可以添加session 創建、刪除的監聽器 return sessionManager; } @Bean(name = "myRealm") public AuthorizingRealm getShiroRealm() { MyRealm realm = new MyRealm(); 
// realm.setName("shiro_auth_cache"); // realm.setAuthenticationCache(getCacheManage().getCache(realm.getName())); // realm.setAuthenticationTokenClass(UserAuthenticationToken.class); return realm; } @Bean(name = "securityManager") public DefaultWebSecurityManager getSecurityManager() { DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager(); securityManager.setCacheManager(getCacheManage()); securityManager.setSessionManager(getSessionManage()); securityManager.setRememberMeManager(getRememberManager()); securityManager.setRealm(getShiroRealm()); return securityManager; } @Bean public MethodInvokingFactoryBean getMethodInvokingFactoryBean(){ MethodInvokingFactoryBean factoryBean = new MethodInvokingFactoryBean(); factoryBean.setStaticMethod("org.apache.shiro.SecurityUtils.setSecurityManager"); factoryBean.setArguments(new Object[]{getSecurityManager()}); return factoryBean; } @Bean @DependsOn("lifecycleBeanPostProcessor") public DefaultAdvisorAutoProxyCreator getAutoProxyCreator(){ DefaultAdvisorAutoProxyCreator creator = new DefaultAdvisorAutoProxyCreator(); creator.setProxyTargetClass(true); return creator; } @Bean public AuthorizationAttributeSourceAdvisor getAuthorizationAttributeSourceAdvisor(){ AuthorizationAttributeSourceAdvisor advisor = new AuthorizationAttributeSourceAdvisor(); advisor.setSecurityManager(getSecurityManager()); return advisor; } /** * @return */ @Bean(name = "shiroFilter") public ShiroFilterFactoryBean getShiroFilterFactoryBean(){ ShiroFilterFactoryBean factoryBean = new ShiroFilterFactoryBean(); factoryBean.setSecurityManager(getSecurityManager()); factoryBean.setLoginUrl("/toLogin"); factoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap); return factoryBean; } }
優化:偽定時清除的Map,最好配合quartz之類的定時任務來清理,不然內存中的Map如果某個key一直不被訪問就不會被清除,容易越積越多。
package com.result.security;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import com.result.NettyGoConstant;

/**
 * @author huangxinyu
 * @version created 2018-01-29 10:31:50
 * A HashMap whose entries expire a given number of milliseconds after they were put.
 *
 * <p>Expiry is lazy: an expired entry is only dropped when it is touched (get, containsKey,
 * containsValue, isInvalid) or when the whole map is scanned (size, isEmpty, entrySet, values).
 * Entries that are never accessed stay in memory, so a periodic sweep is advisable.
 * Not thread-safe.
 */
public class ExpiryMap<K, V> extends HashMap<K, V> {

    private static final long serialVersionUID = 1L;

    /**
     * Default time-to-live in milliseconds: NettyGoConstant.LOGINSESSIONTIMEOUT unless a
     * constructor overrides it.
     */
    private long EXPIRY = NettyGoConstant.LOGINSESSIONTIMEOUT;

    /** key -> absolute expiry instant in System.currentTimeMillis() terms. */
    private HashMap<K, Long> expiryMap = new HashMap<>();

    public ExpiryMap() {
        super();
    }

    /** @param defaultExpiryTime default time-to-live in milliseconds */
    public ExpiryMap(long defaultExpiryTime) {
        this(1 << 4, defaultExpiryTime);
    }

    /**
     * @param initialCapacity   initial HashMap capacity
     * @param defaultExpiryTime default time-to-live in milliseconds
     */
    public ExpiryMap(int initialCapacity, long defaultExpiryTime) {
        super(initialCapacity);
        this.EXPIRY = defaultExpiryTime;
    }

    /** Stores the entry with the default time-to-live (resetting any previous expiry). */
    @Override
    public V put(K key, V value) {
        expiryMap.put(key, System.currentTimeMillis() + EXPIRY);
        return super.put(key, value);
    }

    /** True if the key is present and not expired; an expired key is removed. */
    @Override
    public boolean containsKey(Object key) {
        return !checkExpiry(key, true) && super.containsKey(key);
    }

    /**
     * Stores the entry with its own time-to-live.
     *
     * @param key        key
     * @param value      value
     * @param expiryTime time-to-live of this entry in milliseconds
     * @return the previous value, or null
     */
    public V put(K key, V value, long expiryTime) {
        expiryMap.put(key, System.currentTimeMillis() + expiryTime);
        return super.put(key, value);
    }

    /** Number of non-expired entries (expired ones are purged first). */
    @Override
    public int size() {
        return entrySet().size();
    }

    @Override
    public boolean isEmpty() {
        return entrySet().size() == 0;
    }

    /**
     * True if some non-expired entry holds {@code value}. Expired entries holding the value are
     * removed while searching; the search continues past them, since another key may still hold
     * the same value. Always false for a null value.
     */
    @Override
    public boolean containsValue(Object value) {
        if (value == null) {
            return false;
        }
        Iterator<Map.Entry<K, V>> iterator = super.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<K, V> entry = iterator.next();
            if (value.equals(entry.getValue())) {
                if (checkExpiry(entry.getKey(), false)) {
                    iterator.remove();
                } else {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the live values view after purging expired entries. (The previous version removed
     * entries through a second iterator inside the loop and threw ConcurrentModificationException.)
     */
    @Override
    public Collection<V> values() {
        purgeExpired();
        return super.values();
    }

    /** Returns the value, or null if the key is null, absent or expired (expired keys are removed). */
    @Override
    public V get(Object key) {
        if (key == null) {
            return null;
        }
        if (checkExpiry(key, true)) {
            return null;
        }
        return super.get(key);
    }

    /**
     * Checks whether a key has expired. Expired entries are not removed in real time, so this is
     * of some use.
     *
     * @param key key to inspect
     * @return null if the key is null or has no expiry record; -1 if it has expired (the entry is
     *         then removed); otherwise the current value
     */
    public Object isInvalid(Object key) {
        if (key == null) {
            return null;
        }
        if (!expiryMap.containsKey(key)) {
            return null;
        }
        long expiryTime = expiryMap.get(key);
        boolean flag = System.currentTimeMillis() > expiryTime;
        if (flag) {
            super.remove(key);
            expiryMap.remove(key);
            return -1;
        }
        return super.get(key);
    }

    /** Stores every mapping of {@code m} with the default time-to-live. */
    @Override
    public void putAll(Map<? extends K, ? extends V> m) {
        for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) {
            expiryMap.put(e.getKey(), System.currentTimeMillis() + EXPIRY);
        }
        super.putAll(m);
    }

    /** Returns the live entry-set view after purging expired entries. */
    @Override
    public Set<Map.Entry<K, V>> entrySet() {
        purgeExpired();
        return super.entrySet();
    }

    /** Removes the entry together with its expiry record, so expiryMap does not leak. */
    @Override
    public V remove(Object key) {
        expiryMap.remove(key);
        return super.remove(key);
    }

    /** Removes all entries and all expiry records. */
    @Override
    public void clear() {
        super.clear();
        expiryMap.clear();
    }

    /** Drops every expired entry from both the map and the expiry records. */
    private void purgeExpired() {
        Iterator<Map.Entry<K, V>> iterator = super.entrySet().iterator();
        while (iterator.hasNext()) {
            if (checkExpiry(iterator.next().getKey(), false)) {
                iterator.remove();
            }
        }
    }

    /**
     * Checks whether a key has expired (adapted from qd-ankang, 2016-11-24).
     *
     * @param key           key to check
     * @param isRemoveSuper when true, an expired entry is also removed from the underlying map;
     *                      when false the caller is expected to remove it (e.g. via an iterator)
     * @return true if the key has expired; its expiry record is removed in that case
     */
    private boolean checkExpiry(Object key, boolean isRemoveSuper) {
        if (!expiryMap.containsKey(key)) {
            return false;
        }
        long expiryTime = expiryMap.get(key);
        boolean flag = System.currentTimeMillis() > expiryTime;
        if (flag) {
            if (isRemoveSuper) {
                super.remove(key);
            }
            expiryMap.remove(key);
        }
        return flag;
    }

    /**
     * Deletes an entry and its expiry record.
     *
     * @param key key to delete
     */
    public void del(Object key) {
        super.remove(key);
        expiryMap.remove(key);
    }

    public static void main(String[] args) throws InterruptedException {
        ExpiryMap<String, String> map = new ExpiryMap<>(10);
        map.put("test", "ankang");
        map.put("test1", "ankang");
        map.put("test2", "ankang", 3000);
        System.out.println("test1" + map.get("test"));
        Thread.sleep(1000);
        System.out.println("isInvalid:" + map.isInvalid("test"));
        System.out.println("size:" + map.size());
        System.out.println("size:" + ((HashMap<String, String>) map).size());
        for (Map.Entry<String, String> m : map.entrySet()) {
            System.out.println("isInvalid:" + map.isInvalid(m.getKey()));
            map.containsKey(m.getKey());
            System.out.println("key:" + m.getKey() + " value:" + m.getValue());
        }
        System.out.println("test1" + map.get("test"));
    }

    /**
     * Whether at least half of the time-to-live has elapsed since the key was put. The elapsed
     * time is measured against this map's default EXPIRY (previously always
     * NettyGoConstant.LOGINSESSIONTIMEOUT, which was wrong for maps built with a custom expiry);
     * entries put with their own expiryTime are still measured against EXPIRY.
     *
     * @param key key to inspect
     * @return false if the key has no expiry record
     */
    public boolean isHalfExpiryTime(Object key) {
        Long expiryTime = expiryMap.get(key);
        if (expiryTime == null) {
            return false;
        }
        long insertedAt = expiryTime - EXPIRY;
        return System.currentTimeMillis() - insertedAt >= EXPIRY / 2;
    }
}