001package org.nasdanika.ai.mcp;
002
003import java.util.ArrayList;
004import java.util.Collection;
005import java.util.Collections;
006import java.util.List;
007
008import org.nasdanika.capability.CapabilityLoader;
009import org.nasdanika.cli.CommandGroup;
010
011import io.modelcontextprotocol.server.McpAsyncServer;
012import io.modelcontextprotocol.server.McpServer;
013import io.modelcontextprotocol.server.McpServerFeatures.AsyncPromptSpecification;
014import io.modelcontextprotocol.server.McpServerFeatures.AsyncResourceSpecification;
015import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
016import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
017import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.Builder;
018import io.modelcontextprotocol.spec.McpServerTransportProvider;
019import io.opentelemetry.api.OpenTelemetry;
020import io.opentelemetry.api.trace.Tracer;
021import reactor.core.publisher.Mono;
022
023/**
024 * Base class for MCP server commands. 
025 * This class does nothing - you'd need to override one or more of 
026 * getResourceSpecifications(), getToolSpecifications(), or getPromptSpecifications() methods
027 * to provide server capabilities.
028 */
029public class McpServerCommandBase extends CommandGroup implements McpAsyncServerProvider {
030
031        private OpenTelemetry openTelemetry;
032
033        public McpServerCommandBase(OpenTelemetry openTelemetry) {
034                super();
035                this.openTelemetry = openTelemetry;
036        }
037
038        public McpServerCommandBase(CapabilityLoader capabilityLoader, OpenTelemetry openTelemetry) {
039                super(capabilityLoader);
040                this.openTelemetry = openTelemetry;
041        }
042        
043        public Collection<AsyncResourceSpecification> getResourceSpecifications() {
044                return Collections.emptyList();
045        }
046        
047        public Collection<AsyncToolSpecification> getToolSpecifications() {
048                return Collections.emptyList();
049        }
050        
051        public Collection<AsyncPromptSpecification> getPromptSpecifications() {
052                return Collections.emptyList();
053        }
054        
055        protected boolean isLogging() {
056                return true;
057        }
058        
059        protected void measureDuration(String name, long duration) {
060                
061        }
062        
063        @Override
064        public McpAsyncServer createServer(McpServerTransportProvider transportProvider) {
065                Builder capabilitiesBuilder = ServerCapabilities.builder();
066                Collection<AsyncResourceSpecification> resourceSpecifications = getResourceSpecifications();
067                if (!resourceSpecifications.isEmpty()) {
068                        capabilitiesBuilder.resources(true, true);
069                }
070                Collection<AsyncToolSpecification> toolSpecifications = getToolSpecifications();
071                if (!toolSpecifications.isEmpty()) {
072                        capabilitiesBuilder.tools(true);
073                }
074                Collection<AsyncPromptSpecification> promptSpecifications = getPromptSpecifications();
075                if (!promptSpecifications.isEmpty()) {
076                        capabilitiesBuilder.prompts(true);
077                }
078                if (isLogging()) {
079                        capabilitiesBuilder.logging();
080                }               
081                
082                McpAsyncServer asyncServer = McpServer.async(transportProvider)
083                                .serverInfo(getName(), getVersion())
084                                .capabilities(capabilitiesBuilder.build())
085                                .build();
086                
087                Tracer tracer = openTelemetry.getTracer(getInstrumentationScopeName());
088                McpTelemetryFilter mcpTelemetryFilter = new McpTelemetryFilter(tracer, this::measureDuration);          
089                
090                List<Mono<Void>> registrations = new ArrayList<>();
091                for (AsyncResourceSpecification rSpec: resourceSpecifications) {
092                        registrations.add(asyncServer.addResource(mcpTelemetryFilter.filter(rSpec)));
093                }
094                for (AsyncToolSpecification tSpec: toolSpecifications) {
095                        registrations.add(asyncServer.addTool(mcpTelemetryFilter.filter(tSpec)));
096                }
097                for (AsyncPromptSpecification pSpec: promptSpecifications) {
098                        registrations.add(asyncServer.addPrompt(mcpTelemetryFilter.filter(pSpec)));
099                }
100                
101                return Mono.zip(registrations, ra -> asyncServer).block();
102        }
103
104        protected String getInstrumentationScopeName() {
105                return getClass().getName();
106        }
107
108        /**
109         * MCP Server name, this implementation returns command name.
110         */
111        protected String getName() {
112                return spec.name();
113        }
114        
115        /**
116         * MCP Server name, this implementation returns command name.
117         */
118        protected String getVersion() {
119                String[] version = spec.version();
120                if (version == null) {
121                        return "(unknown)";
122                }
123                if (version.length == 1) {
124                        return version[0];
125                }
126                return String.join(" ", version);
127        }
128
129}