WebSockets Demystified - Part 2: Spring Boot Implementation
Series: Index | Part 1 | Part 2 | Part 3 | Part 4 | Part 5 | Part 6
Table of Contents
- Maven Dependencies
- WebSocket Configuration
- STOMP Message Controller
- Sending Messages from Services
- Security Configuration
- Session Management and Connection Lifecycle
- Error Handling
- Interceptors and Custom Headers
- Heartbeat and Timeout Configuration
- Testing WebSocket Endpoints
1. Maven Dependencies
<!-- pom.xml -->
<properties>
    <java.version>17</java.version>
    <!-- NOTE(review): presumably informational only. The Spring Boot version actually
         in effect is the one on the spring-boot-starter-parent / spring-boot-dependencies
         declaration (not shown here); keep the two in sync. -->
    <spring-boot.version>3.2.0</spring-boot.version>
</properties>
<dependencies>
    <!-- Core Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- WebSocket support - includes STOMP and SockJS -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
    </dependency>
    <!-- Spring Security for WebSocket authentication -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-security</artifactId>
    </dependency>
    <!-- JWT for token-based authentication -->
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-api</artifactId>
        <version>0.12.3</version>
    </dependency>
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-impl</artifactId>
        <version>0.12.3</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-jackson</artifactId>
        <version>0.12.3</version>
        <scope>runtime</scope>
    </dependency>
    <!-- Redis for pub/sub in clustered deployments -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- Connection pooling for Lettuce. Spring Boot only enables the pool configured under
         spring.data.redis.lettuce.pool.* when commons-pool2 is on the classpath; without it
         those settings are silently ignored. Version is managed by Spring Boot. -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-pool2</artifactId>
    </dependency>
    <!-- Spring Data JPA with MySQL -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>
    <!-- Actuator for health checks and metrics -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-actuator</artifactId>
    </dependency>
    <!-- Micrometer for metrics (sends to CloudWatch) -->
    <dependency>
        <groupId>io.micrometer</groupId>
        <artifactId>micrometer-registry-cloudwatch2</artifactId>
    </dependency>
    <!-- Validation -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-validation</artifactId>
    </dependency>
    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
    <!-- Testing -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.springframework.security</groupId>
        <artifactId>spring-security-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
2. WebSocket Configuration
Application Properties
# src/main/resources/application.yml
server:
  port: 8080
  # Important: Tomcat WebSocket configuration
  tomcat:
    max-connections: 10000
    threads:
      max: 200
      min-spare: 20
spring:
  application:
    name: websocket-service
  datasource:
    url: jdbc:mysql://localhost:3306/wsdb?useSSL=true&serverTimezone=UTC&allowPublicKeyRetrieval=true
    username: ${DB_USERNAME}
    password: ${DB_PASSWORD}
    hikari:
      maximum-pool-size: 20
      minimum-idle: 5
      idle-timeout: 30000
      connection-timeout: 20000
  jpa:
    hibernate:
      ddl-auto: validate
    show-sql: false
    properties:
      hibernate:
        # Hibernate 6 (used by Spring Boot 3.x) deprecates the versioned MySQL8Dialect;
        # MySQLDialect covers MySQL 8 and adapts to the connected server version.
        dialect: org.hibernate.dialect.MySQLDialect
        format_sql: true
  data:
    redis:
      host: ${REDIS_HOST:localhost}
      port: ${REDIS_PORT:6379}
      password: ${REDIS_PASSWORD:}
      ssl:
        enabled: ${REDIS_SSL:false}
      timeout: 2000ms
      # Pool settings only take effect when commons-pool2 is on the classpath.
      lettuce:
        pool:
          max-active: 20
          max-idle: 10
          min-idle: 5
app:
  websocket:
    # Heartbeat intervals in milliseconds
    heartbeat-send-interval: 25000
    heartbeat-receive-interval: 25000
    # Message size limits
    message-size-limit: 65536 # 64 KB per message
    send-buffer-size-limit: 524288 # 512 KB per connection send buffer
    send-time-limit: 20000 # 20 seconds max to send a message before disconnect
    # Allowed origins - MUST be explicitly configured.
    # Written as a comma-separated string, not a YAML list: WebSocketConfig injects this with
    # @Value, and @Value cannot resolve a YAML list (it is flattened into allowed-origins[0],
    # allowed-origins[1], ...). Spring converts the comma-separated string into a List<String>.
    allowed-origins: "https://app.example.com,https://www.example.com"
  jwt:
    secret: ${JWT_SECRET}
    expiration-ms: 86400000
management:
  endpoints:
    web:
      exposure:
        # "prometheus" only exists with micrometer-registry-prometheus, which is not a dependency;
        # metrics are shipped to CloudWatch instead.
        include: health,info,metrics
  # Spring Boot 3.0 renamed management.metrics.export.<registry>.* to
  # management.<registry>.metrics.export.*.
  cloudwatch:
    metrics:
      export:
        namespace: WebSocketService
        step: 1m
Core WebSocket Configuration Class
// src/main/java/com/example/ws/config/WebSocketConfig.java
package com.example.ws.config;
import com.example.ws.interceptor.AuthHandshakeInterceptor;
import com.example.ws.interceptor.WebSocketChannelInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
import java.util.List;
@Configuration
@EnableWebSocketMessageBroker // Enables STOMP messaging over WebSocket; routing is done by the simple broker configured below
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    @Autowired
    private AuthHandshakeInterceptor authHandshakeInterceptor;

    @Autowired
    private WebSocketChannelInterceptor channelInterceptor;

    // @Value cannot bind a YAML list; the property must be a comma-separated string,
    // which Spring converts into a List<String>.
    @Value("${app.websocket.allowed-origins}")
    private List<String> allowedOrigins;

    @Value("${app.websocket.message-size-limit}")
    private int messageSizeLimit;

    @Value("${app.websocket.send-buffer-size-limit}")
    private int sendBufferSizeLimit;

    @Value("${app.websocket.send-time-limit}")
    private int sendTimeLimit;

    @Value("${app.websocket.heartbeat-send-interval}")
    private long heartbeatSendInterval;

    @Value("${app.websocket.heartbeat-receive-interval}")
    private long heartbeatReceiveInterval;

    /**
     * Registers the single STOMP endpoint at {@code /ws}.
     * <p>
     * Because the endpoint is registered with SockJS, SockJS clients connect to
     * {@code https://host/ws}. Plain (non-SockJS) WebSocket clients must use the SockJS
     * raw-WebSocket URL {@code wss://host/ws/websocket}.
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/ws")
                // Derives the session Principal from the "userId" handshake attribute
                .setHandshakeHandler(new CustomHandshakeHandler())
                // Add our JWT authentication interceptor
                .addInterceptors(authHandshakeInterceptor)
                // Allowed origins - NEVER use "*" in production
                .setAllowedOrigins(allowedOrigins.toArray(new String[0]))
                // SockJS fallback for environments that don't support WebSocket
                .withSockJS()
                // SockJS heartbeat, kept in step with the configured STOMP heartbeat interval
                .setHeartbeatTime(heartbeatSendInterval)
                // Time without a receiving connection before the client is considered disconnected
                .setDisconnectDelay(5000)
                // Server-to-client messages a session may cache while waiting for the next HTTP poll
                .setHttpMessageCacheSize(1000)
                .setWebSocketEnabled(true);
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // Application destination prefix - messages sent here go to @MessageMapping methods
        registry.setApplicationDestinationPrefixes("/app");
        // Topic prefix - server broadcasts to these (all subscribers receive)
        // Queue prefix - server sends to specific users
        registry.enableSimpleBroker("/topic", "/queue")
                .setHeartbeatValue(new long[]{heartbeatSendInterval, heartbeatReceiveInterval})
                // A TaskScheduler is required once broker heartbeats are enabled
                .setTaskScheduler(heartbeatTaskScheduler());
        // User destination prefix - for private messages to a specific user
        // /user/queue/notifications -> sends to a specific user's notification queue
        registry.setUserDestinationPrefix("/user");
    }

    @Override
    public void configureWebSocketTransport(WebSocketTransportRegistration registration) {
        registration
                // Maximum size of a single incoming message (Spring's default is 64 KB)
                .setMessageSizeLimit(messageSizeLimit)
                // Maximum size of data that can be buffered per session when sending
                .setSendBufferSizeLimit(sendBufferSizeLimit)
                // Maximum time to block when sending to a client before it is considered dead
                .setSendTimeLimit(sendTimeLimit);
    }

    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        // Add our security/validation interceptor to inbound messages
        registration.interceptors(channelInterceptor);
        // Thread pool for processing inbound messages
        registration.taskExecutor()
                .corePoolSize(4)
                .maxPoolSize(20)
                .queueCapacity(100);
    }

    @Override
    public void configureClientOutboundChannel(ChannelRegistration registration) {
        // Thread pool for sending outbound messages
        registration.taskExecutor()
                .corePoolSize(4)
                .maxPoolSize(20)
                .queueCapacity(100);
    }

    /**
     * Task scheduler for broker heartbeats, required when heartbeats are configured on the broker.
     * Fully qualified {@code @Bean} because the annotation is not among this file's imports.
     */
    @org.springframework.context.annotation.Bean
    public org.springframework.scheduling.TaskScheduler heartbeatTaskScheduler() {
        org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler scheduler =
                new org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler();
        scheduler.setPoolSize(2);
        scheduler.setThreadNamePrefix("ws-heartbeat-");
        scheduler.initialize();
        return scheduler;
    }
}
Custom Handshake Handler
// src/main/java/com/example/ws/config/CustomHandshakeHandler.java
package com.example.ws.config;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import java.security.Principal;
import java.util.Map;
/**
 * Handshake handler that builds each WebSocket session's {@link Principal} from the
 * handshake attributes.
 * <p>
 * {@code AuthHandshakeInterceptor} validates the JWT first and, on success, stores the user
 * id under the {@code "userId"} attribute. The Principal returned here carries that id as its
 * name, and Spring uses that name to route {@code /user/**} destinations for the session.
 */
public class CustomHandshakeHandler extends DefaultHandshakeHandler {

    @Override
    protected Principal determineUser(
            ServerHttpRequest request,
            WebSocketHandler wsHandler,
            Map<String, Object> attributes) {
        final String authenticatedId = (String) attributes.get("userId");
        // A missing id means the interceptor did not authenticate this request. Leave the
        // session without a Principal instead of inventing one.
        return authenticatedId == null ? null : () -> authenticatedId;
    }
}
3. STOMP Message Controller
Understanding STOMP Destinations
STOMP Destination Types in Spring:
/app/chat.send -> Goes to @MessageMapping("/chat.send")
/topic/room.123 -> Broadcast to all subscribers of this topic
/queue/private -> Point-to-point queue
/user/queue/notify -> User-specific destination (Spring resolves it to a per-session queue for that user, e.g. /queue/notify-user{sessionId})
Chat Message Controller
// src/main/java/com/example/ws/controller/ChatController.java
package com.example.ws.controller;
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import com.example.ws.dto.TypingIndicator;
import com.example.ws.service.ChatService;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.handler.annotation.DestinationVariable;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.stereotype.Controller;
import java.security.Principal;
@Controller
@RequiredArgsConstructor
@Slf4j
public class ChatController {

    private final ChatService chatService;

    /**
     * Handles a message sent to /app/chat.send/{roomId}.
     * <p>
     * The @MessageMapping path is relative to the app prefix configured in WebSocketConfig,
     * so the full destination is /app/chat.send/{roomId}. The @Valid annotation triggers
     * Bean Validation on the payload.
     * <p>
     * The injected Principal's name is the userId established during the handshake/CONNECT.
     */
    @MessageMapping("/chat.send/{roomId}")
    public void sendMessage(
            @DestinationVariable String roomId,
            @Valid @Payload ChatMessageRequest request,
            Principal principal,
            SimpMessageHeaderAccessor headerAccessor) {
        String sessionId = headerAccessor.getSessionId();
        String userId = principal.getName();
        log.debug("Message from user={} in room={}, session={}", userId, roomId, sessionId);
        // Validate user has access to this room (business logic check)
        chatService.validateRoomAccess(userId, roomId);
        // Process and broadcast the message
        chatService.processAndBroadcast(userId, roomId, request);
    }

    /**
     * Handles typing indicator events sent to /app/chat.typing/{roomId}.
     * ChatService broadcasts them to /topic/room.{roomId}.typing.
     */
    @MessageMapping("/chat.typing/{roomId}")
    public void handleTyping(
            @DestinationVariable String roomId,
            @Payload TypingIndicator indicator,
            Principal principal) {
        String userId = principal.getName();
        chatService.broadcastTypingIndicator(userId, roomId, indicator.isTyping());
    }

    /**
     * Sends a room's recent history when a client subscribes to /app/chat.history/{roomId}.
     * <p>
     * A @SubscribeMapping method fires as soon as the client subscribes, and its return value
     * goes straight back to that client only. This is useful for sending initial state.
     * <p>
     * NOTE(review): ChatHistoryResponse is not imported in this listing; presumably it lives in
     * com.example.ws.dto. Confirm and add the import.
     */
    @SubscribeMapping("/chat.history/{roomId}")
    public ChatHistoryResponse onSubscribeToRoom(
            @DestinationVariable String roomId,
            Principal principal) {
        String userId = principal.getName();
        chatService.validateRoomAccess(userId, roomId);
        // Return last 50 messages - delivered directly to the subscribing client
        return chatService.getRecentMessages(roomId, 50);
    }

    /**
     * Handles errors thrown by any @MessageMapping method in this controller.
     * <p>
     * The @SendToUser annotation sends the error only to the user who triggered it.
     * ErrorResponse is fully qualified because it lives in com.example.ws.dto, which this
     * controller does not import.
     */
    @MessageExceptionHandler
    @SendToUser("/queue/errors")
    public com.example.ws.dto.ErrorResponse handleException(Throwable exception, Principal principal) {
        // Guard against a missing Principal so the handler itself cannot fail with an NPE
        String userId = principal != null ? principal.getName() : "anonymous";
        // Pass the exception as the last argument so SLF4J logs the full stack trace
        log.error("WebSocket message error for user={}: {}", userId, exception.getMessage(), exception);
        // NOTE(review): the raw exception message is sent to the client. For unexpected
        // (non-business) exceptions this can leak internal details; consider a generic message.
        return new com.example.ws.dto.ErrorResponse(exception.getMessage());
    }
}
DTOs with Validation
// src/main/java/com/example/ws/dto/ChatMessageRequest.java
package com.example.ws.dto;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;
import lombok.Data;
@Data
public class ChatMessageRequest {
@NotBlank(message = "Message content must not be blank")
@Size(max = 4000, message = "Message must not exceed 4000 characters")
private String content;
private String messageType = "TEXT"; // TEXT, IMAGE, FILE
}// src/main/java/com/example/ws/dto/ChatMessageResponse.java
package com.example.ws.dto;
import lombok.Builder;
import lombok.Data;
import java.time.Instant;
/**
 * Outbound chat message, broadcast to room subscribers and used for history entries.
 * Instances are created via the Lombok-generated builder.
 */
@Builder
@Data
public class ChatMessageResponse {

    private String messageId;

    private String roomId;

    private String senderId;

    private String senderName;

    private String content;

    private String messageType;

    /** When the message was created. */
    private Instant timestamp;

    private String status; // SENT, DELIVERED, READ
}
// src/main/java/com/example/ws/dto/TypingIndicator.java
package com.example.ws.dto;
import lombok.Data;
/**
 * Client payload for /app/chat.typing/{roomId}.
 */
@Data
public class TypingIndicator {

    /** True while the sender is typing; Lombok exposes it as {@code isTyping()}. */
    private boolean typing;
}
// src/main/java/com/example/ws/dto/ErrorResponse.java
package com.example.ws.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.time.Instant;
/**
 * Error payload sent to a client, for example by a {@code @MessageExceptionHandler}
 * that targets /user/queue/errors.
 */
@AllArgsConstructor
@Data
public class ErrorResponse {

    private String message;

    private Instant timestamp;

    /** Creates an error stamped with the current instant. */
    public ErrorResponse(String message) {
        this(message, Instant.now());
    }
}
4. Sending Messages from Services
SimpMessagingTemplate - The Core Tool
SimpMessagingTemplate is Spring's primary way to push messages to WebSocket clients from anywhere in your application - services, scheduled tasks, event listeners.
// src/main/java/com/example/ws/service/ChatService.java
package com.example.ws.service;
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import com.example.ws.entity.ChatMessage;
import com.example.ws.entity.ChatRoom;
import com.example.ws.exception.AccessDeniedException;
import com.example.ws.exception.RoomNotFoundException;
import com.example.ws.repository.ChatMessageRepository;
import com.example.ws.repository.ChatRoomRepository;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
/**
 * Chat business logic: room access checks, persistence, and pushing messages to
 * WebSocket subscribers through {@link SimpMessagingTemplate}.
 */
@Service
@Slf4j
public class ChatService {

    private final SimpMessagingTemplate messagingTemplate;
    private final ChatMessageRepository messageRepository;
    private final ChatRoomRepository roomRepository;
    private final Counter messagesSentCounter;

    public ChatService(
            SimpMessagingTemplate messagingTemplate,
            ChatMessageRepository messageRepository,
            ChatRoomRepository roomRepository,
            MeterRegistry meterRegistry) {
        this.messagingTemplate = messagingTemplate;
        this.messageRepository = messageRepository;
        this.roomRepository = roomRepository;
        // Register metrics - these go to CloudWatch via Micrometer
        this.messagesSentCounter = Counter.builder("websocket.messages.sent")
                .description("Total WebSocket messages sent")
                .register(meterRegistry);
    }

    /**
     * Validates that a user has access to a room.
     *
     * @throws AccessDeniedException if the user is not a member of the room
     */
    public void validateRoomAccess(String userId, String roomId) {
        if (!roomRepository.isMember(userId, roomId)) {
            throw new AccessDeniedException("User " + userId + " is not a member of room " + roomId);
        }
    }

    /**
     * Saves the message to MySQL and broadcasts it to all subscribers of /topic/room.{roomId}.
     * <p>
     * The broadcast is deferred until the transaction commits. Otherwise subscribers could
     * receive a message whose insert is later rolled back. When no transaction synchronization
     * is active (for example, on a self-invocation that bypasses the proxy), it is sent
     * immediately.
     */
    @Transactional
    public void processAndBroadcast(String userId, String roomId, ChatMessageRequest request) {
        // 1. Build entity and persist to MySQL
        ChatMessage message = ChatMessage.builder()
                .messageId(UUID.randomUUID().toString())
                .roomId(roomId)
                .senderId(userId)
                .content(request.getContent())
                .messageType(request.getMessageType())
                .status("SENT")
                .createdAt(Instant.now())
                .build();
        ChatMessage saved = messageRepository.save(message);
        // 2. Build the response DTO
        ChatMessageResponse response = buildResponse(saved);
        // 3. Broadcast to all subscribers of this room topic and update metrics, after commit
        runAfterCommit(() -> {
            messagingTemplate.convertAndSend("/topic/room." + roomId, response);
            messagesSentCounter.increment();
            log.debug("Broadcast message {} to room {}", saved.getMessageId(), roomId);
        });
    }

    /**
     * Sends a message to a SPECIFIC user only.
     * Client subscribes to: /user/queue/private
     */
    public void sendPrivateMessage(String targetUserId, ChatMessageResponse response) {
        // Spring resolves the user destination to that user's session-specific queue(s)
        messagingTemplate.convertAndSendToUser(
                targetUserId,
                "/queue/private",
                response
        );
    }

    /**
     * Broadcasts a typing indicator to /topic/room.{roomId}.typing.
     * These are ephemeral - not stored in DB.
     */
    public void broadcastTypingIndicator(String userId, String roomId, boolean isTyping) {
        TypingEvent event = new TypingEvent(userId, roomId, isTyping);
        messagingTemplate.convertAndSend("/topic/room." + roomId + ".typing", event);
    }

    /**
     * Retrieves recent messages for a room; called when a user first subscribes.
     * <p>
     * Judging by the repository method name, messages come back newest first.
     * NOTE(review): Spring Data cannot derive a query from a plain {@code int} limit parameter;
     * presumably this repository method is declared with {@code @Query} or similar. Confirm.
     */
    @Transactional(readOnly = true)
    public ChatHistoryResponse getRecentMessages(String roomId, int limit) {
        List<ChatMessage> messages = messageRepository
                .findTopByRoomIdOrderByCreatedAtDesc(roomId, limit);
        List<ChatMessageResponse> responses = messages.stream()
                .map(this::buildResponse)
                .collect(Collectors.toList());
        return new ChatHistoryResponse(roomId, responses);
    }

    /**
     * Maps a persisted entity to its wire DTO.
     * NOTE(review): senderName is never populated here, so clients receive null; confirm
     * whether a user-name lookup is intended.
     */
    private ChatMessageResponse buildResponse(ChatMessage message) {
        return ChatMessageResponse.builder()
                .messageId(message.getMessageId())
                .roomId(message.getRoomId())
                .senderId(message.getSenderId())
                .content(message.getContent())
                .messageType(message.getMessageType())
                .timestamp(message.getCreatedAt())
                .status(message.getStatus())
                .build();
    }

    /** Runs {@code action} after the current transaction commits, or immediately if none is active. */
    private static void runAfterCommit(Runnable action) {
        if (org.springframework.transaction.support.TransactionSynchronizationManager.isSynchronizationActive()) {
            org.springframework.transaction.support.TransactionSynchronizationManager.registerSynchronization(
                    new org.springframework.transaction.support.TransactionSynchronization() {
                        @Override
                        public void afterCommit() {
                            action.run();
                        }
                    });
        } else {
            action.run();
        }
    }
}
Sending Messages from a Scheduled Task
// src/main/java/com/example/ws/scheduler/StockPriceScheduler.java
package com.example.ws.scheduler;
import com.example.ws.dto.PriceUpdate;
import com.example.ws.service.StockPriceService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Periodically pushes market data to WebSocket clients.
 * NOTE(review): @Scheduled methods only run if @EnableScheduling is declared somewhere
 * (not shown in this listing).
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class StockPriceScheduler {

    private final SimpMessagingTemplate messagingTemplate;
    private final StockPriceService stockPriceService;

    /**
     * Pushes live stock prices to all subscribers, one second after the previous run finishes.
     * Clients subscribe to /topic/prices to receive these updates.
     */
    @Scheduled(fixedDelay = 1000)
    public void broadcastPrices() {
        try {
            List<PriceUpdate> prices = stockPriceService.getLatestPrices();
            messagingTemplate.convertAndSend("/topic/prices", prices);
        } catch (Exception e) {
            // Spring's scheduler logs and suppresses exceptions from repeating tasks, so a failure
            // would not cancel later runs. Catching here keeps the error in this class's log,
            // with its stack trace.
            log.error("Failed to broadcast price update: {}", e.getMessage(), e);
        }
    }

    /**
     * Sends a personalized portfolio update to each active viewer via /user/queue/portfolio.
     * Each user only receives their own portfolio data; one user's failure does not stop the others.
     */
    @Scheduled(fixedDelay = 5000)
    public void broadcastPortfolioUpdates() {
        List<String> activeUserIds;
        try {
            activeUserIds = stockPriceService.getActivePortfolioViewers();
        } catch (Exception e) {
            log.error("Failed to load active portfolio viewers: {}", e.getMessage(), e);
            return;
        }
        for (String userId : activeUserIds) {
            try {
                // 'var' avoids depending on the update's class (PortfolioUpdate), which this file does not import
                var update = stockPriceService.getPortfolioUpdate(userId);
                // Each user gets their own update
                messagingTemplate.convertAndSendToUser(userId, "/queue/portfolio", update);
            } catch (Exception e) {
                log.warn("Failed to send portfolio update to user {}: {}", userId, e.getMessage(), e);
            }
        }
    }
}
5. Security Configuration
JWT Authentication Handshake Interceptor
// src/main/java/com/example/ws/interceptor/AuthHandshakeInterceptor.java
package com.example.ws.interceptor;
import com.example.ws.security.JwtTokenProvider;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.util.Map;
/**
 * Authenticates the HTTP upgrade request BEFORE the WebSocket connection is established.
 * <p>
 * WHY HERE: the handshake is the only point where standard HTTP headers and query params
 * are available. Once the protocol switches to WebSocket, there are no more HTTP headers.
 * <p>
 * TOKEN PASSING OPTIONS:
 * 1. Query parameter: wss://host/ws?token=eyJ... (simplest, but the token can end up in
 *    access and proxy logs)
 * 2. Cookie: Set-Cookie during login (more secure, works with SockJS)
 * 3. Custom header during handshake (NOT supported by the browser WebSocket API,
 *    only by programmatic clients like Java, Python etc.)
 * 4. STOMP CONNECT frame headers (works from browsers, but the token only arrives after the
 *    handshake, so it must be validated in a ChannelInterceptor rather than here)
 * <p>
 * For authentication at handshake time from a browser, option 1 or 2 is the only choice.
 * We demonstrate option 1 (query param) here.
 * <p>
 * On success, the user id and token are stored in the session attributes under
 * {@code "userId"} and {@code "token"}.
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class AuthHandshakeInterceptor implements HandshakeInterceptor {

    private final JwtTokenProvider jwtTokenProvider;

    @Override
    public boolean beforeHandshake(
            ServerHttpRequest request,
            ServerHttpResponse response,
            WebSocketHandler wsHandler,
            Map<String, Object> attributes) throws Exception {
        // Extract token from query string: /ws?token=eyJhbGciOiJIUzI1NiJ9...
        String token = extractToken(request);
        if (token == null) {
            log.warn("WebSocket connection attempt without token from {}", request.getRemoteAddress());
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false; // Reject the handshake
        }
        try {
            // Validate the token and extract claims
            String userId = jwtTokenProvider.validateAndGetUserId(token);
            // Store userId in attributes - accessible in CustomHandshakeHandler
            // and in @MessageMapping methods via SimpMessageHeaderAccessor
            attributes.put("userId", userId);
            attributes.put("token", token);
            log.debug("WebSocket handshake authorized for user={}", userId);
            return true; // Allow the handshake
        } catch (Exception e) {
            // Never log the token itself
            log.warn("Invalid JWT token during WebSocket handshake: {}", e.getMessage());
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false;
        }
    }

    @Override
    public void afterHandshake(
            ServerHttpRequest request,
            ServerHttpResponse response,
            WebSocketHandler wsHandler,
            Exception exception) {
        // Log or handle post-handshake events if needed
        if (exception != null) {
            log.error("Error after WebSocket handshake: {}", exception.getMessage());
        }
    }

    /**
     * Returns the value of the first {@code token} query parameter, or {@code null} when it is
     * absent or blank. {@link java.net.URI#getQuery()} yields the already-decoded query string.
     */
    private String extractToken(ServerHttpRequest request) {
        String query = request.getURI().getQuery();
        if (query == null) {
            return null;
        }
        final String prefix = "token=";
        for (String param : query.split("&")) {
            if (param.startsWith(prefix)) {
                String value = param.substring(prefix.length());
                // "?token=" is reported as a missing token rather than an invalid one
                return value.isBlank() ? null : value;
            }
        }
        return null;
    }
}
Channel Interceptor for Message-Level Security
// src/main/java/com/example/ws/interceptor/WebSocketChannelInterceptor.java
package com.example.ws.interceptor;
import com.example.ws.security.JwtTokenProvider;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;
import java.security.Principal;
import java.util.Collections;
import java.util.Map;
/**
* This interceptor validates STOMP frames AFTER the WebSocket connection is established.
*
* Primary responsibility:
* 1. On CONNECT: Set the Spring Security Principal so that
* @AuthenticationPrincipal and SecurityContext work in @MessageMapping methods.
* 2. On SEND/SUBSCRIBE: Validate that the user is still authenticated
* (token could expire after a long connection).
*
* This is a defense-in-depth measure on top of the handshake interceptor.
*/
@Component
@RequiredArgsConstructor
@Slf4j
public class WebSocketChannelInterceptor implements ChannelInterceptor {
private final JwtTokenProvider jwtTokenProvider;
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (accessor == null) {
return message;
}
StompCommand command = accessor.getCommand();
if (StompCommand.CONNECT.equals(command)) {
// STOMP CONNECT frame - user sends this after WebSocket handshake
// At this point, our handshake interceptor has already validated the token
// and stored userId in the session attributes.
Map<String, Object> sessionAttributes = accessor.getSessionAttributes();
if (sessionAttributes != null) {
String userId = (String) sessionAttributes.get("userId");
if (userId != null) {
// Set the Principal on the accessor so Spring can route /user/** correctly
UsernamePasswordAuthenticationToken auth =
new UsernamePasswordAuthenticationToken(
userId, null,
Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"))
);
accessor.setUser(auth);
log.debug("STOMP CONNECT for user={}, session={}", userId, accessor.getSessionId());
}
}
} else if (StompCommand.SEND.equals(command)) {
// Validate the user still has a valid session
Principal user = accessor.getUser();
if (user == null) {
log.warn("Unauthenticated SEND attempt on session={}", accessor.getSessionId());
// Returning null would drop the message silently
// Throwing would send an ERROR frame back
throw new SecurityException("Unauthenticated WebSocket session");
}
} else if (StompCommand.SUBSCRIBE.equals(command)) {
// Validate subscription access
Principal user = accessor.getUser();
String destination = accessor.getDestination();
if (user <mark class="obsidian-highlight"> null) {
throw new SecurityException("Unauthenticated subscription attempt");
}
// Validate the user is allowed to subscribe to this destination
validateSubscriptionAccess(user.getName(), destination);
}
return message;
}
private void validateSubscriptionAccess(String userId, String destination) {
if (destination </mark> null) return;
// Prevent users from subscribing to other users' private queues
// Pattern: /user/{userId}/queue/... should only be accessed by that user
// Note: Spring routes /user/queue/... to /user/{sessionId}/queue/...
// but direct subscriptions need validation
if (destination.startsWith("/topic/admin") && !isAdmin(userId)) {
log.warn("Unauthorized subscription to admin topic by user={}", userId);
throw new SecurityException("Access denied to destination: " + destination);
}
}
private boolean isAdmin(String userId) {
// Check user roles from database or JWT claims
// Simplified for brevity
return false;
}
}Spring Security Configuration
// src/main/java/com/example/ws/config/SecurityConfig.java
package com.example.ws.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.web.SecurityFilterChain;
/**
 * HTTP-level security for the WebSocket service.
 * <p>
 * The /ws/** handshake endpoints are open at the HTTP layer because the JWT check happens in
 * the WebSocket handshake interceptor. /actuator/health is public, every other request must be
 * authenticated, and no HTTP session is created (stateless).
 */
@Configuration
@EnableWebSecurity
public class SecurityConfig {

    @Bean
    public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        // The WebSocket handshake carries a JWT checked by our interceptor rather than a CSRF
        // token (standard CSRF tokens don't fit WebSocket), so CSRF is skipped for /ws/**.
        http.csrf(csrf -> csrf.ignoringRequestMatchers("/ws/**"));

        http.authorizeHttpRequests(auth -> auth
                .requestMatchers("/ws/**").permitAll() // Auth handled by WebSocket interceptor
                .requestMatchers("/actuator/health").permitAll()
                .anyRequest().authenticated());

        http.sessionManagement(session ->
                session.sessionCreationPolicy(SessionCreationPolicy.STATELESS));

        return http.build();
    }
}
6. Session Management and Connection Lifecycle
WebSocket Session Event Listener
// src/main/java/com/example/ws/listener/WebSocketEventListener.java
package com.example.ws.listener;
import com.example.ws.service.PresenceService;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.event.EventListener;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.SessionUnsubscribeEvent;
import java.util.concurrent.atomic.AtomicInteger;
@Component
@Slf4j
public class WebSocketEventListener {
private final PresenceService presenceService;
private final SimpMessagingTemplate messagingTemplate;
private final AtomicInteger activeConnections = new AtomicInteger(0);

// Session ids currently counted in activeConnections. Spring may publish SessionDisconnectEvent
// more than once per session, and also for sessions that never completed a STOMP CONNECT.
// The gauge is therefore only adjusted when a session actually enters or leaves this set.
private final java.util.Set<String> connectedSessions =
        java.util.concurrent.ConcurrentHashMap.newKeySet();

public WebSocketEventListener(
        PresenceService presenceService,
        SimpMessagingTemplate messagingTemplate,
        MeterRegistry meterRegistry) {
    this.presenceService = presenceService;
    this.messagingTemplate = messagingTemplate;
    // Register active connection gauge (exported to CloudWatch)
    Gauge.builder("websocket.connections.active", activeConnections, AtomicInteger::get)
            .description("Current number of active WebSocket connections")
            .register(meterRegistry);
}

/**
 * Fired when a STOMP CONNECTED frame is sent to the client, i.e. the STOMP session is established.
 * Counts the session and marks the user online.
 */
@EventListener
public void handleSessionConnected(SessionConnectedEvent event) {
    StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
    String sessionId = sha.getSessionId();
    String userId = getUserId(sha);
    if (sessionId != null && connectedSessions.add(sessionId)) {
        activeConnections.incrementAndGet();
    }
    log.info("WebSocket CONNECTED - userId={}, sessionId={}, total={}",
            userId, sessionId, activeConnections.get());
    if (userId != null) {
        presenceService.markOnline(userId, sessionId);
        broadcastPresenceChange(userId, "ONLINE");
    }
}

/**
 * Fired when a WebSocket session ends, whatever the reason: browser close, network drop,
 * explicit disconnect or server shutdown.
 * <p>
 * Made idempotent: duplicate events, and events for sessions that never reached
 * STOMP CONNECTED, are ignored so the gauge and presence state stay consistent.
 */
@EventListener
public void handleSessionDisconnect(SessionDisconnectEvent event) {
    StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
    String sessionId = sha.getSessionId();
    String userId = getUserId(sha);
    String closeStatus = String.valueOf(event.getCloseStatus());
    if (sessionId == null || !connectedSessions.remove(sessionId)) {
        log.debug("Ignoring disconnect for untracked session={}, status={}", sessionId, closeStatus);
        return;
    }
    activeConnections.decrementAndGet();
    log.info("WebSocket DISCONNECTED - userId={}, sessionId={}, status={}, total={}",
            userId, sessionId, closeStatus, activeConnections.get());
    if (userId != null) {
        // Only announce OFFLINE when the user has no other active sessions
        presenceService.markOffline(userId, sessionId);
        if (!presenceService.isOnline(userId)) {
            broadcastPresenceChange(userId, "OFFLINE");
        }
    }
}
@EventListener
public void handleSessionSubscribe(SessionSubscribeEvent event) {
StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
log.debug("SUBSCRIBE - userId={}, destination={}, session={}",
getUserId(sha), sha.getDestination(), sha.getSessionId());
}
@EventListener
public void handleSessionUnsubscribe(SessionUnsubscribeEvent event) {
StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
log.debug("UNSUBSCRIBE - userId={}, subscriptionId={}, session={}",
getUserId(sha), sha.getSubscriptionId(), sha.getSessionId());
}
private String getUserId(StompHeaderAccessor sha) {
if (sha.getUser() != null) {
return sha.getUser().getName();
}
return null;
}
private void broadcastPresenceChange(String userId, String status) {
PresenceUpdate update = new PresenceUpdate(userId, status);
messagingTemplate.convertAndSend("/topic/presence", update);
}
}7. Error Handling
Global WebSocket Error Handler
// src/main/java/com/example/ws/exception/WebSocketErrorHandler.java
package com.example.ws.exception;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
/**
 * Logs WebSocket errors at the transport level (as opposed to STOMP-level errors raised while
 * handling messages).
 *
 * <p>This class only logs; it does not close the session or notify the client.
 * NOTE(review): it is not declared as a Spring bean here, so it must be invoked explicitly
 * (e.g. from a WebSocket handler decorator) — confirm the wiring.
 */
@Slf4j
public class WebSocketErrorHandler {

    /**
     * Logs a transport error together with its full stack trace.
     *
     * @param session   the session on which the error occurred
     * @param exception the transport-level failure
     */
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        // The Throwable is passed as an extra trailing argument (no placeholder for it), which makes
        // SLF4J print the exception type and stack trace, not just the message.
        log.error("WebSocket transport error on session={}: {}",
                session.getId(), exception.getMessage(), exception);
    }
}STOMP Exception Handler
// src/main/java/com/example/ws/exception/GlobalWebSocketExceptionHandler.java
package com.example.ws.exception;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.web.bind.annotation.ControllerAdvice;
/**
 * Global exception handler for WebSocket @MessageMapping methods.
 * Similar to @ControllerAdvice for REST controllers.
 *
 * <p>Each handler converts an exception into an {@link ErrorResponse} sent to the user's
 * {@code /queue/errors} destination. {@code broadcast = false} delivers the error only to the
 * session whose message failed. With the default ({@code true}), every open session of the same
 * user (other tabs or devices) would also receive it.
 */
@ControllerAdvice
@Slf4j
public class GlobalWebSocketExceptionHandler {

    /** Reports an authorization failure with code {@code ACCESS_DENIED}. */
    @MessageExceptionHandler(AccessDeniedException.class)
    @SendToUser(destinations = "/queue/errors", broadcast = false)
    public ErrorResponse handleAccessDenied(AccessDeniedException ex) {
        log.warn("Access denied in WebSocket: {}", ex.getMessage());
        return new ErrorResponse("ACCESS_DENIED", ex.getMessage());
    }

    /** Reports invalid input with code {@code VALIDATION_ERROR}, echoing the validation message. */
    @MessageExceptionHandler(ValidationException.class)
    @SendToUser(destinations = "/queue/errors", broadcast = false)
    public ErrorResponse handleValidation(ValidationException ex) {
        // Client errors are expected; log at debug so they are traceable without flooding the logs.
        log.debug("Validation failed in WebSocket: {}", ex.getMessage());
        return new ErrorResponse("VALIDATION_ERROR", ex.getMessage());
    }

    /**
     * Catch-all: logs the full stack trace and returns a generic {@code INTERNAL_ERROR},
     * so internal details are not exposed to the client.
     */
    @MessageExceptionHandler(Exception.class)
    @SendToUser(destinations = "/queue/errors", broadcast = false)
    public ErrorResponse handleGenericException(Exception ex) {
        log.error("Unexpected WebSocket error: {}", ex.getMessage(), ex);
        return new ErrorResponse("INTERNAL_ERROR", "An unexpected error occurred");
    }
}8. Interceptors and Custom Headers
Extracting Custom STOMP Headers
// In a @MessageMapping method, extract custom STOMP headers:
/**
 * Handles an order update sent to {@code /order.update} (relative to the application
 * destination prefix).
 *
 * <p>The client-supplied {@code correlationId} STOMP header, the authenticated user id and the
 * WebSocket session id are placed in the SLF4J MDC only for the duration of processing, so log
 * lines written while handling this message can be correlated.
 *
 * @param request        the deserialized order update payload
 * @param correlationId  value of the (required) {@code correlationId} STOMP header
 * @param headerAccessor accessor for the inbound message headers; supplies the session id
 * @param principal      the authenticated user
 */
@MessageMapping("/order.update")
public void handleOrderUpdate(
        @Payload OrderUpdateRequest request,
        @Header("correlationId") String correlationId,
        SimpMessageHeaderAccessor headerAccessor,
        Principal principal) {
    String sessionId = headerAccessor.getSessionId();
    String userId = principal.getName();
    // Use correlationId for distributed tracing.
    MDC.put("correlationId", correlationId);
    MDC.put("userId", userId);
    if (sessionId != null) {
        MDC.put("sessionId", sessionId);
    }
    try {
        orderService.processUpdate(userId, request, correlationId);
    } finally {
        // Inbound STOMP messages are handled on pooled threads: without cleanup these MDC values
        // would leak into log lines for unrelated messages processed later on the same thread.
        MDC.remove("correlationId");
        MDC.remove("userId");
        MDC.remove("sessionId");
    }
}9. Heartbeat and Timeout Configuration
// Heartbeat config in WebSocketConfig (from Section 2)
// Send interval: server sends a heartbeat every 25 seconds
// Receive interval: server expects a heartbeat from client every 25 seconds
.setHeartbeatValue(new long[]{25000, 25000})
/*
How heartbeat works:
- Both sides negotiate heartbeat during STOMP CONNECT/CONNECTED exchange.
- Client sends: heart-beat:10000,10000 (can send every 10s, wants to receive every 10s)
- Server sends: heart-beat:25000,25000 (can send every 25s, wants to receive every 25s)
- Each direction uses the larger of what the sender can guarantee and what the receiver wants:
  interval = max(sender's send value, receiver's receive value)
Server sends every: max(25000, 10000) = 25000 ms
Client sends every: max(10000, 25000) = 25000 ms
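If nothing (a heartbeat or any other frame) arrives within the negotiated interval plus an
error margin, the connection is considered dead and closed. The STOMP spec leaves the margin
to the implementation; Spring's simple broker and its Java STOMP client allow 3x the
negotiated receive interval before giving up.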
*/10. Testing WebSocket Endpoints
Integration Test with WebSocket Test Client
// src/test/java/com/example/ws/integration/WebSocketIntegrationTest.java
package com.example.ws.integration;
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.simp.stomp.*;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
import java.lang.reflect.Type;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class WebSocketIntegrationTest {
    @LocalServerPort
    private int port;
    private WebSocketStompClient stompClient;
    private String wsUrl;

    /** Creates a SockJS-backed STOMP client that (de)serializes payloads as JSON. */
    @BeforeEach
    void setUp() {
        wsUrl = "ws://localhost:" + port + "/ws";
        stompClient = new WebSocketStompClient(
                new SockJsClient(List.of(new WebSocketTransport(new StandardWebSocketClient())))
        );
        stompClient.setMessageConverter(new MappingJackson2MessageConverter());
    }

    /**
     * A chat message sent to {@code /app/chat.send/room1} comes back to subscribers of
     * {@code /topic/room.room1} with its content intact and the sender's id attached.
     */
    @Test
    void testSendAndReceiveChatMessage() throws Exception {
        BlockingQueue<ChatMessageResponse> receivedMessages = new LinkedBlockingQueue<>();

        // Connect with a test JWT token, generated once and offered both as a handshake query
        // parameter and as a STOMP CONNECT header.
        String token = generateTestJwt("user1");
        StompHeaders connectHeaders = new StompHeaders();
        connectHeaders.set("token", token);
        StompSession session = stompClient.connectAsync(
                wsUrl + "?token=" + token,
                new WebSocketHttpHeaders(),
                connectHeaders,
                new StompSessionHandlerAdapter() {}
        ).get(5, TimeUnit.SECONDS);

        try {
            // Subscribe to the room topic
            session.subscribe("/topic/room.room1", new StompFrameHandler() {
                @Override
                public Type getPayloadType(StompHeaders headers) {
                    return ChatMessageResponse.class;
                }

                @Override
                public void handleFrame(StompHeaders headers, Object payload) {
                    receivedMessages.add((ChatMessageResponse) payload);
                }
            });

            // By default the server may process SUBSCRIBE and SEND frames concurrently on its
            // inbound-channel thread pool, so the broadcast could otherwise overtake the
            // subscription and the test would fail intermittently.
            Thread.sleep(200);

            // Send a message
            ChatMessageRequest request = new ChatMessageRequest();
            request.setContent("Hello, Integration Test!");
            request.setMessageType("TEXT");
            session.send("/app/chat.send/room1", request);

            // Wait for the broadcast to arrive back
            ChatMessageResponse received = receivedMessages.poll(5, TimeUnit.SECONDS);
            assertThat(received).isNotNull();
            assertThat(received.getContent()).isEqualTo("Hello, Integration Test!");
            assertThat(received.getSenderId()).isEqualTo("user1");
        } finally {
            // Release the connection even when an assertion above fails.
            session.disconnect();
        }
    }

    private String generateTestJwt(String userId) {
        // Generate a test JWT token - use your JwtTokenProvider or a test utility
        return "test-token-for-" + userId; // Simplified; use real JWT in actual tests
    }
}Unit Test for Message Controller
// src/test/java/com/example/ws/controller/ChatControllerTest.java
package com.example.ws.controller;
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.service.ChatService;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import java.security.Principal;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class ChatControllerTest {
    @Mock
    private ChatService chatService;
    @InjectMocks
    private ChatController chatController;

    /**
     * sendMessage calls both validateRoomAccess and processAndBroadcast on the service, using
     * the authenticated user's id and the given room id.
     */
    @Test
    void testSendMessage_callsServiceWithCorrectParams() {
        // Arrange
        String roomId = "room1";
        String userId = "user1";
        ChatMessageRequest request = new ChatMessageRequest();
        request.setContent("Test message");
        Principal principal = () -> userId;
        // Use a real accessor rather than a mock. MockitoExtension applies strict stubs, so a
        // stubbed getSessionId() fails the test with UnnecessaryStubbingException whenever the
        // controller does not happen to read it.
        SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create();
        headerAccessor.setSessionId("session1");
        // No stubbing of chatService: void methods on a Mockito mock already do nothing.

        // Act
        chatController.sendMessage(roomId, request, principal, headerAccessor);

        // Assert
        verify(chatService).validateRoomAccess(userId, roomId);
        verify(chatService).processAndBroadcast(userId, roomId, request);
    }
}