/*
 * Decompiled with CFR 0.152.
 */
package org.miaixz.bus.vortex.support;

import java.time.Duration;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.miaixz.bus.core.xyz.StringKit;
import org.miaixz.bus.vortex.Assets;
import org.miaixz.bus.vortex.Context;
import org.miaixz.bus.vortex.Router;
import org.reactivestreams.Publisher;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ExchangeStrategies;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClientRequest;
import reactor.util.annotation.NonNull;

public class McpRequestRouter
implements Router {
    private static final ExchangeStrategies CACHED_EXCHANGE_STRATEGIES = ExchangeStrategies.builder().codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(Math.toIntExact(0x8000000L))).build();
    private final Map<String, WebClient> clients = new ConcurrentHashMap<String, WebClient>();
    private final Map<String, List<ServiceInstance>> serviceCache = new ConcurrentHashMap<String, List<ServiceInstance>>();
    private final Map<String, AtomicInteger> counters = new ConcurrentHashMap<String, AtomicInteger>();
    private static final LoadBalanceStrategy DEFAULT_STRATEGY = LoadBalanceStrategy.ROUND_ROBIN;
    private static final int DEFAULT_WEIGHT = 1;

    @Override
    @NonNull
    public Mono<ServerResponse> route(ServerRequest request, Context context, Assets assets) {
        String serviceName = assets.getName();
        if (!StringKit.hasText((String)serviceName)) {
            return Mono.error((Throwable)new IllegalArgumentException("Service name cannot be empty"));
        }
        List<ServiceInstance> instances = this.serviceCache.get(serviceName);
        if (instances == null || instances.isEmpty()) {
            return Mono.error((Throwable)new IllegalArgumentException("No available instances for service: " + serviceName));
        }
        if ((instances = instances.stream().filter(ServiceInstance::isHealthy).collect(Collectors.toList())).isEmpty()) {
            return Mono.error((Throwable)new IllegalStateException("All instances are unhealthy for service: " + serviceName));
        }
        LoadBalanceStrategy strategy = this.parseLoadBalanceStrategy(null);
        ServiceInstance selectedInstance = this.selectInstance(instances, strategy, serviceName);
        selectedInstance.incrementRequestCount();
        return this.buildAndSendMcpRequest(request, context, assets, selectedInstance);
    }

    private Mono<ServerResponse> buildAndSendMcpRequest(ServerRequest request, Context context, Assets assets, ServiceInstance instance) {
        String baseUrl = this.buildMcpBaseUrl(instance);
        WebClient webClient = this.clients.computeIfAbsent(baseUrl, client -> WebClient.builder().exchangeStrategies(CACHED_EXCHANGE_STRATEGIES).baseUrl(baseUrl).build());
        String targetUri = this.buildMcpTargetUri(assets, context);
        WebClient.RequestBodySpec bodySpec = (WebClient.RequestBodySpec)webClient.method(context.getHttpMethod()).uri(targetUri, new Object[0]);
        bodySpec.headers(headers -> {
            headers.addAll((MultiValueMap)request.headers().asHttpHeaders());
            headers.remove((Object)"Host");
            headers.clearContentHeaders();
            headers.add("X-MCP-Protocol", "1.0");
            headers.add("X-MCP-Request-ID", context.getX_request_id());
            headers.add("X-MCP-Instance-ID", instance.getInstanceId());
        });
        Map<String, String> params = context.getRequestMap();
        if (!params.isEmpty()) {
            bodySpec.contentType(MediaType.APPLICATION_JSON).bodyValue(params);
        }
        return ((WebClient.RequestBodySpec)bodySpec.httpRequest(clientHttpRequest -> {
            HttpClientRequest reactorRequest = (HttpClientRequest)clientHttpRequest.getNativeRequest();
            reactorRequest.responseTimeout(Duration.ofMillis(assets.getTimeout()));
        })).retrieve().toEntity(DataBuffer.class).flatMap(this::processResponse);
    }

    private String buildMcpBaseUrl(ServiceInstance instance) {
        StringBuilder baseUrlBuilder = new StringBuilder("mcp://").append(instance.getHost());
        if (instance.getPort() > 0) {
            baseUrlBuilder.append(":").append(instance.getPort());
        }
        if (instance.getPath() != null && !instance.getPath().isEmpty()) {
            if (!instance.getPath().startsWith("/")) {
                baseUrlBuilder.append("/");
            }
            baseUrlBuilder.append(instance.getPath());
        }
        return baseUrlBuilder.toString();
    }

    private String buildMcpTargetUri(Assets assets, Context context) {
        UriComponentsBuilder builder = UriComponentsBuilder.fromUriString((String)assets.getUrl());
        builder.queryParam("protocol", new Object[]{"mcp"});
        builder.queryParam("version", new Object[]{"1.0"});
        builder.queryParam("requestId", new Object[]{context.getX_request_id()});
        Map<String, String> params = context.getRequestMap();
        if (!params.isEmpty()) {
            LinkedMultiValueMap multiValueMap = new LinkedMultiValueMap(params.size());
            params.forEach((arg_0, arg_1) -> ((MultiValueMap)multiValueMap).add(arg_0, arg_1));
            builder.queryParams((MultiValueMap)multiValueMap);
        }
        return builder.build().toUriString();
    }

    private Mono<ServerResponse> processResponse(ResponseEntity<DataBuffer> responseEntity) {
        return ((ServerResponse.BodyBuilder)ServerResponse.ok().headers(headers -> {
            headers.addAll((MultiValueMap)responseEntity.getHeaders());
            headers.remove((Object)"Content-Length");
        })).body(responseEntity.getBody() == null ? BodyInserters.empty() : BodyInserters.fromDataBuffers((Publisher)Flux.just((Object)((DataBuffer)responseEntity.getBody()))));
    }

    private LoadBalanceStrategy parseLoadBalanceStrategy(String loadBalanceConfig) {
        if (!StringKit.hasText((String)loadBalanceConfig)) {
            return DEFAULT_STRATEGY;
        }
        try {
            return LoadBalanceStrategy.valueOf(loadBalanceConfig.toUpperCase());
        }
        catch (IllegalArgumentException e) {
            return DEFAULT_STRATEGY;
        }
    }

    private ServiceInstance selectInstance(List<ServiceInstance> instances, LoadBalanceStrategy strategy, String serviceName) {
        switch (strategy.ordinal()) {
            case 0: {
                return this.roundRobinSelect(instances, serviceName);
            }
            case 1: {
                return this.randomSelect(instances);
            }
            case 2: {
                return this.weightSelect(instances);
            }
        }
        return this.roundRobinSelect(instances, serviceName);
    }

    private ServiceInstance roundRobinSelect(List<ServiceInstance> instances, String serviceName) {
        AtomicInteger counter = this.counters.computeIfAbsent(serviceName, k -> new AtomicInteger(0));
        int index = Math.abs(counter.getAndIncrement()) % instances.size();
        return instances.get(index);
    }

    private ServiceInstance randomSelect(List<ServiceInstance> instances) {
        int index = (int)(Math.random() * (double)instances.size());
        return instances.get(index);
    }

    private ServiceInstance weightSelect(List<ServiceInstance> instances) {
        int totalWeight = instances.stream().mapToInt(ServiceInstance::getWeight).sum();
        if (totalWeight <= 0) {
            return this.roundRobinSelect(instances, instances.get(0).getServiceName());
        }
        int randomWeight = (int)(Math.random() * (double)totalWeight);
        int currentWeight = 0;
        for (ServiceInstance instance : instances) {
            if (randomWeight > (currentWeight += instance.getWeight())) continue;
            return instance;
        }
        return instances.get(0);
    }

    public void addServiceInstance(String serviceName, ServiceInstance instance) {
        Objects.requireNonNull(serviceName, "Service name cannot be null");
        Objects.requireNonNull(instance, "Service instance cannot be null");
        this.serviceCache.computeIfAbsent(serviceName, k -> new CopyOnWriteArrayList()).add(instance);
        instance.setHealthy(true);
        instance.setLastHealthCheckTime(System.currentTimeMillis());
    }

    public void removeServiceInstance(String serviceName, String instanceId) {
        Objects.requireNonNull(serviceName, "Service name cannot be null");
        Objects.requireNonNull(instanceId, "Instance ID cannot be null");
        List<ServiceInstance> instances = this.serviceCache.get(serviceName);
        if (instances != null) {
            instances.removeIf(instance -> instanceId.equals(instance.getInstanceId()));
        }
    }

    public void updateInstanceHealth(String serviceName, String instanceId, boolean healthy) {
        Objects.requireNonNull(serviceName, "Service name cannot be null");
        Objects.requireNonNull(instanceId, "Instance ID cannot be null");
        List<ServiceInstance> instances = this.serviceCache.get(serviceName);
        if (instances != null) {
            instances.stream().filter(instance -> instanceId.equals(instance.getInstanceId())).forEach(instance -> {
                instance.setHealthy(healthy);
                instance.setLastHealthCheckTime(System.currentTimeMillis());
            });
        }
    }

    private static enum LoadBalanceStrategy {
        ROUND_ROBIN,
        RANDOM,
        WEIGHT;

    }

    public static class ServiceInstance {
        private final String instanceId;
        private final String serviceName;
        private final String host;
        private final int port;
        private final String path;
        private final int weight;
        private volatile boolean healthy;
        private volatile long lastHealthCheckTime;
        private final AtomicInteger requestCount = new AtomicInteger(0);
        private final Map<String, String> metadata = new ConcurrentHashMap<String, String>();

        public ServiceInstance(String instanceId, String serviceName, String host, int port, String path, int weight) {
            this.instanceId = instanceId;
            this.serviceName = serviceName;
            this.host = host;
            this.port = port;
            this.path = path;
            this.weight = weight > 0 ? weight : 1;
            this.healthy = true;
            this.lastHealthCheckTime = System.currentTimeMillis();
        }

        public String getInstanceId() {
            return this.instanceId;
        }

        public String getServiceName() {
            return this.serviceName;
        }

        public String getHost() {
            return this.host;
        }

        public int getPort() {
            return this.port;
        }

        public String getPath() {
            return this.path;
        }

        public int getWeight() {
            return this.weight;
        }

        public boolean isHealthy() {
            return this.healthy;
        }

        public void setHealthy(boolean healthy) {
            this.healthy = healthy;
        }

        public long getLastHealthCheckTime() {
            return this.lastHealthCheckTime;
        }

        public void setLastHealthCheckTime(long lastHealthCheckTime) {
            this.lastHealthCheckTime = lastHealthCheckTime;
        }

        public int getRequestCount() {
            return this.requestCount.get();
        }

        public void incrementRequestCount() {
            this.requestCount.incrementAndGet();
        }

        public Map<String, String> getMetadata() {
            return this.metadata;
        }

        public void addMetadata(String key, String value) {
            this.metadata.put(key, value);
        }

        public String getMetadata(String key) {
            return this.metadata.get(key);
        }
    }
}

