/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.TreeModel;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MiningModelEvaluationContext;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.evaluator.ProbabilityAggregator;
import org.jpmml.evaluator.ProbabilityClassificationMap;
import org.jpmml.evaluator.SegmentResultMap;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.TypeCheckException;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.VoteCounter;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.ModelManager;
import org.jpmml.manager.UnsupportedFeatureException;

public class MiningModelEvaluator
extends ModelEvaluator<MiningModel>
implements HasEntityRegistry<Segment> {
    /**
     * Cache of segment registries, keyed by mining model and built with weak keys.
     * On a miss the loader creates a fresh {@link HashBiMap} and lets
     * {@code EntityUtil.putAll} fill it from the segments of the model's segmentation.
     */
    private static final LoadingCache<MiningModel, BiMap<String, Segment>> entityCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<MiningModel, BiMap<String, Segment>>() {

            @Override
            public BiMap<String, Segment> load(MiningModel miningModel) {
                BiMap<String, Segment> registry = HashBiMap.create();
                EntityUtil.putAll(miningModel.getSegmentation().getSegments(), registry);
                return registry;
            }
        });

    /** Factory used to obtain an evaluator for the model nested in each segment. */
    private static final ModelEvaluatorFactory evaluatorFactory = ModelEvaluatorFactory.getInstance();

    /**
     * Creates an evaluator for the {@link MiningModel} that {@code find} locates
     * among the models of the given PMML document.
     *
     * @param pmml the PMML document whose model list is searched
     */
    public MiningModelEvaluator(PMML pmml) {
        this(pmml, find(pmml.getModels(), MiningModel.class));
    }

    /**
     * Creates an evaluator for an explicitly given mining model.
     *
     * @param pmml the PMML document, passed through to the superclass constructor
     * @param miningModel the mining model to evaluate, passed through to the superclass constructor
     */
    public MiningModelEvaluator(PMML pmml, MiningModel miningModel) {
        super(pmml, miningModel);
    }

    /**
     * Returns a short human-readable label for the model.
     *
     * @return {@code "Random forest"} when {@code isRandomForest} accepts the model,
     *         {@code "Ensemble model"} otherwise
     */
    public String getSummary() {
        return isRandomForest(getModel()) ? "Random forest" : "Ensemble model";
    }

    /**
     * Returns the segment registry for this evaluator's model, as resolved
     * through {@code getValue} against the shared {@code entityCache}.
     *
     * @return the bidirectional map between string keys and segments
     */
    @Override
    public BiMap<String, Segment> getEntityRegistry() {
        BiMap<String, Segment> registry = getValue(entityCache);
        return registry;
    }

    /**
     * Creates a mining-model evaluation context bound to this evaluator.
     *
     * @param parent the parent context handed to the new context
     * @return a new {@link MiningModelEvaluationContext}
     */
    @Override
    public MiningModelEvaluationContext createContext(ModelEvaluationContext parent) {
        MiningModelEvaluationContext miningContext = new MiningModelEvaluationContext(parent, this);
        return miningContext;
    }

    /**
     * Evaluates the model with a generic context by narrowing it to
     * {@link MiningModelEvaluationContext} and delegating to the typed overload.
     *
     * @param context the evaluation context; must be a {@link MiningModelEvaluationContext}
     * @return the evaluation result keyed by field name
     * @throws ClassCastException if {@code context} is of another context type
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        MiningModelEvaluationContext miningContext = (MiningModelEvaluationContext)context;
        return evaluate(miningContext);
    }

    /**
     * Evaluates the mining model.
     *
     * <p>The model must be scorable and must not contain embedded models. The
     * prediction is computed by the strategy matching the model's mining function
     * (regression, classification, clustering, or the generic fallback) and then
     * passed through {@code OutputUtil.evaluate}.
     *
     * @param context the mining-model evaluation context
     * @return the predictions after output processing
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the model declares an embedded model
     */
    public Map<FieldName, ?> evaluate(MiningModelEvaluationContext context) {
        MiningModel model = getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }

        // Embedded models are not supported; the first one found is reported.
        EmbeddedModel firstEmbedded = Iterables.getFirst(model.getEmbeddedModels(), null);
        if (firstEmbedded != null) {
            throw new UnsupportedFeatureException(firstEmbedded);
        }

        final Map<FieldName, ?> rawPredictions;
        switch (model.getFunctionName()) {
            case REGRESSION:
                rawPredictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                rawPredictions = evaluateClassification(context);
                break;
            case CLUSTERING:
                rawPredictions = evaluateClustering(context);
                break;
            default:
                rawPredictions = evaluateAny(context);
                break;
        }

        return OutputUtil.evaluate(rawPredictions, context);
    }

    /**
     * Evaluates the segmentation as a regression ensemble.
     *
     * <p>If {@code getRegressionResult} already yields a result (selection-style
     * methods, or no matching segment), that result is returned as is. Otherwise
     * the decoded target values of the segment results are combined:
     * <ul>
     *   <li>{@code SUM}: the plain sum;</li>
     *   <li>{@code AVERAGE}: the sum divided by the number of segment results;</li>
     *   <li>{@code WEIGHTED_AVERAGE}: the weighted sum divided by the sum of the
     *       segment weights (previously it was divided by the number of segment
     *       results, which is only correct when every weight equals one).</li>
     * </ul>
     *
     * @param context the mining-model evaluation context
     * @return the regression result keyed by field name
     * @throws UnsupportedFeatureException if the multiple-model method is not one of the above
     */
    private Map<FieldName, ?> evaluateRegression(MiningModelEvaluationContext context) {
        MiningModel miningModel = getModel();
        List<SegmentResultMap> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> predictions = getRegressionResult(segmentResults);
        if (predictions != null) {
            return predictions;
        }

        Segmentation segmentation = miningModel.getSegmentation();
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();

        double sum = 0.0;
        // Denominator of the weighted average; only accumulated for WEIGHTED_AVERAGE.
        double weightSum = 0.0;
        for (SegmentResultMap segmentResult : segmentResults) {
            Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());
            Number number = (Number)TypeUtil.parseOrCast(DataType.DOUBLE, targetValue);
            switch (multipleModelMethod) {
                case SUM:
                case AVERAGE: {
                    sum += number.doubleValue();
                    break;
                }
                case WEIGHTED_AVERAGE: {
                    double weight = segmentResult.getWeight();
                    sum += weight * number.doubleValue();
                    weightSum += weight;
                    break;
                }
                default: {
                    throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
                }
            }
        }

        Double result;
        switch (multipleModelMethod) {
            case SUM: {
                result = sum;
                break;
            }
            case AVERAGE: {
                result = sum / (double)segmentResults.size();
                break;
            }
            case WEIGHTED_AVERAGE: {
                // A zero total weight leaves the average undefined (NaN or infinite).
                result = sum / weightSum;
                break;
            }
            default: {
                throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
            }
        }

        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Resolves regression results that need no numeric aggregation.
     *
     * <ul>
     *   <li>{@code SELECT_ALL}: the merged results of all segments;</li>
     *   <li>{@code SELECT_FIRST} / {@code MODEL_CHAIN}: the first / last segment
     *       result when there is one;</li>
     *   <li>{@code SELECT_FIRST}, {@code MODEL_CHAIN}, {@code SUM}, {@code AVERAGE},
     *       {@code WEIGHTED_AVERAGE} with no segment result: the target field
     *       mapped to {@code null}.</li>
     * </ul>
     *
     * @param segmentResults the results of the segments that were evaluated
     * @return the final result, or {@code null} when the caller still has to aggregate
     */
    private Map<FieldName, ?> getRegressionResult(List<SegmentResultMap> segmentResults) {
        MultipleModelMethodType method = getModel().getSegmentation().getMultipleModelMethod();

        switch (method) {
            case SELECT_ALL: {
                return selectAll(segmentResults);
            }
            case SELECT_FIRST: {
                if (segmentResults.size() > 0) {
                    return getFirst(segmentResults);
                }
                break;
            }
            case MODEL_CHAIN: {
                if (segmentResults.size() > 0) {
                    return getLast(segmentResults);
                }
                break;
            }
            case SUM:
            case AVERAGE:
            case WEIGHTED_AVERAGE: {
                if (segmentResults.size() != 0) {
                    return null;
                }
                break;
            }
            default: {
                return null;
            }
        }

        // Only reached by the methods above when no segment produced a result.
        return Collections.singletonMap(getTargetField(), null);
    }

    /**
     * Evaluates the segmentation as a classification ensemble.
     *
     * <p>A result already produced by {@code getClassificationResult} is returned
     * unchanged. Otherwise the segment results are combined by method:
     * vote-based methods fill a {@link ProbabilityClassificationMap} from
     * {@code countVotes} and call {@code normalizeValues()} on it, while
     * {@code AVERAGE}, {@code WEIGHTED_AVERAGE} and {@code MAX} fill a
     * {@link ClassificationMap} of type {@code VOTE} from {@code aggregateProbabilities}.
     *
     * @param context the mining-model evaluation context
     * @return the classification result keyed by field name
     * @throws UnsupportedFeatureException if the multiple-model method is not supported
     */
    private Map<FieldName, ?> evaluateClassification(MiningModelEvaluationContext context) {
        List<SegmentResultMap> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> shortcut = getClassificationResult(segmentResults);
        if (shortcut != null) {
            return shortcut;
        }

        Segmentation segmentation = getModel().getSegmentation();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        ClassificationMap<Object> distribution;
        switch (method) {
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE:
                distribution = new ProbabilityClassificationMap<Object>();
                distribution.putAll(countVotes(segmentation, segmentResults));
                distribution.normalizeValues();
                break;
            case AVERAGE:
            case WEIGHTED_AVERAGE:
            case MAX:
                distribution = new ClassificationMap<Object>(ClassificationMap.Type.VOTE);
                distribution.putAll(aggregateProbabilities(segmentation, segmentResults));
                break;
            default:
                throw new UnsupportedFeatureException(segmentation, method);
        }

        return TargetUtil.evaluateClassification(distribution, context);
    }

    /**
     * Resolves classification results that need no vote counting or probability
     * aggregation.
     *
     * <ul>
     *   <li>{@code SELECT_ALL}: the merged results of all segments;</li>
     *   <li>{@code SELECT_FIRST} / {@code MODEL_CHAIN}: the first / last segment
     *       result when there is one;</li>
     *   <li>any supported method with no segment result: the target field mapped
     *       to {@code null} (a missing result).</li>
     * </ul>
     *
     * <p>The empty-result rule now also covers {@code AVERAGE},
     * {@code WEIGHTED_AVERAGE} and {@code MAX}. Previously those methods went on
     * to aggregate an empty list, while the vote-based methods and the regression
     * counterpart already reported a missing result.
     *
     * @param segmentResults the results of the segments that were evaluated
     * @return the final result, or {@code null} when the caller still has to aggregate
     */
    private Map<FieldName, ?> getClassificationResult(List<SegmentResultMap> segmentResults) {
        MiningModel miningModel = getModel();
        Segmentation segmentation = miningModel.getSegmentation();
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL: {
                return selectAll(segmentResults);
            }
            case SELECT_FIRST: {
                if (segmentResults.size() > 0) {
                    return getFirst(segmentResults);
                }
                break;
            }
            case MODEL_CHAIN: {
                if (segmentResults.size() > 0) {
                    return getLast(segmentResults);
                }
                break;
            }
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE:
            case AVERAGE:
            case WEIGHTED_AVERAGE:
            case MAX: {
                if (segmentResults.size() != 0) {
                    // The caller aggregates the segment results.
                    return null;
                }
                break;
            }
            default: {
                return null;
            }
        }

        // No segment matched: the prediction is missing.
        return Collections.singletonMap(getTargetField(), null);
    }

    /**
     * Evaluates the segmentation as a clustering ensemble.
     *
     * <p>A result already produced by {@code getClusteringResult} is returned
     * unchanged. Otherwise the votes counted by {@code countVotes} are stored in a
     * {@link ClassificationMap} of type {@code VOTE}, which becomes the value of
     * the target field.
     *
     * @param context the mining-model evaluation context
     * @return the clustering result keyed by field name
     */
    private Map<FieldName, ?> evaluateClustering(MiningModelEvaluationContext context) {
        List<SegmentResultMap> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> shortcut = getClusteringResult(segmentResults);
        if (shortcut != null) {
            return shortcut;
        }

        Segmentation segmentation = getModel().getSegmentation();
        ClassificationMap<Object> tally = new ClassificationMap<Object>(ClassificationMap.Type.VOTE);
        Map<Object, Double> votes = countVotes(segmentation, segmentResults);
        tally.putAll(votes);

        return Collections.singletonMap(getTargetField(), tally);
    }

    /**
     * Resolves clustering results that need no vote counting.
     *
     * <ul>
     *   <li>{@code SELECT_ALL}: the merged results of all segments;</li>
     *   <li>{@code SELECT_FIRST} / {@code MODEL_CHAIN}: the first / last segment
     *       result when there is one;</li>
     *   <li>{@code SELECT_FIRST}, {@code MODEL_CHAIN}, {@code MAJORITY_VOTE},
     *       {@code WEIGHTED_MAJORITY_VOTE} with no segment result: the target
     *       field mapped to {@code null}.</li>
     * </ul>
     *
     * @param segmentResults the results of the segments that were evaluated
     * @return the final result, or {@code null} when the caller still has to count votes
     */
    private Map<FieldName, ?> getClusteringResult(List<SegmentResultMap> segmentResults) {
        MultipleModelMethodType method = getModel().getSegmentation().getMultipleModelMethod();

        switch (method) {
            case SELECT_ALL: {
                return selectAll(segmentResults);
            }
            case SELECT_FIRST: {
                if (segmentResults.size() > 0) {
                    return getFirst(segmentResults);
                }
                break;
            }
            case MODEL_CHAIN: {
                if (segmentResults.size() > 0) {
                    return getLast(segmentResults);
                }
                break;
            }
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE: {
                if (segmentResults.size() != 0) {
                    return null;
                }
                break;
            }
            default: {
                return null;
            }
        }

        // Only reached by the methods above when no segment produced a result.
        return Collections.singletonMap(getTargetField(), null);
    }

    /**
     * Evaluates the segmentation for mining functions without a dedicated strategy.
     *
     * <p>Only selection-style methods are handled: {@code SELECT_ALL} merges all
     * segment results, {@code SELECT_FIRST} / {@code MODEL_CHAIN} return the first /
     * last segment result, or the target field mapped to {@code null} when no
     * segment produced a result.
     *
     * @param context the mining-model evaluation context
     * @return the selected result keyed by field name
     * @throws UnsupportedFeatureException for every other multiple-model method
     */
    private Map<FieldName, ?> evaluateAny(MiningModelEvaluationContext context) {
        List<SegmentResultMap> segmentResults = evaluateSegmentation(context);

        Segmentation segmentation = getModel().getSegmentation();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        switch (method) {
            case SELECT_ALL: {
                return selectAll(segmentResults);
            }
            case SELECT_FIRST: {
                if (segmentResults.size() > 0) {
                    return getFirst(segmentResults);
                }
                return Collections.singletonMap(getTargetField(), null);
            }
            case MODEL_CHAIN: {
                if (segmentResults.size() > 0) {
                    return getLast(segmentResults);
                }
                return Collections.singletonMap(getTargetField(), null);
            }
            default: {
                throw new UnsupportedFeatureException(segmentation, method);
            }
        }
    }

    /**
     * Evaluates every segment whose predicate holds and collects the results.
     *
     * <p>For each such segment the nested model is evaluated in a child context
     * created from {@code context}. Its output fields are declared on
     * {@code context}, its warnings are copied over, and the result is stored on
     * {@code context} under the segment's registry key. With {@code SELECT_FIRST}
     * the method returns as soon as the first segment has been evaluated.
     *
     * <p>Every nested model must have the same mining function as the mining
     * model. With {@code MODEL_CHAIN} this is only required of the last model
     * that was evaluated.
     *
     * <p>Compared to the decompiled form, the raw collection types are
     * parameterized so that the enhanced-for loops are well-typed, and the
     * evaluator returned by the factory is narrowed to {@link ModelEvaluator}.
     *
     * @param context the mining-model evaluation context
     * @return the results of the evaluated segments, in segment order
     * @throws UnsupportedFeatureException if the segmentation declares local transformations
     * @throws InvalidFeatureException if a segment lacks a predicate or a model, or if a
     *         nested model's mining function does not match
     * @throws MissingFieldException if a declared output field has no value in the child context
     */
    private List<SegmentResultMap> evaluateSegmentation(MiningModelEvaluationContext context) {
        MiningModel miningModel = getModel();
        List<SegmentResultMap> results = Lists.newArrayList();

        Segmentation segmentation = miningModel.getSegmentation();
        LocalTransformations localTransformations = segmentation.getLocalTransformations();
        if (localTransformations != null) {
            throw new UnsupportedFeatureException(localTransformations);
        }

        // Inverse view of the registry: look up the key under which each segment is registered.
        BiMap<Segment, String> inverseEntities = getEntityRegistry().inverse();

        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
        Model lastModel = null;
        MiningFunctionType miningFunction = miningModel.getFunctionName();

        List<Segment> segments = segmentation.getSegments();
        for (Segment segment : segments) {
            Predicate predicate = segment.getPredicate();
            if (predicate == null) {
                throw new InvalidFeatureException(segment);
            }

            // Segments whose predicate is unknown (null) or false are skipped.
            Boolean status = PredicateUtil.evaluate(predicate, context);
            if (status == null || !status.booleanValue()) {
                continue;
            }

            String id = inverseEntities.get(segment);

            Model model = segment.getModel();
            if (model == null) {
                throw new InvalidFeatureException(segment);
            }

            switch (multipleModelMethod) {
                case MODEL_CHAIN: {
                    // Intermediate chain members may differ; only the last one is checked below.
                    lastModel = model;
                    break;
                }
                default: {
                    if (!miningFunction.equals(model.getFunctionName())) {
                        throw new InvalidFeatureException(model);
                    }
                    break;
                }
            }

            ModelEvaluator<?> evaluator = (ModelEvaluator<?>)evaluatorFactory.getModelManager(getPMML(), model);
            ModelEvaluationContext segmentContext = evaluator.createContext(context);
            Map<FieldName, ?> result = evaluator.evaluate(segmentContext);
            FieldName targetField = evaluator.getTargetField();

            // Publish the segment's output fields on the enclosing context.
            List<FieldName> outputFields = evaluator.getOutputFields();
            for (FieldName outputField : outputFields) {
                FieldValue outputValue = segmentContext.getField(outputField);
                if (outputValue == null) {
                    throw new MissingFieldException(outputField, segment);
                }
                context.declare(outputField, outputValue);
            }

            List<String> warnings = segmentContext.getWarnings();
            for (String warning : warnings) {
                context.addWarning(warning);
            }

            SegmentResultMap segmentResult = new SegmentResultMap(segment, targetField);
            segmentResult.putAll(result);
            context.putResult(id, segmentResult);

            switch (multipleModelMethod) {
                case SELECT_FIRST: {
                    return Collections.singletonList(segmentResult);
                }
                default: {
                    break;
                }
            }

            results.add(segmentResult);
        }

        switch (multipleModelMethod) {
            case MODEL_CHAIN: {
                if (lastModel != null && !miningFunction.equals(lastModel.getFunctionName())) {
                    throw new InvalidFeatureException(lastModel);
                }
                break;
            }
            default: {
                break;
            }
        }

        return results;
    }

    /**
     * Merges the results of all segments into one map from field name to the
     * collection of values that the segments produced for it, in segment order.
     *
     * <p>All segment results must have exactly the same key set as the first one.
     * The collections are parameterized here; in the raw decompiled form the loop
     * over {@code keys} was not well-typed.
     *
     * @param segmentResults the segment results to merge
     * @return the merged view; empty when there are no segment results
     * @throws EvaluationException if a segment result's key set differs from the first one's
     */
    private Map<FieldName, ?> selectAll(List<SegmentResultMap> segmentResults) {
        ArrayListMultimap<FieldName, Object> result = ArrayListMultimap.create();

        // Key set of the first segment result; every later one must match it.
        Set<FieldName> keys = null;
        for (SegmentResultMap segmentResult : segmentResults) {
            if (keys == null) {
                keys = Sets.newLinkedHashSet(segmentResult.keySet());
            }
            if (!keys.equals(segmentResult.keySet())) {
                throw new EvaluationException();
            }
            for (FieldName key : keys) {
                result.put(key, segmentResult.get(key));
            }
        }

        return result.asMap();
    }

    /**
     * Returns the element at index 0 of the list.
     *
     * @param list the list to read from
     * @return the first element
     * @throws IndexOutOfBoundsException if the list is empty
     */
    private static <E> E getFirst(List<E> list) {
        final int firstIndex = 0;
        return list.get(firstIndex);
    }

    /**
     * Returns the element at the highest index of the list.
     *
     * @param list the list to read from
     * @return the last element
     * @throws IndexOutOfBoundsException if the list is empty
     */
    private static <E> E getLast(List<E> list) {
        final int lastIndex = list.size() - 1;
        return list.get(lastIndex);
    }

    /**
     * Counts one vote per segment result for its decoded target value.
     *
     * <p>{@code MAJORITY_VOTE} increments the counter without a weight;
     * {@code WEIGHTED_MAJORITY_VOTE} increments it by the segment result's weight.
     *
     * @param segmentation the segmentation that supplies the multiple-model method
     * @param segmentResults the segment results to tally
     * @return the vote counter, keyed by target category
     * @throws UnsupportedFeatureException if a segment result is encountered under any other method
     */
    private static Map<Object, Double> countVotes(Segmentation segmentation, List<SegmentResultMap> segmentResults) {
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();
        VoteCounter<Object> tally = new VoteCounter<Object>();

        for (SegmentResultMap segmentResult : segmentResults) {
            Object category = EvaluatorUtil.decode(segmentResult.getTargetValue());
            switch (method) {
                case MAJORITY_VOTE:
                    tally.increment(category);
                    break;
                case WEIGHTED_MAJORITY_VOTE:
                    tally.increment(category, segmentResult.getWeight());
                    break;
                default:
                    throw new UnsupportedFeatureException(segmentation, method);
            }
        }

        return tally;
    }

    /**
     * Aggregates the per-category probabilities of all segment results.
     *
     * <p>Every segment result's target value must be a {@link ClassificationMap}
     * of type {@code PROBABILITY}. The probabilities are combined by method:
     * <ul>
     *   <li>{@code MAX}: the maximum probability per category;</li>
     *   <li>{@code AVERAGE}: the sum per category divided by the number of
     *       segment results;</li>
     *   <li>{@code WEIGHTED_AVERAGE}: the weighted sum per category divided by
     *       the sum of the segment weights (previously it was divided by the
     *       number of segment results, which is only correct when every weight
     *       equals one).</li>
     * </ul>
     *
     * @param segmentation the segmentation that supplies the multiple-model method
     * @param segmentResults the segment results to aggregate
     * @return the aggregated values, keyed by target category
     * @throws TypeCheckException if a target value is not a {@link ClassificationMap}
     * @throws EvaluationException if a classification map is not of type {@code PROBABILITY}
     * @throws UnsupportedFeatureException if the multiple-model method is not one of the above
     */
    private static Map<Object, Double> aggregateProbabilities(Segmentation segmentation, List<SegmentResultMap> segmentResults) {
        ProbabilityAggregator<Object> aggregator = new ProbabilityAggregator<Object>();
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();

        // Denominator of the weighted average; only accumulated for WEIGHTED_AVERAGE.
        double weightSum = 0.0;
        for (SegmentResultMap segmentResult : segmentResults) {
            Object targetValue = segmentResult.getTargetValue();
            if (!(targetValue instanceof ClassificationMap)) {
                throw new TypeCheckException(ClassificationMap.class, targetValue);
            }
            ClassificationMap<?> values = (ClassificationMap<?>)targetValue;
            if (!ClassificationMap.Type.PROBABILITY.equals(values.getType())) {
                throw new EvaluationException();
            }

            double weight = 1.0;
            if (multipleModelMethod == MultipleModelMethodType.WEIGHTED_AVERAGE) {
                weight = segmentResult.getWeight();
                weightSum += weight;
            }

            for (Map.Entry<?, Double> entry : values.entrySet()) {
                Object targetCategory = entry.getKey();
                Double probability = entry.getValue();
                switch (multipleModelMethod) {
                    case MAX: {
                        aggregator.max(targetCategory, probability);
                        break;
                    }
                    case AVERAGE: {
                        aggregator.add(targetCategory, probability);
                        break;
                    }
                    case WEIGHTED_AVERAGE: {
                        aggregator.add(targetCategory, weight * probability);
                        break;
                    }
                    default: {
                        throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
                    }
                }
            }
        }

        switch (multipleModelMethod) {
            case MAX: {
                break;
            }
            case AVERAGE: {
                aggregator.divide(Double.valueOf(segmentResults.size()));
                break;
            }
            case WEIGHTED_AVERAGE: {
                aggregator.divide(Double.valueOf(weightSum));
                break;
            }
            default: {
                throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
            }
        }

        return aggregator;
    }

    /**
     * Heuristically decides whether a mining model looks like a random forest:
     * it has a segmentation with more than three segments, and every segment's
     * model is a {@link TreeModel}.
     *
     * <p>The segment list is parameterized here (the raw list could not be
     * iterated as {@code Segment}), and the scan stops at the first non-tree model.
     *
     * @param miningModel the mining model to inspect
     * @return {@code true} if the model matches the heuristic, {@code false} otherwise
     *         (including when it has no segmentation)
     */
    private static boolean isRandomForest(MiningModel miningModel) {
        Segmentation segmentation = miningModel.getSegmentation();
        if (segmentation == null) {
            return false;
        }

        List<Segment> segments = segmentation.getSegments();
        if (segments.size() <= 3) {
            return false;
        }

        for (Segment segment : segments) {
            if (!(segment.getModel() instanceof TreeModel)) {
                return false;
            }
        }
        return true;
    }
}

