/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.ai.mcp.sse;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.context.ImplicitContextKeyed;
import io.opentelemetry.context.propagation.TextMapGetter;
import io.opentelemetry.context.propagation.TextMapPropagator;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import org.json.JSONObject;
import org.nasdanika.http.TelemetryFilter;
import org.nasdanika.telemetry.TelemetryUtil;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.netty.http.server.HttpServerRequest;
import reactor.netty.http.server.HttpServerResponse;
import reactor.netty.http.server.HttpServerRoutes;
import reactor.util.context.ContextView;

/**
 * MCP server transport provider that serves the MCP HTTP+SSE transport on Reactor Netty
 * {@link HttpServerRoutes}.
 *
 * <p>A {@code GET} on the SSE endpoint opens a new session. The first event on that stream is an
 * {@value #ENDPOINT_EVENT_TYPE} event whose data is {@code baseUrl + messageEndpoint + "?sessionId=<id>"};
 * the client {@code POST}s its JSON-RPC messages there. Server-to-client messages are delivered as
 * {@value #MESSAGE_EVENT_TYPE} events on the session's SSE stream.
 *
 * <p>For incoming JSON-RPC messages that carry an {@code id}, the OpenTelemetry context of the
 * HTTP request is captured (via the configured {@link TextMapPropagator}) and later used as the
 * parent of the span created when the message with the same {@code id} is sent back on the same
 * session.
 */
public class HttpServerRoutesTransportProvider implements McpServerTransportProvider {

    private static final String ID_KEY = "id";
    private static final Logger logger = LoggerFactory.getLogger(HttpServerRoutesTransportProvider.class);
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String DEFAULT_BASE_URL = "";
    private static final String SESSION_ID_PARAMETER = "sessionId";
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /** Line terminators recognized by SSE; each payload line must become its own {@code data:} field. */
    private static final Pattern LINE_BREAK = Pattern.compile("\r\n|\r|\n");

    /** Reads propagation headers back out of the string map filled in by {@link TextMapPropagator#inject}. */
    private static final TextMapGetter<Map<String, String>> MAP_GETTER = new TextMapGetter<>() {

        @Override
        public Iterable<String> keys(Map<String, String> carrier) {
            return carrier.keySet();
        }

        @Override
        public String get(Map<String, String> carrier, String key) {
            return carrier == null ? null : carrier.get(key);
        }
    };

    /**
     * Propagation headers of requests awaiting a response, keyed by {@code sessionId:jsonRpcId}.
     * Entries are removed when the matching message is sent or when the session closes.
     */
    private final Map<String, Map<String, String>> contextMap = new ConcurrentHashMap<>();
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private McpServerSession.Factory sessionFactory;
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
    /** Set by {@link #closeGracefully()}; once set, incoming POSTs are rejected with 503. */
    private volatile boolean isClosing = false;
    private final Tracer tracer;
    private final TextMapPropagator propagator;
    private final TelemetryFilter telemetryFilter;

    /** Creates a provider using {@link #DEFAULT_SSE_ENDPOINT} and {@link #DEFAULT_BASE_URL}. */
    public HttpServerRoutesTransportProvider(ObjectMapper objectMapper, String messageEndpoint, HttpServerRoutes httpServerRoutes, Tracer tracer, boolean resolveRemoteHostName, TextMapPropagator propagator, BiConsumer<String, Long> durationConsumer) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT, httpServerRoutes, tracer, resolveRemoteHostName, propagator, durationConsumer);
    }

    /** Creates a provider using {@link #DEFAULT_BASE_URL}. */
    public HttpServerRoutesTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint, HttpServerRoutes httpServerRoutes, Tracer tracer, boolean resolveRemoteHostName, TextMapPropagator propagator, BiConsumer<String, Long> durationConsumer) {
        this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, httpServerRoutes, tracer, resolveRemoteHostName, propagator, durationConsumer);
    }

    /**
     * Creates a provider and registers its routes on {@code httpServerRoutes}:
     * {@code GET sseEndpoint} opens an SSE session, {@code POST messageEndpoint} accepts client messages.
     *
     * @param objectMapper mapper used to (de)serialize JSON-RPC messages
     * @param baseUrl prefix prepended to {@code messageEndpoint} in the endpoint event sent to clients
     * @param messageEndpoint path of the POST route for client messages
     * @param sseEndpoint path of the GET route that opens SSE sessions
     * @param httpServerRoutes routes to register on
     * @param tracer tracer used for the per-message send spans
     * @param resolveRemoteHostName passed through to the {@link TelemetryFilter}
     * @param propagator propagator used to carry request context over to response spans
     * @param durationConsumer passed through to the {@link TelemetryFilter}
     */
    public HttpServerRoutesTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, HttpServerRoutes httpServerRoutes, Tracer tracer, boolean resolveRemoteHostName, TextMapPropagator propagator, BiConsumer<String, Long> durationConsumer) {
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
        this.tracer = tracer;
        this.propagator = propagator;
        this.telemetryFilter = new TelemetryFilter(tracer, propagator, durationConsumer, resolveRemoteHostName);
        httpServerRoutes.get(this.sseEndpoint, this.serveSse()).post(this.messageEndpoint, this::processMessage);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds the GET handler. Each subscription to the {@code Flux.create} source runs the callback
     * below, so every SSE connection gets its own sink, transport and session.
     */
    private BiFunction<HttpServerRequest, HttpServerResponse, Publisher<Void>> serveSse() {
        Flux<ServerSentEvent> events = Flux.create(sink -> {
            HttpServerRoutesSessionTransport sessionTransport = new HttpServerRoutesSessionTransport(sink);
            McpServerSession session = this.sessionFactory.create(sessionTransport);
            String sessionId = session.getId();
            sessionTransport.bindSession(sessionId);
            logger.debug("Created new SSE connection for session: {}", sessionId);
            this.sessions.put(sessionId, session);
            logger.debug("Sending initial endpoint event to session: {}", sessionId);
            sink.next(new ServerSentEvent(ENDPOINT_EVENT_TYPE,
                    this.baseUrl + this.messageEndpoint + "?" + SESSION_ID_PARAMETER + "=" + sessionId));
            // onDispose (not just onCancel) also covers the transport completing the sink, so closed
            // sessions are unregistered and their pending telemetry contexts are dropped.
            sink.onDispose(() -> {
                logger.debug("Session {} closed", sessionId);
                this.sessions.remove(sessionId);
                discardPendingContexts(sessionId);
            });
        });
        return (request, response) -> response.sse().send(events.map(this::toByteBuf), b -> true);
    }

    /**
     * Encodes an event as a UTF-8 SSE frame: an {@code event:} line, one {@code data:} line per
     * payload line, and a terminating blank line.
     */
    private ByteBuf toByteBuf(ServerSentEvent event) {
        StringBuilder frame = new StringBuilder("event: ").append(event.event()).append('\n');
        for (String line : LINE_BREAK.split(event.data(), -1)) {
            frame.append("data: ").append(line).append('\n');
        }
        frame.append('\n');
        byte[] bytes = frame.toString().getBytes(StandardCharsets.UTF_8);
        return ByteBufAllocator.DEFAULT.buffer(bytes.length).writeBytes(bytes);
    }

    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Handles a client POST: validates the session, captures telemetry context, deserializes the
     * JSON-RPC message and hands it to the session.
     *
     * <p>Responses: 503 while closing, 400 when the session id is missing or the message is
     * invalid or fails to process, 404 for an unknown session, 200 otherwise.
     */
    private Publisher<Void> processMessage(HttpServerRequest request, HttpServerResponse response) {
        if (this.isClosing) {
            return response.status(HttpResponseStatus.SERVICE_UNAVAILABLE)
                    .sendString(Mono.just("Server is shutting down"))
                    .then();
        }
        List<String> sessionIds = new QueryStringDecoder(request.uri()).parameters().get(SESSION_ID_PARAMETER);
        if (sessionIds == null || sessionIds.isEmpty()) {
            return response.status(HttpResponseStatus.BAD_REQUEST)
                    .sendString(Mono.just("Session ID is missing"))
                    .then();
        }
        String sessionId = sessionIds.get(0);
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            return sendError(response, HttpResponseStatus.NOT_FOUND, "Session not found: " + sessionId);
        }
        Mono<String> requestBody = Mono.deferContextual(contextView -> {
            Context context = contextView.getOrDefault(Context.class, Context.current());
            return request.receive()
                    .aggregate()
                    .asString(StandardCharsets.UTF_8)
                    .doOnNext(body -> captureRequestContext(sessionId, context, body));
        });
        return this.telemetryFilter.filter(request, requestBody).flatMap(body -> {
            McpSchema.JSONRPCMessage message;
            try {
                message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, (String) body);
            } catch (IOException | IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", e.getMessage());
                return sendError(response, HttpResponseStatus.BAD_REQUEST, "Invalid message format");
            }
            // handle() yields a completion-only Mono, so the OK response is chained with then();
            // a flatMap would never run. onErrorResume wraps handle() itself so its failures are reported.
            return session.handle(message)
                    .then(Mono.defer(() -> response.status(HttpResponseStatus.OK).then()))
                    .onErrorResume(error -> {
                        logger.error("Error processing message: {}", error.getMessage());
                        String detail = error.getMessage() != null ? error.getMessage() : error.toString();
                        return sendError(response, HttpResponseStatus.BAD_REQUEST, detail);
                    });
        });
    }

    /**
     * Records the raw request text on the span found in {@code context}. For a JSON object carrying
     * an {@code id}, it also stores the propagation headers of {@code context} so that the span
     * sending the matching message can be parented to this request.
     */
    private void captureRequestContext(String sessionId, Context context, String requestText) {
        Span span = Span.fromContext(context);
        span.setAttribute("request", requestText);
        String requestId;
        try {
            JSONObject jRequest = new JSONObject(requestText);
            if (!jRequest.has(ID_KEY)) {
                return;
            }
            requestId = jRequest.get(ID_KEY).toString();
        } catch (RuntimeException e) {
            // Not a single JSON object (malformed text or a batch array); reporting the problem is
            // left to the deserialization step in processMessage.
            logger.debug("Not capturing telemetry context for request: {}", e.getMessage());
            return;
        }
        Map<String, String> carrier = new HashMap<>();
        this.propagator.inject(context, carrier, (cr, name, value) -> cr.put(name, value));
        if (!carrier.isEmpty()) {
            this.contextMap.put(contextKey(sessionId, requestId), carrier);
        }
    }

    /** Drops captured propagation headers belonging to {@code sessionId}. */
    private void discardPendingContexts(String sessionId) {
        String prefix = contextKey(sessionId, "");
        this.contextMap.keySet().removeIf(key -> key.startsWith(prefix));
    }

    /** Key under which a request's propagation headers are stored; scoped per session because JSON-RPC ids repeat across clients. */
    private static String contextKey(String sessionId, String requestId) {
        return sessionId + ':' + requestId;
    }

    /**
     * Sends {@code status} with an error body for {@code message}. The body is the JSON-RPC error
     * held by the {@link McpError} when present, otherwise the plain message.
     * {@code McpError(Object)} does not necessarily populate a JSON-RPC error, so it is not dereferenced blindly.
     */
    private static Mono<Void> sendError(HttpServerResponse response, HttpResponseStatus status, String message) {
        McpError mcpError = new McpError(message);
        Object rpcError = mcpError.getJsonRpcError();
        String body = rpcError != null ? rpcError.toString() : message;
        return response.status(status).sendString(Mono.just(body)).then();
    }

    /**
     * Sends a notification to every active session. Per-session failures are logged and do not
     * fail the broadcast.
     */
    @Override
    public Mono<Void> notifyClients(String method, Map<String, Object> params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Marks the provider as closing (new POSTs get 503) and gracefully closes all active sessions.
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values())
                .doFirst(() -> {
                    this.isClosing = true;
                    logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
                })
                .flatMap(McpServerSession::closeGracefully)
                .then();
    }

    /** Fluent builder for {@link HttpServerRoutesTransportProvider}. */
    public static class Builder {
        private ObjectMapper objectMapper;
        private String baseUrl = DEFAULT_BASE_URL;
        private String messageEndpoint;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
        private Tracer tracer;
        private boolean resolveRemoteHostName;
        private TextMapPropagator propagator;
        private BiConsumer<String, Long> durationConsumer;

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /** Sets the prefix prepended to the message endpoint in the endpoint event sent to clients. */
        public Builder basePath(String baseUrl) {
            Assert.notNull(baseUrl, "basePath must not be null");
            this.baseUrl = baseUrl;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull(messageEndpoint, "Message endpoint must not be null");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        public Builder sseEndpoint(String sseEndpoint) {
            Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        public Builder tracer(Tracer tracer) {
            this.tracer = tracer;
            return this;
        }

        public Builder resolveRemoteHostName(boolean resolveRemoteHostName) {
            this.resolveRemoteHostName = resolveRemoteHostName;
            return this;
        }

        public Builder propagator(TextMapPropagator propagator) {
            this.propagator = propagator;
            return this;
        }

        public Builder setDurationConsumer(BiConsumer<String, Long> durationConsumer) {
            this.durationConsumer = durationConsumer;
            return this;
        }

        /**
         * Creates the provider and registers its routes on {@code httpServerRoutes}.
         *
         * @throws IllegalArgumentException if the object mapper, message endpoint or routes are missing
         */
        public HttpServerRoutesTransportProvider build(HttpServerRoutes httpServerRoutes) {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.messageEndpoint, "Message endpoint must be set");
            Assert.notNull(httpServerRoutes, "HttpServerRoutes must not be null");
            return new HttpServerRoutesTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint, httpServerRoutes, this.tracer, this.resolveRemoteHostName, this.propagator, this.durationConsumer);
        }
    }

    /** A single server-sent event: the {@code event:} name and its (possibly multi-line) data. */
    private record ServerSentEvent(String event, String data) {
    }

    /** Per-session transport that pushes outgoing messages into the session's SSE sink. */
    private class HttpServerRoutesSessionTransport implements McpServerTransport {
        private final FluxSink<ServerSentEvent> sink;
        /** Bound right after the session is created; null only for messages sent during session creation. */
        private volatile String sessionId;

        public HttpServerRoutesSessionTransport(FluxSink<ServerSentEvent> sink) {
            this.sink = sink;
        }

        void bindSession(String sessionId) {
            this.sessionId = sessionId;
        }

        /**
         * Serializes {@code message}, emits it as a {@value #MESSAGE_EVENT_TYPE} event and traces the
         * send in a span. A serialization failure is recorded on the span and terminates the SSE sink.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            AtomicReference<Span> spanRef = new AtomicReference<>();
            // Resolves lazily against Context.current(), overlaid with the send span once it has started.
            Context contextDelegate = new Context() {

                private Context getTarget() {
                    Context target = Context.current();
                    Span span = spanRef.get();
                    return span == null ? target : target.with(span);
                }

                @Override
                public <V> Context with(ContextKey<V> k1, V v1) {
                    return getTarget().with(k1, v1);
                }

                @Override
                public <V> V get(ContextKey<V> key) {
                    return getTarget().get(key);
                }
            };
            return Mono.fromSupplier(() -> serializeAndStartSpan(message, spanRef))
                    .doOnNext(jsonText -> {
                        this.sink.next(new ServerSentEvent(MESSAGE_EVENT_TYPE, jsonText));
                        Span span = spanRef.get();
                        if (span != null) {
                            span.setAttribute(MESSAGE_EVENT_TYPE, jsonText);
                            span.setStatus(StatusCode.OK);
                        }
                    })
                    .doOnError(e -> {
                        Throwable exception = Exceptions.unwrap(e);
                        Span span = spanRef.get();
                        if (span != null) {
                            span.recordException(exception);
                            span.setStatus(StatusCode.ERROR);
                        }
                        this.sink.error(exception);
                    })
                    .contextWrite(reactor.util.context.Context.of(Context.class, contextDelegate))
                    .doFinally(signal -> {
                        Span span = spanRef.get();
                        if (span != null) {
                            span.end();
                        }
                    })
                    .then();
        }

        /**
         * Serializes {@code message} and starts its send span. The span is parented to the originating
         * request's context when one was captured for the same session and id. The span is stored
         * in {@code spanRef}, and the JSON text is returned.
         */
        private String serializeAndStartSpan(McpSchema.JSONRPCMessage message, AtomicReference<Span> spanRef) {
            try {
                SpanBuilder spanBuilder = TelemetryUtil.buildSpan(tracer.spanBuilder("sessionTransport.sendMessage"));
                String jsonText = objectMapper.writeValueAsString(message);
                Map<String, String> parentSpanData = takeParentSpanData(jsonText);
                if (parentSpanData != null) {
                    spanBuilder.setParent(propagator.extract(Context.current(), parentSpanData, MAP_GETTER));
                }
                Span span = spanBuilder.startSpan();
                spanRef.set(span);
                return jsonText;
            } catch (IOException e) {
                throw Exceptions.propagate(e);
            }
        }

        /** Removes and returns the captured propagation headers for the message's id, or null if none. */
        private Map<String, String> takeParentSpanData(String jsonText) {
            JSONObject jObj = new JSONObject(jsonText);
            if (!jObj.has(ID_KEY)) {
                return null;
            }
            return contextMap.remove(contextKey(this.sessionId, jObj.get(ID_KEY).toString()));
        }

        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return objectMapper.convertValue(data, typeRef);
        }

        /** Completes the SSE stream when subscribed. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(this.sink::complete);
        }

        /** Completes the SSE stream immediately. */
        @Override
        public void close() {
            this.sink.complete();
        }
    }
}

