/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sklearn.visitors;

import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.Model;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.visitors.DeepFieldResolver;
import org.jpmml.converter.visitors.FieldDependencyResolver;
import org.jpmml.converter.visitors.FieldUtil;

public class FeatureExpander
extends DeepFieldResolver {
    private Map<Model, Set<FieldName>> features = null;
    private Map<Model, Map<FieldName, Set<Field<?>>>> expandedFeatures = new IdentityHashMap();

    public FeatureExpander(Map<Model, Set<FieldName>> features) {
        this.features = Objects.requireNonNull(features);
    }

    public void reset() {
        super.reset();
        this.features.clear();
        this.expandedFeatures.clear();
    }

    public PMMLObject popParent() {
        PMMLObject parent = super.popParent();
        if (parent instanceof Model) {
            Model model = (Model)parent;
            this.processModel(model);
        }
        return parent;
    }

    private void processModel(Model model) {
        Map globalDerivedFields;
        FieldDependencyResolver fieldDependencyResolver = this.getFieldDependencyResolver();
        MiningModel parentMiningModel = null;
        Set<FieldName> features = this.features.get(model);
        if (features == null) {
            parentMiningModel = this.getParent(this.features.keySet());
            if (parentMiningModel != null) {
                features = this.features.get(parentMiningModel);
            }
            if (features == null) {
                return;
            }
        }
        Collection modelFields = this.getFields(new PMMLObject[]{model});
        Collection featureFields = FieldUtil.selectAll((Collection)modelFields, features, (boolean)true);
        Map localDerivedFields = Collections.emptyMap();
        LocalTransformations localTransformations = model.getLocalTransformations();
        if (localTransformations != null && localTransformations.hasDerivedFields()) {
            localDerivedFields = FieldUtil.nameMap((Collection)localTransformations.getDerivedFields());
        }
        if (parentMiningModel != null) {
            if (localDerivedFields.isEmpty()) {
                return;
            }
            featureFields.retainAll(localDerivedFields.values());
        }
        try {
            Method method = FieldDependencyResolver.class.getDeclaredMethod("getGlobalDerivedFields", new Class[0]);
            if (!method.isAccessible()) {
                method.setAccessible(true);
            }
            globalDerivedFields = FieldUtil.nameMap((Collection)((Collection)method.invoke((Object)fieldDependencyResolver, new Object[0])));
        }
        catch (ReflectiveOperationException roe) {
            throw new IllegalArgumentException(roe);
        }
        Map<FieldName, Set<Field<?>>> expandedFields = parentMiningModel != null ? this.ensureExpandedFeatures((Model)parentMiningModel) : this.ensureExpandedFeatures(model);
        for (Field featureField : featureFields) {
            FieldName name = featureField.getName();
            if (featureField instanceof DataField) {
                expandedFields.put(name, Collections.singleton(featureField));
                continue;
            }
            if (featureField instanceof DerivedField) {
                DerivedField derivedField = (DerivedField)featureField;
                HashSet<DerivedField> expandedFeatureFields = new HashSet<DerivedField>();
                expandedFeatureFields.add(derivedField);
                fieldDependencyResolver.expand(expandedFeatureFields, new HashSet(localDerivedFields.values()));
                fieldDependencyResolver.expand(expandedFeatureFields, new HashSet(globalDerivedFields.values()));
                expandedFields.put(name, expandedFeatureFields);
                continue;
            }
            if (featureField instanceof OutputField) {
                expandedFields.put(name, Collections.singleton(featureField));
                continue;
            }
            throw new IllegalArgumentException();
        }
    }

    private MiningModel getParent(Set<Model> models) {
        Deque parents = this.getParents();
        for (PMMLObject parent : parents) {
            MiningModel miningModel;
            if (!(parent instanceof MiningModel) || !models.contains(miningModel = (MiningModel)parent)) continue;
            return miningModel;
        }
        return null;
    }

    private Map<FieldName, Set<Field<?>>> ensureExpandedFeatures(Model model) {
        Map<Model, Map<FieldName, Set<Field<?>>>> expandedFeatures = this.getExpandedFeatures();
        Map<FieldName, Set<Field<?>>> result = expandedFeatures.get(model);
        if (result == null) {
            result = new HashMap();
            expandedFeatures.put(model, result);
        }
        return result;
    }

    public Map<FieldName, Set<Field<?>>> getExpandedFeatures(Model model) {
        Map<Model, Map<FieldName, Set<Field<?>>>> expandedFeatures = this.getExpandedFeatures();
        return expandedFeatures.get(model);
    }

    public Map<Model, Set<FieldName>> getFeatures() {
        return this.features;
    }

    public Map<Model, Map<FieldName, Set<Field<?>>>> getExpandedFeatures() {
        return this.expandedFeatures;
    }
}

