/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.WebRxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.noear.solon.net.http.HttpResponse;
import org.noear.solon.net.http.HttpUtils;
import org.noear.solon.net.http.HttpUtilsBuilder;
import org.noear.solon.net.http.textstream.TextStreamUtil;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;

/**
 * {@link McpClientTransport} for MCP's Streamable HTTP transport, built on Solon's {@link HttpUtils}.
 *
 * <p>Every outbound JSON-RPC message is POSTed as JSON to a single endpoint. The reply is
 * dispatched by its {@code Content-Type}: a single JSON document ({@code application/json}),
 * a JSON text sequence ({@code application/json-seq}) or a Server-Sent Events stream
 * ({@code text/event-stream}). A 404 or 405 answer switches this instance to the wrapped
 * {@link WebRxSseClientTransport} for the rest of its lifetime.
 *
 * <p>State kept between requests:
 * <ul>
 *   <li>the {@code Mcp-Session-Id} returned by the {@code initialize} response, which is
 *       echoed on later requests;</li>
 *   <li>the id of the last fully handled SSE event, sent as {@code Last-Event-ID} when
 *       {@link #connect} reopens the stream;</li>
 *   <li>the handler passed to {@link #connect}, which also receives messages carried in
 *       POST responses.</li>
 * </ul>
 */
public class WebRxStreamableClientTransport implements McpClientTransport {
    private static final Logger LOGGER = LoggerFactory.getLogger(WebRxStreamableClientTransport.class);
    private static final String DEFAULT_MCP_ENDPOINT = "/mcp";
    private static final String MCP_SESSION_ID = "Mcp-Session-Id";
    private static final String LAST_EVENT_ID = "Last-Event-ID";
    private static final String ACCEPT = "Accept";
    private static final String CONTENT_TYPE = "Content-Type";
    private static final String APPLICATION_JSON = "application/json";
    private static final String TEXT_EVENT_STREAM = "text/event-stream";
    private static final String APPLICATION_JSON_SEQ = "application/json-seq";
    /** {@code Accept} value for POSTed messages: "application/json, text/event-stream". */
    private static final String DEFAULT_ACCEPT_VALUES = APPLICATION_JSON + ", " + TEXT_EVENT_STREAM;
    /** SSE event type carrying JSON-RPC payloads; also the SSE default when no type is given. */
    private static final String MESSAGE_EVENT = "message";
    /** RFC 7464 record separator (U+001E) that prefixes each record of application/json-seq. */
    private static final String RECORD_SEPARATOR = "\u001E";
    /** JSON-RPC method whose response carries the session id header. */
    private static final String INITIALIZE_METHOD = "initialize";

    private final WebRxSseClientTransport sseClientTransport;
    private final HttpUtilsBuilder webBuilder;
    private final String endpoint;
    private final ObjectMapper objectMapper;
    /** Id of the last fully handled SSE event; null when none is known. */
    private final AtomicReference<String> lastEventId = new AtomicReference<>();
    /** Session id assigned by the server in the initialize response; null when none is known. */
    private final AtomicReference<String> mcpSessionId = new AtomicReference<>();
    /** Set once the server rejects this transport with 404/405; never reset afterwards. */
    private final AtomicBoolean fallbackToSse = new AtomicBoolean(false);
    /** Handler registered by {@link #connect}; null until connect has been called. */
    private final AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> inboundHandler =
            new AtomicReference<>();

    /**
     * Creates a transport; prefer {@link #builder(HttpUtilsBuilder)}.
     *
     * @param webBuilder         factory for the HTTP requests sent to {@code endpoint}
     * @param objectMapper       JSON (de)serializer for JSON-RPC messages
     * @param endpoint           path of the MCP endpoint, passed to {@code webBuilder.build}
     * @param sseClientTransport transport used once the server rejects Streamable HTTP
     */
    public WebRxStreamableClientTransport(HttpUtilsBuilder webBuilder, ObjectMapper objectMapper, String endpoint, WebRxSseClientTransport sseClientTransport) {
        this.webBuilder = webBuilder;
        this.objectMapper = objectMapper;
        this.endpoint = endpoint;
        this.sseClientTransport = sseClientTransport;
    }

    /**
     * Starts building a transport on top of the given HTTP client builder.
     *
     * @param webBuilder factory for outgoing HTTP requests; must not be null
     * @return a new {@link Builder}
     */
    public static Builder builder(HttpUtilsBuilder webBuilder) {
        return new Builder(webBuilder);
    }

    /**
     * Opens the server-to-client stream and passes every message received on it to {@code handler}.
     *
     * <p>The handler is also remembered for messages returned in POST responses (see
     * {@link #sendMessage(McpSchema.JSONRPCMessage)}). In SSE fallback mode the call is
     * delegated to the SSE transport. Otherwise, a request with
     * {@code Accept: text/event-stream} is sent, plus {@code Last-Event-ID} and
     * {@code Mcp-Session-Id} when known. A 404/405 answer switches to SSE fallback.
     * {@link IllegalStateException}s are retried up to 3 times with exponential backoff
     * starting at 3 seconds.
     *
     * @param handler receives each inbound message wrapped in a {@link Mono}
     * @return a {@link Mono} that completes when the stream ends
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        // Without this, replies carried in POST response bodies would never reach the session.
        this.inboundHandler.set(handler);
        if (this.fallbackToSse.get()) {
            return this.sseClientTransport.connect(handler);
        }
        // NOTE(review): the MCP Streamable HTTP spec opens the server-to-client stream with GET,
        // while this sends a body-less POST. Confirm against the target server before changing it:
        // with GET, the 404/405 fallback below may also trigger on spec-compliant servers.
        Mono<Void> stream = Mono.defer(() -> Mono.fromFuture(() -> this.newStreamRequest().execAsync("POST"))
                .flatMap(response -> this.handleConnectResponse(response, handler))
                .retryWhen(Retry.backoff(3L, Duration.ofSeconds(3L))
                        .filter(err -> err instanceof IllegalStateException))
                .doOnError(e -> LOGGER.error("Streamable transport connection error", e)));
        // NOTE(review): closeGracefully() clears the session id and last event id as soon as it is
        // called, but the Mono it returns is discarded here. As a result, an SSE fallback transport
        // is never actually closed by this path. Confirm the intent.
        return stream.doOnTerminate(this::closeGracefully);
    }

    /** Builds the request that opens the server-to-client stream, including resumption/session headers. */
    private HttpUtils newStreamRequest() {
        HttpUtils request = this.webBuilder.build(this.endpoint);
        request.header(ACCEPT, TEXT_EVENT_STREAM);
        String lastId = this.lastEventId.get();
        if (lastId != null) {
            request.header(LAST_EVENT_ID, lastId);
        }
        return this.withSessionHeader(request);
    }

    /** Applies the fallback rules to the stream-opening response, then consumes its body. */
    private Mono<Void> handleConnectResponse(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (this.switchToSseIfRejected(response.code())) {
            return this.sseClientTransport.connect(handler);
        }
        return this.handleStreamingResponse(response, handler);
    }

    /**
     * Sends {@code message} and routes any JSON-RPC messages returned in the HTTP response to
     * the handler registered by {@link #connect}. If connect has not been called yet, those
     * messages are parsed but discarded.
     *
     * @param message message to send
     * @return a {@link Mono} that completes once the response has been consumed
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler = this.inboundHandler.get();
        if (handler == null) {
            handler = msg -> msg;
        }
        return this.sendMessage(message, handler);
    }

    /**
     * POSTs {@code message} as JSON and passes every JSON-RPC message found in the response
     * to {@code handler}.
     *
     * <p>The session id from an {@code initialize} response is stored. A 202 answer
     * completes immediately. A 404/405 switches to SSE fallback and resends through it. Any
     * other status of 400 or above fails with {@link IllegalArgumentException}.
     *
     * @param message message to send
     * @param handler receives each message contained in the response body
     * @return a {@link Mono} that completes once the response has been consumed
     */
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (this.fallbackToSse.get()) {
            return this.fallbackToSse(message);
        }
        return this.serializeJson(message)
                .flatMap(json -> Mono.fromFuture(() -> this.newMessageRequest(json).execAsync("POST")))
                .flatMap(response -> this.handleSendResponse(message, response, handler))
                .doOnError(e -> LOGGER.error("Streamable transport sendMessages error", e));
    }

    /** Builds the POST request that carries one serialized JSON-RPC message. */
    private HttpUtils newMessageRequest(String json) {
        HttpUtils request = this.webBuilder.build(this.endpoint)
                .bodyOfJson(json)
                .header(ACCEPT, DEFAULT_ACCEPT_VALUES)
                .header(CONTENT_TYPE, APPLICATION_JSON);
        return this.withSessionHeader(request);
    }

    /** Adds the {@code Mcp-Session-Id} header when a session id is known; returns the same request. */
    private HttpUtils withSessionHeader(HttpUtils request) {
        // Read once so the null check and the header value cannot disagree under concurrent updates.
        String sessionId = this.mcpSessionId.get();
        if (sessionId != null) {
            request.header(MCP_SESSION_ID, sessionId);
        }
        return request;
    }

    /** Interprets the status of a POSTed message's response and consumes its body when appropriate. */
    private Mono<Void> handleSendResponse(McpSchema.JSONRPCMessage message, HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (isInitializeRequest(message)) {
            String sessionId = response.header(MCP_SESSION_ID);
            if (sessionId != null) {
                this.mcpSessionId.set(sessionId);
            }
        }
        int status = response.code();
        if (status == 202) {
            return Mono.empty();
        }
        if (this.switchToSseIfRejected(status)) {
            return this.fallbackToSse(message);
        }
        if (status >= 400) {
            return Mono.error(new IllegalArgumentException("Unexpected status code: " + status));
        }
        return this.handleStreamingResponse(response, handler);
    }

    /** Returns true when {@code message} is a JSON-RPC {@code initialize} request. */
    private static boolean isInitializeRequest(McpSchema.JSONRPCMessage message) {
        return message instanceof McpSchema.JSONRPCRequest
                && INITIALIZE_METHOD.equals(((McpSchema.JSONRPCRequest) message).getMethod());
    }

    /**
     * Applies the status rules shared by connect and sendMessage. A 404 drops the current
     * session id, and a 404 or 405 permanently enables SSE fallback.
     *
     * @return true if the transport has just switched to SSE fallback
     */
    private boolean switchToSseIfRejected(int status) {
        if (status == 404) {
            this.mcpSessionId.set(null);
        }
        if (status == 404 || status == 405) {
            LOGGER.warn("Operation not allowed, falling back to SSE");
            this.fallbackToSse.set(true);
            return true;
        }
        return false;
    }

    /**
     * Sends {@code msg} through the SSE transport. Batch requests and batch responses are
     * unpacked, and their items are sent individually (concurrently, via {@code flatMap}).
     */
    private Mono<Void> fallbackToSse(McpSchema.JSONRPCMessage msg) {
        if (msg instanceof McpSchema.JSONRPCBatchRequest) {
            McpSchema.JSONRPCBatchRequest batchReq = (McpSchema.JSONRPCBatchRequest) msg;
            return Flux.fromIterable(batchReq.getItems()).flatMap(this.sseClientTransport::sendMessage).then();
        }
        if (msg instanceof McpSchema.JSONRPCBatchResponse) {
            McpSchema.JSONRPCBatchResponse batch = (McpSchema.JSONRPCBatchResponse) msg;
            return Flux.fromIterable(batch.getItems()).flatMap(this.sseClientTransport::sendMessage).then();
        }
        return this.sseClientTransport.sendMessage(msg);
    }

    /** Serializes {@code msg} to JSON; a serialization failure is logged and returned as an error signal. */
    private Mono<String> serializeJson(McpSchema.JSONRPCMessage msg) {
        try {
            return Mono.just(this.objectMapper.writeValueAsString(msg));
        } catch (IOException e) {
            LOGGER.error("Error serializing JSON-RPC message", e);
            return Mono.error(e);
        }
    }

    /**
     * Dispatches a response body by its {@code Content-Type}. A missing or unrecognized type
     * fails with {@link UnsupportedOperationException}.
     */
    private Mono<Void> handleStreamingResponse(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        String contentType = response.header(CONTENT_TYPE);
        // Media types are case-insensitive, and a missing header must not surface as an NPE.
        String mediaType = contentType == null ? "" : contentType.toLowerCase(Locale.ROOT);
        // json-seq must be tested first: "application/json-seq" also contains "application/json".
        if (mediaType.contains(APPLICATION_JSON_SEQ)) {
            return this.handleJsonStream(response, handler);
        }
        if (mediaType.contains(TEXT_EVENT_STREAM)) {
            return this.handleSseStream(response, handler);
        }
        if (mediaType.contains(APPLICATION_JSON)) {
            return this.handleSingleJson(response, handler);
        }
        return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType));
    }

    /** Reads the whole body as one UTF-8 JSON-RPC message and hands it to {@code handler}. */
    private Mono<Void> handleSingleJson(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Mono.fromCallable(() -> {
            try {
                String body = new String(response.bodyAsBytes(), StandardCharsets.UTF_8);
                McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);
                return handler.apply(Mono.just(msg));
            } catch (IOException e) {
                LOGGER.error("Error processing JSON response", e);
                return Mono.<McpSchema.JSONRPCMessage>error(e);
            }
        }).flatMap(Function.identity()).then();
    }

    /**
     * Reads an {@code application/json-seq} body line by line and hands each JSON-RPC message
     * to {@code handler} in arrival order. Record separators are stripped, and blank lines
     * are skipped.
     */
    private Mono<Void> handleJsonStream(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Flux.from(TextStreamUtil.parseLineStream((InputStream) response.body()))
                .map(line -> line.replace(RECORD_SEPARATOR, ""))
                .filter(line -> !line.trim().isEmpty())
                // concatMap keeps messages in order, matching the SSE path.
                .concatMap(jsonLine -> {
                    try {
                        McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, jsonLine);
                        return handler.apply(Mono.just(message));
                    } catch (IOException e) {
                        LOGGER.error("Error processing JSON line", e);
                        return Mono.<McpSchema.JSONRPCMessage>error(e);
                    }
                })
                .then();
    }

    /**
     * Reads a {@code text/event-stream} body and handles every {@code message} event in
     * arrival order. Events with no explicit type are treated as {@code message}.
     */
    private Mono<Void> handleSseStream(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Flux.from(TextStreamUtil.parseSseStream((InputStream) response.body()))
                .filter(sseEvent -> isMessageEvent(sseEvent.getEvent()))
                .concatMap(sseEvent -> this.handleSseEvent(sseEvent.getId(), sseEvent.getData(), handler))
                .then();
    }

    /** True for the SSE "message" type; the SSE spec uses it when no {@code event:} field is present. */
    private static boolean isMessageEvent(String eventType) {
        return eventType == null || eventType.isEmpty() || MESSAGE_EVENT.equals(eventType);
    }

    /**
     * Handles one SSE event. Its JSON payload (an object or an array of objects) is passed
     * message by message to {@code handler}. Afterwards, the event id, if any, is recorded
     * for {@code Last-Event-ID}.
     */
    private Mono<Void> handleSseEvent(String eventId, String data, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        Mono<Void> rememberEventId = Mono.fromRunnable(() -> {
            if (eventId != null && !eventId.isEmpty()) {
                this.lastEventId.set(eventId);
            }
        });
        String rawData = data == null ? "" : data.trim();
        if (rawData.isEmpty()) {
            // An event without payload has nothing to dispatch; only its id is recorded.
            return rememberEventId;
        }
        List<McpSchema.JSONRPCMessage> messages;
        try {
            messages = this.parseSseMessages(rawData);
        } catch (IOException e) {
            LOGGER.error("Error parsing SSE JSON: {}", rawData, e);
            return Mono.error(e);
        } catch (IllegalArgumentException e) {
            return Mono.error(e);
        }
        // The id is recorded only after all messages were handled, so Last-Event-ID never
        // points past an event that was not fully processed.
        return Flux.fromIterable(messages)
                .concatMap(msg -> handler.apply(Mono.just(msg)))
                .then(rememberEventId);
    }

    /**
     * Parses one SSE data payload into JSON-RPC messages; a JSON array yields one message per element.
     *
     * @throws IOException              if the payload cannot be read as JSON or deserialized
     * @throws IllegalArgumentException if the payload is neither a JSON object nor a JSON array
     */
    private List<McpSchema.JSONRPCMessage> parseSseMessages(String rawData) throws IOException {
        JsonNode node = this.objectMapper.readTree(rawData);
        List<McpSchema.JSONRPCMessage> messages = new ArrayList<>();
        if (node != null && node.isArray()) {
            for (JsonNode item : node) {
                messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, item.toString()));
            }
        } else if (node != null && node.isObject()) {
            messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, node.toString()));
        } else {
            String warning = "Unexpected JSON in SSE data: " + rawData;
            LOGGER.warn(warning);
            throw new IllegalArgumentException(warning);
        }
        return messages;
    }

    /**
     * Forgets the session id and the last event id. In SSE fallback mode, it also returns
     * the SSE transport's close operation.
     *
     * <p>The state reset happens when this method is called, not when the result is subscribed.
     */
    @Override
    public Mono<Void> closeGracefully() {
        this.mcpSessionId.set(null);
        this.lastEventId.set(null);
        if (this.fallbackToSse.get()) {
            return this.sseClientTransport.closeGracefully();
        }
        return Mono.empty();
    }

    /** Converts {@code data} to {@code T} with the configured {@link ObjectMapper}. */
    @Override
    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return this.objectMapper.convertValue(data, typeRef);
    }

    /** Fluent builder; defaults to endpoint "/mcp" and a new {@link ObjectMapper}. */
    public static class Builder {
        private final HttpUtilsBuilder webBuilder;
        private ObjectMapper objectMapper = new ObjectMapper();
        private String endpoint = DEFAULT_MCP_ENDPOINT;

        /**
         * @param webBuilder factory for outgoing HTTP requests; must not be null
         */
        public Builder(HttpUtilsBuilder webBuilder) {
            Assert.notNull(webBuilder, "webBuilder must not be null");
            this.webBuilder = webBuilder;
        }

        /**
         * Sets the MCP endpoint path.
         *
         * @param endpoint endpoint path; must contain text
         * @return this builder
         */
        public Builder endpoint(String endpoint) {
            Assert.hasText(endpoint, "endpoint must not be empty");
            this.endpoint = endpoint;
            return this;
        }

        /**
         * Sets the JSON mapper used for (de)serialization.
         *
         * @param objectMapper mapper; must not be null
         * @return this builder
         */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "objectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Creates the transport together with its SSE fallback transport, which uses the same
         * HTTP builder, endpoint and mapper.
         *
         * @return a new transport
         */
        public WebRxStreamableClientTransport build() {
            WebRxSseClientTransport sseFallback = new WebRxSseClientTransport(this.webBuilder, this.endpoint, this.objectMapper);
            return new WebRxStreamableClientTransport(this.webBuilder, this.objectMapper, this.endpoint, sseFallback);
        }
    }
}

