001
002/*
003 * This class is a copy of HttpClientSseClientTransport v. 0.8.1 modified to support telemetry.
004 * The original class could not be extended to support telemetry due to private fields. 
005 *  
006 * Copyright 2024 - 2025 the original author or authors.
007 */
008package org.nasdanika.ai.mcp;
009
010import java.io.IOException;
011import java.net.URI;
012import java.net.http.HttpClient;
013import java.net.http.HttpRequest;
014import java.net.http.HttpRequest.Builder;
015import java.net.http.HttpResponse;
016import java.time.Duration;
017import java.util.concurrent.CompletableFuture;
018import java.util.concurrent.CountDownLatch;
019import java.util.concurrent.TimeUnit;
020import java.util.concurrent.atomic.AtomicReference;
021import java.util.function.BiConsumer;
022import java.util.function.Function;
023
024import org.slf4j.Logger;
025import org.slf4j.LoggerFactory;
026
027import com.fasterxml.jackson.core.type.TypeReference;
028import com.fasterxml.jackson.databind.ObjectMapper;
029
030import io.modelcontextprotocol.client.transport.FlowSseClient;
031import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent;
032import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
033import io.modelcontextprotocol.spec.McpClientTransport;
034import io.modelcontextprotocol.spec.McpError;
035import io.modelcontextprotocol.spec.McpSchema;
036import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
037import io.modelcontextprotocol.util.Assert;
038import io.opentelemetry.api.trace.Span;
039import io.opentelemetry.api.trace.SpanKind;
040import io.opentelemetry.api.trace.StatusCode;
041import io.opentelemetry.api.trace.Tracer;
042import io.opentelemetry.context.Context;
043import io.opentelemetry.context.Scope;
044import io.opentelemetry.context.propagation.TextMapPropagator;
045import reactor.core.publisher.Mono;
046
047/**
048 * Server-Sent Events (SSE) implementation of the
049 * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE
050 * transport specification, using Java's HttpClient.
051 *
052 * <p>
053 * This transport implementation establishes a bidirectional communication channel between
054 * client and server using SSE for server-to-client messages and HTTP POST requests for
055 * client-to-server messages. The transport:
056 * <ul>
057 * <li>Establishes an SSE connection to receive server messages</li>
058 * <li>Handles endpoint discovery through SSE events</li>
059 * <li>Manages message serialization/deserialization using Jackson</li>
060 * <li>Provides graceful connection termination</li>
061 * </ul>
062 *
063 * <p>
064 * The transport supports two types of SSE events:
065 * <ul>
066 * <li>'endpoint' - Contains the URL for sending client messages</li>
067 * <li>'message' - Contains JSON-RPC message payload</li>
068 * </ul>
069 *
070 * @author Christian Tzolov
071 * @author Pavel Vlasov
072 * @see io.modelcontextprotocol.spec.McpTransport
073 * @see io.modelcontextprotocol.spec.McpClientTransport
074 */
075public class HttpClientTelemetrySseClientTransport implements McpClientTransport {
076
077        private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);
078
079        /** SSE event type for JSON-RPC messages */
080        private static final String MESSAGE_EVENT_TYPE = "message";
081
082        /** SSE event type for endpoint discovery */
083        private static final String ENDPOINT_EVENT_TYPE = "endpoint";
084
085        /** Default SSE endpoint path */
086        private static final String SSE_ENDPOINT = "/sse";
087
088        /** Base URI for the MCP server */
089        private final String baseUri;
090
091        /** SSE client for handling server-sent events. Uses the /sse endpoint */
092        private final FlowSseClient sseClient;
093
094        /**
095         * HTTP client for sending messages to the server. Uses HTTP POST over the message
096         * endpoint
097         */
098        private final HttpClient httpClient;
099
100        /** JSON object mapper for message serialization/deserialization */
101        protected ObjectMapper objectMapper;
102
103        /** Flag indicating if the transport is in closing state */
104        private volatile boolean isClosing = false;
105
106        /** Latch for coordinating endpoint discovery */
107        private final CountDownLatch closeLatch = new CountDownLatch(1);
108
109        /** Holds the discovered message endpoint URL */
110        private final AtomicReference<String> messageEndpoint = new AtomicReference<>();
111
112        /** Holds the SSE connection future */
113        private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference<>();
114
115        private Tracer tracer;
116        
117        private BiConsumer<String, Long> durationConsumer;
118        
119        private TextMapPropagator propagator; 
120
121        /**
122         * Creates a new transport instance with default HTTP client and object mapper.
123         * @param baseUri the base URI of the MCP server
124         * @param tracer If not null, creates a span for sendMessage HTTP request. Pass <code>null</code> when using {@link TelemetryMcpClientTransportFilter} to avoid two nested sendMessage spans 
125         */
126        public HttpClientTelemetrySseClientTransport(
127                        String baseUri, 
128                        Tracer tracer,
129                        TextMapPropagator propagator, 
130                        BiConsumer<String, Long> durationConsumer) {
131                this(HttpClient.newBuilder(), baseUri, new ObjectMapper(), tracer, propagator, durationConsumer);
132        }
133
134        /**
135         * Creates a new transport instance with custom HTTP client builder and object mapper.
136         * @param clientBuilder the HTTP client builder to use
137         * @param baseUri the base URI of the MCP server
138         * @param objectMapper the object mapper for JSON serialization/deserialization
139         * @throws IllegalArgumentException if objectMapper or clientBuilder is null
140         */
141        public HttpClientTelemetrySseClientTransport(
142                        HttpClient.Builder clientBuilder, 
143                        String baseUri, 
144                        ObjectMapper objectMapper,
145                        Tracer tracer,
146                        TextMapPropagator propagator,                   
147                        BiConsumer<String, Long> durationConsumer) {
148                Assert.notNull(objectMapper, "ObjectMapper must not be null");
149                Assert.hasText(baseUri, "baseUri must not be empty");
150                Assert.notNull(clientBuilder, "clientBuilder must not be null");
151                this.baseUri = baseUri;
152                this.objectMapper = objectMapper;
153                this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build();
154                this.sseClient = new FlowSseClient(this.httpClient);
155                this.tracer = tracer;
156                this.propagator = propagator;
157                this.durationConsumer = durationConsumer;
158        }
159
160        /**
161         * Establishes the SSE connection with the server and sets up message handling.
162         *
163         * <p>
164         * This method:
165         * <ul>
166         * <li>Initiates the SSE connection</li>
167         * <li>Handles endpoint discovery events</li>
168         * <li>Processes incoming JSON-RPC messages</li>
169         * </ul>
170         * @param handler the function to process received JSON-RPC messages
171         * @return a Mono that completes when the connection is established
172         */
173        @Override
174        public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
175                CompletableFuture<Void> future = new CompletableFuture<>();
176                connectionFuture.set(future);
177
178                sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler() {
179                        @Override
180                        public void onEvent(SseEvent event) {
181                                if (isClosing) {
182                                        return;
183                                }
184
185                                try {
186                                        if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
187                                                String endpoint = event.data();
188                                                messageEndpoint.set(endpoint);
189                                                closeLatch.countDown();
190                                                future.complete(null);
191                                        }
192                                        else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
193                                                JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
194                                                handler.apply(Mono.just(message)).subscribe();
195                                        }
196                                        else {
197                                                logger.error("Received unrecognized SSE event type: {}", event.type());
198                                        }
199                                }
200                                catch (IOException e) {
201                                        logger.error("Error processing SSE event", e);
202                                        future.completeExceptionally(e);
203                                }
204                        }
205
206                        @Override
207                        public void onError(Throwable error) {
208                                if (!isClosing) {
209                                        logger.error("SSE connection error", error);
210                                        future.completeExceptionally(error);
211                                }
212                        }
213                });
214
215                return Mono.fromFuture(future);
216        }
217
218        /**
219         * Sends a JSON-RPC message to the server.
220         *
221         * <p>
222         * This method waits for the message endpoint to be discovered before sending the
223         * message. The message is serialized to JSON and sent as an HTTP POST request.
224         * @param message the JSON-RPC message to send
225         * @return a Mono that completes when the message is sent
226         * @throws McpError if the message endpoint is not available or the wait times out
227         */
228        @Override
229        public Mono<Void> sendMessage(JSONRPCMessage message) {
230                if (isClosing) {
231                        return Mono.empty();
232                }
233
234                try {
235                        if (!closeLatch.await(10, TimeUnit.SECONDS)) {
236                                return Mono.error(new McpError("Failed to wait for the message endpoint"));
237                        }
238                }
239                catch (InterruptedException e) {
240                        return Mono.error(new McpError("Failed to wait for the message endpoint"));
241                }
242
243                String endpoint = messageEndpoint.get();
244                if (endpoint == null) {
245                        return Mono.error(new McpError("No message endpoint available"));
246                }
247
248                return Mono.deferContextual(contextView -> {
249                        Context parentContext = contextView.getOrDefault(Context.class, Context.current());                             
250                        long start = System.currentTimeMillis();
251                        URI requestURI = URI.create(this.baseUri + endpoint);
252                        
253                Span requestSpan = tracer == null ? Span.fromContext(parentContext) :   
254                        tracer
255                                .spanBuilder("sendMessage")
256                                .setAttribute("uri", requestURI.toString())
257                                .setSpanKind(SpanKind.CLIENT)
258                                .setParent(parentContext)
259                                .startSpan();
260                        try (Scope scope = requestSpan.makeCurrent()) {
261                                String jsonText = this.objectMapper.writeValueAsString(message);                                
262                        requestSpan.setAttribute("message", jsonText);
263                                Builder builder = getHttpRequestBuilder()
264                                        .uri(requestURI)
265                                        .header("Content-Type", "application/json")
266                                        .POST(HttpRequest.BodyPublishers.ofString(jsonText));
267                                
268                                Context telemetryContext = Context.current().with(requestSpan);
269                                propagator.inject(telemetryContext, builder, (b, name, value) -> b.header(name, value));                                                                
270                                HttpRequest request = builder.build();
271                                
272                                return Mono.fromFuture(
273                                                httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> {
274                                                        if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202
275                                                                        && response.statusCode() != 206) {
276                                                                logger.error("Error sending message: {}", response.statusCode());
277                                                        }
278                                                }))
279                                                .map(result -> {
280                                                if (durationConsumer != null) {
281                                                        durationConsumer.accept(requestURI.toString(), System.currentTimeMillis() - start);
282                                                }
283                                                requestSpan.setStatus(StatusCode.OK);
284                                                return result;
285                                                })
286                                                .onErrorMap(error -> {
287                                                        requestSpan.recordException(error);
288                                                        requestSpan.setStatus(StatusCode.ERROR);
289                                                        return error;
290                                                })
291                                                .doFinally(signal -> requestSpan.end());
292                        } catch (IOException e) {
293                                requestSpan.recordException(e);
294                                if (!isClosing) {
295                                        return Mono.error(new RuntimeException("Failed to serialize message", e));
296                                }
297                                return Mono.empty();
298                        }
299                });
300        }
301
302        protected Builder getHttpRequestBuilder() {
303                return HttpRequest.newBuilder();
304        }
305
306        /**
307         * Gracefully closes the transport connection.
308         *
309         * <p>
310         * Sets the closing flag and cancels any pending connection future. This prevents new
311         * messages from being sent and allows ongoing operations to complete.
312         * @return a Mono that completes when the closing process is initiated
313         */
314        @Override
315        public Mono<Void> closeGracefully() {
316                return Mono.fromRunnable(() -> {
317                        isClosing = true;
318                        CompletableFuture<Void> future = connectionFuture.get();
319                        if (future != null && !future.isDone()) {
320                                future.cancel(true);
321                        }
322                });
323        }
324
325        /**
326         * Unmarshals data to the specified type using the configured object mapper.
327         * @param data the data to unmarshal
328         * @param typeRef the type reference for the target type
329         * @param <T> the target type
330         * @return the unmarshalled object
331         */
332        @Override
333        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
334                return this.objectMapper.convertValue(data, typeRef);
335        }
336
337}