SpringBoot、Spring-security、WebSocket整合,WebSocket带token加入Spring-security认证机制

背景

项目里边有个模块首页会显示一些汇总信息,有部分信息需要进行实时更新,因此想到使用WebSocket进行长连接,进入后便链接后台,然后后台根据需求(定时、数据有更新)给前台反馈数据,从而达到一次链接,一直通信的功能。避免了前端轮询调用带来的资源消耗。

WebSocket:

本人是第一次接触websocket,之前对类似的处理方式能想到的就是前端隔一段时间调用一次接口来获取数据。然后项目内的人就推荐让我用websocket去写,这样就不用前端一直调用接口了,直接初始化链接,然后一直保持通信,后台根据需求给前端返回数据。

pom依赖

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

websocket处理类

// 前端链接的路由(ws://localhost:8080/webSocket/projectHome)
@ServerEndpoint(value = "/webSocket/projectHome", subprotocols = "")
@Slf4j
@Component
public class WebSocketServer {
   @Autowired
   private static StatisticsDataService service;
   /**
    * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
    */
   private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();

   /**
    * 与某个客户端的连接会话,需要通过它来给客户端发送数据
    */
   private Session session;

   /**
    * 连接数
    */
   private static final AtomicInteger count = new AtomicInteger();

   /**
    * 用户id,一开始使用前端传入的方式,但是有安全隐患,后来换成后端从security中获取
    */
   private String sid = "";

   /**
    * 连接建立成功调用的方法
    */
   @OnOpen
   public void onOpen(Session session) {
   	// 获取用户信息,保存到websocket中
       Authentication authentication = (Authentication) session.getUserPrincipal();
       SecurityUtils.setAuthentication(authentication);
       String username = SecurityUtils.getUsername();
       this.session = session;
       //如果存在就先删除一个,防止重复推送消息
       for (WebSocketServer webSocket : webSocketSet) {
           if (webSocket.sid.equals(username)) {
               webSocketSet.remove(webSocket);
               count.getAndDecrement();
           }
       }
       count.getAndIncrement();
       webSocketSet.add(this);
       this.sid = username;
   }

   /**
    * 连接关闭调用的方法
    */
   @OnClose
   public void onClose() {
       webSocketSet.remove(this);
       count.getAndDecrement();
   }

   /**
    * 收到客户端消息后调用的方法
    *
    * @param message 客户端发送过来的消息
    */
   @OnMessage
   public void onMessage(String message, Session session) {
       Authentication authentication = (Authentication) session.getUserPrincipal();
       log.info("收到来自" + sid + "的信息:" + message);
       // 实时更新
       service.refresh(sid, authentication);
   }

   @OnError
   public void onError(Session session, Throwable error) {
       log.error("发生错误");
       error.printStackTrace();
   }

   /**
    * 实现服务器主动推送
    */
   private void sendMessage(String type, Object data) throws IOException {
       Map<String, Object> result = new HashMap<>();
       result.put("type", type);
       result.put("data", data);
       this.session.getAsyncRemote().sendText(ObjectMapperBuilder.toJSONString(result));
   }

   /**
    * 群发自定义消息
    */
   public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
       for (WebSocketServer item : webSocketSet) {
           try {
               //这里可以设定只推送给这个sid的,为null则全部推送
               if (sid == null) {
                   item.sendMessage(type, data);
               } else if (item.sid.equals(sid)) {
                   item.sendMessage(type, data);
               }
           } catch (IOException ignored) {
           }
       }
   }

   @Override
   public boolean equals(Object o) {
       if (this == o) {
           return true;
       }
       if (o == null || getClass() != o.getClass()) {
           return false;
       }
       WebSocketServer that = (WebSocketServer) o;
       return Objects.equals(session, that.session);
   }

   /**
    * 判断是否有链接
    *
    * @return
    */
   public static boolean isConn(String sid) {
       for (WebSocketServer item : webSocketSet) {
           if (item.sid.equals(sid)) {
               return true;
           }
       }
       return false;
   }

   @Override
   public int hashCode() {
       return Objects.hash(session);
   }

   @Autowired
   public void setRepository(StatisticsDataService service) {
       WebSocketServer.service = service;
   }

开发出现的问题

  1. 业务代码处理
  2. 注入的service类为null
  3. 上下文认证

问题处理

1.业务代码的调用:
一开始想到的是用定时任务隔一段时间跑一次,如下:

/**
 * 实时更新
 */
@Scheduled(fixedDelay = 5000L)
private void refresh(){
    // 业务处理***
}

但是业务需求是将数据与上一次发送的数据做对比,如果有更新才给客户端发送数据,
而且需要判断用户是否还在线,这样定时任务就不适用了,无法关联用户。
然后转换到在websocket中OnMessage内调用业务方法,如下:

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        Authentication authentication = (Authentication) session.getUserPrincipal();
        log.info("收到来自" + sid + "的信息:" + message);
        // 实时更新
        service.refresh(sid, authentication);
    }

2.这样调用引发了第二个问题,就是注入的service在调用方法时一直报空指针,原理应该是websocket不受spring管控,所以在websocket中是拿不到spring中注入的对象的。然后百度一波,发现在websocket中需要自己设定一个,如下:

	@Autowired
    private static StatisticsDataService service;

	@Autowired
    public void setRepository(StatisticsDataService service) {
        WebSocketServer.service = service;
    }

3.这样就解决了注入对象空指针异常的问题了。本以为可以愉快的跑起来了,但是还没咧开嘴第3个问题就出来了。因为是微服务,所以有的接口是需要通过feign接口来调用的,然后业务处理的话是使用另外一个新线程来跑的,这就造成上下文中没有了用户的信息,导致获取用户权限时失败。。。
解决方法:
1.前端在建立链接时将认证token以参数的方式传入后台
例:ws://localhost:8080/webSocket/projectHome?Authorization=abcdefghi***
2.后台Spring-Security进行拦截过滤,从url中获取token,根据token获取用户信息,然后将用户注入到上下文中。这块代码是别人写的,因为涉及到搭建框架方式,所以不大清楚处理方式,不过逻辑是这样的。
3.在业务处理线程中将用户信息再次注入。代码如下:

websocket内在收到客户端信息后进行业务处理调用
    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) {
    	// 获取用户信息
        Authentication authentication = (Authentication) session.getUserPrincipal();
        log.info("收到来自" + sid + "的信息:" + message);
        // 实时更新
        service.refresh(sid, authentication);
    }

service中
	/**
     * 实时更新
     */
    private void refresh(String userId, Authentication authentication) {
        ThreadPoolExecutorUtil.getPoll().execute(() -> {
            // 注入用户信息
            SecurityUtils.setAuthentication(authentication);

            // 获取元数据
            int num = 0;

            // 判断用户是否在线,不在线则不用处理
            while (WebSocketServer.isConn(userId)) {
                // 获取数据
                int newNum = 1;
                // 判断数据是否有更新
                if (num != newNum) {
                    num = newNum;
                    // 发送最新数据给前端
                    WebSocketServer.sendInfo("num", newNum, userId);
                }
                try {
                	// 没5秒执行一次
                    Thread.sleep(5000L);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        });
    }

这样就能实现长连接实时通信了,而且上下文中也能获取到用户信息,进行用户鉴权了。

后来项目内大佬有写了一种定时任务来替换创建线程执行业务代码,代码如下:

	private void refresh(String userId, Authentication authentication) {
        handler.start(5000L, task -> {
            // 注入用户信息
            SecurityUtils.setAuthentication(authentication);

            // 获取元数据
            int num = 0;

            // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
            if (WebSocketServer.isConn(userId)) {
                // 获取数据
                int newNum = 1;
                // 判断数据是否有更新
                if (num != newNum) {
                    num = newNum;
                    // 发送最新数据给前端
                    WebSocketServer.sendInfo("num", newNum, userId);
                }
                // 设置返回值,判断是否需要继续执行
                return true;
            }
            return false;
        });
    }
handler内:
    public void start(long delay, Function<Timeout, Boolean> function) {
        timer.newTimeout(t -> {
        	// 获取返回值,判断是否执行
            Boolean result = function.apply(t);
            if (result) {
                timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
            }
        }, delay, TimeUnit.MILLISECONDS);
    }

到此WebSocket使用开发完毕,第一次用,有些地方说的不对的还请大佬指出,一起学习。

应评论区要求添加上SecurityUtil类

package com.budsoft.utils;

import cn.hutool.crypto.digest.MD5;
import cn.hutool.json.JSONObject;
import com.budsoft.exception.BadRequestException;
import com.budsoft.security.vo.JwtUser;
import org.apache.commons.codec.digest.DigestUtils;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;

import java.util.List;
import java.util.stream.Collectors;

import static com.budsoft.utils.XAdminConstant.ADMIN_CODE;
import static com.budsoft.utils.XAdminConstant.SYSTEM_CODE;

/**
 * 获取当前登录的用户
 */
public class SecurityUtils {

    public static Authentication getAuthentication() {
        return SecurityContextHolder.getContext().getAuthentication();
    }

    public static String getTokenKey() {
        Object token = getAuthentication().getCredentials();
        return MD5.create().digestHex(token.toString());
    }

    public static String getAuthenticationKey() {
        return DigestUtils.sha256Hex(getTokenKey());
    }

    public static UserDetails getUserDetails() {
        UserDetails userDetails;
        try {
            userDetails = (UserDetails) getAuthentication().getPrincipal();
        } catch (Exception e) {
            throw new BadRequestException(HttpStatus.UNAUTHORIZED, "登录状态过期");
        }
        return userDetails;
    }

    public static JwtUser getJwtUser() {
        JwtUser jwtUser;
        try {
            jwtUser = (JwtUser) getAuthentication().getPrincipal();
        } catch (Exception e) {
            throw new BadRequestException(HttpStatus.UNAUTHORIZED, "登录状态过期");
        }
        return jwtUser;
    }

    /**
     * 获取系统用户名称
     *
     * @return 系统用户名称
     */
    public static String getUsername() {
        Object obj = getUserDetails();
        return new JSONObject(obj).get("username", String.class);
    }

    public static List<String> getPermissions() {
        return getUserDetails().getAuthorities().stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList());
    }

    public static Boolean isAdministrator() {
        List<String> permissions = getPermissions();
        return permissions.contains(ADMIN_CODE) || permissions.contains(SYSTEM_CODE);
    }

    public static void setAuthentication(Authentication authentication) {
        SecurityContextHolder.getContext().setAuthentication(authentication);
    }
}
可以通过在Spring Security中配置一个Token认证过滤器来实现基于WebSocketToken认证。具体步骤如下: 1. 创建一个TokenAuthenticationFilter类,继承自OncePerRequestFilter并实现doFilterInternal方法。该类负责检查请求中是否包含有效的Token,并进行相应的认证处理。 ```java public class TokenAuthenticationFilter extends OncePerRequestFilter { private final TokenProvider tokenProvider; public TokenAuthenticationFilter(TokenProvider tokenProvider) { this.tokenProvider = tokenProvider; } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { String token = getTokenFromRequest(request); if (StringUtils.hasText(token) && tokenProvider.validateToken(token)) { Authentication authentication = tokenProvider.getAuthentication(token); SecurityContextHolder.getContext().setAuthentication(authentication); } filterChain.doFilter(request, response); } private String getTokenFromRequest(HttpServletRequest request) { String bearerToken = request.getHeader("Authorization"); if (StringUtils.hasText(bearerToken) && bearerToken.startsWith("Bearer ")) { return bearerToken.substring(7); } return null; } } ``` 2. 创建一个TokenProvider类,用于生成Token和验证Token的有效性,并根据Token获取用户信息。 ```java @Component public class TokenProvider { private static final String SECRET_KEY = "my-secret-key"; private static final long EXPIRATION_TIME = 86400000; // 1 day public String generateToken(Authentication authentication) { UserPrincipal principal = (UserPrincipal) authentication.getPrincipal(); Date expirationDate = new Date(System.currentTimeMillis() + EXPIRATION_TIME); return Jwts.builder() .setSubject(Long.toString(principal.getId())) .setIssuedAt(new Date()) .setExpiration(expirationDate) .signWith(SignatureAlgorithm.HS512, SECRET_KEY) .compact(); } public boolean validateToken(String token) { try { Jwts.parser().setSigningKey(SECRET_KEY).parseClaimsJws(token); return true; } catch (Exception e) { return false; } } public Authentication getAuthentication(String token) { Claims claims = Jwts.parser().setSigningKey(SECRET_KEY).parseClaimsJws(token).getBody(); Long userId = Long.parseLong(claims.getSubject()); UserPrincipal principal = new UserPrincipal(userId); return new UsernamePasswordAuthenticationToken(principal, "", principal.getAuthorities()); } } ``` 3. 在配置类中注册TokenAuthenticationFilter和TokenProvider,并将TokenAuthenticationFilter添加到Spring Security的过滤器链中。 ```java @Configuration @EnableWebSocketMessageBroker public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Autowired private TokenProvider tokenProvider; @Override public void configureMessageBroker(MessageBrokerRegistry config) { config.enableSimpleBroker("/topic"); config.setApplicationDestinationPrefixes("/app"); } @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/ws").setAllowedOriginPatterns("*").withSockJS(); } @Override public void configureClientInboundChannel(ChannelRegistration registration) { registration.interceptors(new ChannelInterceptorAdapter() { @Override public Message<?> preSend(Message<?> message, MessageChannel channel) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); if (StompCommand.CONNECT.equals(accessor.getCommand())) { String token = accessor.getFirstNativeHeader("Authorization"); if (StringUtils.hasText(token) && token.startsWith("Bearer ")) { token = token.substring(7); TokenAuthenticationFilter filter = new TokenAuthenticationFilter(tokenProvider); SecurityContextHolder.getContext().setAuthentication(filter.getAuthentication(token)); } } return message; } }); } @Override public void configureClientOutboundChannel(ChannelRegistration registration) { } @Override public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) { } @Override public void addReturnValueHandlers(List<HandlerMethodReturnValueHandler> returnValueHandlers) { } @Override public boolean configureMessageConverters(List<MessageConverter> messageConverters) { return true; } @Override public void configureWebSocketTransport(WebSocketTransportRegistration registry) { } @Bean public TokenAuthenticationFilter tokenAuthenticationFilter() throws Exception { TokenAuthenticationFilter filter = new TokenAuthenticationFilter(tokenProvider); filter.setAuthenticationManager(authenticationManager()); return filter; } @Override protected void configure(HttpSecurity http) throws Exception { http.csrf().disable().authorizeRequests() .antMatchers("/api/auth/**").permitAll() .anyRequest().authenticated(); } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth.userDetailsService(userDetailsService()) .passwordEncoder(passwordEncoder()); } @Bean public PasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } @Bean @Override public AuthenticationManager authenticationManagerBean() throws Exception { return super.authenticationManagerBean(); } @Override @Bean public UserDetailsService userDetailsService() { return new UserDetailsServiceImpl(); } } ``` 在上述代码中,我们通过重写configureClientInboundChannel方法,在连接到WebSocket时获取请求中的Token,并使用TokenAuthenticationFilter进行认证。注意,我们需要将TokenAuthenticationFilter添加到Spring Security的过滤器链中,以便它能够在WebSocket连接期间对请求进行拦截。 最后,我们需要在客户端的连接请求中添加Authorization头部,以便在服务端进行Token认证。例如: ```javascript stompClient.connect({}, function(frame) { console.log('Connected: ' + frame); stompClient.subscribe('/topic/greetings', function(greeting) { showGreeting(JSON.parse(greeting.body).content); }); }, function(error) { console.log('Error: ' + error); }, {"Authorization": "Bearer " + token}); ```
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值