踩坑笔记 Spring websocket并发发送消息异常
作者:mmseoamin日期:2023-12-25

文章目录

    • 示例代码
      • WebSocketConfig配置代码
      • 握手拦截器代码
      • 业务处理器代码
      • 问题复现
      • 原因分析
      • 解决方案
        • 方案一 加锁同步发送
        • 方案二 使用ConcurrentWebSocketSessionDecorator
        • 方案三 自研事件驱动队列(借鉴 Tomcat)
        • 总结

          今天刚刚经历了一个坑,非常新鲜,我立刻决定记录下来。首先,让我们先看一下我们项目中使用的 Spring WebSocket 示例代码。

          示例代码

          在我们的项目中,我们使用了 Spring WebSocket 来实现服务器与客户端之间的实时通信。下面是一个简化的示例代码:

          WebSocketConfig配置代码

          @Configuration
          @EnableWebSocket // 启动Websocket
          public class WebSocketConfig implements WebSocketConfigurer {
              @Override
              public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
                  registry.addHandler(myHandler(), "/myHandler/**")
                      // 添加拦截器,可以获取连接的param和 header 用作认证鉴权
                      .addInterceptors(new LakerSessionHandshakeInterceptor())
                      // 设置运行跨域
                      .setAllowedOrigins("*");
              }
                 
              @Bean
              public ServletServerContainerFactoryBean createWebSocketContainer() {
                  ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean();
                  // 设置默认会话空闲超时 以毫秒为单位 非正值意味着无限超时,默认值 0 ,默认没10s检查一次空闲就关闭
                  container.setMaxSessionIdleTimeout(10 * 1000L);
                  // 设置异步发送消息的默认超时时间 以毫秒为单位 非正值意味着无限超时 ,默认值-1,还没看到作用
                  container.setAsyncSendTimeout(10 * 1000L);
                  // 设置文本消息的默认最大缓冲区大小 以字符为单位,默认值 8 * 1024
                  container.setMaxTextMessageBufferSize(8 * 1024);
                  // 设置二进制消息的默认最大缓冲区大小 以字节为单位,默认值 8 * 1024
                  container.setMaxBinaryMessageBufferSize(8 * 1024);
                  return container;
              }
              @Bean
              public WebSocketHandler myHandler() {
                  return new MyHandler();
              }
          }
          

          握手拦截器代码

          public class LakerSessionHandshakeInterceptor extends HttpSessionHandshakeInterceptor {
              // 拦截器返回false,则不会进行websocket协议的转换
              // 可以获取请求参数做认证鉴权
              @Override
              public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception {
                  HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
                  // 将所有查询参数复制到 WebSocketSession 属性映射
                  Enumeration parameterNames = servletRequest.getParameterNames();
                  while (parameterNames.hasMoreElements()) {
                      String parameterName = parameterNames.nextElement();
                      // 放入的值可以从WebSocketHandler里面的WebSocketSession拿出来
                      attributes.put(parameterName, servletRequest.getParameter(parameterName));
                  }
                  if (servletRequest.getHeader(HttpHeaders.AUTHORIZATION) != null) {
                      setErrorResponse(response, HttpStatus.UNAUTHORIZED);
                      return false;
                  }
                  return true;
              }
              @Override
              public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
              }
              private void setErrorResponse(ServerHttpResponse response, HttpStatus status, String errorMessage) {
                  response.setStatusCode(status);
                  if (errorMessage != null) {
                      try {
                          objectMapper.writeValue(response.getBody(), new ErrorDetail(status.value(), errorMessage));
                      } catch (IOException ioe) {
                      }
                  }
              }
          }
          

          业务处理器代码

          public class MyHandler extends AbstractWebSocketHandler {
               private final Map webSocketSessionMap = new ConcurrentHashMap<>();
           
              //成功连接时
              @Override
              public void afterConnectionEstablished(WebSocketSession session) throws Exception {
                  super.afterConnectionEstablished(session);
                  // 设置最大报文大小,如果报文过大则会报错的,可以将限制调大。
                  // 会覆盖config中的配置。
                  session.setBinaryMessageSizeLimit(8 * 1024);
                  session.setTextMessageSizeLimit(8 * 1024);
                  webSocketSessionMap.put(session.getId(), session);
              }
              //处理textmessage
              @Override
              protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
                  super.handleTextMessage(session, message);
                  // 有消息就广播下
                  for (Map.Entry entry : webSocketSessionMap.entrySet()) {
                      String s = entry.getKey();
                      WebSocketSession webSocketSession = entry.getValue();
                      if (webSocketSession.isOpen()) {
                      	webSocketSession.sendMessage(new TextMessage(s + ":" + message.getPayload()));
                      	log.info("send to {} msg:{}", s, message.getPayload());
                      }
                  }
              }
              
              
              //报错时
              @Override
              public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
                  super.handleTransportError(session, exception);
                    log.warn("handleTransportError:: sessionId: {} {}", session.getId(), exception.getMessage(), exception);
                  if (session.isOpen()) {
                      try {
                          session.close();
                      } catch (Exception ex) {
                          // ignore
                      }
                  }
              }
              
             //连接关闭时
              @Override
              public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
                  super.afterConnectionClosed(session, status);
                  WebSocketSession session =  webSocketSessionMap.remove(session.getId());
                  if (session.isOpen()) {
                      try {
                          session.close();
                      } catch (Exception ex) {
                          // ignore
                      }
                  }
              }
              
              //处理binarymessage
              @Override
              protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
                  super.handleBinaryMessage(session, message);
              }
              //处理pong
              @Override
              protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
                  super.handlePongMessage(session, message);
              }
              //是否支持报文拆包
              // 决定是否接受半包,因为websocket协议比较底层,好像Tcp协议一样,如果发送大消息可能会拆成多个小报文。如果不希望处理不完整的报文,希望底层帮忙聚合成完整消息将此方法返回false,这样底层会等待完整报文到达聚合后才回调。
              @Override
              public boolean supportsPartialMessages() {
                  return super.supportsPartialMessages();
              }
          }
          

          问题复现

          在我们的测试环境中,我们发现当多个线程同时尝试通过 WebSocket 会话发送消息时,会抛出异常。

          现在我们用JMeter模拟100个用户发消息。

          当执行一会儿后会发现服务端出现如下异常日志:

          java.lang.IllegalStateException: 远程 endpoint 处于 [TEXT_PARTIAL_WRITING] 状态,是被调用方法的无效状态
          	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase$StateMachine.checkState(WsRemoteEndpointImplBase.java:1274)
          	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase$StateMachine.textPartialStart(WsRemoteEndpointImplBase.java:1231)
          	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase.sendPartialString(WsRemoteEndpointImplBase.java:226)
          	at org.apache.tomcat.websocket.WsRemoteEndpointBasic.sendText(WsRemoteEndpointBasic.java:49)
          	at org.springframework.web.socket.adapter.standard.StandardWebSocketSession.sendTextMessage(StandardWebSocketSession.java:215)
          	at org.springframework.web.socket.adapter.AbstractWebSocketSession.sendMessage(AbstractWebSocketSession.java:108)
          

          原因分析

          经过分析,发现异常的根本原因是在并发发送消息时,WebSocket 会话的状态发生了异常。

          具体来说,当一个线程正在发送文本消息时,另一个线程也尝试发送消息,就会导致状态不一致,从而触发异常。

          即WebSocketSession.sendMessage其不是线程安全的,内部有个状态机来管理防止并发导致问题以fail-fast方式快速告诉使用者。

          那么问题代码就在这里了

          WebSocketSession是不支持并发发送消息的,我们使用者要保证其线程安全,这是我一开始没预期到的。

          解决方案

          方案一 加锁同步发送

          为了解决并发发送消息导致的异常,我们可以在发送消息的代码块上加锁,确保同一时刻只有一个线程能够执行发送操作。

          下面是使用加锁机制的示例代码:

          改起来非常简单,只需要对webSocketSession实例加个锁即可。

          缺点

          当并发度较高时,越后面排队等待锁的人被block的越久。

          大致模型图

          方案二 使用ConcurrentWebSocketSessionDecorator

          另一种解决并发发送消息的方法是使用 ConcurrentWebSocketSessionDecorator。这是 Spring WebSocket 提供的一个装饰器类,用于增强底层的 WebSocketSession 的线程安全性。它通过并发安全的方式包装原始的 WebSocketSession 对象,确保在多线程环境下安全地访问和修改会话属性,以及进行消息发送操作。

          ConcurrentWebSocketSessionDecorator 的工作原理是利用线程安全的数据结构和同步机制,确保对会话执行的操作的原子性和一致性。当需要发送消息时,装饰器会获取锁或使用并发数据结构来协调多个线程之间的访问。这样可以防止对会话状态的并发修改,避免潜在的竞态条件。

          下面是使用 ConcurrentWebSocketSessionDecorator 的示例代码:

              @Override
              public void afterConnectionEstablished(WebSocketSession session) {
                  log.info("{} Connection established!", session.getId());
                  // webSocketSessionMap.put(session.getId(), session);
                  // 把线程安全的session代理装饰类放到容器里
                  webSocketSessionMap.put(session.getId(), new ConcurrentWebSocketSessionDecorator(session, 10 * 1000, 64000));
              }
          

          ConcurrentWebSocketSessionDecorator原理

          ConcurrentWebSocketSessionDecorator整体代码大概200行还是比较容易看懂的。

          // 包装一个WebSocketSession以保证一次只有一个线程可以发送消息。
          // 如果发送速度慢,后续尝试从其他线程发送更多消息将无法获取刷新锁,消息将被缓冲。届时,将检查指定的缓冲区大小限制和发送时间限制,如果超过限制,会话将关闭。
          public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator {
          	private static final Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
          	private final int sendTimeLimit;
          	private final int bufferSizeLimit;
          	private final OverflowStrategy overflowStrategy;
          	@Nullable
          	private Consumer> preSendCallback;
          	private final Queue> buffer = new LinkedBlockingQueue<>();
          	private final AtomicInteger bufferSize = new AtomicInteger();
          	private volatile long sendStartTime;
          	private volatile boolean limitExceeded;
          	private volatile boolean closeInProgress;
          	private final Lock flushLock = new ReentrantLock();
          	private final Lock closeLock = new ReentrantLock();
          	public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) {
          		this(delegate, sendTimeLimit, bufferSizeLimit, OverflowStrategy.TERMINATE);
          	}
              // delegate 需要代理的session
              // sendTimeLimit 表示发送**单个消息**的最大时间
              // bufferSizeLimit 表示发送消息的队列最大字节数,不是消息的数量而是消息的总大小
              // overflowStrategy 表示超过限制时的策略有2个
                           // - 断开连接(默认选项)
                           // - 丢弃最老的数据,直到大小满足
          	public ConcurrentWebSocketSessionDecorator(
          			WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit, OverflowStrategy overflowStrategy) {
          		super(delegate);
          		this.sendTimeLimit = sendTimeLimit;
          		this.bufferSizeLimit = bufferSizeLimit;
          		this.overflowStrategy = overflowStrategy;
          	}
              // 返回自当前发送开始以来的时间(毫秒),如果当前没有发送正在进行则返回 0。
              // 即花费的耗时
          	public long getTimeSinceSendStarted() {
          		long start = this.sendStartTime;
          		return (start > 0 ? (System.currentTimeMillis() - start) : 0);
          	}
          	// 设置在将消息添加到发送缓冲区后调用的回调
          	public void setMessageCallback(Consumer> callback) {
          		this.preSendCallback = callback;
          	}
          	@Override
          	public void sendMessage(WebSocketMessage message) throws IOException {
                  //检查超限了就不发了
          		if (shouldNotSend()) {
          			return;
          		}
          		// 消息放到buffer队列
          		this.buffer.add(message);
                  // 增加bufferSize用于后面判断是不是超限了
          		this.bufferSize.addAndGet(message.getPayloadLength());
          		// 发送缓冲区后调用的回调
          		if (this.preSendCallback != null) {
          			this.preSendCallback.accept(message);
          		}
          		do {
                      // 尝试获取锁,发送消息,只有一个线程负责发送所有消息
          			if (!tryFlushMessageBuffer()) {
          				// 没获取锁的线程,对当前buffer和时间检查,
                          // 检查不过就抛异常 然后框架自己会抓取异常关闭当前的连接
          				checkSessionLimits();
          				break;
          			}
          		}
          		while (!this.buffer.isEmpty() && !shouldNotSend());
          	}
          	// 超限了 不能发送了
          	private boolean shouldNotSend() {
          		return (this.limitExceeded || this.closeInProgress);
          	}
          	// 尝试获取锁 并发送所有缓存的消息
          	private boolean tryFlushMessageBuffer() throws IOException {
          		if (this.flushLock.tryLock()) {
          			try {
                           // 循环发送消息
          				while (true) {
                              // 一次拉一个消息
          					WebSocketMessage message = this.buffer.poll();
                              // 没消息了 或者 超限了
          					if (message == null || shouldNotSend()) {
                                  // 退出 完活
          						break;
          					}
                              // 释放bufferSize
          					this.bufferSize.addAndGet(-message.getPayloadLength());
                              // 用于判断单个消息是否发送超时的
          					this.sendStartTime = System.currentTimeMillis();
                              // 发送消息
          					getDelegate().sendMessage(message);
                              // 重置开始时间
          					this.sendStartTime = 0;
          				}
          			}
          			finally {
          				this.sendStartTime = 0;
          				this.flushLock.unlock();
          			}
          			return true;
          		}
          		return false;
          	}
             //检查是否超时,是否超过buffer限制
          	private void checkSessionLimits() {
                  // 应该发送 且 获取到关闭锁
          		if (!shouldNotSend() && this.closeLock.tryLock()) {
          			try {
                           //检测是否发送超时
          				if (getTimeSinceSendStarted() > getSendTimeLimit()) {
          					String format = "Send time %d (ms) for session '%s' exceeded the allowed limit %d";
          					String reason = String.format(format, getTimeSinceSendStarted(), getId(), getSendTimeLimit());
          					limitExceeded(reason);
          				}
                           //检测buffer大小,根据策略要么抛异常关闭连接,要么丢弃数据
          				else if (getBufferSize() > getBufferSizeLimit()) {
          					switch (this.overflowStrategy) {
                                  // 关闭连接,抛出异常框架就会关闭连接    
          						case TERMINATE:
          							String format = "Buffer size %d bytes for session '%s' exceeds the allowed limit %d";
          							String reason = String.format(format, getBufferSize(), getId(), getBufferSizeLimit());
          							limitExceeded(reason);
          							break;
                                  // 丢弃老数据    
          						case DROP:
          							int i = 0;
          							while (getBufferSize() > getBufferSizeLimit()) {
          								WebSocketMessage message = this.buffer.poll();
          								if (message == null) {
          									break;
          								}
          								this.bufferSize.addAndGet(-message.getPayloadLength());
          								i++;
          							}
          							if (logger.isDebugEnabled()) {
          								logger.debug("Dropped " + i + " messages, buffer size: " + getBufferSize());
          							}
          							break;
          						default:
          							// Should never happen..
          							throw new IllegalStateException("Unexpected OverflowStrategy: " + this.overflowStrategy);
          					}
          				}
          			}
          			finally {
          				this.closeLock.unlock();
          			}
          		}
          	}
          	private void limitExceeded(String reason) {
          		this.limitExceeded = true;
          		throw new SessionLimitExceededException(reason, CloseStatus.SESSION_NOT_RELIABLE);
          	}
          	@Override
          	public void close(CloseStatus status) throws IOException {
          		this.closeLock.lock();
          		try {
          			if (this.closeInProgress) {
          				return;
          			}
          			if (!CloseStatus.SESSION_NOT_RELIABLE.equals(status)) {
          				try {
          					checkSessionLimits();
          				}
          				catch (SessionLimitExceededException ex) {
          					// Ignore
          				}
          				if (this.limitExceeded) {
          					if (logger.isDebugEnabled()) {
          						logger.debug("Changing close status " + status + " to SESSION_NOT_RELIABLE.");
          					}
          					status = CloseStatus.SESSION_NOT_RELIABLE;
          				}
          			}
          			this.closeInProgress = true;
          			super.close(status);
          		}
          		finally {
          			this.closeLock.unlock();
          		}
          	}
          	public enum OverflowStrategy {
          		TERMINATE,
          		DROP
          	}
          }
          

          大致模型图

          方案三 自研事件驱动队列(借鉴 Tomcat)

          除了使用加锁和 ConcurrentWebSocketSessionDecorator,我们还可以借鉴 Tomcat 的事件驱动队列机制来解决并发发送消息的问题。具体的实现需要一些复杂的逻辑和代码,涉及到消息队列、线程池和事件处理机制,因此在这里我不展开讨论。如果你对这个方案感兴趣,可以参考 Tomcat 的源代码,了解更多关于事件驱动队列的实现细节。

          这个参考之前的tomcat设计即可。

          img

          总结

          在本篇博文中,我们讨论了在使用 Spring WebSocket 进行并发发送消息时可能遇到的异常情况。我们深入分析了异常的原因,并提供了三种解决方案:加锁同步发送、使用 ConcurrentWebSocketSessionDecorator 和自研事件驱动队列(借鉴 Tomcat)。每种方案都有其适用的场景和注意事项,你可以根据自己的需求选择合适的方法来解决并发发送消息的异常问题。

          希望本文对你有所帮助,让你在使用 Spring WebSocket 时能够避免类似的坑。如果你对本文有任何疑问或意见,欢迎在评论区留言,我们将尽力为你解答。谢谢阅读!

          参考链接:

          • Spring Framework Documentation: WebSocket
          • Tomcat Documentation: Event-driven Architectures