/*
 * Decompiled with CFR 0.152.
 */
package sktree.tree;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.tree.Tree;
import sktree.tree.ProjectionManager;

public class ObliqueTree
extends Tree {
    public ObliqueTree(String module, String name) {
        super(module, name);
    }

    public ObliqueTree(ObliqueTree that) {
        this(that.getPythonModule(), that.getPythonName());
        this.update((Map)((Object)that));
    }

    public ObliqueTree transform(Schema schema) {
        List features = schema.getFeatures();
        final int[] featureIndices = new int[features.size()];
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = (Feature)features.get(i);
            featureIndices[i] = feature != null ? i : -2;
        }
        ObliqueTree result = new ObliqueTree(this){
            {
                super(that);
                this.delProjVecs();
            }

            public int[] getFeature() {
                return featureIndices;
            }
        };
        return result;
    }

    public Schema transformSchema(Object segmentId, ProjectionManager projectionManager, Schema schema) {
        SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
        Label label = schema.getLabel();
        List<Feature> features = schema.getFeatures();
        features = this.encodeFeatures(segmentId, features, projectionManager, encoder);
        return new Schema((PMMLEncoder)encoder, label, features);
    }

    public List<Feature> encodeFeatures(Object segmentId, List<Feature> features, ProjectionManager projectionManager, SkLearnEncoder encoder) {
        Integer nodeCount = this.getNodeCount();
        List<Number> projVecs = this.getProjVecs();
        int rows = nodeCount;
        int columns = features.size();
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int row = 0; row < rows; ++row) {
            String name = segmentId != null ? FieldNameUtil.create((String)"lc", (Object[])new Object[]{segmentId, row}) : FieldNameUtil.create((String)"lc", (Object[])new Object[]{row});
            List weights = CMatrixUtil.getRow(projVecs, (int)rows, (int)columns, (int)row);
            Feature feature = projectionManager.getOrCreateFeature(name, features, weights, encoder);
            result.add(feature);
        }
        return result;
    }

    public Integer getNodeCount() {
        return this.getInteger("node_count");
    }

    public boolean hasProjVecs() {
        return this.hasattr("proj_vecs");
    }

    public List<Number> getProjVecs() {
        return this.getNumberArray("proj_vecs");
    }

    public void delProjVecs() {
        this.delattr("proj_vecs");
    }
}

