← Back to Articles
6/6/2026 · Admin Post

websockets part2 spring implementation

WebSockets Demystified - Part 2: Spring Boot Implementation

Series: Index | Part 1 | Part 2 | Part 3 | Part 4 | Part 5 | Part 6


Table of Contents

  1. Maven Dependencies
  2. WebSocket Configuration
  3. STOMP Message Controller
  4. Sending Messages from Services
  5. Security Configuration
  6. Session Management and Connection Lifecycle
  7. Error Handling
  8. Interceptors and Custom Headers
  9. Heartbeat and Timeout Configuration
  10. Testing WebSocket Endpoints

1. Maven Dependencies

<!-- pom.xml -->
<properties>
    <java.version>17</java.version>
    <!-- NOTE(review): this property has no effect unless something references it (e.g. an imported
         spring-boot-dependencies BOM). The unversioned starters below rely on a Spring Boot parent/BOM
         that is not shown in this snippet. -->
    <spring-boot.version>3.2.0</spring-boot.version>
</properties>
 
<dependencies>
 
    <!-- Core Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
 
    <!-- WebSocket support - includes STOMP and SockJS -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
    </dependency>
 
    <!-- Spring Security for WebSocket authentication -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-security</artifactId>
    </dependency>
 
    <!-- JWT for token-based authentication -->
    <!-- jjwt-api is the compile-time API; jjwt-impl and jjwt-jackson are runtime-only implementations. -->
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-api</artifactId>
        <version>0.12.3</version>
    </dependency>
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-impl</artifactId>
        <version>0.12.3</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-jackson</artifactId>
        <version>0.12.3</version>
        <scope>runtime</scope>
    </dependency>
 
    <!-- Redis for pub/sub in clustered deployments -->
    <!-- NOTE(review): per the Spring Boot docs, Lettuce connection pooling (spring.data.redis.lettuce.pool.*)
         is only enabled when org.apache.commons:commons-pool2 is on the classpath, and it is not declared
         here. Confirm whether the pool settings are meant to take effect. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
 
    <!-- Spring Data JPA with MySQL -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>
 
    <!-- Actuator for health checks and metrics -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-actuator</artifactId>
    </dependency>
 
    <!-- Micrometer for metrics (sends to CloudWatch) -->
    <!-- NOTE(review): this jar provides the registry implementation only. A CloudWatchMeterRegistry is
         typically wired by Spring Cloud AWS's metrics starter or by a manually defined bean, not by
         Spring Boot itself. Confirm one of those is present. -->
    <dependency>
        <groupId>io.micrometer</groupId>
        <artifactId>micrometer-registry-cloudwatch2</artifactId>
    </dependency>
 
    <!-- Validation -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-validation</artifactId>
    </dependency>
 
    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
 
    <!-- Testing -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.springframework.security</groupId>
        <artifactId>spring-security-test</artifactId>
        <scope>test</scope>
    </dependency>
 
</dependencies>

2. WebSocket Configuration

Application Properties

# src/main/resources/application.yml

server:
  port: 8080
  # Important: Tomcat WebSocket configuration
  tomcat:
    max-connections: 10000
    threads:
      max: 200
      min-spare: 20

spring:
  application:
    name: websocket-service

  datasource:
    url: jdbc:mysql://localhost:3306/wsdb?useSSL=true&serverTimezone=UTC&allowPublicKeyRetrieval=true
    username: ${DB_USERNAME}
    password: ${DB_PASSWORD}
    hikari:
      maximum-pool-size: 20
      minimum-idle: 5
      idle-timeout: 30000
      connection-timeout: 20000

  jpa:
    hibernate:
      ddl-auto: validate
    show-sql: false
    properties:
      hibernate:
        # Hibernate 6 (used by Spring Boot 3.x) deprecates the versioned MySQL8Dialect.
        # MySQLDialect detects the server version itself; this line can also be omitted.
        dialect: org.hibernate.dialect.MySQLDialect
        format_sql: true

  data:
    redis:
      host: ${REDIS_HOST:localhost}
      port: ${REDIS_PORT:6379}
      password: ${REDIS_PASSWORD:}
      ssl:
        enabled: ${REDIS_SSL:false}
      timeout: 2000ms
      lettuce:
        # Pooling only takes effect when org.apache.commons:commons-pool2 is on the classpath.
        pool:
          max-active: 20
          max-idle: 10
          min-idle: 5

app:
  websocket:
    # Heartbeat intervals in milliseconds
    heartbeat-send-interval: 25000
    heartbeat-receive-interval: 25000
    # Message size limits
    message-size-limit: 65536   # 64 KB per message
    send-buffer-size-limit: 524288  # 512 KB per connection send buffer
    send-time-limit: 20000      # 20 seconds max to send a message before disconnect
    # Allowed origins - MUST be explicitly configured.
    # Written as a comma-separated string, NOT a YAML list: WebSocketConfig injects this
    # with @Value, which cannot resolve a YAML list (Spring flattens lists to
    # allowed-origins[0], allowed-origins[1], ...) and would fail at startup.
    allowed-origins: "https://app.example.com,https://www.example.com"

  jwt:
    secret: ${JWT_SECRET}
    expiration-ms: 86400000

management:
  endpoints:
    web:
      exposure:
        # Add micrometer-registry-prometheus before exposing the "prometheus" endpoint.
        include: health,info,metrics
  # Spring Boot 3 moved metrics export settings from management.metrics.export.<registry>
  # to management.<registry>.metrics.export.
  # NOTE(review): the CloudWatch registry itself is wired by Spring Cloud AWS (or a manual
  # bean), not by Spring Boot. Confirm that is in place.
  cloudwatch:
    metrics:
      export:
        namespace: WebSocketService
        step: 1m

Core WebSocket Configuration Class

// src/main/java/com/example/ws/config/WebSocketConfig.java
 
package com.example.ws.config;
 
import com.example.ws.interceptor.AuthHandshakeInterceptor;
import com.example.ws.interceptor.WebSocketChannelInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
 
import java.util.List;
 
@Configuration
@EnableWebSocketMessageBroker  // Enables STOMP messaging with a full-featured message broker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
 
    @Autowired
    private AuthHandshakeInterceptor authHandshakeInterceptor;
 
    @Autowired
    private WebSocketChannelInterceptor channelInterceptor;
 
    @Value("${app.websocket.allowed-origins}")
    private List<String> allowedOrigins;
 
    @Value("${app.websocket.message-size-limit}")
    private int messageSizeLimit;
 
    @Value("${app.websocket.send-buffer-size-limit}")
    private int sendBufferSizeLimit;
 
    @Value("${app.websocket.send-time-limit}")
    private int sendTimeLimit;
 
    @Value("${app.websocket.heartbeat-send-interval}")
    private long heartbeatSendInterval;
 
    @Value("${app.websocket.heartbeat-receive-interval}")
    private long heartbeatReceiveInterval;
 
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
 
        registry.addEndpoint("/ws")
            // STOMP endpoint path - clients connect to wss://host/ws
            .setHandshakeHandler(new CustomHandshakeHandler())
            // Add our JWT authentication interceptor
            .addInterceptors(authHandshakeInterceptor)
            // Allowed origins - NEVER use "*" in production
            .setAllowedOrigins(allowedOrigins.toArray(new String[0]))
            // SockJS fallback for environments that don't support WebSocket
            .withSockJS()
                // SockJS heartbeat
                .setHeartbeatTime(25000)
                // Disconnect delay before considering a client gone
                .setDisconnectDelay(5000)
                // Timeout for HTTP-based transports
                .setHttpMessageCacheSize(1000)
                .setWebSocketEnabled(true);
    }
 
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
 
        // Application destination prefix - messages sent here go to @MessageMapping methods
        registry.setApplicationDestinationPrefixes("/app");
 
        // Topic prefix - server broadcasts to these (all subscribers receive)
        // Queue prefix - server sends to specific users
        registry.enableSimpleBroker("/topic", "/queue")
            .setHeartbeatValue(new long[]{heartbeatSendInterval, heartbeatReceiveInterval})
            .setTaskScheduler(heartbeatTaskScheduler());
 
        // User destination prefix - for private messages to a specific user
        // /user/queue/notifications -> sends to a specific user's notification queue
        registry.setUserDestinationPrefix("/user");
    }
 
    @Override
    public void configureWebSocketTransport(WebSocketTransportRegistration registration) {
        registration
            // Maximum size of a single incoming message (64 KB default)
            .setMessageSizeLimit(messageSizeLimit)
            // Maximum size of data that can be buffered per session when sending
            .setSendBufferSizeLimit(sendBufferSizeLimit)
            // Maximum time to block when sending to a client before it is considered dead
            .setSendTimeLimit(sendTimeLimit);
    }
 
    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        // Add our security/validation interceptor to inbound messages
        registration.interceptors(channelInterceptor);
        // Thread pool for processing inbound messages
        registration.taskExecutor()
            .corePoolSize(4)
            .maxPoolSize(20)
            .queueCapacity(100);
    }
 
    @Override
    public void configureClientOutboundChannel(ChannelRegistration registration) {
        // Thread pool for sending outbound messages
        registration.taskExecutor()
            .corePoolSize(4)
            .maxPoolSize(20)
            .queueCapacity(100);
    }
 
    // Task scheduler for heartbeat - required when configuring heartbeat on the broker
    @Bean
    public org.springframework.scheduling.TaskScheduler heartbeatTaskScheduler() {
        org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler scheduler =
            new org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler();
        scheduler.setPoolSize(2);
        scheduler.setThreadNamePrefix("ws-heartbeat-");
        scheduler.initialize();
        return scheduler;
    }
}

Custom Handshake Handler

// src/main/java/com/example/ws/config/CustomHandshakeHandler.java
 
package com.example.ws.config;
 
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
 
import java.security.Principal;
import java.util.Map;
 
/**
 * Chooses the {@link Principal} attached to each WebSocket session.
 *
 * <p>Spring routes {@code /user/**} destinations by the Principal's name, so the value
 * returned here is the identity used for user-targeted messages. AuthHandshakeInterceptor
 * validates the JWT earlier in the same handshake and places the user id in the
 * handshake attributes under {@code "userId"}.
 */
public class CustomHandshakeHandler extends DefaultHandshakeHandler {

    /** Attribute key written by AuthHandshakeInterceptor. */
    private static final String USER_ID_ATTRIBUTE = "userId";

    /**
     * Builds a Principal named after the authenticated user id.
     *
     * @return a Principal whose name is the user id, or {@code null} (no Principal)
     *         when the attribute is absent
     */
    @Override
    protected Principal determineUser(
            ServerHttpRequest request,
            WebSocketHandler wsHandler,
            Map<String, Object> attributes) {

        final String authenticatedId = (String) attributes.get(USER_ID_ATTRIBUTE);

        // A missing id leaves the session without a Principal. The interceptor rejects
        // unauthenticated handshakes, so this branch is not expected in practice.
        if (authenticatedId != null) {
            // Spring routes /user/{name}/queue/... using this name.
            Principal principal = () -> authenticatedId;
            return principal;
        }
        return null;
    }
}

3. STOMP Message Controller

Understanding STOMP Destinations

STOMP Destination Types in Spring:

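/app/chat.send         -> Routed to @MessageMapping("/chat.send")
/topic/room.123        -> Broadcast to all subscribers of this topic
/queue/private         -> Point-to-point by convention (the simple broker delivers to every subscriber, exactly like /topic)
/user/queue/notify     -> User-specific destination (Spring rewrites it to a unique per-session queue such as /queue/notify-user{sessionId}, so only that user receives it)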

Chat Message Controller

// src/main/java/com/example/ws/controller/ChatController.java
 
package com.example.ws.controller;
 
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import com.example.ws.dto.ErrorResponse;
import com.example.ws.dto.TypingIndicator;
import com.example.ws.service.ChatService;
// NOTE(review): ChatHistoryResponse is used below but not imported here; import it from its package.
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.handler.annotation.DestinationVariable;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.stereotype.Controller;
 
import java.security.Principal;
 
@Controller
@RequiredArgsConstructor
@Slf4j
public class ChatController {
 
    private final ChatService chatService;
 
    /**
     * Handles a message sent to /app/chat.send/{roomId}
     *
     * The @MessageMapping path is relative to the app prefix configured in WebSocketConfig.
     * Full destination: /app/chat.send/{roomId}
     *
     * The @Valid annotation triggers Bean Validation on the payload.
     *
     * The Principal is automatically injected and contains the userId
     * from our CustomHandshakeHandler.
     */
    @MessageMapping("/chat.send/{roomId}")
    public void sendMessage(
            @DestinationVariable String roomId,
            @Valid @Payload ChatMessageRequest request,
            Principal principal,
            SimpMessageHeaderAccessor headerAccessor) {
 
        String sessionId = headerAccessor.getSessionId();
        String userId = principal.getName();
 
        log.debug("Message from user={} in room={}, session={}", userId, roomId, sessionId);
 
        // Validate user has access to this room (business logic check)
        chatService.validateRoomAccess(userId, roomId);
 
        // Process and broadcast the message
        chatService.processAndBroadcast(userId, roomId, request);
    }
 
    /**
     * Handles typing indicator events.
     * Sent to /app/chat.typing/{roomId}
     * Broadcasts to /topic/room.{roomId}.typing
     */
    @MessageMapping("/chat.typing/{roomId}")
    public void handleTyping(
            @DestinationVariable String roomId,
            @Payload TypingIndicator indicator,
            Principal principal) {
 
        String userId = principal.getName();
        chatService.broadcastTypingIndicator(userId, roomId, indicator.isTyping());
    }
 
    /**
     * @SubscribeMapping fires immediately when a client subscribes.
     * The return value is sent directly back to the subscribing client.
     * Full destination client subscribes to: /app/chat.history/{roomId}
     *
     * This is useful for sending initial state when a client subscribes.
     */
    @SubscribeMapping("/chat.history/{roomId}")
    public ChatHistoryResponse onSubscribeToRoom(
            @DestinationVariable String roomId,
            Principal principal) {
 
        String userId = principal.getName();
        chatService.validateRoomAccess(userId, roomId);
 
        // Return last 50 messages - delivered directly to the subscribing client
        return chatService.getRecentMessages(roomId, 50);
    }
 
    /**
     * Handles errors thrown by any @MessageMapping method in this controller.
     * The @SendToUser annotation sends the error only to the user who triggered it.
     */
    @MessageExceptionHandler
    @SendToUser("/queue/errors")
    public ErrorResponse handleException(Throwable exception, Principal principal) {
        log.error("WebSocket message error for user={}: {}", principal.getName(), exception.getMessage());
        return new ErrorResponse(exception.getMessage());
    }
}

DTOs with Validation

// src/main/java/com/example/ws/dto/ChatMessageRequest.java
 
package com.example.ws.dto;
 
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;
import lombok.Data;
 
@Data
public class ChatMessageRequest {
 
    @NotBlank(message = "Message content must not be blank")
    @Size(max = 4000, message = "Message must not exceed 4000 characters")
    private String content;
 
    private String messageType = "TEXT"; // TEXT, IMAGE, FILE
}
// src/main/java/com/example/ws/dto/ChatMessageResponse.java
 
package com.example.ws.dto;
 
import lombok.Builder;
import lombok.Data;
 
import java.time.Instant;
 
/**
 * Outbound representation of a chat message.
 *
 * Broadcast to room subscribers and also used for entries in chat history.
 * Instances are created with the Lombok-generated builder.
 */
@Builder
@Data
public class ChatMessageResponse {

    /** Identifier of the persisted message. */
    private String messageId;

    /** Room the message was posted to. */
    private String roomId;

    /** User id of the author. */
    private String senderId;

    /** Display name of the author, if populated. */
    private String senderName;

    /** Message body. */
    private String content;

    /** Kind of message (e.g. TEXT). */
    private String messageType;

    /** Creation time of the message. */
    private Instant timestamp;

    /** Delivery state: SENT, DELIVERED or READ. */
    private String status;
}
// src/main/java/com/example/ws/dto/TypingIndicator.java
 
package com.example.ws.dto;
 
import lombok.Data;
 
/**
 * Inbound payload for /app/chat.typing/{roomId}.
 *
 * Lombok generates isTyping() and setTyping(boolean) for the primitive flag.
 */
@Data
public class TypingIndicator {

    /** Whether the client reports that the user is currently typing. */
    private boolean typing;
}
// src/main/java/com/example/ws/dto/ErrorResponse.java
 
package com.example.ws.dto;
 
import lombok.AllArgsConstructor;
import lombok.Data;
import java.time.Instant;
 
@Data
@AllArgsConstructor
public class ErrorResponse {
    private String message;
    private Instant timestamp = Instant.now();
 
    public ErrorResponse(String message) {
        this.message = message;
    }
}

4. Sending Messages from Services

SimpMessagingTemplate - The Core Tool

SimpMessagingTemplate is Spring's primary way to push messages to WebSocket clients from anywhere in your application - services, scheduled tasks, event listeners.

// src/main/java/com/example/ws/service/ChatService.java
 
package com.example.ws.service;
 
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import com.example.ws.entity.ChatMessage;
import com.example.ws.entity.ChatRoom;
import com.example.ws.exception.AccessDeniedException;
import com.example.ws.exception.RoomNotFoundException;
import com.example.ws.repository.ChatMessageRepository;
import com.example.ws.repository.ChatRoomRepository;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
 
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
 
@Service
@Slf4j
public class ChatService {
 
    private final SimpMessagingTemplate messagingTemplate;
    private final ChatMessageRepository messageRepository;
    private final ChatRoomRepository roomRepository;
    private final Counter messagesSentCounter;
 
    public ChatService(
            SimpMessagingTemplate messagingTemplate,
            ChatMessageRepository messageRepository,
            ChatRoomRepository roomRepository,
            MeterRegistry meterRegistry) {
 
        this.messagingTemplate = messagingTemplate;
        this.messageRepository = messageRepository;
        this.roomRepository = roomRepository;
 
        // Register metrics - these go to CloudWatch via Micrometer
        this.messagesSentCounter = Counter.builder("websocket.messages.sent")
            .description("Total WebSocket messages sent")
            .register(meterRegistry);
    }
 
    /**
     * Validates that a user has access to a room.
     * Throws AccessDeniedException if they do not.
     */
    public void validateRoomAccess(String userId, String roomId) {
        if (!roomRepository.isMember(userId, roomId)) {
            throw new AccessDeniedException("User " + userId + " is not a member of room " + roomId);
        }
    }
 
    /**
     * Saves the message to MySQL and broadcasts it to all room subscribers.
     */
    @Transactional
    public void processAndBroadcast(String userId, String roomId, ChatMessageRequest request) {
 
        // 1. Build entity and persist to MySQL
        ChatMessage message = ChatMessage.builder()
            .messageId(UUID.randomUUID().toString())
            .roomId(roomId)
            .senderId(userId)
            .content(request.getContent())
            .messageType(request.getMessageType())
            .status("SENT")
            .createdAt(Instant.now())
            .build();
 
        ChatMessage saved = messageRepository.save(message);
 
        // 2. Build the response DTO
        ChatMessageResponse response = buildResponse(saved, userId);
 
        // 3. Broadcast to all subscribers of this room topic
        // All clients subscribed to /topic/room.{roomId} will receive this message
        messagingTemplate.convertAndSend("/topic/room." + roomId, response);
 
        // 4. Update metrics
        messagesSentCounter.increment();
 
        log.debug("Broadcast message {} to room {}", saved.getMessageId(), roomId);
    }
 
    /**
     * Sends a message to a SPECIFIC user only.
     * Uses the user destination prefix - only the user with userId receives it.
     * Client subscribes to: /user/queue/private
     */
    public void sendPrivateMessage(String targetUserId, ChatMessageResponse response) {
        // convertAndSendToUser prepends /user/{userId} automatically
        messagingTemplate.convertAndSendToUser(
            targetUserId,
            "/queue/private",
            response
        );
    }
 
    /**
     * Broadcasts a typing indicator to the room.
     * These are ephemeral - not stored in DB.
     */
    public void broadcastTypingIndicator(String userId, String roomId, boolean isTyping) {
        TypingEvent event = new TypingEvent(userId, roomId, isTyping);
        messagingTemplate.convertAndSend("/topic/room." + roomId + ".typing", event);
    }
 
    /**
     * Retrieves recent messages for a room.
     * Called when a user first subscribes.
     */
    @Transactional(readOnly = true)
    public ChatHistoryResponse getRecentMessages(String roomId, int limit) {
        List<ChatMessage> messages = messageRepository
            .findTopByRoomIdOrderByCreatedAtDesc(roomId, limit);
 
        List<ChatMessageResponse> responses = messages.stream()
            .map(m -> buildResponse(m, m.getSenderId()))
            .collect(Collectors.toList());
 
        return new ChatHistoryResponse(roomId, responses);
    }
 
    private ChatMessageResponse buildResponse(ChatMessage message, String userId) {
        return ChatMessageResponse.builder()
            .messageId(message.getMessageId())
            .roomId(message.getRoomId())
            .senderId(message.getSenderId())
            .content(message.getContent())
            .messageType(message.getMessageType())
            .timestamp(message.getCreatedAt())
            .status(message.getStatus())
            .build();
    }
}

Sending Messages from a Scheduled Task

// src/main/java/com/example/ws/scheduler/StockPriceScheduler.java
 
package com.example.ws.scheduler;
 
import com.example.ws.dto.PriceUpdate;
import com.example.ws.service.StockPriceService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
 
import java.util.List;
 
@Component
@RequiredArgsConstructor
@Slf4j
public class StockPriceScheduler {
 
    private final SimpMessagingTemplate messagingTemplate;
    private final StockPriceService stockPriceService;
 
    /**
     * Pushes live stock prices to all subscribers every 1 second.
     * Clients subscribe to /topic/prices to receive these updates.
     */
    @Scheduled(fixedDelay = 1000)
    public void broadcastPrices() {
        try {
            List<PriceUpdate> prices = stockPriceService.getLatestPrices();
            messagingTemplate.convertAndSend("/topic/prices", prices);
        } catch (Exception e) {
            log.error("Failed to broadcast price update: {}", e.getMessage());
            // Do not let the exception propagate - it would stop the scheduler
        }
    }
 
    /**
     * Sends a personalized portfolio update to each user.
     * Each user only receives their own portfolio data.
     */
    @Scheduled(fixedDelay = 5000)
    public void broadcastPortfolioUpdates() {
        List<String> activeUserIds = stockPriceService.getActivePortfolioViewers();
        for (String userId : activeUserIds) {
            try {
                PortfolioUpdate update = stockPriceService.getPortfolioUpdate(userId);
                // Each user gets their own update
                messagingTemplate.convertAndSendToUser(userId, "/queue/portfolio", update);
            } catch (Exception e) {
                log.warn("Failed to send portfolio update to user {}: {}", userId, e.getMessage());
            }
        }
    }
}

5. Security Configuration

JWT Authentication Handshake Interceptor

// src/main/java/com/example/ws/interceptor/AuthHandshakeInterceptor.java
 
package com.example.ws.interceptor;
 
import com.example.ws.security.JwtTokenProvider;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
 
import java.util.Map;
 
/**
 * Authenticates WebSocket connections during the HTTP upgrade (handshake) request.
 *
 * This interceptor runs BEFORE the WebSocket connection is established. On a valid JWT
 * it stores the user id under the "userId" attribute and allows the handshake. Otherwise
 * it rejects the handshake with 401 Unauthorized.
 *
 * WHY HERE: the handshake is the last point where HTTP headers and query parameters are
 * available. After the protocol switch only WebSocket frames flow. STOMP frames carry
 * their own headers (e.g. on CONNECT), which is another place a token can be checked.
 *
 * TOKEN PASSING OPTIONS:
 * 1. Query parameter: wss://host/ws?token=eyJ...  (simplest, but the URL - and thus the
 *    token - can end up in access logs)
 * 2. Cookie: set during login (more secure, works with SockJS)
 * 3. Authorization header during the handshake (NOT settable through the browser WebSocket
 *    API, only by programmatic clients like Java, Python etc.)
 *
 * For browser clients, option 1 or 2 is the only choice. This implementation reads option 1
 * first and falls back to option 3 ("Authorization: Bearer <token>") for non-browser clients.
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class AuthHandshakeInterceptor implements HandshakeInterceptor {

    private static final String TOKEN_PARAM_PREFIX = "token=";
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtTokenProvider jwtTokenProvider;

    /**
     * Validates the JWT and records the user id in the handshake attributes.
     *
     * @return true to proceed with the upgrade, false to reject it (status set to 401)
     */
    @Override
    public boolean beforeHandshake(
            ServerHttpRequest request,
            ServerHttpResponse response,
            WebSocketHandler wsHandler,
            Map<String, Object> attributes) throws Exception {

        // Query string (/ws?token=eyJhbGciOiJIUzI1NiJ9...) or Authorization: Bearer header
        String token = extractToken(request);

        if (token == null) {
            log.warn("WebSocket connection attempt without token from {}", request.getRemoteAddress());
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false; // Reject the handshake
        }

        try {
            // Validate the token and extract claims
            String userId = jwtTokenProvider.validateAndGetUserId(token);

            // Store userId in attributes - read by CustomHandshakeHandler and
            // available in @MessageMapping methods via SimpMessageHeaderAccessor
            attributes.put("userId", userId);
            attributes.put("token", token);

            log.debug("WebSocket handshake authorized for user={}", userId);
            return true; // Allow the handshake

        } catch (Exception e) {
            log.warn("Invalid JWT token during WebSocket handshake: {}", e.getMessage());
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false;
        }
    }

    /** Logs any error reported after the handshake completed. */
    @Override
    public void afterHandshake(
            ServerHttpRequest request,
            ServerHttpResponse response,
            WebSocketHandler wsHandler,
            Exception exception) {
        if (exception != null) {
            log.error("Error after WebSocket handshake: {}", exception.getMessage());
        }
    }

    /**
     * Returns the token from the "token" query parameter or, failing that, from an
     * "Authorization: Bearer" header. Returns null when neither yields a non-blank value.
     */
    private String extractToken(ServerHttpRequest request) {
        String fromQuery = extractQueryToken(request.getURI().getQuery());
        if (fromQuery != null) {
            return fromQuery;
        }
        return extractBearerToken(request.getHeaders().getFirst("Authorization"));
    }

    /** Returns the first "token" query parameter's value, or null if it is absent or blank. */
    private static String extractQueryToken(String query) {
        if (query == null) {
            return null;
        }
        for (String param : query.split("&")) {
            if (param.startsWith(TOKEN_PARAM_PREFIX)) {
                String value = param.substring(TOKEN_PARAM_PREFIX.length());
                // An empty "token=" is treated as no token rather than an invalid one
                return value.isBlank() ? null : value;
            }
        }
        return null;
    }

    /** Returns the credential of a "Bearer <token>" header (scheme case-insensitive), or null. */
    private static String extractBearerToken(String authorization) {
        if (authorization == null
                || !authorization.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
            return null;
        }
        String value = authorization.substring(BEARER_PREFIX.length()).strip();
        return value.isEmpty() ? null : value;
    }
}

Channel Interceptor for Message-Level Security

// src/main/java/com/example/ws/interceptor/WebSocketChannelInterceptor.java
 
package com.example.ws.interceptor;
 
import com.example.ws.security.JwtTokenProvider;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;
 
import java.security.Principal;
import java.util.Collections;
import java.util.Map;
 
/**
 * This interceptor validates STOMP frames AFTER the WebSocket connection is established.
 *
 * Primary responsibility:
 * 1. On CONNECT: Set the Spring Security Principal so that
 *    @AuthenticationPrincipal and SecurityContext work in @MessageMapping methods.
 * 2. On SEND/SUBSCRIBE: Validate that the user is still authenticated
 *    (token could expire after a long connection).
 *
 * This is a defense-in-depth measure on top of the handshake interceptor.
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class WebSocketChannelInterceptor implements ChannelInterceptor {
 
    private final JwtTokenProvider jwtTokenProvider;
 
    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
 
        StompHeaderAccessor accessor =
            MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
 
        if (accessor == null) {
            return message;
        }
 
        StompCommand command = accessor.getCommand();
 
        if (StompCommand.CONNECT.equals(command)) {
            // STOMP CONNECT frame - user sends this after WebSocket handshake
            // At this point, our handshake interceptor has already validated the token
            // and stored userId in the session attributes.
 
            Map<String, Object> sessionAttributes = accessor.getSessionAttributes();
            if (sessionAttributes != null) {
                String userId = (String) sessionAttributes.get("userId");
                if (userId != null) {
                    // Set the Principal on the accessor so Spring can route /user/** correctly
                    UsernamePasswordAuthenticationToken auth =
                        new UsernamePasswordAuthenticationToken(
                            userId, null,
                            Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"))
                        );
                    accessor.setUser(auth);
                    log.debug("STOMP CONNECT for user={}, session={}", userId, accessor.getSessionId());
                }
            }
 
        } else if (StompCommand.SEND.equals(command)) {
            // Validate the user still has a valid session
            Principal user = accessor.getUser();
            if (user == null) {
                log.warn("Unauthenticated SEND attempt on session={}", accessor.getSessionId());
                // Returning null would drop the message silently
                // Throwing would send an ERROR frame back
                throw new SecurityException("Unauthenticated WebSocket session");
            }
 
        } else if (StompCommand.SUBSCRIBE.equals(command)) {
            // Validate subscription access
            Principal user = accessor.getUser();
            String destination = accessor.getDestination();
 
            if (user <mark class="obsidian-highlight"> null) {
                throw new SecurityException("Unauthenticated subscription attempt");
            }
 
            // Validate the user is allowed to subscribe to this destination
            validateSubscriptionAccess(user.getName(), destination);
        }
 
        return message;
    }
 
    private void validateSubscriptionAccess(String userId, String destination) {
        if (destination </mark> null) return;
 
        // Prevent users from subscribing to other users' private queues
        // Pattern: /user/{userId}/queue/... should only be accessed by that user
        // Note: Spring routes /user/queue/... to /user/{sessionId}/queue/...
        // but direct subscriptions need validation
 
        if (destination.startsWith("/topic/admin") && !isAdmin(userId)) {
            log.warn("Unauthorized subscription to admin topic by user={}", userId);
            throw new SecurityException("Access denied to destination: " + destination);
        }
    }
 
    private boolean isAdmin(String userId) {
        // Check user roles from database or JWT claims
        // Simplified for brevity
        return false;
    }
}

Spring Security Configuration

// src/main/java/com/example/ws/config/SecurityConfig.java
 
package com.example.ws.config;
 
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.web.SecurityFilterChain;
 
/**
 * HTTP-level security configuration.
 *
 * /ws/** is opened at the HTTP layer because the WebSocket handshake is authenticated by
 * the handshake interceptor (token in the query string), not by this filter chain.
 * /actuator/health is public; every other request must be authenticated. No HTTP session
 * is created (STATELESS).
 *
 * NOTE(review): no authentication mechanism (JWT filter, httpBasic, formLogin, ...) is
 * configured in this chain. As written, every request outside the permitAll paths will be
 * rejected; add one if REST endpoints are expected to work.
 */
@Configuration
@EnableWebSecurity
public class SecurityConfig {

    @Bean
    public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        http
            .csrf(csrf -> csrf
                // The handshake carries a token instead of relying on cookies, and our
                // interceptor validates it, so CSRF tokens are not applied to /ws/**.
                .ignoringRequestMatchers("/ws/**")
            )
            .headers(headers -> headers
                // SockJS iframe-based fallback transports need the endpoint to be frameable
                // by the same origin; the default X-Frame-Options: DENY would break them.
                .frameOptions(frameOptions -> frameOptions.sameOrigin())
            )
            .authorizeHttpRequests(auth -> auth
                .requestMatchers("/ws/**").permitAll() // Auth handled by WebSocket interceptor
                .requestMatchers("/actuator/health").permitAll()
                .anyRequest().authenticated()
            )
            .sessionManagement(session ->
                session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)
            );

        return http.build();
    }
}

6. Session Management and Connection Lifecycle

WebSocket Session Event Listener

// src/main/java/com/example/ws/listener/WebSocketEventListener.java
 
package com.example.ws.listener;
 
import com.example.ws.service.PresenceService;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.event.EventListener;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.SessionUnsubscribeEvent;
 
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
 
@Component
@Slf4j
public class WebSocketEventListener {

    private final PresenceService presenceService;
    private final SimpMessagingTemplate messagingTemplate;

    /**
     * IDs of sessions that have reached CONNECTED and have not yet been handled as
     * disconnected.
     *
     * This replaces a bare increment/decrement counter. Spring may publish
     * SessionDisconnectEvent more than once for the same session, and also for sessions
     * that never completed the STOMP CONNECT. A bare counter would drift (even below
     * zero) and presence could be broadcast twice.
     */
    private final Set<String> connectedSessionIds = ConcurrentHashMap.newKeySet();

    public WebSocketEventListener(
            PresenceService presenceService,
            SimpMessagingTemplate messagingTemplate,
            MeterRegistry meterRegistry) {

        this.presenceService = presenceService;
        this.messagingTemplate = messagingTemplate;

        // Expose the active-connection count on the injected MeterRegistry
        // (whatever monitoring backend it is bound to, e.g. CloudWatch).
        Gauge.builder("websocket.connections.active", connectedSessionIds, Set::size)
            .description("Current number of active WebSocket connections")
            .register(meterRegistry);
    }

    /**
     * Fired when a STOMP CONNECTED frame is sent to the client.
     * This means the full STOMP session is established.
     *
     * Counts the session once. For an identified user, it also marks the user online
     * and broadcasts an ONLINE presence update.
     */
    @EventListener
    public void handleSessionConnected(SessionConnectedEvent event) {
        StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
        String sessionId = sha.getSessionId();
        String userId = getUserId(sha);

        if (sessionId == null || !connectedSessionIds.add(sessionId)) {
            // Nothing to track (no id), or this session was already counted.
            log.debug("Ignoring CONNECTED event for untracked or duplicate session={}", sessionId);
            return;
        }

        log.info("WebSocket CONNECTED - userId={}, sessionId={}, total={}",
            userId, sessionId, connectedSessionIds.size());

        if (userId != null) {
            presenceService.markOnline(userId, sessionId);
            broadcastPresenceChange(userId, "ONLINE");
        }
    }

    /**
     * Fired when a WebSocket session disconnects - regardless of reason.
     * This covers: browser close, network drop, explicit disconnect, server shutdown.
     *
     * Only the first disconnect of a session that was counted as connected updates the
     * metric and presence. Repeated events and sessions that never reached CONNECTED are
     * ignored, which makes this handler idempotent.
     */
    @EventListener
    public void handleSessionDisconnect(SessionDisconnectEvent event) {
        StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
        String sessionId = sha.getSessionId();
        String userId = getUserId(sha);

        if (sessionId == null || !connectedSessionIds.remove(sessionId)) {
            log.debug("Ignoring disconnect for untracked or already-disconnected session={}", sessionId);
            return;
        }

        log.info("WebSocket DISCONNECTED - userId={}, sessionId={}, status={}, total={}",
            userId, sessionId, event.getCloseStatus(), connectedSessionIds.size());

        if (userId != null) {
            // Remove this session first. Announce OFFLINE only if PresenceService no
            // longer reports the user online (presumably: no other sessions remain).
            presenceService.markOffline(userId, sessionId);

            if (!presenceService.isOnline(userId)) {
                broadcastPresenceChange(userId, "OFFLINE");
            }
        }
    }

    @EventListener
    public void handleSessionSubscribe(SessionSubscribeEvent event) {
        StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
        log.debug("SUBSCRIBE - userId={}, destination={}, session={}",
            getUserId(sha), sha.getDestination(), sha.getSessionId());
    }

    @EventListener
    public void handleSessionUnsubscribe(SessionUnsubscribeEvent event) {
        StompHeaderAccessor sha = StompHeaderAccessor.wrap(event.getMessage());
        log.debug("UNSUBSCRIBE - userId={}, subscriptionId={}, session={}",
            getUserId(sha), sha.getSubscriptionId(), sha.getSessionId());
    }

    /** Returns the Principal name carried on the message, or null when there is none. */
    private String getUserId(StompHeaderAccessor sha) {
        if (sha.getUser() != null) {
            return sha.getUser().getName();
        }
        return null;
    }

    /** Publishes a presence change for {@code userId} to every /topic/presence subscriber. */
    private void broadcastPresenceChange(String userId, String status) {
        PresenceUpdate update = new PresenceUpdate(userId, status);
        messagingTemplate.convertAndSend("/topic/presence", update);
    }
}

7. Error Handling

Global WebSocket Error Handler

// src/main/java/com/example/ws/exception/WebSocketErrorHandler.java
 
package com.example.ws.exception;
 
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
 
/**
 * Logs WebSocket transport-level errors (as opposed to STOMP-level errors).
 *
 * This class only logs: it does not close the session or notify the client.
 * NOTE(review): nothing here registers it with Spring (no annotation, no interface), so
 * it must be invoked explicitly, e.g. from a WebSocketHandler decorator.
 */
@Slf4j
public class WebSocketErrorHandler {

    /**
     * Logs a transport error together with its full stack trace.
     *
     * @param session   the session on which the error occurred
     * @param exception the error reported by the transport
     */
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        // Passing the Throwable as an extra trailing argument makes SLF4J print the stack
        // trace; logging only getMessage() would lose the cause chain.
        log.error("WebSocket transport error on session={}: {}",
            session.getId(), exception.getMessage(), exception);
    }
}

STOMP Exception Handler

// src/main/java/com/example/ws/exception/GlobalWebSocketExceptionHandler.java
 
package com.example.ws.exception;
 
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.web.bind.annotation.ControllerAdvice;
 
/**
 * Global exception handler for WebSocket @MessageMapping methods.
 * Similar to @ControllerAdvice for REST controllers.
 */
@ControllerAdvice
@RequiredArgsConstructor
@Slf4j
public class GlobalWebSocketExceptionHandler {
 
    @MessageExceptionHandler(AccessDeniedException.class)
    @SendToUser("/queue/errors")
    public ErrorResponse handleAccessDenied(AccessDeniedException ex) {
        log.warn("Access denied in WebSocket: {}", ex.getMessage());
        return new ErrorResponse("ACCESS_DENIED", ex.getMessage());
    }
 
    @MessageExceptionHandler(ValidationException.class)
    @SendToUser("/queue/errors")
    public ErrorResponse handleValidation(ValidationException ex) {
        return new ErrorResponse("VALIDATION_ERROR", ex.getMessage());
    }
 
    @MessageExceptionHandler(Exception.class)
    @SendToUser("/queue/errors")
    public ErrorResponse handleGenericException(Exception ex) {
        log.error("Unexpected WebSocket error: {}", ex.getMessage(), ex);
        return new ErrorResponse("INTERNAL_ERROR", "An unexpected error occurred");
    }
}

8. Interceptors and Custom Headers

Extracting Custom STOMP Headers

// In a @MessageMapping method, extract custom STOMP headers:
 
/**
 * Handles "order.update" STOMP messages and shows how to read custom STOMP headers.
 *
 * The optional {@code correlationId} header is used for distributed tracing. When the
 * client omits it (or sends a blank value), a random UUID is generated so every update
 * stays traceable instead of failing header resolution.
 *
 * correlationId, userId and sessionId are put into the SLF4J MDC only for the duration of
 * this call. MDC is thread-local, and message-handling threads are typically pooled, so
 * the keys are removed in {@code finally} to stop them leaking into log lines of
 * unrelated messages handled later on the same thread.
 */
@MessageMapping("/order.update")
public void handleOrderUpdate(
        @Payload OrderUpdateRequest request,
        @Header(name = "correlationId", required = false) String correlationId,
        SimpMessageHeaderAccessor headerAccessor,
        Principal principal) {

    String sessionId = headerAccessor.getSessionId();
    String userId = principal.getName();

    String traceId = (correlationId == null || correlationId.isBlank())
        ? java.util.UUID.randomUUID().toString()
        : correlationId;

    MDC.put("correlationId", traceId);
    MDC.put("userId", userId);
    MDC.put("sessionId", sessionId);
    try {
        orderService.processUpdate(userId, request, traceId);
    } finally {
        MDC.remove("correlationId");
        MDC.remove("userId");
        MDC.remove("sessionId");
    }
}

9. Heartbeat and Timeout Configuration

// Heartbeat config in WebSocketConfig (from Section 2)
// First value: how often the server can send heartbeats (25 seconds)
// Second value: how often the server wants to receive heartbeats from the client (25 seconds)
// The effective intervals are negotiated with the client - see below.
// NOTE(review): with the simple broker, non-zero heartbeat values are expected to also need
// a TaskScheduler (setTaskScheduler) - confirm against the WebSocketConfig in Section 2.
 
.setHeartbeatValue(new long[]{25000, 25000})
 
/*
How heartbeat works:
- Both sides negotiate heartbeat during STOMP CONNECT/CONNECTED exchange.
- Client sends: heart-beat:10000,10000  (can send every 10s, wants to receive every 10s)
- Server sends: heart-beat:25000,25000  (can send every 25s, wants to receive every 25s)
- For each direction, the actual interval is MAX(sender's "can send" value,
  receiver's "wants to receive" value); a 0 on either side disables that direction:
   Server sends every: max(25000, 10000) = 25000 ms
   Client sends every: max(10000, 25000) = 25000 ms
 
If nothing is received within a tolerance window derived from the negotiated interval,
the connection is considered dead and closed. The tolerance multiplier is
implementation-specific. NOTE(review): Spring's simple broker appears to use 3x the
interval and stomp.js-style clients commonly use 2x. The 1.5x figure previously stated
here is unverified - check your broker and client before relying on a specific value.
*/

10. Testing WebSocket Endpoints

Integration Test with WebSocket Test Client

// src/test/java/com/example/ws/integration/WebSocketIntegrationTest.java
 
package com.example.ws.integration;
 
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.dto.ChatMessageResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.simp.stomp.*;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
 
import java.lang.reflect.Type;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
 
import static org.assertj.core.api.Assertions.assertThat;
 
/**
 * End-to-end test: connects a STOMP client over SockJS to the running server, subscribes
 * to a room topic, sends a chat message and expects it to be broadcast back.
 */
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class WebSocketIntegrationTest {

    @LocalServerPort
    private int port;

    private WebSocketStompClient stompClient;
    private String wsUrl;

    @BeforeEach
    void setUp() {
        wsUrl = "ws://localhost:" + port + "/ws";

        // SockJS client restricted to the WebSocket transport.
        stompClient = new WebSocketStompClient(
            new SockJsClient(List.of(new WebSocketTransport(new StandardWebSocketClient())))
        );
        stompClient.setMessageConverter(new MappingJackson2MessageConverter());
    }

    @Test
    void testSendAndReceiveChatMessage() throws Exception {
        BlockingQueue<ChatMessageResponse> receivedMessages = new LinkedBlockingQueue<>();

        // Generate the token once. It is sent in the handshake query string (read by
        // the handshake interceptor) and as a STOMP CONNECT header.
        String token = generateTestJwt("user1");
        StompHeaders connectHeaders = new StompHeaders();
        connectHeaders.set("token", token);

        StompSession session = stompClient.connectAsync(
            wsUrl + "?token=" + token,
            new WebSocketHttpHeaders(),
            connectHeaders,
            new StompSessionHandlerAdapter() {}
        ).get(5, TimeUnit.SECONDS);

        try {
            // Subscribe to the room topic
            session.subscribe("/topic/room.room1", new StompFrameHandler() {
                @Override
                public Type getPayloadType(StompHeaders headers) {
                    return ChatMessageResponse.class;
                }

                @Override
                public void handleFrame(StompHeaders headers, Object payload) {
                    receivedMessages.add((ChatMessageResponse) payload);
                }
            });

            // Send a message.
            // NOTE(review): SUBSCRIBE and SEND may be processed concurrently on the server.
            // If this test is flaky, wait for the subscription to be registered before
            // sending.
            ChatMessageRequest request = new ChatMessageRequest();
            request.setContent("Hello, Integration Test!");
            request.setMessageType("TEXT");

            session.send("/app/chat.send/room1", request);

            // Wait for the broadcast to arrive back
            ChatMessageResponse received = receivedMessages.poll(5, TimeUnit.SECONDS);

            assertThat(received).isNotNull();
            assertThat(received.getContent()).isEqualTo("Hello, Integration Test!");
            assertThat(received.getSenderId()).isEqualTo("user1");
        } finally {
            // Release the connection even when an assertion above fails.
            session.disconnect();
        }
    }

    private String generateTestJwt(String userId) {
        // Generate a test JWT token - use your JwtTokenProvider or a test utility
        return "test-token-for-" + userId; // Simplified; use real JWT in actual tests
    }
}

Unit Test for Message Controller

// src/test/java/com/example/ws/controller/ChatControllerTest.java
 
package com.example.ws.controller;
 
import com.example.ws.dto.ChatMessageRequest;
import com.example.ws.service.ChatService;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
 
import java.security.Principal;
 
import static org.mockito.Mockito.*;
 
/**
 * Unit test for ChatController.sendMessage: checks that it delegates room-access
 * validation and broadcasting to ChatService with the caller's identity.
 */
@ExtendWith(MockitoExtension.class)
class ChatControllerTest {

    @Mock
    private ChatService chatService;

    @InjectMocks
    private ChatController chatController;

    @Test
    void testSendMessage_callsServiceWithCorrectParams() {
        // Arrange
        String roomId = "room1";
        String userId = "user1";
        ChatMessageRequest request = new ChatMessageRequest();
        request.setContent("Test message");

        Principal principal = () -> userId;

        // Use a real accessor rather than a mock. It is a plain header holder, and an
        // unused when(...) stub fails under MockitoExtension's strict stubbing if the
        // controller never reads the session id.
        SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create();
        headerAccessor.setSessionId("session1");

        // No doNothing() stubs: void methods on a Mockito mock already do nothing.

        // Act
        chatController.sendMessage(roomId, request, principal, headerAccessor);

        // Assert
        verify(chatService).validateRoomAccess(userId, roomId);
        verify(chatService).processAndBroadcast(userId, roomId, request);
    }
}

Next: Part 3 - Advanced Patterns