/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.fpm.FPGrowthModel;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.association.AssociationModel;
import org.dmg.pmml.association.AssociationRule;
import org.dmg.pmml.association.Item;
import org.dmg.pmml.association.ItemRef;
import org.dmg.pmml.association.Itemset;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.AssociationRulesModelConverter;
import org.jpmml.sparkml.ItemSetFeature;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Converts a Spark ML {@link FPGrowthModel} into a PMML {@link AssociationModel}.
 *
 * <p>The model's association rules are collected to the driver and translated into
 * PMML {@code Item}, {@code Itemset} and {@code AssociationRule} elements. The
 * resulting PMML model is marked as not scorable.
 */
public class FPGrowthModelConverter
extends AssociationRulesModelConverter<FPGrowthModel> {

    /** Orders item references by the numeric value of the item id they point to. */
    private static final Comparator<ItemRef> ITEM_REF_ORDER = new Comparator<ItemRef>() {

        @Override
        public int compare(ItemRef left, ItemRef right) {
            int leftId = Integer.parseInt(left.getItemRef());
            int rightId = Integer.parseInt(right.getItemRef());
            return Integer.compare(leftId, rightId);
        }
    };

    public FPGrowthModelConverter(FPGrowthModel model) {
        super(model);
    }

    /**
     * Declares the single categorical string field that holds transaction items.
     *
     * <p>The field is named after the model's items column with one trailing
     * {@code "s"} removed (for example {@code "items"} becomes {@code "item"}).
     * A one-character column name is used unchanged, so the field name never
     * becomes empty.
     *
     * @param encoder the encoder used to create the data field
     * @return a singleton list containing the item-set feature
     */
    @Override
    public List<Feature> getFeatures(SparkMLEncoder encoder) {
        FPGrowthModel model = getTransformer();

        String itemsCol = model.getItemsCol();
        String fieldName = itemsCol;
        // Guard the length so a column literally named "s" does not become "".
        if (itemsCol.length() > 1 && itemsCol.endsWith("s")) {
            fieldName = itemsCol.substring(0, itemsCol.length() - 1);
        }

        DataField dataField = encoder.createDataField(FieldName.create(fieldName), OpType.CATEGORICAL, DataType.STRING);
        Feature feature = new ItemSetFeature(encoder, dataField);
        return Collections.singletonList(feature);
    }

    /**
     * Encodes the model's association rules as a PMML association model.
     *
     * <p>Items and itemsets are de-duplicated and given sequential string ids
     * ({@code "1"}, {@code "2"}, ...) in order of first appearance. Confidence is
     * read from the third column of each rule row. Lift and support are read from
     * the {@code "lift"} and {@code "support"} columns when the rules dataset has
     * them and the value is non-null; otherwise they default to {@code 0.0}.
     *
     * @param schema the schema; must contain exactly one feature (the items field)
     * @return a non-scorable association model
     */
    @Override
    public AssociationModel encodeModel(Schema schema) {
        FPGrowthModel model = getTransformer();

        List<? extends Feature> features = schema.getFeatures();
        SchemaUtil.checkSize(1, features);
        Feature feature = features.get(0);

        // Insertion-ordered so that the emitted elements follow id order.
        Map<String, Item> items = new LinkedHashMap<>();
        Map<List<String>, Itemset> itemsets = new LinkedHashMap<>();
        List<AssociationRule> associationRules = new ArrayList<>();

        List<Row> associationRuleRows = model.associationRules().collectAsList();
        for (Row associationRuleRow : associationRuleRows) {
            List<String> antecedent = getItemValues(associationRuleRow, 0);
            List<String> consequent = getItemValues(associationRuleRow, 1);
            Double confidence = (Double) associationRuleRow.apply(2);
            // These columns only exist in some Spark versions' rules dataset; fall back to 0.0 when missing.
            Double lift = getOptionalDouble(associationRuleRow, "lift", 0.0);
            Double support = getOptionalDouble(associationRuleRow, "support", 0.0);

            Itemset antecedentItemset = ensureItemset(feature, antecedent, itemsets, items);
            Itemset consequentItemset = ensureItemset(feature, consequent, itemsets, items);

            AssociationRule associationRule = new AssociationRule()
                .setAntecedent(antecedentItemset.getId())
                .setConsequent(consequentItemset.getId())
                .setConfidence(confidence)
                .setLift(lift)
                .setSupport(support);
            associationRules.add(associationRule);
        }

        // NOTE(review): the number of transactions is not taken from the model here; 0 is written as a placeholder.
        int numberOfTransactions = 0;

        MiningSchema miningSchema = new MiningSchema();

        AssociationModel associationModel = new AssociationModel(MiningFunction.ASSOCIATION_RULES, numberOfTransactions, model.getMinSupport(), model.getMinConfidence(), items.size(), itemsets.size(), associationRules.size(), miningSchema)
            .setScorable(Boolean.FALSE);

        associationModel.getItems().addAll(items.values());
        associationModel.getItemsets().addAll(itemsets.values());
        associationModel.getAssociationRules().addAll(associationRules);

        return associationModel;
    }

    /**
     * Returns the string items stored in the given column of a rule row.
     *
     * <p>The values are copied into a plain Java list so that the result is detached
     * from the Scala collection and can safely be used as a map key.
     *
     * @param row a row of the association rules dataset
     * @param index the column index of the item array
     * @return a new list with the column's items
     */
    @SuppressWarnings("unchecked")
    private static List<String> getItemValues(Row row, int index) {
        // The antecedent/consequent columns hold arrays of strings, surfaced to Java as Scala Seqs.
        Seq<String> values = (Seq<String>) row.apply(index);
        return new ArrayList<>(JavaConversions.seqAsJavaList(values));
    }

    /**
     * Reads an optional numeric column from a rule row by name.
     *
     * @param row a row of the association rules dataset
     * @param name the column name to look up
     * @param defaultValue the value returned when the row has no schema, has no
     *     such column, or holds a null/non-numeric value in it
     * @return the column value as a {@code Double}, or {@code defaultValue}
     */
    private static Double getOptionalDouble(Row row, String name, Double defaultValue) {
        if (row.schema() == null) {
            return defaultValue;
        }
        String[] fieldNames = row.schema().fieldNames();
        for (int i = 0; i < fieldNames.length; i++) {
            if (name.equals(fieldNames[i])) {
                Object value = row.get(i);
                if (value instanceof Number) {
                    return ((Number) value).doubleValue();
                }
                return defaultValue;
            }
        }
        return defaultValue;
    }

    /**
     * Returns the itemset for the given item values, creating it (and any missing
     * items) on first use.
     *
     * <p>New itemsets and items get the next sequential id based on the current
     * size of their map. Item references inside a new itemset are sorted by
     * numeric item id.
     *
     * @param feature the items feature whose name is attached to new items
     * @param values the item values of the itemset (used as the lookup key)
     * @param itemsets the already-created itemsets, keyed by item values; updated in place
     * @param items the already-created items, keyed by value; updated in place
     * @return the existing or newly created itemset
     */
    private static Itemset ensureItemset(Feature feature, List<String> values, Map<List<String>, Itemset> itemsets, Map<String, Item> items) {
        Itemset itemset = itemsets.get(values);
        if (itemset != null) {
            return itemset;
        }

        itemset = new Itemset(String.valueOf(itemsets.size() + 1));
        for (String value : values) {
            Item item = items.get(value);
            if (item == null) {
                item = new Item(String.valueOf(items.size() + 1), value).setField(feature.getName());
                items.put(value, item);
            }
            itemset.addItemRefs(new ItemRef(item.getId()));
        }

        List<ItemRef> itemRefs = itemset.getItemRefs();
        if (itemRefs.size() > 1) {
            Collections.sort(itemRefs, ITEM_REF_ORDER);
        }

        itemsets.put(values, itemset);
        return itemset;
    }
}

