001package org.nasdanika.ai.mcp;
002
import java.util.Objects;
import java.util.function.Function;
import java.util.function.UnaryOperator;

import com.fasterxml.jackson.core.type.TypeReference;

import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import reactor.core.publisher.Mono;
016
017/**
018 * Creates {@link Span}s for transport method calls. 
019 */
020public class TelemetryMcpClientTransportFilter implements McpClientTransport {
021        
022        private McpClientTransport target;
023        private Tracer tracer;
024        private Context context;
025
026        public TelemetryMcpClientTransportFilter(
027                        McpClientTransport target,
028                        Tracer tracer,
029                        Context context) {
030                
031                this.target = target;
032                this.tracer = tracer;
033                this.context = context;
034        }
035        
036        public TelemetryMcpClientTransportFilter(
037                        McpClientTransport target,
038                        Tracer tracer) {
039                
040                this.target = target;
041                this.tracer = tracer;
042                this.context = Context.current();
043        }       
044        
045        public void setContext(Context context) {
046                this.context = context;
047        }
048
049        @Override
050        public Mono<Void> closeGracefully() {
051                return Mono.deferContextual(contextView -> {
052                        Context parentContext = contextView.getOrDefault(Context.class, getContext());
053                Span span = tracer
054                                .spanBuilder("closeGracefully")
055                                .setSpanKind(SpanKind.CLIENT)
056                                .setParent(parentContext)
057                                .startSpan();
058
059                        try (Scope scope = span.makeCurrent()) {                                
060                                return 
061                                        target.closeGracefully()
062                                                .map(result -> {
063                                                span.setStatus(StatusCode.OK);
064                                                        return result;
065                                                })
066                                                .onErrorMap(error -> {
067                                                span.recordException(error);
068                                                span.setStatus(StatusCode.ERROR);
069                                                        return error;
070                                                })
071                                                .contextWrite(reactor.util.context.Context.of(Context.class, getContext().with(span)))
072                                                .doFinally(signal -> span.end());
073                        }
074                });             
075        }
076
077        @Override
078        public Mono<Void> sendMessage(JSONRPCMessage message) {
079                return Mono.deferContextual(contextView -> {
080                        Context parentContext = contextView.getOrDefault(Context.class, getContext());
081                        
082                Span span = tracer
083                                .spanBuilder("sendMessage")
084                                .setSpanKind(SpanKind.CLIENT)
085                                .setParent(parentContext)
086                                .setAttribute("message", message.toString())
087                                .startSpan();
088                
089                try (Scope scope = span.makeCurrent()) { 
090                                return 
091                                        target.sendMessage(message)
092                                                .map(result -> {
093                                                span.setStatus(StatusCode.OK);
094                                                        return result;
095                                                })
096                                                .onErrorMap(error -> {
097                                                span.recordException(error);
098                                                span.setStatus(StatusCode.ERROR);
099                                                        return error;
100                                                })
101                                                .contextWrite(reactor.util.context.Context.of(Context.class, getContext().with(span)))
102                                                .doFinally(signal -> {
103                                                        span.end();
104                                                });
105                }
106                });             
107        }
108
109        protected Context getContext() {
110                return context == null ? Context.current() : context;
111        }
112
113        @Override
114        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
115                return target.unmarshalFrom(data, typeRef);
116        }
117
118        @Override
119        public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
120                return Mono.deferContextual(contextView -> {
121                        Context parentContext = contextView.getOrDefault(Context.class, getContext());
122                Span span = tracer
123                                .spanBuilder("connect")
124                                .setSpanKind(SpanKind.CLIENT)
125                                .setParent(parentContext)
126                                .startSpan();
127
128                try (Scope scope = span.makeCurrent()) {
129                                return 
130                                        target.connect(filterHandler(handler))
131                                                .doOnNext(result -> {
132                                                span.setStatus(StatusCode.OK);                                          
133                                                })
134                                                .onErrorMap(error -> {
135                                                span.recordException(error);
136                                                span.setStatus(StatusCode.ERROR);
137                                                        return error;
138                                                })
139                                                .contextWrite(reactor.util.context.Context.of(Context.class, getContext().with(span)))
140                                                .doFinally(signal -> {
141                                                        span.end();
142                                                });
143                }
144                });             
145                
146        }
147
148        private Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> filterHandler(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
149                return requestMono -> {
150                        return Mono.deferContextual(contextView -> {
151                                Context parentContext = contextView.getOrDefault(Context.class, getContext());
152                        Span span = tracer
153                                        .spanBuilder("handle")
154                                        .setSpanKind(SpanKind.CLIENT)
155                                        .setParent(parentContext)
156                                        .startSpan();
157                        
158                        try (Scope scope = span.makeCurrent()) {
159                                        return 
160                                                handler.apply(requestMono)
161                                                        .map(result -> {
162                                                        span.setStatus(StatusCode.OK);
163                                                        span.setAttribute("response", result.toString());
164                                                                return result;
165                                                        })
166                                                        .onErrorMap(error -> {
167                                                        span.recordException(error);
168                                                        span.setStatus(StatusCode.ERROR);
169                                                                return error;
170                                                        })
171                                                        .contextWrite(reactor.util.context.Context.of(Context.class, getContext().with(span)))
172                                                        .doFinally(signal -> span.end());
173                        }
174                        });             
175                        
176                };
177        }
178
179}