001/* 002 * This class is an adaptation of https://github.com/Nasdanika/mcp-java-sdk/blob/main/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java 003 * to Reactor Netty and OpenTelemetry. 004 */ 005 006package org.nasdanika.ai.mcp.sse; 007 008import java.io.ByteArrayOutputStream; 009import java.io.IOException; 010import java.io.OutputStreamWriter; 011import java.io.Writer; 012import java.util.HashMap; 013import java.util.Map; 014import java.util.concurrent.ConcurrentHashMap; 015import java.util.concurrent.atomic.AtomicReference; 016import java.util.function.BiConsumer; 017import java.util.function.BiFunction; 018 019import org.json.JSONObject; 020import org.nasdanika.http.TelemetryFilter; 021import org.nasdanika.telemetry.TelemetryUtil; 022import org.reactivestreams.Publisher; 023import org.slf4j.Logger; 024import org.slf4j.LoggerFactory; 025 026import com.fasterxml.jackson.core.type.TypeReference; 027import com.fasterxml.jackson.databind.ObjectMapper; 028 029import io.modelcontextprotocol.spec.McpError; 030import io.modelcontextprotocol.spec.McpSchema; 031import io.modelcontextprotocol.spec.McpServerSession; 032import io.modelcontextprotocol.spec.McpServerTransport; 033import io.modelcontextprotocol.spec.McpServerTransportProvider; 034import io.modelcontextprotocol.util.Assert; 035import io.netty.buffer.ByteBuf; 036import io.netty.buffer.ByteBufAllocator; 037import io.netty.handler.codec.http.HttpResponseStatus; 038import io.netty.handler.codec.http.QueryStringDecoder; 039import io.opentelemetry.api.trace.Span; 040import io.opentelemetry.api.trace.SpanBuilder; 041import io.opentelemetry.api.trace.StatusCode; 042import io.opentelemetry.api.trace.Tracer; 043import io.opentelemetry.context.Context; 044import io.opentelemetry.context.ContextKey; 045import io.opentelemetry.context.propagation.TextMapGetter; 046import io.opentelemetry.context.propagation.TextMapPropagator; 047import reactor.core.Exceptions; 048import reactor.core.publisher.Flux; 049import reactor.core.publisher.FluxSink; 050import reactor.core.publisher.Mono; 051import reactor.netty.http.server.HttpServerRequest; 052import reactor.netty.http.server.HttpServerResponse; 053import reactor.netty.http.server.HttpServerRoutes; 054//import reactor.core.publisher.ImmutableSignal; 055 056/** 057 * Reactor Netty implementation of {@link McpServerTransportProvider}. 058 */ 059public class HttpServerRoutesTransportProvider implements McpServerTransportProvider { 060 061 private static final String ID_KEY = "id"; 062 063 /** 064 * Maps message IDs to telemetry contexts injected in POST to be extracted in sendMessage. 065 */ 066 private Map<String, Map<String,String>> contextMap = new ConcurrentHashMap<>(); 067 068 private static final Logger logger = LoggerFactory.getLogger(HttpServerRoutesTransportProvider.class); 069 070 public static final String DEFAULT_SSE_ENDPOINT = "/sse"; 071 072 public static final String DEFAULT_BASE_URL = ""; 073 074 private static final String SESSION_ID_PARAMETER = "sessionId"; 075 public static final String MESSAGE_EVENT_TYPE = "message"; 076 public static final String ENDPOINT_EVENT_TYPE = "endpoint"; 077 078 private final ObjectMapper objectMapper; 079 080 private final String baseUrl; 081 082 private final String messageEndpoint; 083 084 private final String sseEndpoint; 085 086 private McpServerSession.Factory sessionFactory; 087 088 private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>(); 089 090 private volatile boolean isClosing = false; 091 092 private Tracer tracer; 093 private TextMapPropagator propagator; 094 private TelemetryFilter telemetryFilter; 095 096 public HttpServerRoutesTransportProvider( 097 ObjectMapper objectMapper, 098 String messageEndpoint, 099 HttpServerRoutes httpServerRoutes, 100 Tracer tracer, 101 boolean resolveRemoteHostName, 102 TextMapPropagator propagator, 103 BiConsumer<String, Long> durationConsumer) { 104 105 this( 106 objectMapper, 107 messageEndpoint, 108 DEFAULT_SSE_ENDPOINT, 109 httpServerRoutes, 110 tracer, 111 resolveRemoteHostName, 112 propagator, 113 durationConsumer); 114 } 115 116 public HttpServerRoutesTransportProvider( 117 ObjectMapper objectMapper, 118 String messageEndpoint, 119 String sseEndpoint, 120 HttpServerRoutes httpServerRoutes, 121 Tracer tracer, 122 boolean resolveRemoteHostName, 123 TextMapPropagator propagator, 124 BiConsumer<String, Long> durationConsumer) { 125 126 this( 127 objectMapper, 128 DEFAULT_BASE_URL, 129 messageEndpoint, 130 sseEndpoint, 131 httpServerRoutes, 132 tracer, 133 resolveRemoteHostName, 134 propagator, 135 durationConsumer); 136 } 137 138 public HttpServerRoutesTransportProvider( 139 ObjectMapper objectMapper, 140 String baseUrl, 141 String messageEndpoint, 142 String sseEndpoint, 143 HttpServerRoutes httpServerRoutes, 144 Tracer tracer, 145 boolean resolveRemoteHostName, 146 TextMapPropagator propagator, 147 BiConsumer<String, Long> durationConsumer) { 148 149 this.objectMapper = objectMapper; 150 this.baseUrl = baseUrl; 151 this.messageEndpoint = messageEndpoint; 152 this.sseEndpoint = sseEndpoint; 153 154 this.tracer = tracer; 155 this.propagator = propagator; 156 this.telemetryFilter = new TelemetryFilter( 157 tracer, 158 propagator, 159 durationConsumer, 160 resolveRemoteHostName); 161 httpServerRoutes 162 .get(this.sseEndpoint, serveSse()) 163 .post(this.messageEndpoint, this::processMessage); 164 } 165 166 public static class Builder { 167 168 private ObjectMapper objectMapper; 169 private String baseUrl = DEFAULT_BASE_URL; 170 private String messageEndpoint; 171 private String sseEndpoint = DEFAULT_SSE_ENDPOINT; 172 private Tracer tracer; 173 private boolean resolveRemoteHostName; 174 private TextMapPropagator propagator; 175 private BiConsumer<String, Long> durationConsumer; 176 177 public Builder objectMapper(ObjectMapper objectMapper) { 178 Assert.notNull(objectMapper, "ObjectMapper must not be null"); 179 this.objectMapper = objectMapper; 180 return this; 181 } 182 183 public Builder basePath(String baseUrl) { 184 Assert.notNull(baseUrl, "basePath must not be null"); 185 this.baseUrl = baseUrl; 186 return this; 187 } 188 189 public Builder messageEndpoint(String messageEndpoint) { 190 Assert.notNull(messageEndpoint, "Message endpoint must not be null"); 191 this.messageEndpoint = messageEndpoint; 192 return this; 193 } 194 195 public Builder sseEndpoint(String sseEndpoint) { 196 Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); 197 this.sseEndpoint = sseEndpoint; 198 return this; 199 } 200 201 public Builder tracer(Tracer tracer) { 202 this.tracer = tracer; 203 return this; 204 } 205 206 public Builder resolveRemoteHostName(boolean resolveRemoteHostName) { 207 this.resolveRemoteHostName = resolveRemoteHostName; 208 return this; 209 } 210 211 public Builder propagator(TextMapPropagator propagator) { 212 this.propagator = propagator; 213 return this; 214 } 215 216 /** 217 * Consumer of path, duration in milliseconds. E.g. histogram 218 * @param durationConsumer 219 * @return 220 */ 221 public Builder setDurationConsumer(BiConsumer<String, Long> durationConsumer) { 222 this.durationConsumer = durationConsumer; 223 return this; 224 } 225 226 public HttpServerRoutesTransportProvider build(HttpServerRoutes httpServerRoutes) { 227 return new HttpServerRoutesTransportProvider( 228 objectMapper, 229 baseUrl, 230 messageEndpoint, 231 sseEndpoint, 232 httpServerRoutes, 233 tracer, 234 resolveRemoteHostName, 235 propagator, 236 durationConsumer); 237 } 238 239 } 240 241 public static Builder builder() { 242 return new Builder(); 243 } 244 245 private record ServerSentEvent(String event, String data) {} 246 247 private class HttpServerRoutesSessionTransport implements McpServerTransport { 248 249 private final FluxSink<ServerSentEvent> sink; 250 251 public HttpServerRoutesSessionTransport(FluxSink<ServerSentEvent> sink) { 252 this.sink = sink; 253 } 254 255 @Override 256 public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) { 257 AtomicReference<Span> spanRef = new AtomicReference<>(); 258 259 Context contextDelegate = new Context() { 260 261 private Context getTarget() { 262 Context target = Context.current(); 263 Span span = spanRef.get(); 264 if (span == null) { 265 return target; 266 } 267 return target.with(span); 268 } 269 270 @Override 271 public <V> Context with(ContextKey<V> k1, V v1) { 272 return getTarget().with(k1, v1); 273 } 274 275 @Override 276 public <V> V get(ContextKey<V> key) { 277 return getTarget().get(key); 278 } 279 }; 280 281 return Mono.fromSupplier(() -> { 282 try { 283 SpanBuilder spanBuilder = TelemetryUtil.buildSpan(tracer.spanBuilder("sessionTransport.sendMessage")); 284 String jsonText = objectMapper.writeValueAsString(message); 285 JSONObject jObj = new JSONObject(jsonText); 286 if (jObj.has(ID_KEY)) { 287 Map<String,String> parentSpanData = contextMap.remove(jObj.get(ID_KEY).toString()); 288 if (parentSpanData != null) { 289 TextMapGetter<Map<String, String>> mapper = new TextMapGetter<Map<String,String>>() { 290 291 @Override 292 public Iterable<String> keys(Map<String, String> carrier) { 293 return carrier.keySet(); 294 } 295 296 @Override 297 public String get(Map<String, String> carrier, String key) { 298 return carrier.get(key); 299 } 300 301 }; 302 Context telemetrycontext = propagator.extract( 303 Context.current(), 304 parentSpanData, 305 mapper); 306 spanBuilder.setParent(telemetrycontext); 307 } 308 } 309 Span span = spanBuilder.startSpan(); 310 spanRef.set(span); 311 return jsonText; 312 } 313 catch (IOException e) { 314 throw Exceptions.propagate(e); 315 } 316 }) 317// .doOnEach(signal -> { 318// Context ctx = signal.getContextView().getOrDefault(Context.class, Context.current()); 319// if (signal instanceof ImmutableSignal) { 320// // add data as attribute 321// } 322//// System.out.println(signal); 323// Span signalSpan = Span.fromContext(ctx); 324// signalSpan.addEvent(signal.toString()); 325// }) 326 .doOnNext(jsonText -> { 327 sink.next(new ServerSentEvent("message", jsonText)); 328 Span span = spanRef.get(); 329 if (span != null) { 330 span.setAttribute("message", jsonText); 331 span.setStatus(StatusCode.OK); 332 } 333 }) 334 .doOnError(e -> { 335 Throwable exception = Exceptions.unwrap(e); 336 Span span = spanRef.get(); 337 if (span != null) { 338 span.recordException(exception); 339 span.setStatus(StatusCode.ERROR); 340 } 341 sink.error(exception); 342 }) 343 .contextWrite(reactor.util.context.Context.of(Context.class, contextDelegate)) 344 .doFinally(signal -> { 345 Span span = spanRef.get(); 346 if (span != null) { 347 span.end(); 348 } 349 }) 350 .then(); 351 } 352 353 @Override 354 public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) { 355 return objectMapper.convertValue(data, typeRef); 356 } 357 358 @Override 359 public Mono<Void> closeGracefully() { 360 return Mono.fromRunnable(sink::complete); 361 } 362 363 @Override 364 public void close() { 365 sink.complete(); 366 } 367 368 } 369 370 private BiFunction<HttpServerRequest, HttpServerResponse, Publisher<Void>> serveSse() { 371 Flux<ServerSentEvent> flux = Flux.create(sink -> { 372 HttpServerRoutesSessionTransport sessionTransport = new HttpServerRoutesSessionTransport(sink); 373 McpServerSession session = sessionFactory.create(sessionTransport); 374 String sessionId = session.getId(); 375 376 logger.debug("Created new SSE connection for session: {}", sessionId); 377 sessions.put(sessionId, session); 378 379 logger.debug("Sending initial endpoint event to session: {}", sessionId); 380 sink.next(new ServerSentEvent("endpoint", this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)); 381 sink.onCancel(() -> { 382 logger.debug("Session {} cancelled", sessionId); 383 sessions.remove(sessionId); 384 }); 385 }); 386 return (request, response) -> 387 response.sse() 388 .send(flux.map(this::toByteBuf), b -> true); 389 } 390 391 /** 392 * Transforms the Object to ByteBuf following the expected SSE format. 393 */ 394 private ByteBuf toByteBuf(ServerSentEvent event) { 395 ByteArrayOutputStream out = new ByteArrayOutputStream(); 396 try (Writer writer = new OutputStreamWriter(out)) { 397 writer.write("event: "); 398 writer.write(event.event()); 399 writer.write("\n"); 400 writer.write("data: "); 401 writer.write(event.data()); 402 writer.write("\n\n"); 403 } catch (Exception e) { 404 throw new RuntimeException(e); 405 } 406 return ByteBufAllocator.DEFAULT 407 .buffer() 408 .writeBytes(out.toByteArray()); 409 } 410 411 @Override 412 public void setSessionFactory(McpServerSession.Factory sessionFactory) { 413 this.sessionFactory = sessionFactory; 414 } 415 416 private Publisher<Void> processMessage(HttpServerRequest request, HttpServerResponse response) { 417 if (isClosing) { 418 return response 419 .status(HttpResponseStatus.SERVICE_UNAVAILABLE) 420 .sendString(Mono.just("Server is shutting down")) 421 .then(); 422 } 423 QueryStringDecoder decoder = new QueryStringDecoder(request.uri()); 424 if (!decoder.parameters().containsKey(SESSION_ID_PARAMETER)) { 425 return response 426 .status(HttpResponseStatus.BAD_REQUEST) 427 .sendString(Mono.just("Session ID is missing")) 428 .then(); 429 } 430 431 McpServerSession session = sessions.get(decoder.parameters().get(SESSION_ID_PARAMETER).get(0)); 432 Mono<String> requestBody = Mono.deferContextual(contextView -> { 433 Context context = contextView.getOrDefault(Context.class, Context.current()); 434 Span span = Span.fromContext(context); 435 436 return request 437 .receive() 438 .aggregate() 439 .asString() 440 .doOnNext(rb -> { 441 if (span != null) { 442 span.setAttribute("request", rb); 443 JSONObject jRequest = new JSONObject(rb); 444 if (jRequest.has(ID_KEY)) { 445 Map<String,String> carrier = new HashMap<>(); 446 propagator.inject(context, carrier, (cr, name, value) -> cr.put(name, value)); 447 if (!carrier.isEmpty()) { 448 contextMap.put(jRequest.get(ID_KEY).toString(), carrier); 449 } 450 } 451 } 452 }); 453 }); 454 455 return telemetryFilter.filter(request, requestBody) 456 .flatMap(body -> { 457 try { 458 McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); 459 Mono<Void> handled = session.handle(message); 460 return handled 461 .flatMap(rsp -> response.status(HttpResponseStatus.OK).then().onErrorResume(error -> { 462 logger.error("Error processing message: {}", error.getMessage()); 463 McpError mcpError = new McpError(error.getMessage()); 464 return response 465 .status(HttpResponseStatus.BAD_REQUEST) 466 .sendString(Mono.just(mcpError.getJsonRpcError().toString())) 467 .then(); 468 469 })); 470 } catch (IllegalArgumentException | IOException e) { 471 // TODO - span record error 472 logger.error("Failed to deserialize message: {}", e.getMessage()); 473 McpError mcpError = new McpError("Invalid message format"); 474 return response 475 .status(HttpResponseStatus.BAD_REQUEST) 476 .sendString(Mono.just(mcpError.getJsonRpcError().toString())) 477 .then(); 478 } 479 }); 480 481 } 482 483 @Override 484 public Mono<Void> notifyClients(String method, Map<String, Object> params) { 485 // TODO - telemetry span 486 if (sessions.isEmpty()) { 487 logger.debug("No active sessions to broadcast message to"); 488 return Mono.empty(); 489 } 490 491 logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); 492 493 return Flux.fromIterable(sessions.values()) 494 .flatMap(session -> session.sendNotification(method, params) 495 .doOnError( 496 e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) 497 .onErrorComplete()) 498 .then(); 499 } 500 501 @Override 502 public Mono<Void> closeGracefully() { 503 // TODO - telemetry span 504 return Flux.fromIterable(sessions.values()) 505 .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) 506 .flatMap(McpServerSession::closeGracefully) 507 .then(); 508 } 509 510}