Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
Expand Down Expand Up @@ -87,6 +88,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet

public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";

private static final int KEEP_ALIVE_FAILURE_THRESHOLD = 3;

/**
* The endpoint URI where clients should send their JSON-RPC messages. Defaults to
* "/mcp".
Expand All @@ -107,6 +110,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
*/
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();

private final ConcurrentHashMap<String, Integer> keepAliveFailureCounts = new ConcurrentHashMap<>();

private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

/**
Expand Down Expand Up @@ -158,6 +163,8 @@ private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, S
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values()))
.initialDelay(keepAliveInterval)
.interval(keepAliveInterval)
.onSuccess(this::resetKeepAliveFailures)
.onFailure(this::handleKeepAliveFailure)
.build();

this.keepAliveScheduler.start();
Expand Down Expand Up @@ -231,8 +238,10 @@ public Mono<Void> closeGracefully() {
});

this.sessions.clear();
this.keepAliveFailureCounts.clear();
}).then().doOnSuccess(v -> {
sessions.clear();
keepAliveFailureCounts.clear();
logger.debug("Graceful shutdown completed");
if (this.keepAliveScheduler != null) {
this.keepAliveScheduler.shutdown();
Expand Down Expand Up @@ -445,6 +454,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
.startSession(initializeRequest);
this.sessions.put(init.session().getId(), init.session());
this.keepAliveFailureCounts.remove(init.session().getId());

try {
McpSchema.InitializeResult initResult = init.initResult().block();
Expand Down Expand Up @@ -614,6 +624,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
try {
session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
this.sessions.remove(sessionId);
this.keepAliveFailureCounts.remove(sessionId);
response.setStatus(HttpServletResponse.SC_OK);
}
catch (Exception e) {
Expand All @@ -640,6 +651,42 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m
return;
}

void resetKeepAliveFailures(McpSession session) {
if (session instanceof McpStreamableServerSession streamableSession) {
String sessionId = streamableSession.getId();
if (this.sessions.get(sessionId) == streamableSession) {
this.keepAliveFailureCounts.remove(sessionId);
}
}
}

void handleKeepAliveFailure(McpSession session, Throwable error) {
if (!(session instanceof McpStreamableServerSession streamableSession)) {
return;
}

String sessionId = streamableSession.getId();
if (this.sessions.get(sessionId) != streamableSession) {
return;
}

int failures = this.keepAliveFailureCounts.merge(sessionId, 1, Integer::sum);
if (failures < KEEP_ALIVE_FAILURE_THRESHOLD) {
logger.debug("Keep-alive ping failed for session {} ({}/{} consecutive failures): {}", sessionId, failures,
KEEP_ALIVE_FAILURE_THRESHOLD, error.getMessage());
return;
}

if (this.sessions.remove(sessionId, streamableSession)) {
this.keepAliveFailureCounts.remove(sessionId);
streamableSession.close();
logger.info("Evicted session {} after {} failed keep-alive attempts", sessionId, failures);
}
else {
this.keepAliveFailureCounts.remove(sessionId);
}
}

/**
* Sends an SSE event to a client with a specific ID.
* @param writer The writer to send the event through
Expand Down Expand Up @@ -748,6 +795,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
catch (Exception e) {
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
HttpServletStreamableServerTransportProvider.this.keepAliveFailureCounts.remove(this.sessionId);
this.asyncContext.complete();
}
finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;

import org.slf4j.Logger;
Expand Down Expand Up @@ -57,6 +59,10 @@ public class KeepAliveScheduler {
/** Supplier for reactive McpSession instances */
private final Supplier<Flux<McpSession>> mcpSessions;

private final Consumer<McpSession> onSuccess;

private final BiConsumer<McpSession, Throwable> onFailure;

/**
* Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a
* supplier for McpSession instances.
Expand All @@ -66,11 +72,14 @@ public class KeepAliveScheduler {
* @param mcpSessions Supplier for McpSession instances
*/
KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval,
Supplier<Flux<McpSession>> mcpSessions) {
Supplier<Flux<McpSession>> mcpSessions, Consumer<McpSession> onSuccess,
BiConsumer<McpSession, Throwable> onFailure) {
this.scheduler = scheduler;
this.initialDelay = initialDelay;
this.interval = interval;
this.mcpSessions = mcpSessions;
this.onSuccess = onSuccess;
this.onFailure = onFailure;
}

/**
Expand All @@ -92,8 +101,12 @@ public Disposable start() {
.doOnNext(tick -> {
this.mcpSessions.get()
.flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)
.doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session,
e.getMessage()))
.doOnSuccess(result -> this.notifySuccess(session))
.doOnError(e -> {
logger.warn("Failed to send keep-alive ping to session {}: {}", session,
e.getMessage());
this.notifyFailure(session, e);
})
.onErrorComplete())
.subscribe();
})
Expand Down Expand Up @@ -131,6 +144,24 @@ public boolean isRunning() {
return this.isRunning.get();
}

private void notifySuccess(McpSession session) {
try {
this.onSuccess.accept(session);
}
catch (Exception e) {
logger.warn("Keep-alive success callback failed for session {}: {}", session, e.getMessage());
}
}

private void notifyFailure(McpSession session, Throwable error) {
try {
this.onFailure.accept(session, error);
}
catch (Exception e) {
logger.warn("Keep-alive failure callback failed for session {}: {}", session, e.getMessage());
}
}

/**
* Shuts down the scheduler and releases resources.
*/
Expand All @@ -154,6 +185,12 @@ public static class Builder {

private Supplier<Flux<McpSession>> mcpSessions;

private Consumer<McpSession> onSuccess = session -> {
};

private BiConsumer<McpSession, Throwable> onFailure = (session, error) -> {
};

/**
* Creates a new Builder instance with a supplier for McpSession instances.
* @param mcpSessions The supplier for McpSession instances
Expand Down Expand Up @@ -204,12 +241,34 @@ public Builder interval(Duration interval) {
return this;
}

/**
* Sets the callback invoked after a keep-alive ping completes successfully.
* @param onSuccess The success callback. Must not be null.
* @return This builder instance for method chaining
*/
public Builder onSuccess(Consumer<McpSession> onSuccess) {
Assert.notNull(onSuccess, "OnSuccess callback must not be null");
this.onSuccess = onSuccess;
return this;
}

/**
* Sets the callback invoked after a keep-alive ping fails.
* @param onFailure The failure callback. Must not be null.
* @return This builder instance for method chaining
*/
public Builder onFailure(BiConsumer<McpSession, Throwable> onFailure) {
Assert.notNull(onFailure, "OnFailure callback must not be null");
this.onFailure = onFailure;
return this;
}

/**
* Builds and returns a new KeepAliveScheduler instance.
* @return A new KeepAliveScheduler configured with the builder's settings
*/
public KeepAliveScheduler build() {
return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions);
return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions, onSuccess, onFailure);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright 2026-2026 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.lang.reflect.Field;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.json.gson.GsonMcpJsonMapper;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for keep-alive failure eviction in
* {@link HttpServletStreamableServerTransportProvider}.
*/
class HttpServletStreamableServerTransportProviderTests {

@Test
void firstKeepAliveFailureDoesNotEvictSession() throws Exception {
HttpServletStreamableServerTransportProvider provider = createProvider();
TrackingStreamableSession session = createSession("session-1");
putSession(provider, session);

provider.handleKeepAliveFailure(session, new RuntimeException("ping failed"));

assertThat(sessions(provider)).containsEntry("session-1", session);
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
assertThat(session.closeCount()).isZero();
}

@Test
void repeatedKeepAliveFailuresEvictSession() throws Exception {
HttpServletStreamableServerTransportProvider provider = createProvider();
TrackingStreamableSession session = createSession("session-1");
putSession(provider, session);

provider.handleKeepAliveFailure(session, new RuntimeException("first failure"));
provider.handleKeepAliveFailure(session, new RuntimeException("second failure"));
provider.handleKeepAliveFailure(session, new RuntimeException("third failure"));

assertThat(sessions(provider)).doesNotContainKey("session-1");
assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1");
assertThat(session.closeCount()).isOne();
}

@Test
void successfulKeepAliveResetsFailureCount() throws Exception {
HttpServletStreamableServerTransportProvider provider = createProvider();
TrackingStreamableSession session = createSession("session-1");
putSession(provider, session);

provider.handleKeepAliveFailure(session, new RuntimeException("first failure"));
provider.handleKeepAliveFailure(session, new RuntimeException("second failure"));
provider.resetKeepAliveFailures(session);
provider.handleKeepAliveFailure(session, new RuntimeException("failure after success"));

assertThat(sessions(provider)).containsEntry("session-1", session);
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
assertThat(session.closeCount()).isZero();
}

@Test
void successfulKeepAliveFromReplacedSessionDoesNotResetReplacementFailureCount() throws Exception {
HttpServletStreamableServerTransportProvider provider = createProvider();
TrackingStreamableSession oldSession = createSession("session-1");
TrackingStreamableSession replacementSession = createSession("session-1");
putSession(provider, replacementSession);

provider.handleKeepAliveFailure(replacementSession, new RuntimeException("replacement failure"));
provider.resetKeepAliveFailures(oldSession);

assertThat(sessions(provider)).containsEntry("session-1", replacementSession);
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
assertThat(oldSession.closeCount()).isZero();
assertThat(replacementSession.closeCount()).isZero();
}

@Test
void keepAliveFailureDoesNotCloseReplacedSession() throws Exception {
HttpServletStreamableServerTransportProvider provider = createProvider();
TrackingStreamableSession oldSession = createSession("session-1");
TrackingStreamableSession replacementSession = createSession("session-1");
putSession(provider, replacementSession);

provider.handleKeepAliveFailure(oldSession, new RuntimeException("first failure"));
provider.handleKeepAliveFailure(oldSession, new RuntimeException("second failure"));
provider.handleKeepAliveFailure(oldSession, new RuntimeException("third failure"));

assertThat(sessions(provider)).containsEntry("session-1", replacementSession);
assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1");
assertThat(oldSession.closeCount()).isZero();
assertThat(replacementSession.closeCount()).isZero();
}

private HttpServletStreamableServerTransportProvider createProvider() {
return HttpServletStreamableServerTransportProvider.builder().jsonMapper(new GsonMcpJsonMapper()).build();
}

private TrackingStreamableSession createSession(String sessionId) {
return new TrackingStreamableSession(sessionId);
}

private void putSession(HttpServletStreamableServerTransportProvider provider, TrackingStreamableSession session)
throws Exception {
sessions(provider).put(session.getId(), session);
}

@SuppressWarnings("unchecked")
private Map<String, McpStreamableServerSession> sessions(HttpServletStreamableServerTransportProvider provider)
throws Exception {
Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("sessions");
field.setAccessible(true);
return (Map<String, McpStreamableServerSession>) field.get(provider);
}

@SuppressWarnings("unchecked")
private Map<String, Integer> keepAliveFailureCounts(HttpServletStreamableServerTransportProvider provider)
throws Exception {
Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("keepAliveFailureCounts");
field.setAccessible(true);
return (Map<String, Integer>) field.get(provider);
}

private static class TrackingStreamableSession extends McpStreamableServerSession {

private final AtomicInteger closeCount = new AtomicInteger();

TrackingStreamableSession(String id) {
super(id, McpSchema.ClientCapabilities.builder().build(),
new McpSchema.Implementation("test-client", "1.0.0"), Duration.ofSeconds(5), Map.of(), Map.of(),
Mono::empty);
}

@Override
public void close() {
this.closeCount.incrementAndGet();
}

int closeCount() {
return this.closeCount.get();
}

}

}
Loading