/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import java.lang.reflect.Field;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.Coefficient;
import org.dmg.pmml.Coefficients;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Kernel;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.SupportVector;
import org.dmg.pmml.SupportVectorMachine;
import org.dmg.pmml.SupportVectorMachineModel;
import org.dmg.pmml.SupportVectors;
import org.dmg.pmml.SvmClassificationMethodType;
import org.dmg.pmml.SvmRepresentationType;
import org.dmg.pmml.VectorDictionary;
import org.dmg.pmml.VectorFields;
import org.dmg.pmml.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.KernelUtil;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.evaluator.VoteDistribution;
import org.jpmml.model.ReflectionUtil;

public class SupportVectorMachineModelEvaluator
extends ModelEvaluator<SupportVectorMachineModel> {
    private static final LoadingCache<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheBuilder.newBuilder().weakKeys().build((CacheLoader)new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>(){

        public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
            return ImmutableMap.copyOf((Map)SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel));
        }
    });

    public SupportVectorMachineModelEvaluator(PMML pmml) {
        super(pmml, SupportVectorMachineModel.class);
    }

    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        Map<FieldName, Object> predictions;
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        if (!supportVectorMachineModel.isScorable()) {
            throw new InvalidResultException((PMMLObject)supportVectorMachineModel);
        }
        SvmRepresentationType svmRepresentation = supportVectorMachineModel.getSvmRepresentation();
        switch (svmRepresentation) {
            case SUPPORT_VECTORS: {
                break;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)supportVectorMachineModel, (Enum<?>)svmRepresentation);
            }
        }
        MiningFunctionType miningFunction = supportVectorMachineModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION: {
                predictions = this.evaluateRegression(context);
                break;
            }
            case CLASSIFICATION: {
                predictions = this.evaluateClassification(context);
                break;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)supportVectorMachineModel, (Enum<?>)miningFunction);
            }
        }
        return OutputUtil.evaluate(predictions, context);
    }

    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidFeatureException((PMMLObject)supportVectorMachineModel);
        }
        SupportVectorMachine supportVectorMachine = (SupportVectorMachine)supportVectorMachines.get(0);
        double[] input = this.createInput(context);
        Double result = this.evaluateSupportVectorMachine(supportVectorMachine, input);
        return TargetUtil.evaluateRegression(result, context);
    }

    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext context) {
        Classification result;
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        if (supportVectorMachines.size() < 1) {
            throw new InvalidFeatureException((PMMLObject)supportVectorMachineModel);
        }
        String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();
        SvmClassificationMethodType svmClassificationMethod = this.getClassificationMethod();
        switch (svmClassificationMethod) {
            case ONE_AGAINST_ALL: {
                result = new Classification(Classification.Type.DISTANCE);
                break;
            }
            case ONE_AGAINST_ONE: {
                result = new VoteDistribution();
                break;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)supportVectorMachineModel, (Enum<?>)svmClassificationMethod);
            }
        }
        double[] input = this.createInput(context);
        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            String targetCategory = supportVectorMachine.getTargetCategory();
            String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
            Double value = this.evaluateSupportVectorMachine(supportVectorMachine, input);
            switch (svmClassificationMethod) {
                case ONE_AGAINST_ALL: {
                    if (targetCategory == null || alternateTargetCategory != null) {
                        throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
                    }
                    result.put(targetCategory, value);
                    break;
                }
                case ONE_AGAINST_ONE: {
                    String label;
                    Double vote;
                    if (alternateBinaryTargetCategory != null) {
                        String label2;
                        if (targetCategory == null || alternateTargetCategory != null) {
                            throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
                        }
                        long roundedValue = Math.round(value);
                        if (roundedValue == 1L) {
                            label2 = targetCategory;
                        } else if (roundedValue == 0L) {
                            label2 = alternateBinaryTargetCategory;
                        } else {
                            throw new EvaluationException("Invalid numeric prediction " + value);
                        }
                        Double vote2 = result.get(label2);
                        if (vote2 == null) {
                            vote2 = 0.0;
                        }
                        result.put(label2, vote2 + 1.0);
                        break;
                    }
                    if (targetCategory == null || alternateTargetCategory == null) {
                        throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
                    }
                    Double threshold = supportVectorMachine.getThreshold();
                    if (threshold == null) {
                        threshold = supportVectorMachineModel.getThreshold();
                    }
                    if ((vote = result.get(label = value.compareTo(threshold) < 0 ? targetCategory : alternateTargetCategory)) == null) {
                        vote = 0.0;
                    }
                    result.put(label, vote + 1.0);
                    break;
                }
            }
        }
        return TargetUtil.evaluateClassification(result, context);
    }

    private Double evaluateSupportVectorMachine(SupportVectorMachine supportVectorMachine, double[] input) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        double result = 0.0;
        Kernel kernel = supportVectorMachineModel.getKernel();
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        Iterator coefficientIterator = coefficients.iterator();
        SupportVectors supportVectors = supportVectorMachine.getSupportVectors();
        Iterator supportVectorIterator = supportVectors.iterator();
        Map<String, double[]> vectorMap = this.getVectorMap();
        while (coefficientIterator.hasNext() && supportVectorIterator.hasNext()) {
            Coefficient coefficient = (Coefficient)coefficientIterator.next();
            SupportVector supportVector = (SupportVector)supportVectorIterator.next();
            double[] vector = vectorMap.get(supportVector.getVectorId());
            if (vector == null) {
                throw new InvalidFeatureException((PMMLObject)supportVector);
            }
            Double value = KernelUtil.evaluate(kernel, input, vector);
            result += coefficient.getValue() * value;
        }
        if (coefficientIterator.hasNext() || supportVectorIterator.hasNext()) {
            throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
        }
        return result += coefficients.getAbsoluteValue();
    }

    private SvmClassificationMethodType getClassificationMethod() {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        Field field = ReflectionUtil.getField((Object)supportVectorMachineModel, (String)"classificationMethod");
        SvmClassificationMethodType svmClassificationMethod = (SvmClassificationMethodType)ReflectionUtil.getFieldValue((Field)field, (Object)supportVectorMachineModel);
        if (svmClassificationMethod != null) {
            return svmClassificationMethod;
        }
        List supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines();
        String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();
        if (alternateBinaryTargetCategory != null) {
            if (supportVectorMachines.size() == 1) {
                SupportVectorMachine supportVectorMachine = (SupportVectorMachine)supportVectorMachines.get(0);
                String targetCategory = supportVectorMachine.getTargetCategory();
                if (targetCategory != null) {
                    return SvmClassificationMethodType.ONE_AGAINST_ONE;
                }
                throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
            }
            throw new InvalidFeatureException((PMMLObject)supportVectorMachineModel);
        }
        Iterator i$ = supportVectorMachines.iterator();
        if (i$.hasNext()) {
            SupportVectorMachine supportVectorMachine = (SupportVectorMachine)i$.next();
            String targetCategory = supportVectorMachine.getTargetCategory();
            String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
            if (targetCategory != null) {
                if (alternateTargetCategory != null) {
                    return SvmClassificationMethodType.ONE_AGAINST_ONE;
                }
                return SvmClassificationMethodType.ONE_AGAINST_ALL;
            }
            throw new InvalidFeatureException((PMMLObject)supportVectorMachine);
        }
        throw new InvalidFeatureException((PMMLObject)supportVectorMachineModel);
    }

    private double[] createInput(EvaluationContext context) {
        SupportVectorMachineModel supportVectorMachineModel = (SupportVectorMachineModel)this.getModel();
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();
        List fieldRefs = vectorFields.getFieldRefs();
        double[] result = new double[fieldRefs.size()];
        for (int i = 0; i < fieldRefs.size(); ++i) {
            FieldRef fieldRef = (FieldRef)fieldRefs.get(i);
            FieldValue value = ExpressionUtil.evaluate((Expression)fieldRef, context);
            if (value == null) {
                throw new MissingFieldException(fieldRef.getField(), (PMMLObject)vectorFields);
            }
            result[i] = value.asNumber().doubleValue();
        }
        return result;
    }

    private Map<String, double[]> getVectorMap() {
        return this.getValue(vectorCache);
    }

    private static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        VectorFields vectorFields = vectorDictionary.getVectorFields();
        List fieldRefs = vectorFields.getFieldRefs();
        LinkedHashMap<String, double[]> result = new LinkedHashMap<String, double[]>();
        List vectorInstances = vectorDictionary.getVectorInstances();
        for (VectorInstance vectorInstance : vectorInstances) {
            double[] vector;
            String id = vectorInstance.getId();
            if (id == null) {
                throw new InvalidFeatureException((PMMLObject)vectorInstance);
            }
            Array array = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getREALSparseArray();
            if (array != null && sparseArray == null) {
                vector = ArrayUtil.toArray(array);
            } else if (array == null && sparseArray != null) {
                vector = SparseArrayUtil.toArray(sparseArray);
            } else {
                throw new InvalidFeatureException((PMMLObject)vectorInstance);
            }
            if (fieldRefs.size() != vector.length) {
                throw new InvalidFeatureException((PMMLObject)vectorInstance);
            }
            result.put(id, vector);
        }
        return result;
    }
}

