/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpNotificationHandler;
import io.modelcontextprotocol.server.McpRequestHandler;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.MissingMcpTransportSession;
import io.modelcontextprotocol.util.Assert;
import java.time.Duration;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
 * Server-side state for a single MCP streamable session.
 *
 * <p>A session holds at most one "listening" stream, to which server-initiated requests and
 * notifications are routed. It can also create any number of per-request response streams via
 * {@link #responseStream}; each of those is closed after its single response is written. While no
 * listening stream is attached, outbound traffic goes to a {@link MissingMcpTransportSession}.
 */
public class McpStreamableServerSession
implements McpLoggableSession {
    private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class);

    /** JSON-RPC protocol version stamped on every message this class builds. */
    private static final String JSONRPC_VERSION = "2.0";

    /** JSON-RPC 2.0 "Method not found" error code. */
    private static final int METHOD_NOT_FOUND = -32601;

    /** JSON-RPC 2.0 "Internal error" error code. */
    private static final int INTERNAL_ERROR = -32603;

    /** Reactor context key from which the {@link McpTransportContext} is read. */
    private static final String TRANSPORT_CONTEXT_KEY = "MCP_TRANSPORT_CONTEXT";

    /** Maps the id of every in-flight server-to-client request to the stream that sent it. */
    private final ConcurrentHashMap<Object, McpStreamableServerSessionStream> requestIdToStream = new ConcurrentHashMap<>();
    private final String id;
    /** How long {@code sendRequest} waits for the client's reply before failing. */
    private final Duration requestTimeout;
    private final AtomicLong requestCounter = new AtomicLong(0L);
    /** Client-to-server request handlers, keyed by JSON-RPC method name. */
    private final Map<String, McpRequestHandler<?>> requestHandlers;
    /** Client-to-server notification handlers, keyed by JSON-RPC method name. */
    private final Map<String, McpNotificationHandler> notificationHandlers;
    private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference<>();
    private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();
    /** The attached listening stream, or {@link #missingMcpTransportSession} when none is attached. */
    private final AtomicReference<McpLoggableSession> listeningStreamRef;
    private final MissingMcpTransportSession missingMcpTransportSession;
    private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO;

    /**
     * Creates a session with no listening stream attached.
     *
     * @param id                   the session id; also used as the prefix of generated request ids
     * @param clientCapabilities   capabilities announced by the client
     * @param clientInfo           implementation info announced by the client
     * @param requestTimeout       timeout applied to each server-to-client request
     * @param requestHandlers      handlers for incoming requests, keyed by method name
     * @param notificationHandlers handlers for incoming notifications, keyed by method name
     */
    public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, Duration requestTimeout, Map<String, McpRequestHandler<?>> requestHandlers, Map<String, McpNotificationHandler> notificationHandlers) {
        this.id = id;
        this.missingMcpTransportSession = new MissingMcpTransportSession(id);
        this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession);
        this.clientCapabilities.lazySet(clientCapabilities);
        this.clientInfo.lazySet(clientInfo);
        this.requestTimeout = requestTimeout;
        this.requestHandlers = requestHandlers;
        this.notificationHandlers = notificationHandlers;
    }

    /**
     * Sets the minimum level of log notifications forwarded to the client.
     *
     * @param minLoggingLevel the new minimum level; {@code null} is rejected via {@code Assert.notNull}
     */
    @Override
    public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) {
        Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null");
        this.minLoggingLevel = minLoggingLevel;
    }

    /**
     * Reports whether a notification at {@code loggingLevel} passes the current minimum level.
     *
     * @param loggingLevel level of the candidate notification
     * @return {@code true} when its numeric level is at least the configured minimum
     */
    @Override
    public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) {
        return loggingLevel.level() >= this.minLoggingLevel.level();
    }

    /** @return the id of this session */
    public String getId() {
        return this.id;
    }

    /** Produces a new request id of the form {@code <sessionId>-<counter>}. */
    private String generateRequestId() {
        return this.id + "-" + this.requestCounter.getAndIncrement();
    }

    /** Sends a server-initiated request over whichever listening stream is attached at subscription time. */
    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return listeningStream.sendRequest(method, requestParams, typeRef);
        });
    }

    /** Sends a notification over whichever listening stream is attached at subscription time. */
    @Override
    public Mono<Void> sendNotification(String method, Object params) {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return listeningStream.sendNotification(method, params);
        });
    }

    /**
     * Deletes this session by gracefully closing its listening stream.
     *
     * @return a {@code Mono} that completes once the listening stream has been closed
     */
    public Mono<Void> delete() {
        // NOTE(review): no event history is kept (see replay), so there is nothing else to purge yet.
        return this.closeGracefully();
    }

    /**
     * Attaches a new listening stream backed by {@code transport}.
     *
     * <p>The previously attached stream, if any, is replaced but NOT closed here.
     *
     * @param transport transport the new stream writes to
     * @return the newly attached stream
     */
    public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) {
        McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport);
        this.listeningStreamRef.set(listeningStream);
        return listeningStream;
    }

    /**
     * Replays messages after {@code lastEventId}. Replay is not implemented: the id is ignored and the
     * result is always empty.
     */
    public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {
        return Flux.empty();
    }

    /**
     * Handles one client request and writes exactly one JSON-RPC response to {@code transport}.
     * The transport is then closed gracefully.
     *
     * <p>An unknown method yields a "Method not found" error response. A handler failure yields an
     * "Internal error" response carrying the exception message. That includes a handler that throws
     * synchronously.
     *
     * @param jsonrpcRequest the incoming request
     * @param transport      dedicated transport for this request's response
     * @return a {@code Mono} completing after the response is sent and the transport is closed
     */
    public Mono<Void> responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) {
        return Mono.deferContextual(ctx -> {
            McpTransportContext transportContext = ctx.getOrDefault(TRANSPORT_CONTEXT_KEY, McpTransportContext.EMPTY);
            McpRequestHandler<?> requestHandler = this.requestHandlers.get(jsonrpcRequest.getMethod());
            if (requestHandler == null) {
                MethodNotFoundError error = this.getMethodNotFoundError(jsonrpcRequest.getMethod());
                McpSchema.JSONRPCResponse notFound = new McpSchema.JSONRPCResponse(JSONRPC_VERSION, jsonrpcRequest.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(METHOD_NOT_FOUND, error.getMessage(), error.getData()));
                // Close the per-request stream after the error, exactly as the handled path does.
                return transport.sendMessage(notFound).then(Mono.defer(transport::closeGracefully));
            }
            McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport);
            McpAsyncServerExchange exchange = new McpAsyncServerExchange(this.id, stream, this.clientCapabilities.get(), this.clientInfo.get(), transportContext);
            // defer() converts a synchronous throw from the handler into an error signal, so it is
            // still reported to the client by onErrorResume instead of escaping the pipeline.
            return Mono.defer(() -> requestHandler.handle(exchange, jsonrpcRequest.getParams()))
                .map(result -> new McpSchema.JSONRPCResponse(JSONRPC_VERSION, jsonrpcRequest.getId(), result, null))
                .onErrorResume(e -> Mono.just(new McpSchema.JSONRPCResponse(JSONRPC_VERSION, jsonrpcRequest.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(INTERNAL_ERROR, e.getMessage(), null))))
                .flatMap(transport::sendMessage)
                // Deferred so the close is only initiated after the response has been written.
                .then(Mono.defer(transport::closeGracefully));
        });
    }

    /**
     * Dispatches a client notification to its registered handler.
     *
     * <p>Notifications without a handler are logged at WARN and dropped.
     */
    public Mono<Void> accept(McpSchema.JSONRPCNotification notification) {
        return Mono.deferContextual(ctx -> {
            McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.getMethod());
            if (notificationHandler == null) {
                logger.warn("No handler registered for notification method: {}", notification.getMethod());
                return Mono.empty();
            }
            McpTransportContext transportContext = ctx.getOrDefault(TRANSPORT_CONTEXT_KEY, McpTransportContext.EMPTY);
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.getParams());
        });
    }

    /**
     * Completes the pending server-to-client request whose id matches {@code response}.
     *
     * @return an empty {@code Mono} on success, or an {@link McpError} when the id is {@code null}
     *         or matches no in-flight request
     */
    public Mono<Void> accept(McpSchema.JSONRPCResponse response) {
        return Mono.defer(() -> {
            Object responseId = response.getId();
            // ConcurrentHashMap rejects null keys, so check before any lookup.
            if (responseId == null) {
                return unknownResponseId(responseId);
            }
            // Remove (not just get) so answered requests do not accumulate in the map.
            McpStreamableServerSessionStream stream = this.requestIdToStream.remove(responseId);
            if (stream == null) {
                return unknownResponseId(responseId);
            }
            MonoSink<McpSchema.JSONRPCResponse> sink = stream.pendingResponses.remove(responseId);
            if (sink == null) {
                return unknownResponseId(responseId);
            }
            sink.success(response);
            return Mono.empty();
        });
    }

    /** Builds the error signalled when a response cannot be matched to an in-flight request. */
    private static Mono<Void> unknownResponseId(Object responseId) {
        return Mono.error(new McpError((Object)("Unexpected response for unknown id " + responseId)));
    }

    private MethodNotFoundError getMethodNotFoundError(String method) {
        return new MethodNotFoundError(method, "Method not found: " + method, null);
    }

    /** Detaches the listening stream (falling back to the missing-transport placeholder) and closes it gracefully. */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(this.missingMcpTransportSession);
            return listeningStream.closeGracefully();
        });
    }

    /** Detaches the listening stream (falling back to the missing-transport placeholder) and closes it immediately. */
    @Override
    public void close() {
        McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(this.missingMcpTransportSession);
        if (listeningStream != null) {
            listeningStream.close();
        }
    }

    /**
     * One stream of the session bound to a single {@link McpStreamableServerTransport}.
     *
     * <p>It tracks its own in-flight server-to-client requests. Closing it fails those requests and
     * unregisters it from the owning session.
     */
    public final class McpStreamableServerSessionStream
    implements McpLoggableSession {
        /** Sinks awaiting the client's reply, keyed by request id. */
        private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
        private final McpStreamableServerTransport transport;
        private final String transportId;
        /** Generates per-message ids of the form {@code <transportId>_<random UUID>}. */
        private final Supplier<String> uuidGenerator;

        public McpStreamableServerSessionStream(McpStreamableServerTransport transport) {
            this.transport = transport;
            this.transportId = UUID.randomUUID().toString();
            this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID();
        }

        /** Delegates to the owning session, which validates the argument. */
        @Override
        public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) {
            McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel);
        }

        @Override
        public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) {
            return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel);
        }

        /**
         * Sends a request to the client over this stream and waits for the matching response.
         *
         * <p>A fresh request id is generated per subscription. Registration in both the stream and
         * session maps is undone on every terminal outcome: success, error, timeout or cancellation.
         *
         * @return the unmarshalled result, empty when {@code typeRef} is {@code Void}, or an
         *         {@link McpError} when the client answered with a JSON-RPC error
         */
        @Override
        public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
            return Mono.<McpSchema.JSONRPCResponse>defer(() -> {
                String requestId = McpStreamableServerSession.this.generateRequestId();
                return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
                    // Register before sending so a fast reply always finds its sink.
                    McpStreamableServerSession.this.requestIdToStream.put(requestId, this);
                    this.pendingResponses.put(requestId, sink);
                    McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(JSONRPC_VERSION, method, requestId, requestParams);
                    String messageId = this.uuidGenerator.get();
                    this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(ignored -> {}, sink::error);
                }).timeout(McpStreamableServerSession.this.requestTimeout).doFinally(signal -> {
                    this.pendingResponses.remove(requestId);
                    McpStreamableServerSession.this.requestIdToStream.remove(requestId);
                });
            }).handle((jsonRpcResponse, sink) -> {
                if (jsonRpcResponse.getError() != null) {
                    sink.error(new McpError(jsonRpcResponse.getError()));
                } else if (typeRef.getType().equals(Void.class)) {
                    sink.complete();
                } else {
                    sink.next(this.transport.unmarshalFrom(jsonRpcResponse.getResult(), typeRef));
                }
            });
        }

        /** Sends a JSON-RPC notification over this stream's transport. */
        @Override
        public Mono<Void> sendNotification(String method, Object params) {
            McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(JSONRPC_VERSION, method, params);
            String messageId = this.uuidGenerator.get();
            return this.transport.sendMessage(jsonrpcNotification, messageId);
        }

        /** Fails pending requests, unregisters this stream, then closes the transport gracefully. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.defer(() -> {
                this.failPendingAndUnregister();
                return this.transport.closeGracefully();
            });
        }

        /** Fails pending requests, unregisters this stream, then closes the transport immediately. */
        @Override
        public void close() {
            this.failPendingAndUnregister();
            this.transport.close();
        }

        /**
         * Errors every pending request with "Stream closed" and removes this stream from the session.
         *
         * <p>Each entry is removed before it is failed. A request registered concurrently is therefore
         * either failed here or left in place; it is never silently dropped.
         */
        private void failPendingAndUnregister() {
            for (Object requestId : this.pendingResponses.keySet()) {
                MonoSink<McpSchema.JSONRPCResponse> sink = this.pendingResponses.remove(requestId);
                if (sink != null) {
                    sink.error(new RuntimeException("Stream closed"));
                }
            }
            McpStreamableServerSession.this.listeningStreamRef.compareAndSet(this, McpStreamableServerSession.this.missingMcpTransportSession);
            // Identity, not equals(): only entries registered by this very stream must go.
            McpStreamableServerSession.this.requestIdToStream.values().removeIf(stream -> stream == this);
        }

        @Generated
        public ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> getPendingResponses() {
            return this.pendingResponses;
        }

        @Generated
        public McpStreamableServerTransport getTransport() {
            return this.transport;
        }

        @Generated
        public String getTransportId() {
            return this.transportId;
        }

        @Generated
        public Supplier<String> getUuidGenerator() {
            return this.uuidGenerator;
        }

        /**
         * Compares streams by their immutable fields ({@code transport}, {@code transportId},
         * {@code uuidGenerator}).
         *
         * <p>The mutable {@code pendingResponses} map is deliberately excluded, so equality and hash
         * code stay stable while requests come and go.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof McpStreamableServerSessionStream)) {
                return false;
            }
            McpStreamableServerSessionStream other = (McpStreamableServerSessionStream)o;
            return Objects.equals(this.transport, other.transport)
                && Objects.equals(this.transportId, other.transportId)
                && Objects.equals(this.uuidGenerator, other.uuidGenerator);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.transport, this.transportId, this.uuidGenerator);
        }

        @Generated
        public String toString() {
            return "McpStreamableServerSession.McpStreamableServerSessionStream(pendingResponses=" + this.getPendingResponses() + ", transport=" + this.getTransport() + ", transportId=" + this.getTransportId() + ", uuidGenerator=" + this.getUuidGenerator() + ")";
        }

        @Generated
        public McpStreamableServerSessionStream(McpStreamableServerTransport transport, String transportId, Supplier<String> uuidGenerator) {
            this.transport = transport;
            this.transportId = transportId;
            this.uuidGenerator = uuidGenerator;
        }
    }

    /** Pairs a newly started session with the {@code Mono} producing its initialize result. */
    public static class McpStreamableServerSessionInit {
        McpStreamableServerSession session;
        Mono<McpSchema.InitializeResult> initResult;

        @Generated
        public McpStreamableServerSession getSession() {
            return this.session;
        }

        @Generated
        public Mono<McpSchema.InitializeResult> getInitResult() {
            return this.initResult;
        }

        @Generated
        public void setSession(McpStreamableServerSession session) {
            this.session = session;
        }

        @Generated
        public void setInitResult(Mono<McpSchema.InitializeResult> initResult) {
            this.initResult = initResult;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof McpStreamableServerSessionInit)) {
                return false;
            }
            McpStreamableServerSessionInit other = (McpStreamableServerSessionInit)o;
            if (!other.canEqual(this)) {
                return false;
            }
            McpStreamableServerSession this$session = this.getSession();
            McpStreamableServerSession other$session = other.getSession();
            if (this$session == null ? other$session != null : !this$session.equals(other$session)) {
                return false;
            }
            Mono<McpSchema.InitializeResult> this$initResult = this.getInitResult();
            Mono<McpSchema.InitializeResult> other$initResult = other.getInitResult();
            return !(this$initResult == null ? other$initResult != null : !this$initResult.equals(other$initResult));
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof McpStreamableServerSessionInit;
        }

        @Generated
        public int hashCode() {
            int result = 1;
            McpStreamableServerSession $session = this.getSession();
            result = result * 59 + ($session == null ? 43 : $session.hashCode());
            Mono<McpSchema.InitializeResult> $initResult = this.getInitResult();
            result = result * 59 + ($initResult == null ? 43 : $initResult.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "McpStreamableServerSession.McpStreamableServerSessionInit(session=" + this.getSession() + ", initResult=" + this.getInitResult() + ")";
        }

        @Generated
        public McpStreamableServerSessionInit(McpStreamableServerSession session, Mono<McpSchema.InitializeResult> initResult) {
            this.session = session;
            this.initResult = initResult;
        }

        @Generated
        public McpStreamableServerSessionInit() {
        }
    }

    /** Creates a session in response to a client's initialize request. */
    public static interface Factory {
        public McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest);
    }

    /** Computes the initialize result for a client's initialize request. */
    public static interface InitRequestHandler {
        public Mono<McpSchema.InitializeResult> handle(McpSchema.InitializeRequest initializeRequest);
    }

    /** Details used to build a JSON-RPC "Method not found" error response. */
    public static class MethodNotFoundError {
        String method;
        String message;
        Object data;

        @Generated
        public String getMethod() {
            return this.method;
        }

        @Generated
        public String getMessage() {
            return this.message;
        }

        @Generated
        public Object getData() {
            return this.data;
        }

        @Generated
        public void setMethod(String method) {
            this.method = method;
        }

        @Generated
        public void setMessage(String message) {
            this.message = message;
        }

        @Generated
        public void setData(Object data) {
            this.data = data;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MethodNotFoundError)) {
                return false;
            }
            MethodNotFoundError other = (MethodNotFoundError)o;
            if (!other.canEqual(this)) {
                return false;
            }
            String this$method = this.getMethod();
            String other$method = other.getMethod();
            if (this$method == null ? other$method != null : !this$method.equals(other$method)) {
                return false;
            }
            String this$message = this.getMessage();
            String other$message = other.getMessage();
            if (this$message == null ? other$message != null : !this$message.equals(other$message)) {
                return false;
            }
            Object this$data = this.getData();
            Object other$data = other.getData();
            return !(this$data == null ? other$data != null : !this$data.equals(other$data));
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof MethodNotFoundError;
        }

        @Generated
        public int hashCode() {
            int result = 1;
            String $method = this.getMethod();
            result = result * 59 + ($method == null ? 43 : $method.hashCode());
            String $message = this.getMessage();
            result = result * 59 + ($message == null ? 43 : $message.hashCode());
            Object $data = this.getData();
            result = result * 59 + ($data == null ? 43 : $data.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "McpStreamableServerSession.MethodNotFoundError(method=" + this.getMethod() + ", message=" + this.getMessage() + ", data=" + this.getData() + ")";
        }

        @Generated
        public MethodNotFoundError(String method, String message, Object data) {
            this.method = method;
            this.message = message;
            this.data = data;
        }

        @Generated
        public MethodNotFoundError() {
        }
    }
}

