/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.util.Assert;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.util.context.ContextView;

/**
 * Runs the MCP client lifecycle handshake (an {@code initialize} request followed by a
 * {@code notifications/initialized} notification) and shares its outcome between concurrent callers.
 *
 * <p>At most one {@link DefaultInitialization} is installed in {@link #initializationRef} at a time. The first
 * caller of {@link #withIntitialization} installs one and performs the handshake. Callers arriving while it is
 * installed join it and observe the same result or error. The slot is cleared by {@link #close()},
 * {@link #closeGracefully()}, and by {@link #handleException(Throwable)} when the transport reports a
 * {@link McpTransportSessionNotFoundException}.
 *
 * <p>NOTE(review): a failed initialization is not removed from the slot, so later callers join it and receive
 * the same error until the slot is cleared by one of the methods above. Confirm this is the intended policy.
 */
class LifecycleInitializer {
    private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class);
    private final Function<ContextView, McpClientSession> sessionSupplier;
    private final McpSchema.ClientCapabilities clientCapabilities;
    private final McpSchema.Implementation clientInfo;
    // Read from reactive pipelines; volatile so a value set via setProtocolVersions is visible to them.
    private volatile List<String> protocolVersions;
    private final AtomicReference<DefaultInitialization> initializationRef = new AtomicReference<>();
    private final Duration initializationTimeout;

    /**
     * @param clientCapabilities capabilities advertised in the {@code initialize} request
     * @param clientInfo client implementation info advertised in the {@code initialize} request
     * @param protocolVersions supported protocol versions; the last element is the one requested
     * @param initializationTimeout maximum time to wait for the handshake to complete
     * @param sessionSupplier creates a new session for each initialization attempt
     */
    public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, List<String> protocolVersions, Duration initializationTimeout, Function<ContextView, McpClientSession> sessionSupplier) {
        Assert.notNull(sessionSupplier, "Session supplier must not be null");
        Assert.notNull(clientCapabilities, "Client capabilities must not be null");
        Assert.notNull(clientInfo, "Client info must not be null");
        Assert.notEmpty(protocolVersions, "Protocol versions must not be empty");
        Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
        this.sessionSupplier = sessionSupplier;
        this.clientCapabilities = clientCapabilities;
        this.clientInfo = clientInfo;
        this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions));
        this.initializationTimeout = initializationTimeout;
    }

    /**
     * Replaces the supported protocol versions used by subsequent handshakes. Validated and defensively copied
     * exactly like the constructor argument.
     */
    void setProtocolVersions(List<String> protocolVersions) {
        Assert.notEmpty(protocolVersions, "Protocol versions must not be empty");
        this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions));
    }

    /** Returns {@code true} if the currently installed initialization has completed successfully. */
    public boolean isInitialized() {
        return this.currentInitializationResult() != null;
    }

    /**
     * Returns the server's {@code initialize} result for the currently installed initialization, or {@code null}
     * if none is installed or it has not completed successfully.
     */
    public McpSchema.InitializeResult currentInitializationResult() {
        DefaultInitialization current = this.initializationRef.get();
        return current != null ? current.initializeResult() : null;
    }

    /**
     * Reacts to a transport-level error. A {@link McpTransportSessionNotFoundException} discards and closes the
     * current initialization and starts a fresh one in the background. Other errors are only logged.
     */
    public void handleException(Throwable t) {
        logger.warn("Handling exception", t);
        if (t instanceof McpTransportSessionNotFoundException) {
            DefaultInitialization previous = this.initializationRef.getAndSet(null);
            if (previous != null) {
                previous.close();
            }
            // Failures are already logged as warnings by withIntitialization. The error consumer only keeps
            // Reactor from reporting the error as unhandled.
            this.withIntitialization("re-initializing", result -> Mono.empty())
                .subscribe(ignored -> { }, ex -> logger.debug("Re-initialization after lost session failed", ex));
        }
    }

    /**
     * Ensures the client is initialized, then runs {@code operation} against the initialization.
     *
     * <p>If no initialization is installed, this call installs one and performs the handshake. Otherwise it joins
     * the installed one. The timeout bounds only the wait for the handshake, not {@code operation}. Handshake
     * errors and timeouts are wrapped in a {@link RuntimeException} mentioning {@code actionName}.
     *
     * @param actionName description of the caller's action, used in the error message
     * @param operation work to run once initialization has succeeded
     */
    public <T> Mono<T> withIntitialization(String actionName, Function<Initialization, Mono<T>> operation) {
        return Mono.deferContextual(ctx -> {
            DefaultInitialization newInit = new DefaultInitialization();
            DefaultInitialization previous = this.installOrJoin(newInit);
            boolean needsToInitialize = previous == null;
            logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
            // Resolve to the initialization this call actually awaits. Re-reading initializationRef after
            // completion could observe null (after a concurrent close) and make map() fail with an NPE.
            Initialization active = needsToInitialize ? newInit : previous;
            Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? this.doInitialize(newInit, ctx) : previous.await();
            return initializationJob.map(initializeResult -> active)
                .timeout(this.initializationTimeout)
                .onErrorResume(ex -> {
                    logger.warn("Failed to initialize", ex);
                    return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex));
                })
                .flatMap(operation);
        });
    }

    /**
     * Installs {@code candidate} as the current initialization if the slot is empty, otherwise returns the one
     * already installed. Retries when the slot is cleared between the failed CAS and the read.
     *
     * @return {@code null} if {@code candidate} was installed, otherwise the existing initialization
     */
    private DefaultInitialization installOrJoin(DefaultInitialization candidate) {
        while (true) {
            if (this.initializationRef.compareAndSet(null, candidate)) {
                return null;
            }
            DefaultInitialization existing = this.initializationRef.get();
            if (existing != null) {
                return existing;
            }
            // Slot was cleared concurrently (close/handleException); try to install again.
        }
    }

    /**
     * Performs the handshake for {@code initialization} and publishes its outcome to the initialization's sink so
     * joined callers observe it. Deferred, so synchronous failures (for example from the session supplier) also
     * reach the sink instead of leaving joiners waiting.
     */
    private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization initialization, ContextView ctx) {
        return Mono.defer(() -> this.startHandshake(initialization, ctx))
            .doOnNext(initialization::complete)
            .doOnError(initialization::error);
    }

    /**
     * Creates the session and sends {@code initialize} with the latest supported version. It fails with an
     * {@link McpError} if the server answers with an unsupported version. Otherwise it sends
     * {@code notifications/initialized} and emits the server's result.
     */
    private Mono<McpSchema.InitializeResult> startHandshake(DefaultInitialization initialization, ContextView ctx) {
        McpClientSession mcpClientSession = this.sessionSupplier.apply(ctx);
        initialization.setMcpClientSession(mcpClientSession);
        // Snapshot so the requested version and the supported-version check use the same list.
        List<String> supportedVersions = this.protocolVersions;
        String latestVersion = supportedVersions.get(supportedVersions.size() - 1);
        McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, this.clientCapabilities, this.clientInfo);
        Mono<McpSchema.InitializeResult> response = mcpClientSession.sendRequest("initialize", initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF);
        return response.flatMap(initializeResult -> {
            logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", initializeResult.getProtocolVersion(), initializeResult.getCapabilities(), initializeResult.getServerInfo(), initializeResult.getInstructions());
            if (!supportedVersions.contains(initializeResult.getProtocolVersion())) {
                return Mono.error(McpError.builder(-32602).message("Unsupported protocol version").data("Unsupported protocol version from the server: " + initializeResult.getProtocolVersion()).build());
            }
            return mcpClientSession.sendNotification("notifications/initialized", null).thenReturn(initializeResult);
        });
    }

    /** Clears the current initialization (if any) and closes its session immediately. */
    public void close() {
        DefaultInitialization current = this.initializationRef.getAndSet(null);
        if (current != null) {
            current.close();
        }
    }

    /**
     * Returns a Mono that, on subscription, clears the current initialization (if any) and closes its session
     * gracefully. It completes empty when nothing is installed.
     */
    public Mono<?> closeGracefully() {
        return Mono.defer(() -> {
            DefaultInitialization current = this.initializationRef.getAndSet(null);
            Mono<Void> sessionClose = current != null ? current.closeGracefully() : Mono.empty();
            return sessionClose;
        });
    }

    /**
     * One initialization attempt: its session, its successful result, and a one-shot sink through which joined
     * callers await the outcome.
     */
    private static class DefaultInitialization
    implements Initialization {
        private final Sinks.One<McpSchema.InitializeResult> initSink = Sinks.one();
        private final AtomicReference<McpSchema.InitializeResult> result = new AtomicReference<>();
        private final AtomicReference<McpClientSession> mcpClientSession = new AtomicReference<>();

        private DefaultInitialization() {
        }

        @Override
        public McpClientSession mcpSession() {
            return this.mcpClientSession.get();
        }

        @Override
        public McpSchema.InitializeResult initializeResult() {
            return this.result.get();
        }

        private void setMcpClientSession(McpClientSession mcpClientSession) {
            this.mcpClientSession.set(mcpClientSession);
        }

        /** Emits the result or error published by {@link #complete} or {@link #error}. */
        private Mono<McpSchema.InitializeResult> await() {
            return this.initSink.asMono();
        }

        private void complete(McpSchema.InitializeResult initializeResult) {
            // Record the result before emitting so subscribers woken by the sink can read it.
            this.result.set(initializeResult);
            this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
        }

        private void error(Throwable t) {
            this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
        }

        /** Closes the session, if one has been created yet. */
        private void close() {
            McpClientSession session = this.mcpSession();
            if (session != null) {
                session.close();
            }
        }

        /** Gracefully closes the session, or completes empty if none has been created yet. */
        private Mono<Void> closeGracefully() {
            McpClientSession session = this.mcpSession();
            return session != null ? session.closeGracefully() : Mono.empty();
        }
    }

    /** View of an initialization handed to operations run by {@link #withIntitialization}. */
    interface Initialization {
        McpClientSession mcpSession();

        McpSchema.InitializeResult initializeResult();
    }
}

