/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.base.Equivalence;
import com.google.common.collect.MapDifference;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Batch;
import org.jpmml.evaluator.Conflict;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.HasGroupFields;

/**
 * Static helpers for evaluating a {@link Batch} and for converting between tables of string cells
 * and records keyed by {@link FieldName}.
 */
public class BatchUtil {
    private BatchUtil() {
    }

    /**
     * Evaluates every input row of a batch and compares the actual results against the expected output rows.
     *
     * <p>If the evaluator implements {@link HasGroupFields}, the input rows are first grouped via
     * {@link EvaluatorUtil#groupRows}. On both the expected and the actual side, only fields accepted by
     * the batch predicate are compared, and {@link Evaluator#DEFAULT_TARGET_NAME} is always excluded.
     *
     * @param batch the batch supplying the evaluator, the input rows, the expected output rows and the field predicate
     * @param equivalence the equivalence used to compare expected and actual field values
     * @return one conflict per row whose results differ, or whose evaluation/comparison threw an exception;
     *     an empty list when every row matches
     * @throws IllegalArgumentException if the number of (possibly grouped) input rows differs from the number
     *     of expected output rows
     * @throws Exception if obtaining the evaluator, the rows or the predicate from the batch fails
     */
    public static List<Conflict> evaluate(Batch batch, Equivalence<Object> equivalence) throws Exception {
        Evaluator evaluator = batch.getEvaluator();

        List input = batch.getInput();
        List<Map<FieldName, ?>> output = batch.getOutput();

        if (evaluator instanceof HasGroupFields) {
            HasGroupFields hasGroupFields = (HasGroupFields)evaluator;

            input = EvaluatorUtil.groupRows(hasGroupFields, input);
        }

        if (input.size() != output.size()) {
            throw new IllegalArgumentException("Expected the same number of data rows, got " + input.size() + " input data rows and " + output.size() + " expected output data rows");
        }

        // The default target name is never compared, regardless of the batch predicate
        Predicate<FieldName> predicate = batch.getPredicate().and(name -> !Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name));

        List<Conflict> conflicts = new ArrayList<>();

        for (int i = 0; i < input.size(); i++) {
            Integer id = i;

            Map arguments = (Map)input.get(i);
            Map<FieldName, ?> expectedResults = Maps.filterKeys(output.get(i), predicate::test);

            Conflict conflict;

            try {
                Map actualResults = evaluator.evaluate(arguments);
                actualResults = Maps.filterKeys(actualResults, predicate::test);

                MapDifference difference = Maps.difference(expectedResults, actualResults, equivalence);
                if (difference.areEqual()) {
                    continue;
                }

                conflict = new Conflict(id, arguments, difference);
            } catch (Exception e) {
                // A failing row is reported as a conflict instead of aborting the whole batch
                conflict = new Conflict(id, arguments, e);
            }

            conflicts.add(conflict);
        }

        return conflicts;
    }

    /**
     * Converts a table of string cells into a list of records keyed by field name.
     *
     * <p>The first row of the table is the header row. Every subsequent row becomes one record, in which
     * each cell value is passed through {@code function} and stored under the field name of its column.
     *
     * @param table the header row followed by zero or more body rows
     * @param function the transformation applied to every body cell value
     * @return one record per body row, with entries in column order
     * @throws IllegalArgumentException if the table is empty, if the header row contains duplicate cell names,
     *     or if a body row has a different number of cells than the header row
     */
    public static List<Map<FieldName, String>> parseRecords(List<List<String>> table, Function<String, String> function) {

        // Without this guard an empty table fails with an opaque "Illegal Capacity: -1"
        if (table.isEmpty()) {
            throw new IllegalArgumentException("Expected a header row, got an empty table");
        }

        List<String> headerRow = table.get(0);

        // Single counting pass; LinkedHashMap keeps the first-occurrence order used in the error message
        Map<String, Integer> headerCellCounts = new LinkedHashMap<>();
        for (String headerCell : headerRow) {
            headerCellCounts.merge(headerCell, 1, Integer::sum);
        }

        if (headerCellCounts.size() < headerRow.size()) {
            LinkedHashSet<String> duplicateHeaderCells = new LinkedHashSet<>();

            for (Map.Entry<String, Integer> entry : headerCellCounts.entrySet()) {
                if (entry.getValue() > 1) {
                    duplicateHeaderCells.add(entry.getKey());
                }
            }

            throw new IllegalArgumentException("Expected unique cell names, got non-unique cell name(s) " + duplicateHeaderCells);
        }

        List<Map<FieldName, String>> records = new ArrayList<>(table.size() - 1);

        for (int row = 1; row < table.size(); row++) {
            List<String> bodyRow = table.get(row);

            if (headerRow.size() != bodyRow.size()) {
                throw new IllegalArgumentException("Expected " + headerRow.size() + " cells, got " + bodyRow.size() + " cells (data record " + (row - 1) + ")");
            }

            Map<FieldName, String> record = new LinkedHashMap<>();

            for (int column = 0; column < headerRow.size(); column++) {
                FieldName name = FieldName.create(headerRow.get(column));
                String value = function.apply(bodyRow.get(column));

                record.put(name, value);
            }

            records.add(record);
        }

        return records;
    }

    /**
     * Converts a list of records into a table of string cells.
     *
     * <p>The first row of the result is the header row built from {@code names}, where a {@code null} name is
     * rendered as {@code "(null)"}. Each record then contributes one body row, in {@code names} order, containing
     * {@code function} applied to whatever {@code record.get(name)} returns (possibly {@code null}).
     *
     * @param records the records to format
     * @param names the field names that select and order the columns
     * @param function the formatter applied to every record value
     * @return the header row followed by one body row per record
     */
    public static List<List<String>> formatRecords(List<Map<FieldName, ?>> records, List<FieldName> names, Function<Object, String> function) {
        List<List<String>> table = new ArrayList<>(1 + records.size());

        List<String> headerRow = new ArrayList<>(names.size());
        for (FieldName name : names) {
            headerRow.add(name != null ? name.getValue() : "(null)");
        }

        table.add(headerRow);

        for (Map<FieldName, ?> values : records) {
            List<String> bodyRow = new ArrayList<>(names.size());

            for (FieldName name : names) {
                bodyRow.add(function.apply(values.get(name)));
            }

            table.add(bodyRow);
        }

        return table;
    }
}

