001/*
002 * This class is an adaptation of https://github.com/Nasdanika/mcp-java-sdk/blob/main/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java
003 * to Reactor Netty and OpenTelemetry.
004 */
005
006package org.nasdanika.ai.mcp.sse;
007
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

import org.json.JSONObject;
import org.nasdanika.http.TelemetryFilter;
import org.nasdanika.telemetry.TelemetryUtil;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.context.propagation.TextMapGetter;
import io.opentelemetry.context.propagation.TextMapPropagator;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.netty.http.server.HttpServerRequest;
import reactor.netty.http.server.HttpServerResponse;
import reactor.netty.http.server.HttpServerRoutes;
//import reactor.core.publisher.ImmutableSignal;
055
056/**
057 * Reactor Netty implementation of {@link McpServerTransportProvider}.
058 */
059public class HttpServerRoutesTransportProvider implements McpServerTransportProvider {
060        
061        private static final String ID_KEY = "id";
062
063        /**
064         * Maps message IDs to telemetry contexts injected in POST to be extracted in sendMessage. 
065         */
066        private Map<String, Map<String,String>> contextMap = new ConcurrentHashMap<>();
067        
068        private static final Logger logger = LoggerFactory.getLogger(HttpServerRoutesTransportProvider.class);  
069
070        public static final String DEFAULT_SSE_ENDPOINT = "/sse";
071
072        public static final String DEFAULT_BASE_URL = "";
073                
074        private static final String SESSION_ID_PARAMETER = "sessionId";
075        public static final String MESSAGE_EVENT_TYPE = "message";
076        public static final String ENDPOINT_EVENT_TYPE = "endpoint";    
077
078        private final ObjectMapper objectMapper;
079
080        private final String baseUrl;
081
082        private final String messageEndpoint;
083
084        private final String sseEndpoint;
085
086        private McpServerSession.Factory sessionFactory;
087
088        private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
089
090        private volatile boolean isClosing = false;
091
092        private Tracer tracer;
093        private TextMapPropagator propagator;
094        private TelemetryFilter telemetryFilter;
095
096        public HttpServerRoutesTransportProvider(
097                        ObjectMapper objectMapper, 
098                        String messageEndpoint, 
099                        HttpServerRoutes httpServerRoutes,
100                        Tracer tracer,
101                        boolean resolveRemoteHostName,
102                        TextMapPropagator propagator, 
103                        BiConsumer<String, Long> durationConsumer) {
104                
105                this(
106                        objectMapper, 
107                        messageEndpoint, 
108                        DEFAULT_SSE_ENDPOINT, 
109                        httpServerRoutes, 
110                        tracer, 
111                        resolveRemoteHostName,
112                        propagator,
113                        durationConsumer);
114        }
115
116        public HttpServerRoutesTransportProvider(
117                        ObjectMapper objectMapper, 
118                        String messageEndpoint, 
119                        String sseEndpoint,
120                        HttpServerRoutes httpServerRoutes,
121                        Tracer tracer,
122                        boolean resolveRemoteHostName,
123                        TextMapPropagator propagator, 
124                        BiConsumer<String, Long> durationConsumer) {
125                
126                this(
127                        objectMapper, 
128                        DEFAULT_BASE_URL, 
129                        messageEndpoint, 
130                        sseEndpoint, 
131                        httpServerRoutes, 
132                        tracer, 
133                        resolveRemoteHostName,
134                        propagator,
135                        durationConsumer);
136        }
137
138        public HttpServerRoutesTransportProvider(
139                        ObjectMapper objectMapper, 
140                        String baseUrl, 
141                        String messageEndpoint,
142                        String sseEndpoint,
143                        HttpServerRoutes httpServerRoutes,
144                        Tracer tracer,
145                        boolean resolveRemoteHostName,
146                        TextMapPropagator propagator, 
147                        BiConsumer<String, Long> durationConsumer) {
148
149                this.objectMapper = objectMapper;
150                this.baseUrl = baseUrl;
151                this.messageEndpoint = messageEndpoint;
152                this.sseEndpoint = sseEndpoint;
153                
154                this.tracer = tracer;
155                this.propagator = propagator;
156                this.telemetryFilter = new TelemetryFilter(
157                                tracer, 
158                                propagator, 
159                                durationConsumer, 
160                                resolveRemoteHostName);                         
161                httpServerRoutes
162                        .get(this.sseEndpoint, serveSse())
163                        .post(this.messageEndpoint, this::processMessage);
164        }
165        
166        public static class Builder {
167
168                private ObjectMapper objectMapper;
169                private String baseUrl = DEFAULT_BASE_URL;
170                private String messageEndpoint;
171                private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
172                private Tracer tracer;
173                private boolean resolveRemoteHostName;
174                private TextMapPropagator propagator;
175                private BiConsumer<String, Long> durationConsumer;
176                                
177                public Builder objectMapper(ObjectMapper objectMapper) {
178                        Assert.notNull(objectMapper, "ObjectMapper must not be null");
179                        this.objectMapper = objectMapper;
180                        return this;
181                }
182
183                public Builder basePath(String baseUrl) {
184                        Assert.notNull(baseUrl, "basePath must not be null");
185                        this.baseUrl = baseUrl;
186                        return this;
187                }
188
189                public Builder messageEndpoint(String messageEndpoint) {
190                        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
191                        this.messageEndpoint = messageEndpoint;
192                        return this;
193                }
194
195                public Builder sseEndpoint(String sseEndpoint) {
196                        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
197                        this.sseEndpoint = sseEndpoint;
198                        return this;
199                }
200                
201                public Builder tracer(Tracer tracer) {
202                        this.tracer = tracer;
203                        return this;
204                }
205                
206                public Builder resolveRemoteHostName(boolean resolveRemoteHostName) {
207                        this.resolveRemoteHostName = resolveRemoteHostName;
208                        return this;
209                }
210                
211                public Builder propagator(TextMapPropagator propagator) {
212                        this.propagator = propagator;
213                        return this;
214                }
215                
216                /**
217                 * Consumer of path, duration in milliseconds. E.g. histogram
218                 * @param durationConsumer
219                 * @return
220                 */
221                public Builder setDurationConsumer(BiConsumer<String, Long> durationConsumer) {
222                        this.durationConsumer = durationConsumer;
223                        return this;
224                }
225
226                public HttpServerRoutesTransportProvider build(HttpServerRoutes httpServerRoutes) {
227                        return new HttpServerRoutesTransportProvider(
228                                        objectMapper, 
229                                        baseUrl, 
230                                        messageEndpoint, 
231                                        sseEndpoint,
232                                        httpServerRoutes,
233                                        tracer,
234                                        resolveRemoteHostName,
235                                        propagator,
236                                        durationConsumer);
237                }
238
239        }
240        
241        public static Builder builder() {
242                return new Builder();
243        }       
244                
245        private record ServerSentEvent(String event, String data) {}    
246        
247        private class HttpServerRoutesSessionTransport implements McpServerTransport {
248
249                private final FluxSink<ServerSentEvent> sink;
250
251                public HttpServerRoutesSessionTransport(FluxSink<ServerSentEvent> sink) {
252                        this.sink = sink;
253                }
254
255                @Override
256                public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
257                        AtomicReference<Span> spanRef = new AtomicReference<>();
258                        
259                        Context contextDelegate = new Context() {
260                                
261                                private Context getTarget() {
262                                        Context target = Context.current();
263                                        Span span = spanRef.get();
264                                        if (span == null) {
265                                                return target;
266                                        }
267                                        return target.with(span);
268                                }
269                                
270                                @Override
271                                public <V> Context with(ContextKey<V> k1, V v1) {
272                                        return getTarget().with(k1, v1);
273                                }
274                                
275                                @Override
276                                public <V> V get(ContextKey<V> key) {
277                                        return getTarget().get(key);
278                                }
279                        };
280                        
281                        return Mono.fromSupplier(() -> {
282                                try {
283                                        SpanBuilder spanBuilder = TelemetryUtil.buildSpan(tracer.spanBuilder("sessionTransport.sendMessage"));
284                                        String jsonText = objectMapper.writeValueAsString(message);
285                                        JSONObject jObj = new JSONObject(jsonText);
286                                        if (jObj.has(ID_KEY)) {
287                                                Map<String,String> parentSpanData = contextMap.remove(jObj.get(ID_KEY).toString());
288                                                if (parentSpanData != null) {
289                                                        TextMapGetter<Map<String, String>> mapper = new TextMapGetter<Map<String,String>>() {
290
291                                                                @Override
292                                                                public Iterable<String> keys(Map<String, String> carrier) {
293                                                                        return carrier.keySet();
294                                                                }
295
296                                                                @Override
297                                                                public String get(Map<String, String> carrier, String key) {
298                                                                        return carrier.get(key);
299                                                                }
300                                
301                                                        };
302                                                        Context telemetrycontext = propagator.extract(
303                                                                        Context.current(),
304                                                                        parentSpanData,
305                                                                        mapper);
306                                                        spanBuilder.setParent(telemetrycontext);
307                                                }
308                                        }
309                                        Span span = spanBuilder.startSpan();
310                                        spanRef.set(span);
311                                        return jsonText;
312                                }
313                                catch (IOException e) {
314                                        throw Exceptions.propagate(e);
315                                }
316                        })              
317//                      .doOnEach(signal -> {
318//                              Context ctx = signal.getContextView().getOrDefault(Context.class, Context.current());
319//                              if (signal instanceof ImmutableSignal) {
320//                                      // add data as attribute
321//                              }
322////                            System.out.println(signal);
323//                              Span signalSpan = Span.fromContext(ctx);
324//                              signalSpan.addEvent(signal.toString());
325//                      })                                      
326                        .doOnNext(jsonText -> {
327                                sink.next(new ServerSentEvent("message", jsonText));
328                                Span span = spanRef.get();
329                                if (span != null) {
330                                        span.setAttribute("message", jsonText);
331                                        span.setStatus(StatusCode.OK);
332                                }
333                        })
334                        .doOnError(e -> {
335                                Throwable exception = Exceptions.unwrap(e);
336                                Span span = spanRef.get();
337                                if (span != null) {
338                                        span.recordException(exception);
339                                        span.setStatus(StatusCode.ERROR);
340                                }
341                                sink.error(exception);
342                        })
343                .contextWrite(reactor.util.context.Context.of(Context.class, contextDelegate))
344                        .doFinally(signal -> {
345                                Span span = spanRef.get();
346                                if (span != null) {
347                                        span.end();
348                                }
349                        })
350                        .then();
351                }
352
353                @Override
354                public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
355                        return objectMapper.convertValue(data, typeRef);
356                }
357
358                @Override
359                public Mono<Void> closeGracefully() {
360                        return Mono.fromRunnable(sink::complete);
361                }
362
363                @Override
364                public void close() {
365                        sink.complete();
366                }
367
368        }
369        
370        private BiFunction<HttpServerRequest, HttpServerResponse, Publisher<Void>> serveSse() {
371                Flux<ServerSentEvent> flux = Flux.create(sink -> {
372                        HttpServerRoutesSessionTransport sessionTransport = new HttpServerRoutesSessionTransport(sink);
373                        McpServerSession session = sessionFactory.create(sessionTransport);
374                        String sessionId = session.getId();
375
376                        logger.debug("Created new SSE connection for session: {}", sessionId);
377                        sessions.put(sessionId, session);
378
379                        logger.debug("Sending initial endpoint event to session: {}", sessionId);
380                        sink.next(new ServerSentEvent("endpoint", this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId));
381                        sink.onCancel(() -> {
382                                logger.debug("Session {} cancelled", sessionId);
383                                sessions.remove(sessionId);
384                        });                     
385                });
386                return (request, response) ->
387                        response.sse()
388                                .send(flux.map(this::toByteBuf), b -> true);
389        }
390        
391        /**
392         * Transforms the Object to ByteBuf following the expected SSE format.
393         */
394        private ByteBuf toByteBuf(ServerSentEvent event) {
395                ByteArrayOutputStream out = new ByteArrayOutputStream();
396                try (Writer writer = new OutputStreamWriter(out)) {
397                        writer.write("event: ");
398                        writer.write(event.event());
399                        writer.write("\n");
400                        writer.write("data: ");
401                        writer.write(event.data());
402                        writer.write("\n\n");
403                } catch (Exception e) {
404                        throw new RuntimeException(e);
405                }
406                return ByteBufAllocator.DEFAULT
407                                       .buffer()
408                                       .writeBytes(out.toByteArray());
409        }       
410
411        @Override
412        public void setSessionFactory(McpServerSession.Factory sessionFactory) {
413                this.sessionFactory = sessionFactory;
414        }
415        
416        private Publisher<Void> processMessage(HttpServerRequest request, HttpServerResponse response) {
417                if (isClosing) {
418                        return response
419                                        .status(HttpResponseStatus.SERVICE_UNAVAILABLE)
420                                        .sendString(Mono.just("Server is shutting down"))
421                                        .then();
422                }
423                QueryStringDecoder decoder = new QueryStringDecoder(request.uri());
424                if (!decoder.parameters().containsKey(SESSION_ID_PARAMETER)) {
425                        return response
426                                        .status(HttpResponseStatus.BAD_REQUEST)
427                                        .sendString(Mono.just("Session ID is missing"))
428                                        .then();
429                }
430                
431                McpServerSession session = sessions.get(decoder.parameters().get(SESSION_ID_PARAMETER).get(0));
432                Mono<String> requestBody = Mono.deferContextual(contextView -> {
433                        Context context = contextView.getOrDefault(Context.class, Context.current());
434                        Span span = Span.fromContext(context);
435                        
436                        return request
437                                .receive()
438                                .aggregate()
439                                .asString()
440                                .doOnNext(rb -> {
441                                        if (span != null) {
442                                                span.setAttribute("request", rb);
443                                                JSONObject jRequest = new JSONObject(rb);
444                                                if (jRequest.has(ID_KEY)) {
445                                                        Map<String,String> carrier = new HashMap<>();
446                                                        propagator.inject(context, carrier, (cr, name, value) -> cr.put(name, value));          
447                                                        if (!carrier.isEmpty()) {
448                                                                contextMap.put(jRequest.get(ID_KEY).toString(), carrier);
449                                                        }
450                                                }
451                                        }                                       
452                                });                     
453                });
454                                
455                return telemetryFilter.filter(request, requestBody)             
456                                .flatMap(body -> {
457                                        try {
458                                                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
459                                                Mono<Void> handled = session.handle(message);
460                                                return handled
461                                                                .flatMap(rsp -> response.status(HttpResponseStatus.OK).then().onErrorResume(error -> {
462                                                                        logger.error("Error processing  message: {}", error.getMessage());
463                                                                        McpError mcpError = new McpError(error.getMessage());
464                                                                        return response
465                                                                                        .status(HttpResponseStatus.BAD_REQUEST)
466                                                                                        .sendString(Mono.just(mcpError.getJsonRpcError().toString()))
467                                                                                        .then();
468                                                                        
469                                                                }));
470                                        } catch (IllegalArgumentException | IOException e) {
471                                                // TODO - span record error
472                                                logger.error("Failed to deserialize message: {}", e.getMessage());
473                                                McpError mcpError = new McpError("Invalid message format");
474                                                return response
475                                                                .status(HttpResponseStatus.BAD_REQUEST)
476                                                                .sendString(Mono.just(mcpError.getJsonRpcError().toString()))
477                                                                .then();
478                                        }
479                                });
480                
481        }       
482        
483        @Override
484        public Mono<Void> notifyClients(String method, Map<String, Object> params) {
485                // TODO - telemetry span
486                if (sessions.isEmpty()) {
487                        logger.debug("No active sessions to broadcast message to");
488                        return Mono.empty();
489                }
490
491                logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
492
493                return Flux.fromIterable(sessions.values())
494                        .flatMap(session -> session.sendNotification(method, params)
495                                .doOnError(
496                                                e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
497                                .onErrorComplete())
498                        .then();
499        }
500
501        @Override
502        public Mono<Void> closeGracefully() {
503                // TODO - telemetry span
504                return Flux.fromIterable(sessions.values())
505                        .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()))
506                        .flatMap(McpServerSession::closeGracefully)
507                        .then();
508        }
509
510}