/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.spark;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.expressions.CreateStruct;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.spark.ColumnProducer;
import org.jpmml.evaluator.spark.DataFrameUtil;
import org.jpmml.evaluator.spark.ScalaUtil;
import scala.runtime.AbstractFunction1;

/**
 * Spark ML {@link Transformer} that scores a {@link DataFrame} with a PMML model.
 *
 * <p>For every row, the columns named after the evaluator's input fields are prepared and passed
 * to {@link Evaluator#evaluate}. The configured {@link ColumnProducer}s then format their result
 * fields, and the formatted values are appended to the frame as a single struct column named by
 * {@link #getOutputCol()} ({@code "pmml"} by default).
 */
public class PMMLTransformer
extends Transformer {
    /**
     * Identifier returned by {@link #uid()}, generated once per instance.
     * The shape (prefix, underscore, last 12 UUID characters) mirrors Spark's
     * {@code Identifiable.randomUID}.
     */
    private final String uid = "pmmlTransformer_" + UUID.randomUUID().toString().substring(24);
    private String outputCol = "pmml";
    private Evaluator evaluator = null;
    private List<ColumnProducer<?>> columnProducers = null;
    private StructType outputSchema = null;

    /**
     * Creates a transformer for the given evaluator.
     *
     * <p>Each column producer is initialised against {@code evaluator} here. The
     * {@link StructField}s they return become, in list order, the fields of the output struct.
     *
     * @param evaluator the PMML model evaluator used to score rows
     * @param columnProducers producers for the result fields to emit; the list is copied, so later
     *     changes to the caller's list cannot make emitted rows disagree with the output schema
     */
    public PMMLTransformer(Evaluator evaluator, List<ColumnProducer<?>> columnProducers) {
        // Copy first, so the schema computed below and the producers used by transform() are
        // guaranteed to describe the same columns in the same order.
        List<ColumnProducer<?>> producers = new ArrayList<ColumnProducer<?>>(columnProducers);
        StructType outputSchema = new StructType();
        for (ColumnProducer<?> columnProducer : producers) {
            outputSchema = outputSchema.add(columnProducer.init(evaluator));
        }
        this.setEvaluator(evaluator);
        this.setColumnProducers(producers);
        this.setOutputSchema(outputSchema);
    }

    /**
     * Returns this transformer's identifier, which stays fixed for the lifetime of the instance.
     *
     * @return a non-null identifier with the prefix {@code "pmmlTransformer_"}
     */
    public String uid() {
        return this.uid;
    }

    /**
     * Copying is not supported by this transformer.
     *
     * @param extra ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public PMMLTransformer copy(ParamMap extra) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns {@code schema} with the non-nullable output struct column appended.
     * The presence of the input columns is not checked here.
     *
     * @param schema the input schema
     * @return {@code schema} plus a field named {@link #getOutputCol()} of type
     *     {@link #getOutputSchema()}
     */
    public StructType transformSchema(StructType schema) {
        StructField outputField = DataTypes.createStructField(this.getOutputCol(), this.getOutputSchema(), false);
        return schema.add(outputField);
    }

    /**
     * Appends the PMML scoring result to {@code dataFrame} as a struct column.
     *
     * <p>The columns named after the evaluator's input fields are packed, in input-field order,
     * into one struct that a UDF hands to {@link EvaluatorFunction}.
     *
     * @param dataFrame the frame to score; it must contain a column for every input field
     * @return {@code dataFrame} with the output column (named {@link #getOutputCol()}) added or
     *     replaced
     */
    public DataFrame transform(final DataFrame dataFrame) {
        Evaluator evaluator = this.getEvaluator();
        List<ColumnProducer<?>> columnProducers = this.getColumnProducers();
        List<InputField> inputFields = evaluator.getInputFields();

        // One column expression per input field, in evaluator order. EvaluatorFunction relies on
        // this order when it reads the packed struct back by position.
        Function<InputField, Expression> toExpression = new Function<InputField, Expression>() {

            public Expression apply(InputField inputField) {
                Column column = dataFrame.apply(DataFrameUtil.escapeColumnName(inputField.getName().getValue()));
                return column.expr();
            }
        };
        List<Expression> activeExpressions = Lists.newArrayList(Lists.transform(inputFields, toExpression));

        // Static nested class, so the serialized UDF does not drag this transformer along.
        EvaluatorFunction evaluatorFunction = new EvaluatorFunction(evaluator, inputFields, columnProducers);

        ScalaUDF evaluateExpression = new ScalaUDF(
            evaluatorFunction,
            this.getOutputSchema(),
            ScalaUtil.<Expression>singletonSeq(new CreateStruct(ScalaUtil.<Expression>toSeq(activeExpressions))),
            ScalaUtil.<DataType>emptySeq());
        Column outputColumn = new Column(evaluateExpression);
        return dataFrame.withColumn(DataFrameUtil.escapeColumnName(this.getOutputCol()), outputColumn);
    }

    /**
     * Returns the names of the columns that {@link #transform(DataFrame)} reads.
     *
     * <p>The names come from the evaluator's input fields, the same list {@code transform}
     * resolves against the frame, so the two cannot disagree.
     *
     * @return the input column names, in evaluator order
     */
    public String[] getInputCols() {
        List<InputField> inputFields = this.getEvaluator().getInputFields();
        String[] inputCols = new String[inputFields.size()];
        for (int i = 0; i < inputCols.length; ++i) {
            inputCols[i] = inputFields.get(i).getName().getValue();
        }
        return inputCols;
    }

    /** Returns the name of the struct column that {@link #transform(DataFrame)} adds. */
    public String getOutputCol() {
        return this.outputCol;
    }

    /**
     * Sets the name of the struct column that {@link #transform(DataFrame)} adds.
     *
     * @param outputCol the column name
     * @throws IllegalArgumentException if {@code outputCol} is {@code null}
     */
    public void setOutputCol(String outputCol) {
        if (outputCol == null) {
            throw new IllegalArgumentException("outputCol must not be null");
        }
        this.outputCol = outputCol;
    }

    /** Returns the PMML evaluator supplied at construction. */
    public Evaluator getEvaluator() {
        return this.evaluator;
    }

    private void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }

    /** Returns the column producers, in output-struct order (a copy of the constructor argument). */
    public List<ColumnProducer<?>> getColumnProducers() {
        return this.columnProducers;
    }

    private void setColumnProducers(List<ColumnProducer<?>> columnProducers) {
        this.columnProducers = columnProducers;
    }

    /** Returns the struct type of the output column, built from the column producers' fields. */
    public StructType getOutputSchema() {
        return this.outputSchema;
    }

    private void setOutputSchema(StructType outputSchema) {
        this.outputSchema = outputSchema;
    }

    /**
     * Row-scoring function used as the body of the Spark UDF.
     *
     * <p>It takes the struct of input-column values built in {@link #transform(DataFrame)},
     * evaluates it, and returns one formatted value per column producer. It is declared
     * {@code static} so that its serialized form holds only the three fields below and not
     * an enclosing {@code PMMLTransformer}.
     */
    private static class EvaluatorFunction
    extends SerializableAbstractFunction1<Row, Row> {
        private final Evaluator evaluator;
        private final List<InputField> inputFields;
        private final List<ColumnProducer<?>> columnProducers;

        EvaluatorFunction(Evaluator evaluator, List<InputField> inputFields, List<ColumnProducer<?>> columnProducers) {
            this.evaluator = evaluator;
            this.inputFields = inputFields;
            this.columnProducers = columnProducers;
        }

        public Row apply(Row row) {
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
            for (int i = 0; i < this.inputFields.size(); ++i) {
                InputField inputField = this.inputFields.get(i);
                // Position i of the packed struct holds the column for inputFields.get(i).
                arguments.put(inputField.getName(), inputField.prepare(row.get(i)));
            }
            Map<FieldName, ?> result = this.evaluator.evaluate(arguments);
            List<Object> formattedValues = new ArrayList<Object>(this.columnProducers.size());
            for (ColumnProducer<?> columnProducer : this.columnProducers) {
                FieldName name = columnProducer.getField().getName();
                formattedValues.add(columnProducer.format(result.get(name)));
            }
            return RowFactory.create(formattedValues.toArray());
        }
    }

    /**
     * Scala {@code Function1} base class that is also {@link Serializable}, so that
     * implementations can be shipped to Spark executors.
     *
     * @param <T1> the argument type
     * @param <R> the result type
     */
    public static abstract class SerializableAbstractFunction1<T1, R>
    extends AbstractFunction1<T1, R>
    implements Serializable {
    }
}

