/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.feature.SQLTransformer;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.AliasExpression;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.ExpressionTranslator;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.ScalaUtil;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a Spark ML {@link SQLTransformer} to PMML.
 *
 * <p>The transformer's SQL statement is analysed against the encoder's current schema, and every
 * Catalyst expression of the resulting logical plan is translated to a PMML feature or field.
 */
public class SQLTransformerConverter extends FeatureConverter<SQLTransformer> {

    public SQLTransformerConverter(SQLTransformer sqlTransformer) {
        super(sqlTransformer);
    }

    /**
     * Encodes every output column of the SQL statement as one or more features.
     *
     * <p>Columns that resolve to already-known features are passed through unchanged. Columns that
     * resolve to a PMML field are turned into a feature via {@link SparkMLEncoder#createFeature}.
     * That feature is then registered under the field's name with
     * {@link SparkMLEncoder#putOnlyFeature}.
     *
     * @param encoder the encoder supplying the input schema and the known features
     * @return the features, in the order produced by {@link #encodeLogicalPlan}
     * @throws IllegalArgumentException if the logical plan yields neither a feature list nor a field
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        SQLTransformer transformer = getTransformer();

        String statement = transformer.getStatement();

        SparkSession sparkSession = SparkSession.builder().getOrCreate();

        StructType schema = encoder.getSchema();

        LogicalPlan logicalPlan = DatasetUtil.createAnalyzedLogicalPlan(sparkSession, schema, statement);

        List<Feature> result = new ArrayList<>();

        List<?> objects = encodeLogicalPlan(encoder, logicalPlan);
        for (Object object : objects) {

            if (object instanceof List) {
                List<?> features = (List<?>)object;

                for (Object feature : features) {
                    result.add((Feature)feature);
                }

                continue;
            }

            if (object instanceof Field) {
                Field<?> field = (Field<?>)object;

                String name = field.requireName();

                Feature feature = encoder.createFeature(field);

                encoder.putOnlyFeature(name, feature);

                result.add(feature);

                continue;
            }

            throw new IllegalArgumentException("Expected a list of features or a PMML field, got "
                + (object != null ? object.getClass().getName() : "null"));
        }

        return result;
    }

    /**
     * Encodes the features only for their side effects on {@code encoder}.
     *
     * <p>The side effects include the {@code putOnlyFeature} registrations done by
     * {@link #encodeFeatures}. The returned list is discarded.
     */
    @Override
    public void registerFeatures(SparkMLEncoder encoder) {
        encodeFeatures(encoder);
    }

    /**
     * Translates the expressions of a Catalyst logical plan to PMML.
     *
     * <p>Child plans are encoded first, only for their side effects on {@code encoder}; their
     * return values are discarded. Each expression of {@code logicalPlan} itself then contributes
     * exactly one element to the result, in expression order:
     * <ul>
     *   <li>a {@code List<Feature>} when it is a bare column reference to already-known features;</li>
     *   <li>a {@link Field} when it is a bare column reference to an existing or newly created field;</li>
     *   <li>otherwise, a new {@link DerivedField} wrapping the translated expression.</li>
     * </ul>
     *
     * @param encoder the encoder used for translation and for field lookup/creation
     * @param logicalPlan the analyzed logical plan to encode
     * @return one element per expression of {@code logicalPlan}, each either a {@code List<Feature>}
     *     or a {@link Field}
     */
    public static List<?> encodeLogicalPlan(SparkMLEncoder encoder, LogicalPlan logicalPlan) {
        // Heterogeneous on purpose: holds both feature lists and PMML fields. Typing this as a
        // list of lists would force an invalid Field -> List cast (ClassCastException at runtime).
        List<Object> result = new ArrayList<>();

        // NOTE(review): child results are discarded; only this plan's own expressions become
        // outputs. Children are presumably visited so that any fields they create are registered
        // with the encoder before this plan's expressions are translated. Confirm.
        List<LogicalPlan> children = ScalaUtil.seqAsJavaList(logicalPlan.children());
        for (LogicalPlan child : children) {
            encodeLogicalPlan(encoder, child);
        }

        List<org.apache.spark.sql.catalyst.expressions.Expression> expressions = ScalaUtil.seqAsJavaList(logicalPlan.expressions());
        for (org.apache.spark.sql.catalyst.expressions.Expression expression : expressions) {
            result.add(encodeExpression(encoder, expression));
        }

        return result;
    }

    /**
     * Translates one Catalyst expression to either a list of known features or a PMML field.
     *
     * @return a {@code List<Feature>}, an existing/created {@link Field}, or a new {@link DerivedField}
     */
    private static Object encodeExpression(final SparkMLEncoder encoder, org.apache.spark.sql.catalyst.expressions.Expression expression) {
        Expression pmmlExpression = ExpressionTranslator.translate(encoder, expression);

        // A bare column reference is passed through instead of being wrapped in a new derived field
        if (pmmlExpression instanceof FieldRef) {
            String fieldName = ((FieldRef)pmmlExpression).requireField();

            if (encoder.hasFeatures(fieldName)) {
                return encoder.getFeatures(fieldName);
            }

            Field<?> field = ensureField(encoder, fieldName);
            if (field != null) {
                return field;
            }

            // NOTE(review): the referenced field could be neither found nor created. Execution
            // falls through and wraps the bare reference in a derived field; confirm this is intended.
        }

        String name;
        if (pmmlExpression instanceof AliasExpression) {
            name = ((AliasExpression)pmmlExpression).getName();
        } else {
            name = FieldNameUtil.create("sql", new Object[]{String.valueOf(expression)});
        }

        DataType dataType = DatasetUtil.translateDataType(expression.dataType());
        OpType opType = TypeUtil.getOpType(dataType);

        pmmlExpression = AliasExpression.unwrap(pmmlExpression);

        // Try to resolve (or create) every field referenced from inside the expression before it is used
        AbstractVisitor visitor = new AbstractVisitor() {

            @Override
            public VisitorAction visit(FieldRef fieldRef) {
                ensureField(encoder, fieldRef.requireField());

                return super.visit(fieldRef);
            }
        };
        visitor.applyTo((Visitable)pmmlExpression);

        return encoder.createDerivedField(name, opType, dataType, pmmlExpression);
    }

    /**
     * Looks up the field with the given name, falling back to creating a data field for it.
     *
     * <p>This is a deliberate best-effort helper: failures are reported as {@code null} rather than
     * thrown, and callers treat {@code null} as "unresolved".
     *
     * @param encoder the encoder to query and extend
     * @param name the field name
     * @return the existing or newly created field, or {@code null} if both lookup and creation failed
     */
    private static Field<?> ensureField(SparkMLEncoder encoder, String name) {
        try {
            return encoder.getField(name);
        } catch (IllegalArgumentException pmmlIae) {
            // Lookup failed (presumably no such field yet); fall through and try to create one
        }

        try {
            return encoder.createDataField(name);
        } catch (IllegalArgumentException sparkIae) {
            return null;
        }
    }
}

