/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.support_vector_machine;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import java.lang.reflect.Field;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.Array;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.SupportVectors;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorFields;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValues;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvalidElementListException;
import org.jpmml.evaluator.MisplacedAttributeException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.PMMLElements;
import org.jpmml.evaluator.PMMLException;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.XPathUtil;
import org.jpmml.evaluator.support_vector_machine.DistanceDistribution;
import org.jpmml.evaluator.support_vector_machine.KernelUtil;
import org.jpmml.evaluator.support_vector_machine.VoteDistribution;
import org.jpmml.evaluator.support_vector_machine.VoteMap;
import org.jpmml.model.ReflectionUtil;

public class SupportVectorMachineModelEvaluator
extends ModelEvaluator<SupportVectorMachineModel> {
    private transient Map<String, double[]> vectorMap = null;
    private static final LoadingCache<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheUtil.buildLoadingCache(new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>(){

        public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
            return ImmutableMap.copyOf((Map)SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel));
        }
    });

    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, SupportVectorMachineModelEvaluator.selectModel(pmml, SupportVectorMachineModel.class));
    }

    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
        boolean maxWins = supportVectorMachineModel.isMaxWins();
        if (maxWins) {
            throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, PMMLAttributes.SUPPORTVECTORMACHINEMODEL_MAXWINS, maxWins);
        }
        SupportVectorMachineModel.Representation representation = supportVectorMachineModel.getRepresentation();
        switch (representation) {
            case SUPPORT_VECTORS: {
                break;
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)representation);
            }
        }
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        if (vectorDictionary == null) {
            throw new MissingElementException((PMMLObject)supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_VECTORDICTIONARY);
        }
        VectorFields vectorFields = vectorDictionary.getVectorFields();
        if (vectorFields == null) {
            throw new MissingElementException((PMMLObject)vectorDictionary, PMMLElements.VECTORDICTIONARY_VECTORFIELDS);
        }
        if (!supportVectorMachineModel.hasSupportVectorMachines()) {
            throw new MissingElementException((PMMLObject)supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_SUPPORTVECTORMACHINES);
        }
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        Map<FieldName, Object> predictions;
        ValueFactory<Double> valueFactory;
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.ensureScorableModel();
        MathContext mathContext = supportVectorMachineModel.getMathContext();
        switch (mathContext) {
            case DOUBLE: {
                valueFactory = this.ensureValueFactory();
                break;
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)mathContext);
            }
        }
        MiningFunction miningFunction = supportVectorMachineModel.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION: {
                predictions = this.evaluateRegression(valueFactory, context);
                break;
            }
            case CLASSIFICATION: {
                predictions = this.evaluateClassification(valueFactory, context);
                break;
            }
            case ASSOCIATION_RULES: 
            case SEQUENCES: 
            case CLUSTERING: 
            case TIME_SERIES: 
            case MIXED: {
                throw new InvalidAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)miningFunction);
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)miningFunction);
            }
        }
        return OutputUtil.evaluate(predictions, context);
    }

    private Map<FieldName, ?> evaluateRegression(ValueFactory<Double> valueFactory, EvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidElementListException(supportVectorMachines);
        }
        SupportVectorMachine supportVectorMachine = (SupportVectorMachine)supportVectorMachines.get(0);
        double[] input = this.createInput(context);
        Value<Double> result = this.evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);
        return TargetUtil.evaluateRegression(this.getTargetField(), result);
    }

    /*
     * WARNING - void declaration
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private Map<FieldName, ? extends Classification<Double>> evaluateClassification(final ValueFactory<Double> valueFactory, EvaluationContext context) {
        VoteMap<String, Double> values;
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = this.getClassificationMethod();
        switch (classificationMethod) {
            case ONE_AGAINST_ALL: {
                values = new ValueMap(2 * supportVectorMachines.size());
                break;
            }
            case ONE_AGAINST_ONE: {
                values = new VoteMap<String, Double>(2 * supportVectorMachines.size()){

                    @Override
                    public ValueFactory<Double> getValueFactory() {
                        return valueFactory;
                    }
                };
                break;
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)classificationMethod);
            }
        }
        double[] input = this.createInput(context);
        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            String targetCategory = supportVectorMachine.getTargetCategory();
            if (targetCategory == null) {
                throw new MissingAttributeException((PMMLObject)supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_TARGETCATEGORY);
            }
            String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
            Value<Double> value = this.evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);
            switch (classificationMethod) {
                case ONE_AGAINST_ALL: {
                    if (alternateTargetCategory != null) {
                        throw new MisplacedAttributeException((PMMLObject)supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
                    }
                    values.put(targetCategory, (Double)((Object)value));
                    break;
                }
                case ONE_AGAINST_ONE: {
                    String label;
                    if (alternateBinaryTargetCategory != null) {
                        if (alternateTargetCategory != null) {
                            throw new MisplacedAttributeException((PMMLObject)supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
                        }
                        value.round();
                        if (value.equals(1.0)) {
                            label = targetCategory;
                        } else {
                            if (!value.equals(0.0)) throw new EvaluationException("Expected " + PMMLException.formatValue(0.0) + " or " + PMMLException.formatValue(1.0) + ", got " + PMMLException.formatValue(value.getValue()));
                            label = alternateBinaryTargetCategory;
                        }
                    } else {
                        if (alternateTargetCategory == null) {
                            throw new MissingAttributeException((PMMLObject)supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY);
                        }
                        Double threshold = supportVectorMachine.getThreshold();
                        if (threshold == null) {
                            threshold = supportVectorMachineModel.getThreshold();
                        }
                        label = value.compareTo(threshold) < 0 ? targetCategory : alternateTargetCategory;
                    }
                    VoteMap votes = values;
                    votes.increment(label);
                    break;
                }
            }
        }
        switch (classificationMethod) {
            case ONE_AGAINST_ALL: {
                void var9_12;
                DistanceDistribution<Double> distanceDistribution = new DistanceDistribution<Double>(values);
                return TargetUtil.evaluateClassification(this.getTargetField(), var9_12);
            }
            case ONE_AGAINST_ONE: {
                void var9_12;
                VoteDistribution<Double> voteDistribution = new VoteDistribution<Double>(values);
                return TargetUtil.evaluateClassification(this.getTargetField(), var9_12);
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)supportVectorMachineModel, (Enum<?>)classificationMethod);
            }
        }
    }

    private Value<Double> evaluateSupportVectorMachine(ValueFactory<Double> valueFactory, SupportVectorMachine supportVectorMachine, double[] input) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        Value<Double> result = valueFactory.newValue();
        Kernel kernel = supportVectorMachineModel.getKernel();
        if (kernel == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachineModel.getClass()) + "/<Kernel>"), (PMMLObject)supportVectorMachine);
        }
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        Iterator coefficientIt = coefficients.iterator();
        SupportVectors supportVectors = supportVectorMachine.getSupportVectors();
        Iterator supportVectorIt = supportVectors.iterator();
        Map<String, double[]> vectorMap = this.getVectorMap();
        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = (Coefficient)coefficientIt.next();
            SupportVector supportVector = (SupportVector)supportVectorIt.next();
            String vectorId = supportVector.getVectorId();
            if (vectorId == null) {
                throw new MissingAttributeException((PMMLObject)supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID);
            }
            double[] vector = vectorMap.get(vectorId);
            if (vector == null) {
                throw new InvalidAttributeException((PMMLObject)supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID, vectorId);
            }
            result.add((double)coefficient.getValue(), KernelUtil.evaluate(kernel, input, vector));
        }
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidElementException((PMMLObject)supportVectorMachine);
        }
        double absoluteValue = coefficients.getAbsoluteValue();
        if (absoluteValue != 0.0) {
            result.add(absoluteValue);
        }
        return result;
    }

    private SupportVectorMachineModel.ClassificationMethod getClassificationMethod() {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = (SupportVectorMachineModel.ClassificationMethod)ReflectionUtil.getFieldValue((Field)PMMLAttributes.SUPPORTVECTORMACHINEMODEL_CLASSIFICATIONMETHOD, (Object)supportVectorMachineModel);
        if (classificationMethod != null) {
            return classificationMethod;
        }
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();
        if (alternateBinaryTargetCategory != null) {
            if (supportVectorMachines.size() == 1) {
                SupportVectorMachine supportVectorMachine = (SupportVectorMachine)supportVectorMachines.get(0);
                String targetCategory = supportVectorMachine.getTargetCategory();
                if (targetCategory != null) {
                    return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE;
                }
                throw new InvalidElementException((PMMLObject)supportVectorMachine);
            }
            throw new InvalidElementException((PMMLObject)supportVectorMachineModel);
        }
        Iterator iterator = supportVectorMachines.iterator();
        if (iterator.hasNext()) {
            SupportVectorMachine supportVectorMachine = (SupportVectorMachine)iterator.next();
            String targetCategory = supportVectorMachine.getTargetCategory();
            String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
            if (targetCategory != null) {
                if (alternateTargetCategory != null) {
                    return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE;
                }
                return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ALL;
            }
            throw new InvalidElementException((PMMLObject)supportVectorMachine);
        }
        throw new InvalidElementException((PMMLObject)supportVectorMachineModel);
    }

    private double[] createInput(EvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();
        List content = vectorFields.getContent();
        double[] result = new double[content.size()];
        for (int i = 0; i < content.size(); ++i) {
            FieldValue value;
            FieldName name;
            PMMLObject object = (PMMLObject)content.get(i);
            if (object instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef)content.get(i);
                name = fieldRef.getField();
                value = ExpressionUtil.evaluate((Expression)fieldRef, context);
                if (Objects.equals(FieldValues.MISSING_VALUE, value)) {
                    throw new MissingValueException(name, (PMMLObject)vectorFields);
                }
                result[i] = value.asNumber().doubleValue();
                continue;
            }
            if (object instanceof CategoricalPredictor) {
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor)object;
                name = categoricalPredictor.getName();
                if (name == null) {
                    throw new MissingAttributeException((PMMLObject)categoricalPredictor, PMMLAttributes.CATEGORICALPREDICTOR_FIELD);
                }
                value = context.evaluate(name);
                if (Objects.equals(FieldValues.MISSING_VALUE, value)) {
                    throw new MissingValueException(name, (PMMLObject)categoricalPredictor);
                }
                double coefficient = categoricalPredictor.getCoefficient();
                if (coefficient != 1.0) {
                    throw new InvalidAttributeException((PMMLObject)categoricalPredictor, PMMLAttributes.CATEGORICALPREDICTOR_COEFFICIENT, coefficient);
                }
                boolean equals = value.equals((HasValue<?>)categoricalPredictor);
                result[i] = equals ? 1.0 : 0.0;
                continue;
            }
            throw new MisplacedElementException(object);
        }
        return result;
    }

    private Map<String, double[]> getVectorMap() {
        if (this.vectorMap == null) {
            this.vectorMap = this.getValue(vectorCache);
        }
        return this.vectorMap;
    }

    private static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();
        List content = vectorFields.getContent();
        LinkedHashMap<String, double[]> result = new LinkedHashMap<String, double[]>();
        List vectorInstances = vectorDictionary.getVectorInstances();
        for (VectorInstance vectorInstance : vectorInstances) {
            List<Number> values;
            String id = vectorInstance.getId();
            if (id == null) {
                throw new MissingAttributeException((PMMLObject)vectorInstance, PMMLAttributes.VECTORINSTANCE_ID);
            }
            Array array = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getRealSparseArray();
            if (array != null && sparseArray == null) {
                values = ArrayUtil.asNumberList(array);
            } else if (array == null && sparseArray != null) {
                values = SparseArrayUtil.asNumberList(sparseArray);
            } else {
                throw new InvalidElementException((PMMLObject)vectorInstance);
            }
            if (content.size() != values.size()) {
                throw new InvalidElementException((PMMLObject)vectorInstance);
            }
            double[] vector = Doubles.toArray(values);
            result.put(id, vector);
        }
        return result;
    }
}

