/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.spec.StatelessMcpSession;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.noear.solon.SolonApp;
import org.noear.solon.core.handle.Context;
import org.noear.solon.core.handle.Entity;
import org.noear.solon.web.sse.SseEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Streamable-HTTP MCP server transport provider for Solon.
 *
 * <p>{@link #toHttpHandler(SolonApp)} registers a {@code POST} and a {@code DELETE} handler on a single
 * endpoint. A {@code POST} body may carry one JSON-RPC message or a JSON array of messages; each one is
 * handed to the session selected by the {@code Mcp-Session-Id} request header. When no session factory
 * has been set, every request is served by a fresh {@link StatelessMcpSession}.
 *
 * <p>NOTE(review): a stateful session is created with the transport (and therefore the {@link Context})
 * of the request that created it. If later outgoing messages for that session go through that
 * transport, they target the original request — confirm this is intended.
 */
public class WebRxStreamableServerTransportProvider implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebRxStreamableServerTransportProvider.class);
    /** Header carrying the MCP session id in requests and responses. */
    private static final String MCP_SESSION_ID = "Mcp-Session-Id";
    /** All-lower-case spelling of {@link #MCP_SESSION_ID}, used as a lookup fallback. */
    private static final String MCP_SESSION_ID_LOWER = "mcp-session-id";
    private static final String APPLICATION_JSON = "application/json";
    private static final String TEXT_EVENT_STREAM = "text/event-stream";
    private static final String DEFAULT_MCP_ENDPOINT = "/mcp";
    /** Session id that is never echoed back to the client (presumably the id of {@link StatelessMcpSession}). */
    private static final String STATELESS_SESSION_ID = "stateless";

    private final ObjectMapper objectMapper;
    private final String endpoint;
    /** Stateful sessions keyed by id; entries are added when created and removed by {@code DELETE}. */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<String, McpServerSession>();
    private McpServerSession.Factory sessionFactory;

    /**
     * @param objectMapper mapper used to (de)serialize JSON-RPC messages
     * @param endpoint     HTTP path the handlers are registered on
     */
    public WebRxStreamableServerTransportProvider(ObjectMapper objectMapper, String endpoint) {
        this.objectMapper = objectMapper;
        this.endpoint = endpoint;
    }

    /**
     * Sends a heartbeat on every stored session whose transport supports it.
     *
     * <p>Only transports of type {@code WebRxSseServerTransportProvider.WebRxMcpSessionTransport} have a
     * heartbeat. Sessions created by this provider use {@link StreamableHttpServerTransport}, so they are
     * skipped rather than failing with a {@link ClassCastException}.
     */
    public void sendHeartbeat() {
        for (McpServerSession session : this.sessions.values()) {
            Object transport = session.getTransport();
            if (transport instanceof WebRxSseServerTransportProvider.WebRxMcpSessionTransport) {
                ((WebRxSseServerTransportProvider.WebRxMcpSessionTransport) transport).sendHeartbeat();
            }
        }
    }

    /**
     * Registers the {@code POST} and {@code DELETE} handlers on {@link #endpoint}.
     *
     * @param app the Solon application; ignored when {@code null}
     */
    public void toHttpHandler(SolonApp app) {
        if (app != null) {
            app.post(this.endpoint, this::doPost);
            app.delete(this.endpoint, this::doDelete);
        }
    }

    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every stored session. A failure for one session is logged and does not
     * stop delivery to the others.
     */
    @Override
    public Mono<Void> notifyClients(String method, Map<String, Object> params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", (Object) this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}",
                                (Object) session.getId(), (Object) e.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /** Closes every stored session gracefully; completes once all of them have closed. */
    @Override
    public Mono<Void> closeGracefully() {
        logger.debug("Initiating graceful shutdown with {} active sessions", (Object) this.sessions.size());
        return Flux.fromIterable(this.sessions.values()).flatMap(McpSession::closeGracefully).then();
    }

    /**
     * Handles a client {@code POST}: resolves or creates the session, then dispatches every JSON-RPC
     * message in the body to it.
     *
     * <p>Responds 406 when the {@code Accept} header names neither {@code application/json} nor
     * {@code text/event-stream}. Responds 404 when {@code Mcp-Session-Id} names an unknown session.
     * Media-type parameters such as {@code ;q=0.9} are ignored, and SSE is preferred when both types
     * are accepted.
     */
    public void doPost(Context ctx) throws Throwable {
        List<String> acceptTypes = parseAcceptHeader(ctx.headerOrDefault("Accept", ""));
        boolean acceptsSse = containsMediaType(acceptTypes, TEXT_EVENT_STREAM);
        boolean acceptsJson = containsMediaType(acceptTypes, APPLICATION_JSON);
        if (!acceptsJson && !acceptsSse) {
            ctx.status(406, "Legacy transport not available");
            return;
        }

        StreamableHttpServerTransport transport = new StreamableHttpServerTransport(ctx, this.objectMapper);
        McpSession session = this.getOrCreateSession(sessionIdOf(ctx), transport);
        if (session == null) {
            // The client referenced a session id that is not (or no longer) registered here.
            ctx.status(404, "Session not found");
            return;
        }
        if (!STATELESS_SESSION_ID.equals(session.getId())) {
            ctx.headerSet(MCP_SESSION_ID, session.getId());
        }

        // StreamableHttpServerTransport.sendMessage reads this to choose plain JSON or SSE framing.
        ctx.contentType(acceptsSse ? TEXT_EVENT_STREAM : APPLICATION_JSON);

        Mono<Entity> result = this.parseRequestBodyAsStream(ctx)
                .flatMap(session::handle)
                // Responses are written by the transport itself; the returned entity is empty.
                .then(Mono.fromSupplier(() -> new Entity()))
                .onErrorResume(error -> {
                    logger.error("Error processing  message: {}", (Object) error.getMessage(), error);
                    return Mono.just(new Entity().status(500).body((Object) new McpError((Object) error.getMessage())));
                });
        ctx.returnValue(result);
    }

    /**
     * Handles a client {@code DELETE}: removes the session named by {@code Mcp-Session-Id} and closes it
     * in the background. Responds 404 when the session is unknown and 204 otherwise.
     */
    public void doDelete(Context ctx) throws IOException {
        String sessionId = sessionIdOf(ctx);
        // A single atomic remove() avoids the containsKey/remove race with a concurrent DELETE.
        McpServerSession session = sessionId == null ? null : this.sessions.remove(sessionId);
        if (session == null) {
            ctx.status(404, "Session not found");
            return;
        }
        session.closeGracefully().subscribe(
                ignored -> { },
                e -> logger.warn("Failed to close session {}: {}", (Object) sessionId, (Object) e.getMessage()));
        ctx.status(204);
    }

    /**
     * Lazily reads the request body as one JSON-RPC message or a JSON array of messages. Any other JSON
     * value yields an empty stream. Parse errors are emitted as error signals.
     */
    private Flux<McpSchema.JSONRPCMessage> parseRequestBodyAsStream(Context req) {
        return Mono.fromCallable(() -> this.readMessages(req)).flatMapMany(Flux::fromIterable);
    }

    /** Blocking helper for {@link #parseRequestBodyAsStream(Context)}; closes the body stream. */
    private List<McpSchema.JSONRPCMessage> readMessages(Context req) throws Exception {
        try (InputStream inputStream = req.bodyAsStream()) {
            JsonNode node = this.objectMapper.readTree(inputStream);
            if (node == null) {
                // Older Jackson versions return null for an empty body.
                return Collections.emptyList();
            }
            if (node.isArray()) {
                List<McpSchema.JSONRPCMessage> messages = new ArrayList<McpSchema.JSONRPCMessage>(node.size());
                for (JsonNode item : node) {
                    messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, item));
                }
                return messages;
            }
            if (node.isObject()) {
                return Collections.singletonList(McpSchema.deserializeJsonRpcMessage(this.objectMapper, node));
            }
            return Collections.emptyList();
        }
    }

    /**
     * Resolves the session for a request.
     *
     * @return the stored session when {@code sessionId} is given and a factory is set (may be
     *         {@code null} if the id is unknown); a newly created and stored session when no id is given
     *         and a factory is set; otherwise a new {@link StatelessMcpSession}
     */
    private McpSession getOrCreateSession(String sessionId, McpServerTransport transport) {
        if (this.sessionFactory == null) {
            return new StatelessMcpSession(transport);
        }
        if (sessionId != null) {
            return this.sessions.get(sessionId);
        }
        McpServerSession session = this.sessionFactory.create(transport);
        this.sessions.put(session.getId(), session);
        return session;
    }

    /** Reads the session id header, falling back to the lower-case name in case lookup is case-sensitive. */
    private static String sessionIdOf(Context ctx) {
        String sessionId = ctx.header(MCP_SESSION_ID);
        return sessionId != null ? sessionId : ctx.header(MCP_SESSION_ID_LOWER);
    }

    /** Splits an {@code Accept} header into bare media types, dropping parameters and empty entries. */
    private static List<String> parseAcceptHeader(String accept) {
        if (accept == null) {
            return Collections.emptyList();
        }
        return Arrays.stream(accept.split(","))
                .map(part -> {
                    int semicolon = part.indexOf(';');
                    return (semicolon >= 0 ? part.substring(0, semicolon) : part).trim();
                })
                .filter(part -> !part.isEmpty())
                .collect(Collectors.toList());
    }

    /** Case-insensitive membership test; media types are case-insensitive. */
    private static boolean containsMediaType(List<String> mediaTypes, String expected) {
        for (String mediaType : mediaTypes) {
            if (mediaType.equalsIgnoreCase(expected)) {
                return true;
            }
        }
        return false;
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder; the object mapper is required and the endpoint defaults to {@code /mcp}. */
    public static class Builder {
        private ObjectMapper objectMapper;
        private String endpoint = DEFAULT_MCP_ENDPOINT;

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        public Builder endpoint(String endpoint) {
            Assert.notNull(endpoint, "Endpoint must not be null");
            this.endpoint = endpoint;
            return this;
        }

        public WebRxStreamableServerTransportProvider build() {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.endpoint, "Endpoint must be set");
            return new WebRxStreamableServerTransportProvider(this.objectMapper, this.endpoint);
        }
    }

    /**
     * Per-request transport that writes JSON-RPC messages to one Solon {@link Context}. It writes plain
     * JSON when the response content type is {@code application/json} and an SSE event otherwise.
     */
    public static class StreamableHttpServerTransport implements McpServerTransport {
        private final ObjectMapper objectMapper;
        private final Context ctx;

        public StreamableHttpServerTransport(Context ctx, ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
            this.ctx = ctx;
        }

        /** Serializes {@code message} and writes it when subscribed; failures surface as errors. */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String json = this.objectMapper.writeValueAsString(message);
                    if (this.isJsonResponse()) {
                        this.ctx.output(json);
                    } else {
                        SseEvent event = new SseEvent().id(UUID.randomUUID().toString()).data((Object) json);
                        this.ctx.render((Object) event);
                    }
                } catch (Throwable e) {
                    throw new RuntimeException("Failed to send message", e);
                }
            });
        }

        /** True when the response was set up as plain JSON; tolerates a trailing charset or other parameters. */
        private boolean isJsonResponse() {
            String contentType = this.ctx.contentTypeNew();
            return contentType != null
                    && contentType.startsWith(WebRxStreamableServerTransportProvider.APPLICATION_JSON);
        }

        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return this.objectMapper.convertValue(data, typeRef);
        }

        /** Flushes and closes the response on a best-effort basis. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                try {
                    this.ctx.flush();
                    this.ctx.close();
                } catch (IOException e) {
                    // The client may already be gone; closing is best effort, so only record it.
                    logger.debug("Failed to close streamable HTTP response: {}", (Object) e.getMessage());
                }
            });
        }
    }
}

