001 002/* 003 * This class is a copy of HttpClientSseClientTransport v. 0.8.1 modified to support telemetry. 004 * The original class could not be extended to support telemetry due to private fields. 005 * 006 * Copyright 2024 - 2025 the original author or authors. 007 */ 008package org.nasdanika.ai.mcp; 009 010import java.io.IOException; 011import java.net.URI; 012import java.net.http.HttpClient; 013import java.net.http.HttpRequest; 014import java.net.http.HttpRequest.Builder; 015import java.net.http.HttpResponse; 016import java.time.Duration; 017import java.util.concurrent.CompletableFuture; 018import java.util.concurrent.CountDownLatch; 019import java.util.concurrent.TimeUnit; 020import java.util.concurrent.atomic.AtomicReference; 021import java.util.function.BiConsumer; 022import java.util.function.Function; 023 024import org.slf4j.Logger; 025import org.slf4j.LoggerFactory; 026 027import com.fasterxml.jackson.core.type.TypeReference; 028import com.fasterxml.jackson.databind.ObjectMapper; 029 030import io.modelcontextprotocol.client.transport.FlowSseClient; 031import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; 032import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; 033import io.modelcontextprotocol.spec.McpClientTransport; 034import io.modelcontextprotocol.spec.McpError; 035import io.modelcontextprotocol.spec.McpSchema; 036import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; 037import io.modelcontextprotocol.util.Assert; 038import io.opentelemetry.api.trace.Span; 039import io.opentelemetry.api.trace.SpanKind; 040import io.opentelemetry.api.trace.StatusCode; 041import io.opentelemetry.api.trace.Tracer; 042import io.opentelemetry.context.Context; 043import io.opentelemetry.context.Scope; 044import io.opentelemetry.context.propagation.TextMapPropagator; 045import reactor.core.publisher.Mono; 046 047/** 048 * Server-Sent Events (SSE) implementation of the 049 * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE 050 * transport specification, using Java's HttpClient. 051 * 052 * <p> 053 * This transport implementation establishes a bidirectional communication channel between 054 * client and server using SSE for server-to-client messages and HTTP POST requests for 055 * client-to-server messages. The transport: 056 * <ul> 057 * <li>Establishes an SSE connection to receive server messages</li> 058 * <li>Handles endpoint discovery through SSE events</li> 059 * <li>Manages message serialization/deserialization using Jackson</li> 060 * <li>Provides graceful connection termination</li> 061 * </ul> 062 * 063 * <p> 064 * The transport supports two types of SSE events: 065 * <ul> 066 * <li>'endpoint' - Contains the URL for sending client messages</li> 067 * <li>'message' - Contains JSON-RPC message payload</li> 068 * </ul> 069 * 070 * @author Christian Tzolov 071 * @author Pavel Vlasov 072 * @see io.modelcontextprotocol.spec.McpTransport 073 * @see io.modelcontextprotocol.spec.McpClientTransport 074 */ 075public class HttpClientTelemetrySseClientTransport implements McpClientTransport { 076 077 private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); 078 079 /** SSE event type for JSON-RPC messages */ 080 private static final String MESSAGE_EVENT_TYPE = "message"; 081 082 /** SSE event type for endpoint discovery */ 083 private static final String ENDPOINT_EVENT_TYPE = "endpoint"; 084 085 /** Default SSE endpoint path */ 086 private static final String SSE_ENDPOINT = "/sse"; 087 088 /** Base URI for the MCP server */ 089 private final String baseUri; 090 091 /** SSE client for handling server-sent events. Uses the /sse endpoint */ 092 private final FlowSseClient sseClient; 093 094 /** 095 * HTTP client for sending messages to the server. Uses HTTP POST over the message 096 * endpoint 097 */ 098 private final HttpClient httpClient; 099 100 /** JSON object mapper for message serialization/deserialization */ 101 protected ObjectMapper objectMapper; 102 103 /** Flag indicating if the transport is in closing state */ 104 private volatile boolean isClosing = false; 105 106 /** Latch for coordinating endpoint discovery */ 107 private final CountDownLatch closeLatch = new CountDownLatch(1); 108 109 /** Holds the discovered message endpoint URL */ 110 private final AtomicReference<String> messageEndpoint = new AtomicReference<>(); 111 112 /** Holds the SSE connection future */ 113 private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference<>(); 114 115 private Tracer tracer; 116 117 private BiConsumer<String, Long> durationConsumer; 118 119 private TextMapPropagator propagator; 120 121 /** 122 * Creates a new transport instance with default HTTP client and object mapper. 123 * @param baseUri the base URI of the MCP server 124 * @param tracer If not null, creates a span for sendMessage HTTP request. Pass <code>null</code> when using {@link TelemetryMcpClientTransportFilter} to avoid two nested sendMessage spans 125 */ 126 public HttpClientTelemetrySseClientTransport( 127 String baseUri, 128 Tracer tracer, 129 TextMapPropagator propagator, 130 BiConsumer<String, Long> durationConsumer) { 131 this(HttpClient.newBuilder(), baseUri, new ObjectMapper(), tracer, propagator, durationConsumer); 132 } 133 134 /** 135 * Creates a new transport instance with custom HTTP client builder and object mapper. 136 * @param clientBuilder the HTTP client builder to use 137 * @param baseUri the base URI of the MCP server 138 * @param objectMapper the object mapper for JSON serialization/deserialization 139 * @throws IllegalArgumentException if objectMapper or clientBuilder is null 140 */ 141 public HttpClientTelemetrySseClientTransport( 142 HttpClient.Builder clientBuilder, 143 String baseUri, 144 ObjectMapper objectMapper, 145 Tracer tracer, 146 TextMapPropagator propagator, 147 BiConsumer<String, Long> durationConsumer) { 148 Assert.notNull(objectMapper, "ObjectMapper must not be null"); 149 Assert.hasText(baseUri, "baseUri must not be empty"); 150 Assert.notNull(clientBuilder, "clientBuilder must not be null"); 151 this.baseUri = baseUri; 152 this.objectMapper = objectMapper; 153 this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); 154 this.sseClient = new FlowSseClient(this.httpClient); 155 this.tracer = tracer; 156 this.propagator = propagator; 157 this.durationConsumer = durationConsumer; 158 } 159 160 /** 161 * Establishes the SSE connection with the server and sets up message handling. 162 * 163 * <p> 164 * This method: 165 * <ul> 166 * <li>Initiates the SSE connection</li> 167 * <li>Handles endpoint discovery events</li> 168 * <li>Processes incoming JSON-RPC messages</li> 169 * </ul> 170 * @param handler the function to process received JSON-RPC messages 171 * @return a Mono that completes when the connection is established 172 */ 173 @Override 174 public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) { 175 CompletableFuture<Void> future = new CompletableFuture<>(); 176 connectionFuture.set(future); 177 178 sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler() { 179 @Override 180 public void onEvent(SseEvent event) { 181 if (isClosing) { 182 return; 183 } 184 185 try { 186 if (ENDPOINT_EVENT_TYPE.equals(event.type())) { 187 String endpoint = event.data(); 188 messageEndpoint.set(endpoint); 189 closeLatch.countDown(); 190 future.complete(null); 191 } 192 else if (MESSAGE_EVENT_TYPE.equals(event.type())) { 193 JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); 194 handler.apply(Mono.just(message)).subscribe(); 195 } 196 else { 197 logger.error("Received unrecognized SSE event type: {}", event.type()); 198 } 199 } 200 catch (IOException e) { 201 logger.error("Error processing SSE event", e); 202 future.completeExceptionally(e); 203 } 204 } 205 206 @Override 207 public void onError(Throwable error) { 208 if (!isClosing) { 209 logger.error("SSE connection error", error); 210 future.completeExceptionally(error); 211 } 212 } 213 }); 214 215 return Mono.fromFuture(future); 216 } 217 218 /** 219 * Sends a JSON-RPC message to the server. 220 * 221 * <p> 222 * This method waits for the message endpoint to be discovered before sending the 223 * message. The message is serialized to JSON and sent as an HTTP POST request. 224 * @param message the JSON-RPC message to send 225 * @return a Mono that completes when the message is sent 226 * @throws McpError if the message endpoint is not available or the wait times out 227 */ 228 @Override 229 public Mono<Void> sendMessage(JSONRPCMessage message) { 230 if (isClosing) { 231 return Mono.empty(); 232 } 233 234 try { 235 if (!closeLatch.await(10, TimeUnit.SECONDS)) { 236 return Mono.error(new McpError("Failed to wait for the message endpoint")); 237 } 238 } 239 catch (InterruptedException e) { 240 return Mono.error(new McpError("Failed to wait for the message endpoint")); 241 } 242 243 String endpoint = messageEndpoint.get(); 244 if (endpoint == null) { 245 return Mono.error(new McpError("No message endpoint available")); 246 } 247 248 return Mono.deferContextual(contextView -> { 249 Context parentContext = contextView.getOrDefault(Context.class, Context.current()); 250 long start = System.currentTimeMillis(); 251 URI requestURI = URI.create(this.baseUri + endpoint); 252 253 Span requestSpan = tracer == null ? Span.fromContext(parentContext) : 254 tracer 255 .spanBuilder("sendMessage") 256 .setAttribute("uri", requestURI.toString()) 257 .setSpanKind(SpanKind.CLIENT) 258 .setParent(parentContext) 259 .startSpan(); 260 try (Scope scope = requestSpan.makeCurrent()) { 261 String jsonText = this.objectMapper.writeValueAsString(message); 262 requestSpan.setAttribute("message", jsonText); 263 Builder builder = getHttpRequestBuilder() 264 .uri(requestURI) 265 .header("Content-Type", "application/json") 266 .POST(HttpRequest.BodyPublishers.ofString(jsonText)); 267 268 Context telemetryContext = Context.current().with(requestSpan); 269 propagator.inject(telemetryContext, builder, (b, name, value) -> b.header(name, value)); 270 HttpRequest request = builder.build(); 271 272 return Mono.fromFuture( 273 httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { 274 if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 275 && response.statusCode() != 206) { 276 logger.error("Error sending message: {}", response.statusCode()); 277 } 278 })) 279 .map(result -> { 280 if (durationConsumer != null) { 281 durationConsumer.accept(requestURI.toString(), System.currentTimeMillis() - start); 282 } 283 requestSpan.setStatus(StatusCode.OK); 284 return result; 285 }) 286 .onErrorMap(error -> { 287 requestSpan.recordException(error); 288 requestSpan.setStatus(StatusCode.ERROR); 289 return error; 290 }) 291 .doFinally(signal -> requestSpan.end()); 292 } catch (IOException e) { 293 requestSpan.recordException(e); 294 if (!isClosing) { 295 return Mono.error(new RuntimeException("Failed to serialize message", e)); 296 } 297 return Mono.empty(); 298 } 299 }); 300 } 301 302 protected Builder getHttpRequestBuilder() { 303 return HttpRequest.newBuilder(); 304 } 305 306 /** 307 * Gracefully closes the transport connection. 308 * 309 * <p> 310 * Sets the closing flag and cancels any pending connection future. This prevents new 311 * messages from being sent and allows ongoing operations to complete. 312 * @return a Mono that completes when the closing process is initiated 313 */ 314 @Override 315 public Mono<Void> closeGracefully() { 316 return Mono.fromRunnable(() -> { 317 isClosing = true; 318 CompletableFuture<Void> future = connectionFuture.get(); 319 if (future != null && !future.isDone()) { 320 future.cancel(true); 321 } 322 }); 323 } 324 325 /** 326 * Unmarshals data to the specified type using the configured object mapper. 327 * @param data the data to unmarshal 328 * @param typeRef the type reference for the target type 329 * @param <T> the target type 330 * @return the unmarshalled object 331 */ 332 @Override 333 public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) { 334 return this.objectMapper.convertValue(data, typeRef); 335 } 336 337}