/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Maps;
import java.io.Serializable;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.Coefficient;
import org.dmg.pmml.Coefficients;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.KernelType;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.SupportVector;
import org.dmg.pmml.SupportVectorMachine;
import org.dmg.pmml.SupportVectorMachineModel;
import org.dmg.pmml.SupportVectors;
import org.dmg.pmml.SvmClassificationMethodType;
import org.dmg.pmml.SvmRepresentationType;
import org.dmg.pmml.VectorDictionary;
import org.dmg.pmml.VectorFields;
import org.dmg.pmml.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.KernelTypeUtil;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLObjectUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@code SupportVectorMachineModel} elements.
 *
 * <p>Only the {@code SUPPORT_VECTORS} representation is supported, for the
 * {@code REGRESSION} and {@code CLASSIFICATION} mining functions. The parsed
 * {@code VectorDictionary} of each model is memoized in a weak-keyed Guava cache so that
 * repeated evaluations do not re-parse it.
 */
public class SupportVectorMachineModelEvaluator
extends ModelEvaluator<SupportVectorMachineModel> {

    /**
     * Parsed vector dictionaries (vector id -> coordinates), keyed weakly by model instance
     * so an entry does not keep its model reachable.
     */
    private static final LoadingCache<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>(){

            @Override
            public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
                return SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel);
            }
        });

    /**
     * Creates an evaluator for the {@code SupportVectorMachineModel} located in the given PMML document.
     *
     * @param pmml the PMML document
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, SupportVectorMachineModelEvaluator.find(pmml.getModels(), SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for an explicitly given model of the PMML document.
     *
     * @param pmml the enclosing PMML document
     * @param supportVectorMachineModel the model to evaluate
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    /**
     * Evaluates the model against the given field values.
     *
     * @param arguments input field values
     * @return the predictions, post-processed by {@link OutputUtil}
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException for an unsupported SVM representation or mining function
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        if (!supportVectorMachineModel.isScorable()) {
            throw new InvalidResultException(supportVectorMachineModel);
        }

        SvmRepresentationType svmRepresentation = supportVectorMachineModel.getSvmRepresentation();
        switch (svmRepresentation) {
            case SUPPORT_VECTORS:
                break;
            default:
                throw new UnsupportedFeatureException(supportVectorMachineModel, svmRepresentation);
        }

        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
        context.pushFrame(arguments);

        // Wildcard value type: the regression and classification paths return maps with
        // different value types, and neither is assignable to a concrete common type.
        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = supportVectorMachineModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = this.evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = this.evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(supportVectorMachineModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Regression: the model must contain exactly one SupportVectorMachine, whose decision
     * value is handed to {@link TargetUtil#evaluateRegression}.
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelManagerEvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();

        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidFeatureException(supportVectorMachineModel);
        }

        SupportVectorMachine supportVectorMachine = supportVectorMachines.get(0);

        double[] input = this.createInput(context);

        Double value = this.evaluateSupportVectorMachine(supportVectorMachine, input);

        return TargetUtil.evaluateRegression(value, context);
    }

    /**
     * Classification.
     *
     * <p>One-against-all: every machine must declare only a target category; its decision
     * value is stored under that category in a DISTANCE-typed map.
     *
     * <p>One-against-one: every machine must declare both categories. A decision value below
     * the threshold (machine-level, falling back to model-level) casts one vote for the target
     * category, otherwise for the alternate category; votes are tallied in a VOTE-typed map.
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelManagerEvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();

        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.isEmpty()) {
            throw new InvalidFeatureException(supportVectorMachineModel);
        }

        SvmClassificationMethodType svmClassificationMethod = this.getClassificationMethod();

        ClassificationMap<String> result;
        switch (svmClassificationMethod) {
            case ONE_AGAINST_ALL:
                result = new ClassificationMap<String>(ClassificationMap.Type.DISTANCE);
                break;
            case ONE_AGAINST_ONE:
                result = new ClassificationMap<String>(ClassificationMap.Type.VOTE);
                break;
            default:
                throw new UnsupportedFeatureException(supportVectorMachineModel, svmClassificationMethod);
        }

        double[] input = this.createInput(context);

        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            String category = supportVectorMachine.getTargetCategory();
            String alternateCategory = supportVectorMachine.getAlternateTargetCategory();

            Double value = this.evaluateSupportVectorMachine(supportVectorMachine, input);

            switch (svmClassificationMethod) {
                case ONE_AGAINST_ALL: {
                    if (category == null || alternateCategory != null) {
                        throw new InvalidFeatureException(supportVectorMachine);
                    }

                    result.put(category, value);
                    break;
                }
                case ONE_AGAINST_ONE: {
                    if (category == null || alternateCategory == null) {
                        throw new InvalidFeatureException(supportVectorMachine);
                    }

                    Double threshold = supportVectorMachine.getThreshold();
                    if (threshold == null) {
                        threshold = supportVectorMachineModel.getThreshold();
                    }

                    String label = (value.compareTo(threshold) < 0) ? category : alternateCategory;

                    Double vote = (Double)result.get(label);
                    double previousVotes = (vote != null) ? vote.doubleValue() : 0.0;

                    result.put(label, previousVotes + 1.0);
                    break;
                }
                default:
                    throw new UnsupportedFeatureException(supportVectorMachineModel, svmClassificationMethod);
            }
        }

        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Computes the decision value of a single machine: the sum over paired
     * (Coefficient, SupportVector) entries of {@code coefficient * K(input, vector)}, plus the
     * Coefficients' absolute value, where K is the model's kernel.
     *
     * @throws InvalidFeatureException if a support vector id is not in the vector dictionary,
     *         or the numbers of coefficients and support vectors differ
     */
    private Double evaluateSupportVectorMachine(SupportVectorMachine supportVectorMachine, double[] input) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();

        KernelType kernelType = supportVectorMachineModel.getKernelType();

        Coefficients coefficients = supportVectorMachine.getCoefficients();
        Iterator<Coefficient> coefficientIterator = coefficients.iterator();

        SupportVectors supportVectors = supportVectorMachine.getSupportVectors();
        Iterator<SupportVector> supportVectorIterator = supportVectors.iterator();

        Map<String, double[]> vectorMap = this.getVectorMap();

        double result = 0.0;

        // Coefficients and support vectors are paired positionally
        while (coefficientIterator.hasNext() && supportVectorIterator.hasNext()) {
            Coefficient coefficient = coefficientIterator.next();
            SupportVector supportVector = supportVectorIterator.next();

            double[] vector = vectorMap.get(supportVector.getVectorId());
            if (vector == null) {
                throw new InvalidFeatureException(supportVector);
            }

            Double value = KernelTypeUtil.evaluate(kernelType, input, vector);

            result += coefficient.getValue() * value;
        }

        // Either list left over means the two lists have different lengths
        if (coefficientIterator.hasNext() || supportVectorIterator.hasNext()) {
            throw new InvalidFeatureException(supportVectorMachine);
        }

        result += coefficients.getAbsoluteValue();

        return result;
    }

    /**
     * Returns the classification method.
     *
     * <p>It is read reflectively from the model's {@code classificationMethod} field. When
     * that is absent, it is inferred from the first machine: both target categories present
     * means one-against-one, only the target category means one-against-all.
     */
    private SvmClassificationMethodType getClassificationMethod() {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();

        SvmClassificationMethodType svmClassificationMethod = (SvmClassificationMethodType)((Object)PMMLObjectUtil.getField(supportVectorMachineModel, "classificationMethod"));
        if (svmClassificationMethod != null) {
            return svmClassificationMethod;
        }

        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.isEmpty()) {
            throw new InvalidFeatureException(supportVectorMachineModel);
        }

        SupportVectorMachine supportVectorMachine = supportVectorMachines.get(0);

        String category = supportVectorMachine.getTargetCategory();
        String alternateCategory = supportVectorMachine.getAlternateTargetCategory();
        if (category == null) {
            throw new InvalidFeatureException(supportVectorMachine);
        }

        return (alternateCategory != null) ? SvmClassificationMethodType.ONE_AGAINST_ONE : SvmClassificationMethodType.ONE_AGAINST_ALL;
    }

    /**
     * Builds the input vector by evaluating each VectorFields FieldRef, in declaration order,
     * as a double.
     *
     * @throws InvalidFeatureException if numberOfFields is declared and disagrees with the
     *         FieldRef count
     * @throws MissingFieldException if a referenced field has no value
     */
    private double[] createInput(EvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();

        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();

        List<FieldRef> fieldRefs = vectorFields.getFieldRefs();

        // Validate the model structure before evaluating any inputs
        Integer numberOfFields = vectorFields.getNumberOfFields();
        if (numberOfFields != null && numberOfFields.intValue() != fieldRefs.size()) {
            throw new InvalidFeatureException(vectorFields);
        }

        double[] result = new double[fieldRefs.size()];

        for (int i = 0; i < result.length; i++) {
            FieldRef fieldRef = fieldRefs.get(i);

            FieldValue value = ExpressionUtil.evaluate(fieldRef, context);
            if (value == null) {
                throw new MissingFieldException(fieldRef.getField(), (PMMLObject)vectorFields);
            }

            result[i] = value.asNumber().doubleValue();
        }

        return result;
    }

    /**
     * Returns the cached, parsed vector dictionary of this evaluator's model.
     */
    private Map<String, double[]> getVectorMap() {
        return this.getValue(vectorCache);
    }

    /**
     * Parses the model's VectorDictionary into an insertion-ordered map from vector id to
     * coordinates.
     *
     * <p>Each VectorInstance must carry exactly one of Array or REAL-SparseArray. Vector
     * lengths are checked against numberOfFields, and the vector count against
     * numberOfVectors, when those attributes are declared.
     *
     * @throws InvalidFeatureException on malformed instances, length/count mismatches,
     *         or duplicate vector ids
     */
    private static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();

        Integer numberOfFields = vectorFields.getNumberOfFields();

        Map<String, double[]> result = Maps.newLinkedHashMap();

        List<VectorInstance> vectorInstances = vectorDictionary.getVectorInstances();
        for (VectorInstance vectorInstance : vectorInstances) {
            Array array = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getREALSparseArray();

            double[] vector;
            if (array != null && sparseArray == null) {
                vector = ArrayUtil.toArray(array);
            } else if (array == null && sparseArray != null) {
                vector = SparseArrayUtil.toArray(sparseArray);
            } else {
                throw new InvalidFeatureException(vectorInstance);
            }

            if (numberOfFields != null && numberOfFields.intValue() != vector.length) {
                throw new InvalidFeatureException(vectorInstance);
            }

            String id = vectorInstance.getId();

            // Duplicate ids would otherwise silently replace an earlier vector
            if (result.containsKey(id)) {
                throw new InvalidFeatureException(vectorInstance);
            }

            result.put(id, vector);
        }

        Integer numberOfVectors = vectorDictionary.getNumberOfVectors();
        if (numberOfVectors != null && numberOfVectors.intValue() != result.size()) {
            throw new InvalidFeatureException(vectorDictionary);
        }

        return result;
    }
}

