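The problem
Several systems are sometimes deployed under the same domain, and each system may itself run on several nodes. To share the login state across these systems and spare users a second login, the sessions must be stored in one central place.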
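Solution
This article stores every session in Redis. It overrides methods of the request and the session so that sessions are read from and written to Redis instead of the servlet container's in-memory store.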
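The following is a minimal in-memory sketch of the idea. It is an illustration, not the article's code. A ConcurrentHashMap stands in for Redis, and the hypothetical class SharedSessionDemo uses its Node type to stand in for two separate systems. Both nodes read and write the same store by session ID, so a login recorded on one is visible on the other.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical in-memory illustration of a central session store.
// STORE plays the role of Redis; each Node plays the role of one system or one node.
class SharedSessionDemo {
    // sessionId -> attributes, shared by every Node (Redis in the real setup)
    static final Map<String, Map<String, Object>> STORE = new ConcurrentHashMap<>();

    static class Node {
        // Write an attribute into the session identified by sessionId, creating it if needed
        void setAttribute(String sessionId, String name, Object value) {
            STORE.computeIfAbsent(sessionId, id -> new ConcurrentHashMap<>()).put(name, value);
        }

        // Read an attribute; null if the session or the attribute does not exist
        Object getAttribute(String sessionId, String name) {
            Map<String, Object> attributes = STORE.get(sessionId);
            return attributes == null ? null : attributes.get(name);
        }
    }
}
```

In the real setup the store is Redis. The session ID travels between systems in the JSESSIONID cookie.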
Implementation
1. Extend the decorator class HttpServletRequestWrapper to create a custom request, and override getSession
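```java
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpSession;
import org.apache.commons.lang3.StringUtils;

// ProtostuffSerializer, RedisService and WebApplicationContextUtil are the project's own helper classes
public class MyRequest extends HttpServletRequestWrapper {
    /**
     * Serialization helper
     */
    private static final ProtostuffSerializer serializer = new ProtostuffSerializer();
    /**
     * Session ID taken from the JSESSIONID cookie (may be null)
     */
    private String sessionID;

    public MyRequest(HttpServletRequest request) {
        super(request);
    }

    public MyRequest(HttpServletRequest request, String sessionID) {
        super(request);
        this.sessionID = sessionID;
    }

    @Override
    public HttpSession getSession() {
        return getSession(true);
    }

    @Override
    public HttpSession getSession(boolean create) {
        // A session ID is present: try to load the session from Redis
        if (StringUtils.isNotBlank(sessionID)) {
            try {
                String key = MySession.getSessionKey(sessionID);
                RedisService redisService = WebApplicationContextUtil.getBean(RedisService.class);
                HttpSession session = serializer.deserialize(redisService.getSession(key));
                if (session != null) {
                    return session;
                }
            } catch (Exception e) {
                // Redis unavailable, key expired or data unreadable: treat it as "no session"
            }
        }
        if (!create) {
            return null;
        }
        // No stored session: create one. The MySession constructor saves it to Redis
        MySession mySession = new MySession(super.getSession(true));
        this.sessionID = mySession.getId();
        return mySession;
    }
}
```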
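The lookup-or-create decision in getSession can be sketched without Servlet or Redis dependencies. In the hypothetical SessionLookupDemo below, a map plays the role of Redis and a placeholder string plays the serialized session. The UUID-based ID is an assumption of the sketch, because in the article the new ID comes from the container's own session.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch of the decision MyRequest.getSession makes.
// REDIS stands in for Redis; the String value stands in for the serialized session.
class SessionLookupDemo {
    static final Map<String, String> REDIS = new ConcurrentHashMap<>();

    // Same key format as MySession.getSessionKey
    static String key(String sessionId) {
        return "s|" + sessionId.replace(" ", "");
    }

    // Returns the ID of the session the request ends up using
    static String resolveSession(String cookieSessionId) {
        if (cookieSessionId != null && !cookieSessionId.trim().isEmpty()
                && REDIS.containsKey(key(cookieSessionId))) {
            return cookieSessionId; // existing session found in the shared store
        }
        // No cookie, or the key expired / never existed: create and save a new session
        String newId = UUID.randomUUID().toString().replace("-", "").toUpperCase();
        REDIS.put(key(newId), "serialized session");
        return newId;
    }
}
```

A request that presents an ID whose key is missing, for example after the Redis entry expired, gets a fresh session instead of a null one.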
2. Create a custom MySession class that implements HttpSession
```java
import java.io.Serializable;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionContext;

public class MySession implements HttpSession, Serializable {
    private static final long serialVersionUID = 1789323123543L;
    private static final ProtostuffSerializer protostuffSerializer = new ProtostuffSerializer();
    private String id;
    private long creationTime;
    private long lastAccessedTime;
    private int maxInactiveInterval = 3600;
    private Map<String, Object> attributes = new HashMap<>();

    public MySession() {
    }

    // Copies the container session's metadata and saves the new session to Redis
    public MySession(HttpSession httpSession) {
        this.id = httpSession.getId();
        this.creationTime = httpSession.getCreationTime();
        this.lastAccessedTime = httpSession.getLastAccessedTime();
        this.maxInactiveInterval = httpSession.getMaxInactiveInterval();
        updateSessionCache();
    }

    @Override
    public long getCreationTime() {
        return creationTime;
    }

    @Override
    public String getId() {
        return id;
    }

    @Override
    public long getLastAccessedTime() {
        return lastAccessedTime;
    }

    @Override
    public ServletContext getServletContext() {
        return WebApplicationContextUtil.getWebApplicationContext().getServletContext();
    }

    @Override
    public void setMaxInactiveInterval(int maxInactiveInterval) {
        this.maxInactiveInterval = maxInactiveInterval;
        updateSessionCache();
    }

    @Override
    public int getMaxInactiveInterval() {
        return maxInactiveInterval;
    }

    // Deprecated in the Servlet API and not supported here
    @Override
    public HttpSessionContext getSessionContext() {
        return null;
    }

    @Override
    public Object getAttribute(String name) {
        return attributes.get(name);
    }

    @Override
    public Object getValue(String name) {
        return getAttribute(name);
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        return Collections.enumeration(attributes.keySet());
    }

    @Override
    public String[] getValueNames() {
        return attributes.keySet().toArray(new String[0]);
    }

    @Override
    public void setAttribute(String name, Object value) {
        attributes.put(name, value);
        updateSessionCache();
    }

    @Override
    public void putValue(String name, Object value) {
        setAttribute(name, value);
    }

    @Override
    public void removeAttribute(String name) {
        attributes.remove(name);
        updateSessionCache();
    }

    @Override
    public void removeValue(String name) {
        removeAttribute(name);
    }

    // Drops all attributes (e.g. on logout) and writes the empty session back to Redis,
    // so every system sees the user as logged out. Deleting the key itself would need a
    // delete method on RedisService, which is not shown here.
    @Override
    public void invalidate() {
        attributes.clear();
        updateSessionCache();
    }

    @Override
    public boolean isNew() {
        return false;
    }

    // Serializes this session and writes it to Redis under its key
    public void updateSessionCache() {
        String key = getSessionKey(id);
        byte[] sessionInfo = protostuffSerializer.serialize(this);
        try {
            RedisService redisService = WebApplicationContextUtil.getBean(RedisService.class);
            if (redisService != null) {
                redisService.setSession(key, sessionInfo);
            }
        } catch (Exception e) {
            // Redis write failed: the change is not visible to other systems
        }
    }

    // Redis key for a session ID, e.g. "s|ABC123"
    static String getSessionKey(String sessionId) {
        return "s|" + sessionId.replace(" ", "");
    }
}
```
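ProtostuffSerializer is the author's own helper, and its code is not shown. The sketch below therefore swaps in plain JDK serialization to illustrate the round trip that updateSessionCache and MyRequest.getSession rely on. The session state becomes a byte[] for Redis and is rebuilt from those bytes on another node. SessionBytesDemo and SessionData are hypothetical names, and SessionData keeps only a few of MySession's fields.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;

// Sketch only: JDK serialization stands in for the article's ProtostuffSerializer.
class SessionBytesDemo {
    // Minimal stand-in for MySession's state (not a full HttpSession)
    static class SessionData implements Serializable {
        private static final long serialVersionUID = 1L;
        String id;
        int maxInactiveInterval = 3600;
        Map<String, Object> attributes = new HashMap<>();
    }

    // Same key format as MySession.getSessionKey
    static String sessionKey(String sessionId) {
        return "s|" + sessionId.replace(" ", "");
    }

    // Session -> bytes, i.e. what would be written to Redis
    static byte[] serialize(SessionData data) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(data);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Bytes read back from Redis -> session
    static SessionData deserialize(byte[] bytes) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (SessionData) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

With JDK serialization, every value stored in the attributes map must itself implement Serializable. Otherwise the write fails.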
3. Add a filter that wraps every request in the custom MyRequest
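```java
import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;

public class MyFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        WebApplicationContextUtil.setWebApplicationContext(filterConfig.getServletContext());
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        // Replace the original request with MyRequest so that getSession() goes through Redis
        String sessionId = getSessionId(httpServletRequest);
        MyRequest myRequest = new MyRequest(httpServletRequest, sessionId);
        filterChain.doFilter(myRequest, servletResponse);
    }

    @Override
    public void destroy() {
    }

    // Reads the session ID from the JSESSIONID cookie; null means the user has not logged in yet
    private String getSessionId(HttpServletRequest httpServletRequest) {
        Cookie[] cookies = httpServletRequest.getCookies();
        if (!ArrayUtils.isEmpty(cookies)) {
            for (Cookie cookie : cookies) {
                if ("JSESSIONID".equals(cookie.getName()) && StringUtils.isNotBlank(cookie.getValue())) {
                    return cookie.getValue();
                }
            }
        }
        return null;
    }
}
```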
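The cookie lookup in getSessionId can also be reproduced on a raw Cookie header string. The hypothetical helper below is a simplified sketch. It splits on ';' and '=' and ignores quoting and other details that the servlet container normally handles when it builds the Cookie[] array.

```java
// Hypothetical standalone version of MyFilter.getSessionId that parses a raw
// "Cookie" request header instead of using the servlet Cookie[] array.
class CookieSessionIdDemo {
    static String sessionIdFromCookieHeader(String cookieHeader) {
        if (cookieHeader == null) {
            return null;
        }
        for (String pair : cookieHeader.split(";")) {
            int eq = pair.indexOf('=');
            if (eq < 0) {
                continue; // not a name=value pair
            }
            String name = pair.substring(0, eq).trim();
            String value = pair.substring(eq + 1).trim();
            if ("JSESSIONID".equals(name) && !value.isEmpty()) {
                return value;
            }
        }
        return null; // no usable JSESSIONID: treat the user as not logged in
    }
}
```

For example, the header `theme=dark; JSESSIONID=ABC123` yields `ABC123`. A header without JSESSIONID yields null.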
4. Register the filter and give it a small order value so that it runs before the other filters (lower values run earlier)
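```java
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class FilterConfiguration {
    @Bean
    public FilterRegistrationBean<MyFilter> createMyFilter() {
        FilterRegistrationBean<MyFilter> filterRegistrationBean = new FilterRegistrationBean<>();
        filterRegistrationBean.setFilter(new MyFilter());
        filterRegistrationBean.setName("myFilter");
        // No URL pattern is set, so Spring Boot maps the filter to /*; -1 makes it run early
        filterRegistrationBean.setOrder(-1);
        return filterRegistrationBean;
    }
}
```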
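Spring orders registered filters by ascending order value. Ordered.HIGHEST_PRECEDENCE is Integer.MIN_VALUE, and FilterRegistrationBean defaults to Ordered.LOWEST_PRECEDENCE (Integer.MAX_VALUE). So -1 runs before filters left at the default, but after any filter registered with a smaller value. The hypothetical FilterOrderDemo below only simulates that sort. The filter names other than myFilter are made up.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical simulation of ordering filters by ascending order value.
class FilterOrderDemo {
    static final class Registration {
        final String name;
        final int order;

        Registration(String name, int order) {
            this.name = name;
            this.order = order;
        }
    }

    // Returns filter names in the order they would run (smallest order value first)
    static List<String> executionOrder(List<Registration> registrations) {
        List<Registration> sorted = new ArrayList<>(registrations);
        sorted.sort(Comparator.comparingInt((Registration r) -> r.order));
        List<String> names = new ArrayList<>();
        for (Registration r : sorted) {
            names.add(r.name);
        }
        return names;
    }
}
```

If myFilter must run ahead of everything, Ordered.HIGHEST_PRECEDENCE could be passed to setOrder instead of -1. It would then share that value with any other filter that also uses it.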
Put MyRequest, MySession and MyFilter in a shared core package. Each service then only needs to depend on that package and register the filter.