/*
 * Decompiled with CFR 0.152.
 */
package sklearn.decomposition;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.FeatureMapper;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Transformer;

/**
 * Converter for scikit-learn's {@code sklearn.decomposition.PCA} transformer.
 *
 * <p>Each principal component {@code i} is encoded as one PMML {@link DerivedField} whose
 * expression is {@code sum_j((x_j - mean_[j]) * components_[i][j])}. When the estimator was
 * fitted with {@code whiten=True}, that sum is additionally divided by
 * {@code sqrt(explained_variance_[i])}.
 */
public class PCA
extends Transformer {
    /**
     * Counter shared by all PCA converters. Its next value is passed to {@code createName} as a
     * prefix, presumably so that derived fields from different PCA steps get distinct names.
     */
    private static final AtomicInteger SEQUENCE = new AtomicInteger(1);

    public PCA(String module, String name) {
        super(module, name);
    }

    /**
     * Replaces the input features with one continuous derived feature per principal component.
     *
     * @param ids names of the incoming features; cleared and refilled in place with the names of
     *     the created derived fields
     * @param inputFeatures the incoming features, one per column of {@code components_}
     * @param featureMapper used to register the created derived fields
     * @return one {@link ContinuousFeature} per row of {@code components_}
     * @throws IllegalArgumentException if {@code ids} or {@code inputFeatures} does not have
     *     exactly as many entries as {@code components_} has columns
     */
    @Override
    public List<Feature> encodeFeatures(List<String> ids, List<Feature> inputFeatures, FeatureMapper featureMapper) {
        int[] shape = this.getComponentsShape();
        int numberOfComponents = shape[0];
        int numberOfFeatures = shape[1];
        if (ids.size() != numberOfFeatures || inputFeatures.size() != numberOfFeatures) {
            throw new IllegalArgumentException("Expected " + numberOfFeatures + " input feature(s), got "
                    + ids.size() + " id(s) and " + inputFeatures.size() + " feature(s)");
        }
        String id = String.valueOf(SEQUENCE.getAndIncrement());
        List<? extends Number> components = this.getComponents();
        List<? extends Number> mean = this.getMean();
        // A missing "whiten" attribute is treated as false rather than failing on null unboxing
        boolean whiten = Boolean.TRUE.equals(this.getWhiten());
        List<? extends Number> explainedVariance = whiten ? this.getExplainedVariance() : null;
        ids.clear();
        List<Feature> features = new ArrayList<Feature>(numberOfComponents);
        for (int i = 0; i < numberOfComponents; ++i) {
            List<? extends Number> component = MatrixUtil.getRow(components, numberOfComponents, numberOfFeatures, i);
            Apply apply = new Apply("sum");
            for (int j = 0; j < numberOfFeatures; ++j) {
                // Declared as Expression (not FieldRef) because it is rewrapped in Apply nodes below
                Expression expression = inputFeatures.get(j).ref();
                Number meanValue = mean.get(j);
                // Skip the no-op "- 0" term to keep the generated PMML compact
                if (!ValueUtil.isZero(meanValue)) {
                    expression = PMMLUtil.createApply("-", new Expression[]{expression, PMMLUtil.createConstant((Object)meanValue)});
                }
                Number componentValue = component.get(j);
                // Skip the no-op "* 1" term to keep the generated PMML compact
                if (!ValueUtil.isOne(componentValue)) {
                    expression = PMMLUtil.createApply("*", new Expression[]{expression, PMMLUtil.createConstant((Object)componentValue)});
                }
                apply.addExpressions(new Expression[]{expression});
            }
            if (whiten) {
                Number explainedVarianceValue = explainedVariance.get(i);
                // Dividing by sqrt(1) is a no-op, so it is omitted
                if (!ValueUtil.isOne(explainedVarianceValue)) {
                    double scale = Math.sqrt(ValueUtil.asDouble(explainedVarianceValue));
                    apply = PMMLUtil.createApply("/", new Expression[]{apply, PMMLUtil.createConstant((Object)scale)});
                }
            }
            DerivedField derivedField = featureMapper.createDerivedField(this.createName(id, i), apply);
            ids.add(derivedField.getName().getValue());
            features.add(new ContinuousFeature(derivedField));
        }
        return features;
    }

    @Override
    protected String name() {
        return "pca";
    }

    /**
     * Returns the pickled {@code whiten} attribute.
     *
     * @return the raw attribute value; may be {@code null} if the attribute is absent (not
     *     verified here)
     */
    public Boolean getWhiten() {
        return (Boolean)this.get("whiten");
    }

    /** Returns the flattened {@code components_} array, read row by row via {@link MatrixUtil#getRow}. */
    public List<? extends Number> getComponents() {
        return ClassDictUtil.getArray(this, "components_");
    }

    /** Returns the {@code explained_variance_} array; only read when whitening is enabled. */
    public List<? extends Number> getExplainedVariance() {
        return ClassDictUtil.getArray(this, "explained_variance_");
    }

    /** Returns the {@code mean_} array, one value per input feature. */
    public List<? extends Number> getMean() {
        return ClassDictUtil.getArray(this, "mean_");
    }

    /** Returns the two-dimensional shape of {@code components_}: {@code [components, features]}. */
    private int[] getComponentsShape() {
        return ClassDictUtil.getShape(this, "components_", 2);
    }
}

