/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.ai.mcp;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.FlowSseClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ImplicitContextKeyed;
import io.opentelemetry.context.Scope;
import io.opentelemetry.context.propagation.TextMapPropagator;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

public class HttpClientTelemetrySseClientTransport
implements McpClientTransport {
    private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);
    private static final String MESSAGE_EVENT_TYPE = "message";
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final String SSE_ENDPOINT = "/sse";
    private final String baseUri;
    private final FlowSseClient sseClient;
    private final HttpClient httpClient;
    protected ObjectMapper objectMapper;
    private volatile boolean isClosing = false;
    private final CountDownLatch closeLatch = new CountDownLatch(1);
    private final AtomicReference<String> messageEndpoint = new AtomicReference();
    private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference();
    private Tracer tracer;
    private BiConsumer<String, Long> durationConsumer;
    private TextMapPropagator propagator;

    public HttpClientTelemetrySseClientTransport(String baseUri, Tracer tracer, TextMapPropagator propagator, BiConsumer<String, Long> durationConsumer) {
        this(HttpClient.newBuilder(), baseUri, new ObjectMapper(), tracer, propagator, durationConsumer);
    }

    public HttpClientTelemetrySseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper, Tracer tracer, TextMapPropagator propagator, BiConsumer<String, Long> durationConsumer) {
        Assert.notNull((Object)objectMapper, (String)"ObjectMapper must not be null");
        Assert.hasText((String)baseUri, (String)"baseUri must not be empty");
        Assert.notNull((Object)clientBuilder, (String)"clientBuilder must not be null");
        this.baseUri = baseUri;
        this.objectMapper = objectMapper;
        this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10L)).build();
        this.sseClient = new FlowSseClient(this.httpClient);
        this.tracer = tracer;
        this.propagator = propagator;
        this.durationConsumer = durationConsumer;
    }

    public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        final CompletableFuture future = new CompletableFuture();
        this.connectionFuture.set(future);
        this.sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler(){

            public void onEvent(FlowSseClient.SseEvent event) {
                if (HttpClientTelemetrySseClientTransport.this.isClosing) {
                    return;
                }
                try {
                    if (HttpClientTelemetrySseClientTransport.ENDPOINT_EVENT_TYPE.equals(event.type())) {
                        String endpoint = event.data();
                        HttpClientTelemetrySseClientTransport.this.messageEndpoint.set(endpoint);
                        HttpClientTelemetrySseClientTransport.this.closeLatch.countDown();
                        future.complete(null);
                    } else if (HttpClientTelemetrySseClientTransport.MESSAGE_EVENT_TYPE.equals(event.type())) {
                        McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((ObjectMapper)HttpClientTelemetrySseClientTransport.this.objectMapper, (String)event.data());
                        ((Mono)handler.apply(Mono.just((Object)message))).subscribe();
                    } else {
                        logger.error("Received unrecognized SSE event type: {}", (Object)event.type());
                    }
                }
                catch (IOException e) {
                    logger.error("Error processing SSE event", (Throwable)e);
                    future.completeExceptionally(e);
                }
            }

            public void onError(Throwable error) {
                if (!HttpClientTelemetrySseClientTransport.this.isClosing) {
                    logger.error("SSE connection error", error);
                    future.completeExceptionally(error);
                }
            }
        });
        return Mono.fromFuture(future);
    }

    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        if (this.isClosing) {
            return Mono.empty();
        }
        try {
            if (!this.closeLatch.await(10L, TimeUnit.SECONDS)) {
                return Mono.error((Throwable)new McpError((Object)"Failed to wait for the message endpoint"));
            }
        }
        catch (InterruptedException e) {
            return Mono.error((Throwable)new McpError((Object)"Failed to wait for the message endpoint"));
        }
        String endpoint = this.messageEndpoint.get();
        if (endpoint == null) {
            return Mono.error((Throwable)new McpError((Object)"No message endpoint available"));
        }
        return Mono.deferContextual(contextView -> {
            Mono mono;
            block9: {
                Context parentContext = (Context)contextView.getOrDefault(Context.class, (Object)Context.current());
                long start = System.currentTimeMillis();
                URI requestURI = URI.create(this.baseUri + endpoint);
                Span requestSpan = this.tracer == null ? Span.fromContext((Context)parentContext) : this.tracer.spanBuilder("sendMessage").setAttribute("uri", requestURI.toString()).setSpanKind(SpanKind.CLIENT).setParent(parentContext).startSpan();
                Scope scope = requestSpan.makeCurrent();
                try {
                    String jsonText = this.objectMapper.writeValueAsString((Object)message);
                    requestSpan.setAttribute(MESSAGE_EVENT_TYPE, jsonText);
                    HttpRequest.Builder builder = this.getHttpRequestBuilder().uri(requestURI).header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonText));
                    Context telemetryContext = Context.current().with((ImplicitContextKeyed)requestSpan);
                    this.propagator.inject(telemetryContext, (Object)builder, (b, name, value) -> b.header(name, value));
                    HttpRequest request = builder.build();
                    mono = Mono.fromFuture((CompletableFuture)this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> {
                        if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 && response.statusCode() != 206) {
                            logger.error("Error sending message: {}", (Object)response.statusCode());
                        }
                    })).map(result -> {
                        if (this.durationConsumer != null) {
                            this.durationConsumer.accept(requestURI.toString(), System.currentTimeMillis() - start);
                        }
                        requestSpan.setStatus(StatusCode.OK);
                        return result;
                    }).onErrorMap(error -> {
                        requestSpan.recordException(error);
                        requestSpan.setStatus(StatusCode.ERROR);
                        return error;
                    }).doFinally(signal -> requestSpan.end());
                    if (scope == null) break block9;
                }
                catch (Throwable throwable) {
                    try {
                        if (scope != null) {
                            try {
                                scope.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (IOException e) {
                        requestSpan.recordException((Throwable)e);
                        if (!this.isClosing) {
                            return Mono.error((Throwable)new RuntimeException("Failed to serialize message", e));
                        }
                        return Mono.empty();
                    }
                }
                scope.close();
            }
            return mono;
        });
    }

    protected HttpRequest.Builder getHttpRequestBuilder() {
        return HttpRequest.newBuilder();
    }

    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            CompletableFuture<Void> future = this.connectionFuture.get();
            if (future != null && !future.isDone()) {
                future.cancel(true);
            }
        });
    }

    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return (T)this.objectMapper.convertValue(data, typeRef);
    }
}

