001package org.nasdanika.ai.mcp;
002
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import io.modelcontextprotocol.server.McpServerFeatures.AsyncPromptSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.AsyncResourceSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncResourceSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import reactor.core.publisher.Mono;
021
022/**
023 * Filters (wraps) Mcp features for collecting telemetry
024 */
025public class McpTelemetryFilter {
026        
027        protected Tracer tracer;
028        protected BiConsumer<String,Long> durationConsumer;
029        
030        public McpTelemetryFilter(Tracer tracer, BiConsumer<String,Long> durationConsumer) {
031                
032                this.tracer = tracer;
033                this.durationConsumer = durationConsumer;
034        }
035                        
036        public SyncToolSpecification filter(SyncToolSpecification syncToolSpecification) {
037                return new SyncToolSpecification(
038                        syncToolSpecification.tool(), 
039                        (exchange, request) -> {
040                                long start = System.currentTimeMillis();
041                                Span span = tracer.spanBuilder("Sync tool " + syncToolSpecification.tool().name())
042                                        .setAttribute("description", syncToolSpecification.tool().description())
043                                        .startSpan();
044                                
045                                for (Entry<String, Object> re: request.entrySet()) {
046                                        span.setAttribute("request." + re.getKey(), String.valueOf(re.getValue()));
047                                }
048                                try (Scope scope = span.makeCurrent()) {                                
049                                        CallToolResult result = syncToolSpecification.call().apply(exchange, request);
050                                        span.setStatus(StatusCode.OK);
051                                        return result;
052                                } catch (RuntimeException e) {
053                                        span.recordException(e);
054                                        span.setStatus(StatusCode.ERROR);
055                                        throw e;
056                                } finally {
057                                        if (durationConsumer != null) {
058                                                durationConsumer.accept("tool.sync." + syncToolSpecification.tool().name(), System.currentTimeMillis() - start);
059                                        }
060                                        span.end();
061                                }                                                               
062                        });                                             
063        }
064        
065        public AsyncToolSpecification filter(AsyncToolSpecification asyncToolSpecification) {
066                return new AsyncToolSpecification(
067                        asyncToolSpecification.tool(),                          
068                        (exchange, request) -> {                                
069                                return Mono.deferContextual(contextView -> {
070                                        Context parentContext = contextView.getOrDefault(Context.class, Context.current());
071                                
072                                        long start = System.currentTimeMillis();                                        
073                                        Span span = tracer.spanBuilder("Async tool " + asyncToolSpecification.tool().name())
074                                                .setAttribute("description", asyncToolSpecification.tool().description())
075                                                .setParent(parentContext)
076                                                .startSpan();
077                                        
078                                        for (Entry<String, Object> re: request.entrySet()) {
079                                                span.setAttribute("request." + re.getKey(), String.valueOf(re.getValue()));
080                                        }
081                                
082                                        try (Scope scope = span.makeCurrent()) {
083                                                Mono<CallToolResult> publisher = asyncToolSpecification.call().apply(exchange, request);
084                                                return publisher
085                                                        .map(result -> {
086                                                        span.setStatus(StatusCode.OK);
087                                                                return result;
088                                                        })
089                                                        .onErrorMap(error -> {
090                                                        span.recordException(error);
091                                                        span.setStatus(StatusCode.ERROR);
092                                                                return error;
093                                                        })
094                                                .contextWrite(reactor.util.context.Context.of(Context.class, Context.current().with(span)))
095                                                        .doFinally(signal -> {
096                                                                if (durationConsumer != null) {
097                                                                        durationConsumer.accept("tool.sync." + asyncToolSpecification.tool().name(), System.currentTimeMillis() - start);
098                                                                }
099                                                                span.end();
100                                                        });
101                                        }
102                                });
103                        });                                             
104        }
105                
106        public SyncResourceSpecification filter(SyncResourceSpecification syncResourceSpecification) {
107                return new SyncResourceSpecification(
108                        syncResourceSpecification.resource(), 
109                        (exchange, request) -> {
110                                long start = System.currentTimeMillis();
111                                Span span = tracer.spanBuilder("Sync resource " + syncResourceSpecification.resource().name())
112                                        .setAttribute("description", syncResourceSpecification.resource().description())
113                                        .setAttribute("resource-uri", syncResourceSpecification.resource().uri())
114                                        .setAttribute("request-uri", request.uri())
115                                        .setAttribute("mime-type", syncResourceSpecification.resource().mimeType())
116                                        .startSpan();
117                                
118                                try (Scope scope = span.makeCurrent()) {                                
119                                        ReadResourceResult result = syncResourceSpecification.readHandler().apply(exchange, request);
120                                        span.setStatus(StatusCode.OK);
121                                        return result;
122                                } catch (RuntimeException e) {
123                                        span.recordException(e);
124                                        span.setStatus(StatusCode.ERROR);
125                                        throw e;
126                                } finally {
127                                        if (durationConsumer != null) {
128                                                durationConsumer.accept("tool.sync." + syncResourceSpecification.resource().name(), System.currentTimeMillis() - start);
129                                        }
130                                        span.end();
131                                }                                                               
132                        });                                             
133        }
134        
135        public AsyncResourceSpecification filter(AsyncResourceSpecification asyncResourceSpecification) {
136                return new AsyncResourceSpecification(
137                        asyncResourceSpecification.resource(),                          
138                        (exchange, request) -> {
139                                return Mono.deferContextual(contextView -> {
140                                        Context parentContext = contextView.getOrDefault(Context.class, Context.current());
141                                
142                                        long start = System.currentTimeMillis();                                        
143                                        Span span = tracer.spanBuilder("Async resource " + asyncResourceSpecification.resource().name())
144                                                .setAttribute("description", asyncResourceSpecification.resource().description())
145                                                .setAttribute("resource-uri", asyncResourceSpecification.resource().uri())
146                                                .setAttribute("request-uri", request.uri())
147                                                .setAttribute("mime-type", asyncResourceSpecification.resource().mimeType())
148                                                .setParent(parentContext)
149                                                .startSpan();
150
151                                        try (Scope scope = span.makeCurrent()) {                                
152                                                Mono<ReadResourceResult> publisher = asyncResourceSpecification.readHandler().apply(exchange, request);
153                                                return publisher
154                                                        .map(result -> {
155                                                        span.setStatus(StatusCode.OK);
156                                                                return result;
157                                                        })
158                                                        .onErrorMap(error -> {
159                                                        span.recordException(error);
160                                                        span.setStatus(StatusCode.ERROR);
161                                                                return error;
162                                                        })
163                                                .contextWrite(reactor.util.context.Context.of(Context.class, Context.current().with(span)))
164                                                        .doFinally(signal -> {
165                                                                if (durationConsumer != null) {
166                                                                        durationConsumer.accept("tool.sync." + asyncResourceSpecification.resource().name(), System.currentTimeMillis() - start);
167                                                                }
168                                                                span.end();
169                                                        });
170                                        }
171                                });
172                        });
173        }
174        
175        public SyncPromptSpecification filter(SyncPromptSpecification syncPromptSpecification) {
176        return new SyncPromptSpecification(
177                syncPromptSpecification.prompt(), 
178                (exchange, request) -> {
179                        long start = System.currentTimeMillis();
180                        Span span = tracer.spanBuilder("Sync prompt " + syncPromptSpecification.prompt().name())
181                                .setAttribute("description", syncPromptSpecification.prompt().description())
182                                .startSpan();
183                        
184                        try (Scope scope = span.makeCurrent()) {                                
185                                GetPromptResult result = syncPromptSpecification.promptHandler().apply(exchange, request);
186                                span.setStatus(StatusCode.OK);
187                                return result;
188                        } catch (RuntimeException e) {
189                                span.recordException(e);
190                                span.setStatus(StatusCode.ERROR);
191                                throw e;
192                        } finally {
193                                if (durationConsumer != null) {
194                                        durationConsumer.accept("tool.sync." + syncPromptSpecification.prompt().name(), System.currentTimeMillis() - start);
195                                }
196                                span.end();
197                        }                                                               
198                });                                             
199        }
200        
201        public AsyncPromptSpecification filter(AsyncPromptSpecification asyncPromptSpecification) {
202        return new AsyncPromptSpecification(
203                asyncPromptSpecification.prompt(),                              
204                (exchange, request) -> {                                
205                        return Mono.deferContextual(contextView -> {
206                                Context parentContext = contextView.getOrDefault(Context.class, Context.current());
207                        
208                                long start = System.currentTimeMillis();                                        
209                                Span span = tracer.spanBuilder("Async prompt " + asyncPromptSpecification.prompt().name())
210                                        .setAttribute("description", asyncPromptSpecification.prompt().description())
211                                        .setParent(parentContext)
212                                        .startSpan();
213                        
214                                try (Scope scope = span.makeCurrent()) {
215                                        Mono<GetPromptResult> publisher = asyncPromptSpecification.promptHandler().apply(exchange, request);
216                                        return publisher
217                                                .map(result -> {
218                                                span.setStatus(StatusCode.OK);
219                                                        return result;
220                                                })
221                                                .onErrorMap(error -> {
222                                                span.recordException(error);
223                                                span.setStatus(StatusCode.ERROR);
224                                                        return error;
225                                                })
226                                        .contextWrite(reactor.util.context.Context.of(Context.class, Context.current().with(span)))
227                                                .doFinally(signal -> {
228                                                        if (durationConsumer != null) {
229                                                                durationConsumer.accept("tool.sync." + asyncPromptSpecification.prompt().name(), System.currentTimeMillis() - start);
230                                                        }
231                                                        span.end();
232                                                });
233                                }
234                        });
235                });                                             
236        }       
237        
238}