/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import lombok.Generated;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
 * Client-side {@link McpSession} that correlates outgoing JSON-RPC requests with their responses
 * and dispatches incoming server requests and notifications to registered handlers.
 *
 * <p>Outgoing requests are tracked in {@code pendingResponses} under a generated id. An incoming
 * {@link McpSchema.JSONRPCResponse} completes the matching entry. An incoming
 * {@link McpSchema.JSONRPCRequest} is routed to the {@link RequestHandler} registered for its
 * method, and the resulting response is sent back through the transport. An incoming
 * {@link McpSchema.JSONRPCNotification} is routed to the matching {@link NotificationHandler}.
 */
public class McpClientSession
implements McpSession {
    private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class);

    /** JSON-RPC protocol version stamped on every message created by this session. */
    private static final String JSONRPC_VERSION = "2.0";

    /** JSON-RPC 2.0 "Method not found" error code. */
    private static final int METHOD_NOT_FOUND = -32601;

    /** JSON-RPC 2.0 "Internal error" error code. */
    private static final int INTERNAL_ERROR = -32603;

    private final Duration requestTimeout;
    private final McpClientTransport transport;
    /** Sinks of in-flight outgoing requests, keyed by request id. */
    private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, RequestHandler<?>> requestHandlers = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, NotificationHandler> notificationHandlers = new ConcurrentHashMap<>();
    /** Random per-session prefix so request ids from different sessions do not collide. */
    private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8);
    private final AtomicLong requestCounter = new AtomicLong(0L);

    /**
     * Creates a session whose transport connection is subscribed without any connect hook.
     *
     * @deprecated use
     *     {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)}, which
     *     accepts a connect hook; this constructor passes {@link Function#identity()}.
     */
    @Deprecated
    public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
        this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity());
    }

    /**
     * Creates a session and immediately subscribes to the transport connection.
     *
     * @param requestTimeout maximum time {@link #sendRequest} waits for a response
     * @param transport transport used to send and receive JSON-RPC messages
     * @param requestHandlers handlers for incoming requests, keyed by method name (copied)
     * @param notificationHandlers handlers for incoming notifications, keyed by method name (copied)
     * @param connectHook applied to the {@code Mono} returned by {@code transport.connect(...)}
     *     before it is subscribed
     */
    public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
        Assert.notNull(requestTimeout, "The requestTimeout can not be null");
        Assert.notNull(transport, "The transport can not be null");
        Assert.notNull(requestHandlers, "The requestHandlers can not be null");
        Assert.notNull(notificationHandlers, "The notificationHandlers can not be null");
        this.requestTimeout = requestTimeout;
        this.transport = transport;
        this.requestHandlers.putAll(requestHandlers);
        this.notificationHandlers.putAll(notificationHandlers);
        // Every inbound message flowing through the connection is routed to handle(...).
        this.transport.connect(incoming -> incoming.doOnNext(this::handle)).transform(connectHook).subscribe();
    }

    /**
     * Fails every in-flight request with a "session terminated" error.
     *
     * <p>Each entry is removed atomically before its sink is signalled. A separate
     * {@code clear()} after iterating could silently drop a request registered concurrently,
     * leaving its caller waiting until timeout.
     */
    private void dismissPendingResponses() {
        this.pendingResponses.forEach((id, sink) -> {
            if (this.pendingResponses.remove(id, sink)) {
                logger.warn("Abruptly terminating exchange for request {}", id);
                sink.error(new RuntimeException("MCP session with server terminated"));
            }
        });
    }

    /**
     * Routes a single inbound JSON-RPC message by its concrete type.
     *
     * @param message message received from the transport
     */
    private void handle(McpSchema.JSONRPCMessage message) {
        if (message instanceof McpSchema.JSONRPCResponse) {
            McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) message;
            logger.debug("Received Response: {}", response);
            MonoSink<McpSchema.JSONRPCResponse> sink = this.pendingResponses.remove(response.getId());
            if (sink == null) {
                logger.warn("Unexpected response for unknown id {}", response.getId());
            } else {
                sink.success(response);
            }
        } else if (message instanceof McpSchema.JSONRPCRequest) {
            McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message;
            logger.debug("Received request: {}", request);
            this.handleIncomingRequest(request).onErrorResume(error -> {
                // A failing handler is reported to the peer as a JSON-RPC internal error.
                McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(INTERNAL_ERROR, error.getMessage(), null));
                return Mono.just(errorResponse);
            }).flatMap(this.transport::sendMessage).onErrorComplete(t -> {
                logger.warn("Issue sending response to the client, ", t);
                return true;
            }).subscribe();
        } else if (message instanceof McpSchema.JSONRPCNotification) {
            McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) message;
            logger.debug("Received notification: {}", notification);
            this.handleIncomingNotification(notification).onErrorComplete(t -> {
                // Trailing throwable argument makes SLF4J log the stack trace as well.
                logger.error("Error handling notification: {}", t.getMessage(), t);
                return true;
            }).subscribe();
        } else {
            logger.warn("Received unknown message type: {}", message);
        }
    }

    /**
     * Builds the response for an incoming request.
     *
     * <p>If no handler is registered for the request's method, the response carries a
     * "method not found" error. Otherwise it carries the handler's result.
     *
     * @param request incoming request
     * @return lazily evaluated response
     */
    private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
        return Mono.defer(() -> {
            RequestHandler<?> handler = this.requestHandlers.get(request.getMethod());
            if (handler == null) {
                MethodNotFoundError error = this.getMethodNotFoundError(request.getMethod());
                return Mono.just(new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(METHOD_NOT_FOUND, error.getMessage(), error.getData())));
            }
            return handler.handle(request.getParams()).map(result -> new McpSchema.JSONRPCResponse(JSONRPC_VERSION, request.getId(), result, null));
        });
    }

    /**
     * Describes why {@code method} has no handler.
     *
     * <p>{@code roots/list} gets a capability-specific message; every other method gets a
     * generic one.
     */
    private MethodNotFoundError getMethodNotFoundError(String method) {
        if ("roots/list".equals(method)) {
            return new MethodNotFoundError(method, "Roots not supported", Utils.asMap("reason", "Client does not have roots capability"));
        }
        return new MethodNotFoundError(method, "Method not found: " + method, null);
    }

    /**
     * Dispatches an incoming notification to its registered handler.
     *
     * <p>If no handler is registered, this logs an error and completes empty.
     */
    private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
        return Mono.defer(() -> {
            NotificationHandler handler = this.notificationHandlers.get(notification.getMethod());
            if (handler == null) {
                logger.error("No handler registered for notification method: {}", notification.getMethod());
                return Mono.empty();
            }
            return handler.handle(notification.getParams());
        });
    }

    /** Returns a request id of the form {@code <sessionPrefix>-<counter>}. */
    private String generateRequestId() {
        return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement();
    }

    /**
     * Sends a request and waits for the matching response.
     *
     * <p>The returned {@code Mono} errors with {@link McpError} when the response carries an
     * error, and with a timeout error when no response arrives within the configured request
     * timeout. It completes empty when {@code typeRef} is {@link Void}. Otherwise it emits the
     * result unmarshalled via {@code transport.unmarshalFrom}.
     *
     * @param method JSON-RPC method name
     * @param requestParams request parameters
     * @param typeRef target type of the result
     * @return the typed result of the request
     */
    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
        String requestId = this.generateRequestId();
        return Mono.deferContextual(ctx -> Mono.<McpSchema.JSONRPCResponse>create(pendingResponseSink -> {
            logger.debug("Sending message for method {}", method);
            this.pendingResponses.put(requestId, pendingResponseSink);
            // Unregister on any termination or cancellation (including timeout), so requests
            // that never receive a response do not accumulate in pendingResponses.
            pendingResponseSink.onDispose(() -> this.pendingResponses.remove(requestId));
            McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(JSONRPC_VERSION, method, requestId, requestParams);
            this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> {}, error -> {
                this.pendingResponses.remove(requestId);
                pendingResponseSink.error(error);
            });
        })).timeout(this.requestTimeout).handle((jsonRpcResponse, deliveredResponseSink) -> {
            if (jsonRpcResponse.getError() != null) {
                logger.error("Error handling request: {}", jsonRpcResponse.getError());
                deliveredResponseSink.error(new McpError(jsonRpcResponse.getError()));
            } else if (typeRef.getType().equals(Void.class)) {
                deliveredResponseSink.complete();
            } else {
                deliveredResponseSink.next(this.transport.unmarshalFrom(jsonRpcResponse.getResult(), typeRef));
            }
        });
    }

    /**
     * Sends a notification; no response is expected.
     *
     * @param method JSON-RPC method name
     * @param params notification parameters
     * @return completion of the transport send
     */
    @Override
    public Mono<Void> sendNotification(String method, Object params) {
        McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(JSONRPC_VERSION, method, params);
        return this.transport.sendMessage(jsonrpcNotification);
    }

    /**
     * Lazily fails all in-flight requests on subscription. This method does not close the
     * transport.
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(this::dismissPendingResponses);
    }

    /** Immediately fails all in-flight requests. This method does not close the transport. */
    @Override
    public void close() {
        this.dismissPendingResponses();
    }

    /** Details used to build the "method not found" error sent for an unhandled request. */
    public static class MethodNotFoundError {
        String method;
        String message;
        Object data;

        @Generated
        public String getMethod() {
            return this.method;
        }

        @Generated
        public String getMessage() {
            return this.message;
        }

        @Generated
        public Object getData() {
            return this.data;
        }

        @Generated
        public void setMethod(String method) {
            this.method = method;
        }

        @Generated
        public void setMessage(String message) {
            this.message = message;
        }

        @Generated
        public void setData(Object data) {
            this.data = data;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MethodNotFoundError)) {
                return false;
            }
            MethodNotFoundError other = (MethodNotFoundError) o;
            if (!other.canEqual(this)) {
                return false;
            }
            String thisMethod = this.getMethod();
            String otherMethod = other.getMethod();
            if (thisMethod == null ? otherMethod != null : !thisMethod.equals(otherMethod)) {
                return false;
            }
            String thisMessage = this.getMessage();
            String otherMessage = other.getMessage();
            if (thisMessage == null ? otherMessage != null : !thisMessage.equals(otherMessage)) {
                return false;
            }
            Object thisData = this.getData();
            Object otherData = other.getData();
            return thisData == null ? otherData == null : thisData.equals(otherData);
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof MethodNotFoundError;
        }

        @Generated
        public int hashCode() {
            int result = 1;
            String method = this.getMethod();
            result = result * 59 + (method == null ? 43 : method.hashCode());
            String message = this.getMessage();
            result = result * 59 + (message == null ? 43 : message.hashCode());
            Object data = this.getData();
            result = result * 59 + (data == null ? 43 : data.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "McpClientSession.MethodNotFoundError(method=" + this.getMethod() + ", message=" + this.getMessage() + ", data=" + this.getData() + ")";
        }

        @Generated
        public MethodNotFoundError(String method, String message, Object data) {
            this.method = method;
            this.message = message;
            this.data = data;
        }

        @Generated
        public MethodNotFoundError() {
        }
    }

    /** Handles an incoming notification's parameters; the returned Mono signals completion. */
    @FunctionalInterface
    public static interface NotificationHandler {
        public Mono<Void> handle(Object var1);
    }

    /**
     * Handles an incoming request's parameters and produces the result placed in the response.
     *
     * @param <T> result type
     */
    @FunctionalInterface
    public static interface RequestHandler<T> {
        public Mono<T> handle(Object var1);
    }
}

