/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.client.WebClientOptions;
import io.vertx.mutiny.core.Vertx;
import io.vertx.mutiny.ext.web.client.WebClient;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.eclipse.microprofile.context.ThreadContext;
import org.kie.kogito.explainability.ConversionUtils;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.models.ExplainabilityRequest;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.explainability.models.PredictInput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RemotePredictionProvider
implements PredictionProvider {
    private static final Logger LOG = LoggerFactory.getLogger(RemotePredictionProvider.class);
    private final ExplainabilityRequest request;
    private final ThreadContext threadContext;
    private final Executor asyncExecutor;
    private final WebClient client;

    public RemotePredictionProvider(ExplainabilityRequest request, Vertx vertx, ThreadContext threadContext, Executor asyncExecutor) {
        this.request = request;
        URI uri = URI.create(request.getServiceUrl());
        this.client = this.getClient(vertx, uri);
        this.threadContext = threadContext;
        this.asyncExecutor = asyncExecutor;
    }

    public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
        return this.sendPredictRequest(inputs, this.request.getModelIdentifier());
    }

    protected WebClient getClient(Vertx vertx, URI uri) {
        int port = uri.getPort() != -1 ? uri.getPort() : 80;
        return WebClient.create((Vertx)vertx, (WebClientOptions)new WebClientOptions().setDefaultHost(uri.getHost()).setDefaultPort(port).setSsl("https".equalsIgnoreCase(uri.getScheme())).setLogActivity(true));
    }

    protected PredictionOutput toPredictionOutput(JsonObject mainObj) {
        if (mainObj == null || !mainObj.containsKey("result")) {
            LOG.error("Malformed json {}", (Object)mainObj);
            return null;
        }
        List<Output> resultOutputs = ConversionUtils.toOutputList(mainObj.getJsonObject("result"));
        List resultOutputNames = resultOutputs.stream().map(Output::getName).collect(Collectors.toList());
        List outputs = Stream.concat(resultOutputs.stream().filter(output -> this.request.getOutputs().containsKey(output.getName())), this.request.getOutputs().keySet().stream().filter(key -> !resultOutputNames.contains(key)).map(key -> new Output(key, Type.UNDEFINED, new Value(null), 1.0))).collect(Collectors.toList());
        return new PredictionOutput(outputs);
    }

    protected Map<String, Object> toMap(List<Feature> features) {
        Map<String, Object> map = new HashMap<String, Object>();
        for (Feature f : features) {
            if (Type.COMPOSITE.equals((Object)f.getType())) {
                List compositeFeatures = (List)f.getValue().getUnderlyingObject();
                HashMap<String, Object> maps = new HashMap<String, Object>();
                for (Feature cf : compositeFeatures) {
                    Map<String, Object> compositeFeatureMap = this.toMap(List.of(cf));
                    maps.putAll(compositeFeatureMap);
                }
                map.put(f.getName(), maps);
                continue;
            }
            if (Type.UNDEFINED.equals((Object)f.getType())) {
                Feature underlyingFeature = (Feature)f.getValue().getUnderlyingObject();
                map.put(f.getName(), this.toMap(List.of(underlyingFeature)));
                continue;
            }
            Object underlyingObject = f.getValue().getUnderlyingObject();
            map.put(f.getName(), underlyingObject);
        }
        if (map.containsKey("context")) {
            map = (Map)map.get("context");
        }
        return map;
    }

    protected CompletableFuture<List<PredictionOutput>> sendPredictRequest(List<PredictionInput> inputs, ModelIdentifier modelIdentifier) {
        List piList = inputs.stream().map(input -> new PredictInput(modelIdentifier, this.toMap(input.getFeatures()))).collect(Collectors.toList());
        return this.threadContext.withContextCapture(this.client.post("/predict").sendJson(piList).subscribeAsCompletionStage()).thenApplyAsync(r -> this.parseRawResult(r.bodyAsJsonArray()), this.asyncExecutor);
    }

    protected List<PredictionOutput> parseRawResult(JsonArray jsonArray) {
        return jsonArray.stream().filter(JsonObject.class::isInstance).map(JsonObject.class::cast).map(this::toPredictionOutput).filter(Objects::nonNull).collect(Collectors.toList());
    }
}

