From a82637f09f06945c45cd7efef0cef063ade7ce47 Mon Sep 17 00:00:00 2001 From: linxiuqiang <15060002560@163.com> Date: Sun, 14 Jun 2026 16:51:52 +0800 Subject: [PATCH] Evict streamable HTTP sessions after failed keep-alive pings --- ...vletStreamableServerTransportProvider.java | 48 ++++++ .../util/KeepAliveScheduler.java | 67 +++++++- ...treamableServerTransportProviderTests.java | 153 ++++++++++++++++++ .../util/KeepAliveSchedulerTests.java | 117 ++++++++++++++ 4 files changed, 381 insertions(+), 4 deletions(-) create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index e6af4fd0f..3873fd474 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.spec.McpStreamableServerSession; import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -87,6 +88,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + private static final int KEEP_ALIVE_FAILURE_THRESHOLD = 3; + /** * The endpoint URI where clients should send their JSON-RPC messages. Defaults to * "/mcp". @@ -107,6 +110,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap keepAliveFailureCounts = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; /** @@ -158,6 +163,8 @@ private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, S .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) .initialDelay(keepAliveInterval) .interval(keepAliveInterval) + .onSuccess(this::resetKeepAliveFailures) + .onFailure(this::handleKeepAliveFailure) .build(); this.keepAliveScheduler.start(); @@ -231,8 +238,10 @@ public Mono closeGracefully() { }); this.sessions.clear(); + this.keepAliveFailureCounts.clear(); }).then().doOnSuccess(v -> { sessions.clear(); + keepAliveFailureCounts.clear(); logger.debug("Graceful shutdown completed"); if (this.keepAliveScheduler != null) { this.keepAliveScheduler.shutdown(); @@ -445,6 +454,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); this.sessions.put(init.session().getId(), init.session()); + this.keepAliveFailureCounts.remove(init.session().getId()); try { McpSchema.InitializeResult initResult = init.initResult().block(); @@ -614,6 +624,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response try { session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); this.sessions.remove(sessionId); + this.keepAliveFailureCounts.remove(sessionId); response.setStatus(HttpServletResponse.SC_OK); } catch (Exception e) { @@ -640,6 +651,42 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m return; } + void resetKeepAliveFailures(McpSession session) { + if (session instanceof McpStreamableServerSession streamableSession) { + String sessionId = streamableSession.getId(); + if (this.sessions.get(sessionId) == streamableSession) { + this.keepAliveFailureCounts.remove(sessionId); + } + } + } + + void handleKeepAliveFailure(McpSession session, Throwable error) { + if (!(session instanceof McpStreamableServerSession streamableSession)) { + return; + } + + String sessionId = streamableSession.getId(); + if (this.sessions.get(sessionId) != streamableSession) { + return; + } + + int failures = this.keepAliveFailureCounts.merge(sessionId, 1, Integer::sum); + if (failures < KEEP_ALIVE_FAILURE_THRESHOLD) { + logger.debug("Keep-alive ping failed for session {} ({}/{} consecutive failures): {}", sessionId, failures, + KEEP_ALIVE_FAILURE_THRESHOLD, error.getMessage()); + return; + } + + if (this.sessions.remove(sessionId, streamableSession)) { + this.keepAliveFailureCounts.remove(sessionId); + streamableSession.close(); + logger.info("Evicted session {} after {} failed keep-alive attempts", sessionId, failures); + } + else { + this.keepAliveFailureCounts.remove(sessionId); + } + } + /** * Sends an SSE event to a client with a specific ID. * @param writer The writer to send the event through @@ -748,6 +795,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId catch (Exception e) { logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + HttpServletStreamableServerTransportProvider.this.keepAliveFailureCounts.remove(this.sessionId); this.asyncContext.complete(); } finally { diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java index 6d53ed516..26e6ab019 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -6,6 +6,8 @@ import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import org.slf4j.Logger; @@ -57,6 +59,10 @@ public class KeepAliveScheduler { /** Supplier for reactive McpSession instances */ private final Supplier> mcpSessions; + private final Consumer onSuccess; + + private final BiConsumer onFailure; + /** * Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a * supplier for McpSession instances. @@ -66,11 +72,14 @@ public class KeepAliveScheduler { * @param mcpSessions Supplier for McpSession instances */ KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval, - Supplier> mcpSessions) { + Supplier> mcpSessions, Consumer onSuccess, + BiConsumer onFailure) { this.scheduler = scheduler; this.initialDelay = initialDelay; this.interval = interval; this.mcpSessions = mcpSessions; + this.onSuccess = onSuccess; + this.onFailure = onFailure; } /** @@ -92,8 +101,12 @@ public Disposable start() { .doOnNext(tick -> { this.mcpSessions.get() .flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF) - .doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session, - e.getMessage())) + .doOnSuccess(result -> this.notifySuccess(session)) + .doOnError(e -> { + logger.warn("Failed to send keep-alive ping to session {}: {}", session, + e.getMessage()); + this.notifyFailure(session, e); + }) .onErrorComplete()) .subscribe(); }) @@ -131,6 +144,24 @@ public boolean isRunning() { return this.isRunning.get(); } + private void notifySuccess(McpSession session) { + try { + this.onSuccess.accept(session); + } + catch (Exception e) { + logger.warn("Keep-alive success callback failed for session {}: {}", session, e.getMessage()); + } + } + + private void notifyFailure(McpSession session, Throwable error) { + try { + this.onFailure.accept(session, error); + } + catch (Exception e) { + logger.warn("Keep-alive failure callback failed for session {}: {}", session, e.getMessage()); + } + } + /** * Shuts down the scheduler and releases resources. */ @@ -154,6 +185,12 @@ public static class Builder { private Supplier> mcpSessions; + private Consumer onSuccess = session -> { + }; + + private BiConsumer onFailure = (session, error) -> { + }; + /** * Creates a new Builder instance with a supplier for McpSession instances. * @param mcpSessions The supplier for McpSession instances @@ -204,12 +241,34 @@ public Builder interval(Duration interval) { return this; } + /** + * Sets the callback invoked after a keep-alive ping completes successfully. + * @param onSuccess The success callback. Must not be null. + * @return This builder instance for method chaining + */ + public Builder onSuccess(Consumer onSuccess) { + Assert.notNull(onSuccess, "OnSuccess callback must not be null"); + this.onSuccess = onSuccess; + return this; + } + + /** + * Sets the callback invoked after a keep-alive ping fails. + * @param onFailure The failure callback. Must not be null. + * @return This builder instance for method chaining + */ + public Builder onFailure(BiConsumer onFailure) { + Assert.notNull(onFailure, "OnFailure callback must not be null"); + this.onFailure = onFailure; + return this; + } + /** * Builds and returns a new KeepAliveScheduler instance. * @return A new KeepAliveScheduler configured with the builder's settings */ public KeepAliveScheduler build() { - return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions); + return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions, onSuccess, onFailure); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java new file mode 100644 index 000000000..e0d3d3e11 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.json.gson.GsonMcpJsonMapper; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for keep-alive failure eviction in + * {@link HttpServletStreamableServerTransportProvider}. + */ +class HttpServletStreamableServerTransportProviderTests { + + @Test + void firstKeepAliveFailureDoesNotEvictSession() throws Exception { + HttpServletStreamableServerTransportProvider provider = createProvider(); + TrackingStreamableSession session = createSession("session-1"); + putSession(provider, session); + + provider.handleKeepAliveFailure(session, new RuntimeException("ping failed")); + + assertThat(sessions(provider)).containsEntry("session-1", session); + assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1); + assertThat(session.closeCount()).isZero(); + } + + @Test + void repeatedKeepAliveFailuresEvictSession() throws Exception { + HttpServletStreamableServerTransportProvider provider = createProvider(); + TrackingStreamableSession session = createSession("session-1"); + putSession(provider, session); + + provider.handleKeepAliveFailure(session, new RuntimeException("first failure")); + provider.handleKeepAliveFailure(session, new RuntimeException("second failure")); + provider.handleKeepAliveFailure(session, new RuntimeException("third failure")); + + assertThat(sessions(provider)).doesNotContainKey("session-1"); + assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1"); + assertThat(session.closeCount()).isOne(); + } + + @Test + void successfulKeepAliveResetsFailureCount() throws Exception { + HttpServletStreamableServerTransportProvider provider = createProvider(); + TrackingStreamableSession session = createSession("session-1"); + putSession(provider, session); + + provider.handleKeepAliveFailure(session, new RuntimeException("first failure")); + provider.handleKeepAliveFailure(session, new RuntimeException("second failure")); + provider.resetKeepAliveFailures(session); + provider.handleKeepAliveFailure(session, new RuntimeException("failure after success")); + + assertThat(sessions(provider)).containsEntry("session-1", session); + assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1); + assertThat(session.closeCount()).isZero(); + } + + @Test + void successfulKeepAliveFromReplacedSessionDoesNotResetReplacementFailureCount() throws Exception { + HttpServletStreamableServerTransportProvider provider = createProvider(); + TrackingStreamableSession oldSession = createSession("session-1"); + TrackingStreamableSession replacementSession = createSession("session-1"); + putSession(provider, replacementSession); + + provider.handleKeepAliveFailure(replacementSession, new RuntimeException("replacement failure")); + provider.resetKeepAliveFailures(oldSession); + + assertThat(sessions(provider)).containsEntry("session-1", replacementSession); + assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1); + assertThat(oldSession.closeCount()).isZero(); + assertThat(replacementSession.closeCount()).isZero(); + } + + @Test + void keepAliveFailureDoesNotCloseReplacedSession() throws Exception { + HttpServletStreamableServerTransportProvider provider = createProvider(); + TrackingStreamableSession oldSession = createSession("session-1"); + TrackingStreamableSession replacementSession = createSession("session-1"); + putSession(provider, replacementSession); + + provider.handleKeepAliveFailure(oldSession, new RuntimeException("first failure")); + provider.handleKeepAliveFailure(oldSession, new RuntimeException("second failure")); + provider.handleKeepAliveFailure(oldSession, new RuntimeException("third failure")); + + assertThat(sessions(provider)).containsEntry("session-1", replacementSession); + assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1"); + assertThat(oldSession.closeCount()).isZero(); + assertThat(replacementSession.closeCount()).isZero(); + } + + private HttpServletStreamableServerTransportProvider createProvider() { + return HttpServletStreamableServerTransportProvider.builder().jsonMapper(new GsonMcpJsonMapper()).build(); + } + + private TrackingStreamableSession createSession(String sessionId) { + return new TrackingStreamableSession(sessionId); + } + + private void putSession(HttpServletStreamableServerTransportProvider provider, TrackingStreamableSession session) + throws Exception { + sessions(provider).put(session.getId(), session); + } + + @SuppressWarnings("unchecked") + private Map sessions(HttpServletStreamableServerTransportProvider provider) + throws Exception { + Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("sessions"); + field.setAccessible(true); + return (Map) field.get(provider); + } + + @SuppressWarnings("unchecked") + private Map keepAliveFailureCounts(HttpServletStreamableServerTransportProvider provider) + throws Exception { + Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("keepAliveFailureCounts"); + field.setAccessible(true); + return (Map) field.get(provider); + } + + private static class TrackingStreamableSession extends McpStreamableServerSession { + + private final AtomicInteger closeCount = new AtomicInteger(); + + TrackingStreamableSession(String id) { + super(id, McpSchema.ClientCapabilities.builder().build(), + new McpSchema.Implementation("test-client", "1.0.0"), Duration.ofSeconds(5), Map.of(), Map.of(), + Mono::empty); + } + + @Override + public void close() { + this.closeCount.incrementAndGet(); + } + + int closeCount() { + return this.closeCount.get(); + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java index d5ef8a91c..55f414ac6 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -82,6 +82,20 @@ void testBuilderWithNullInterval() { .hasMessage("Interval must not be null"); } + @Test + void testBuilderWithNullOnSuccess() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).onSuccess(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("OnSuccess callback must not be null"); + } + + @Test + void testBuilderWithNullOnFailure() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).onFailure(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("OnFailure callback must not be null"); + } + @Test void testBuilderDefaults() { KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier).build(); @@ -156,6 +170,30 @@ void testStartWithEmptySessionsList() { scheduler.stop(); } + @Test + void testPingSuccessInvokesSuccessHook() { + AtomicInteger successCount = new AtomicInteger(); + AtomicInteger failureCount = new AtomicInteger(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .onSuccess(session -> successCount.incrementAndGet()) + .onFailure((session, error) -> failureCount.incrementAndGet()) + .build(); + + scheduler.start(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(successCount).hasValue(1); + assertThat(failureCount).hasValue(0); + + scheduler.stop(); + } + @Test void testStartWhenAlreadyRunning() { KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) @@ -229,6 +267,85 @@ void testPingFailureHandling() { scheduler.stop(); } + @Test + void testPingFailureInvokesFailureHook() { + mockSession1.setShouldFailPing(true); + AtomicInteger successCount = new AtomicInteger(); + AtomicInteger failureCount = new AtomicInteger(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .onSuccess(session -> successCount.incrementAndGet()) + .onFailure((session, error) -> failureCount.incrementAndGet()) + .build(); + + scheduler.start(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(successCount).hasValue(0); + assertThat(failureCount).hasValue(1); + assertThat(scheduler.isRunning()).isTrue(); + + scheduler.stop(); + } + + @Test + void testSuccessHookExceptionDoesNotStopScheduler() { + AtomicInteger successAttempts = new AtomicInteger(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .onSuccess(session -> { + successAttempts.incrementAndGet(); + throw new IllegalStateException("success callback failed"); + }) + .build(); + + scheduler.start(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); + + assertThat(mockSession1.getPingCount()).isEqualTo(2); + assertThat(successAttempts).hasValue(2); + assertThat(scheduler.isRunning()).isTrue(); + + scheduler.stop(); + } + + @Test + void testFailureHookExceptionDoesNotStopScheduler() { + mockSession1.setShouldFailPing(true); + AtomicInteger failureAttempts = new AtomicInteger(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .onFailure((session, error) -> { + failureAttempts.incrementAndGet(); + throw new IllegalStateException("failure callback failed"); + }) + .build(); + + scheduler.start(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); + + assertThat(mockSession1.getPingCount()).isEqualTo(2); + assertThat(failureAttempts).hasValue(2); + assertThat(scheduler.isRunning()).isTrue(); + + scheduler.stop(); + } + @Test void testDisposableReturnedFromStart() { KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier)