/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.iforest;

import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import numpy.core.Scalar;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.AbstractTransformation;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.OutlierTransformation;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sklearn.SkLearnUtil;
import org.jpmml.sklearn.TreeModelProducer;
import sklearn.Regressor;
import sklearn.ensemble.EnsembleRegressor;
import sklearn.tree.ExtraTreeRegressor;
import sklearn.tree.Tree;
import sklearn.tree.TreeModelUtil;

public class IsolationForest
extends EnsembleRegressor
implements TreeModelProducer {
    public IsolationForest(String module, String name) {
        super(module, name);
    }

    @Override
    public boolean isSupervised() {
        return false;
    }

    public MiningModel encodeModel(Schema schema) {
        String sklearnVersion = this.getSkLearnVersion();
        List<? extends Regressor> estimators = this.getEstimators();
        final boolean corrected = sklearnVersion != null && SkLearnUtil.compareVersion(sklearnVersion, "0.19") >= 0;
        PredicateManager predicateManager = new PredicateManager();
        Schema segmentSchema = schema.toAnonymousSchema();
        ArrayList<TreeModel> treeModels = new ArrayList<TreeModel>();
        for (Regressor regressor : estimators) {
            ExtraTreeRegressor treeRegressor = (ExtraTreeRegressor)regressor;
            final Tree tree = treeRegressor.getTree();
            TreeModel treeModel = TreeModelUtil.encodeTreeModel(treeRegressor, predicateManager, MiningFunction.REGRESSION, segmentSchema);
            AbstractVisitor visitor = new AbstractVisitor(){
                private int[] nodeSamples;
                {
                    this.nodeSamples = tree.getNodeSamples();
                }

                public VisitorAction visit(Node node) {
                    if (node.getScore() != null) {
                        PMMLObject parent;
                        double nodeDepth = 0.0;
                        Deque parents = this.getParents();
                        Iterator i$ = parents.iterator();
                        while (i$.hasNext() && (parent = (PMMLObject)i$.next()) instanceof Node) {
                            nodeDepth += 1.0;
                        }
                        double nodeSample = this.nodeSamples[Integer.parseInt(node.getId()) - 1];
                        double averagePathLength = corrected ? IsolationForest.correctedAveragePathLength(nodeSample) : IsolationForest.averagePathLength(nodeSample);
                        node.setScore(ValueUtil.formatValue((Number)(nodeDepth + averagePathLength)));
                    }
                    return super.visit(node);
                }
            };
            visitor.applyTo((Visitable)treeModel);
            treeModels.add(treeModel);
        }
        AbstractTransformation normalizedAnomalyScore = new AbstractTransformation(){

            public FieldName getName(FieldName name) {
                return FieldName.create((String)"normalizedAnomalyScore");
            }

            public Expression createExpression(FieldRef fieldRef) {
                double maxSamples = IsolationForest.this.getMaxSamples();
                double averagePathLength = corrected ? IsolationForest.correctedAveragePathLength(maxSamples) : IsolationForest.averagePathLength(maxSamples);
                return PMMLUtil.createApply((String)"/", (Expression[])new Expression[]{fieldRef, PMMLUtil.createConstant((Object)averagePathLength)});
            }
        };
        AbstractTransformation abstractTransformation = new AbstractTransformation(){

            public FieldName getName(FieldName name) {
                return FieldName.create((String)"decisionFunction");
            }

            public Expression createExpression(FieldRef fieldRef) {
                return PMMLUtil.createApply((String)"-", (Expression[])new Expression[]{PMMLUtil.createConstant((Object)0.5), PMMLUtil.createApply((String)"pow", (Expression[])new Expression[]{PMMLUtil.createConstant((Object)2.0), PMMLUtil.createApply((String)"*", (Expression[])new Expression[]{PMMLUtil.createConstant((Object)-1.0), fieldRef})})});
            }
        };
        OutlierTransformation outlier = new OutlierTransformation(){

            public Expression createExpression(FieldRef fieldRef) {
                return PMMLUtil.createApply((String)"lessOrEqual", (Expression[])new Expression[]{fieldRef, PMMLUtil.createConstant((Object)IsolationForest.this.getThreshold())});
            }
        };
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)schema.getLabel())).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.AVERAGE, treeModels)).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldName.create((String)"rawAnomalyScore"), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[]{normalizedAnomalyScore, abstractTransformation, outlier}));
        return miningModel;
    }

    public String getSkLearnVersion() {
        return (String)this.get("_sklearn_version");
    }

    public int getMaxSamples() {
        return ValueUtil.asInt((Number)((Number)this.get("max_samples_")));
    }

    public double getThreshold() {
        Scalar threshold = (Scalar)this.get("threshold_");
        return ValueUtil.asDouble((Number)((Number)threshold.getOnlyElement()));
    }

    private static double averagePathLength(double n) {
        if (n <= 1.0) {
            return 1.0;
        }
        return 2.0 * (Math.log(n) + 0.5772156649) - 2.0 * ((n - 1.0) / n);
    }

    private static double correctedAveragePathLength(double n) {
        if (n <= 1.0) {
            return 1.0;
        }
        return 2.0 * (Math.log(n - 1.0) + 0.5772156649015329) - 2.0 * ((n - 1.0) / n);
    }
}

