/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.feature.SQLTransformer;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.StringFeature;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.AliasExpression;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.ExpressionTranslator;
import org.jpmml.sparkml.ExpressionUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Converts a Spark ML {@link SQLTransformer} into PMML fields and JPMML features.
 *
 * <p>The transformer's SQL statement is turned into an analyzed logical plan
 * (see {@link DatasetUtil#createAnalyzedLogicalPlan}), and every expression of
 * that plan is translated into either an existing/new data field or a new
 * derived field on the encoder.</p>
 */
public class SQLTransformerConverter
extends FeatureConverter<SQLTransformer> {
    public SQLTransformerConverter(SQLTransformer sqlTransformer) {
        super(sqlTransformer);
    }

    /**
     * Encodes the output columns of the SQL statement as features.
     *
     * <p>Each field produced by {@link #encodeLogicalPlan} is wrapped in a
     * feature chosen by its PMML data type, registered on the encoder under
     * the field's name via {@code putOnlyFeature}, and appended to the result.</p>
     *
     * @param encoder the encoder that supplies the schema and receives the fields and features
     * @return the features, in the order in which the fields were produced
     * @throws IllegalArgumentException if a field has a data type other than
     *         STRING, INTEGER, DOUBLE or BOOLEAN
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
        SQLTransformer transformer = this.getTransformer();
        String statement = transformer.getStatement();
        SparkSession sparkSession = SparkSession.builder().getOrCreate();
        StructType schema = encoder.getSchema();
        LogicalPlan logicalPlan = DatasetUtil.createAnalyzedLogicalPlan(sparkSession, schema, statement);
        List<Feature> result = new ArrayList<Feature>();
        List<Field<?>> fields = SQLTransformerConverter.encodeLogicalPlan(encoder, logicalPlan);
        for (Field<?> field : fields) {
            // Declared as the common supertype: the concrete feature class depends on the data type.
            Feature feature;
            FieldName name = field.getName();
            DataType dataType = field.getDataType();
            switch (dataType) {
                case STRING: {
                    feature = new StringFeature(encoder, field);
                    break;
                }
                case INTEGER:
                case DOUBLE: {
                    feature = new ContinuousFeature(encoder, field);
                    break;
                }
                case BOOLEAN: {
                    feature = new BooleanFeature(encoder, field);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Data type " + dataType + " is not supported");
                }
            }
            encoder.putOnlyFeature(name.getValue(), feature);
            result.add(feature);
        }
        return result;
    }

    /**
     * Registers the features on the encoder by running {@link #encodeFeatures}
     * and discarding the returned list.
     *
     * @param encoder the encoder that receives the features
     */
    @Override
    public void registerFeatures(SparkMLEncoder encoder) {
        this.encodeFeatures(encoder);
    }

    /**
     * Translates the expressions of a logical plan node into PMML fields.
     *
     * <p>Child plans are processed first, but their returned fields are not
     * collected; only the expressions of {@code logicalPlan} itself contribute
     * to the returned list.</p>
     *
     * <p>An expression that translates to a plain {@link FieldRef} whose field can
     * be resolved by {@link #ensureField} yields that field directly. Any other
     * expression yields a new derived field, named after its alias when the
     * translated expression is an {@link AliasExpression}, and otherwise by
     * {@code FieldNameUtil.create("sql", ...)} over the formatted Spark expression.</p>
     *
     * @param encoder the encoder used to look up and create fields
     * @param logicalPlan the plan node whose expressions are translated
     * @return the fields for the expressions of {@code logicalPlan}, in expression order
     */
    public static List<Field<?>> encodeLogicalPlan(final SparkMLEncoder encoder, LogicalPlan logicalPlan) {
        List<Field<?>> result = new ArrayList<Field<?>>();
        List<LogicalPlan> children = JavaConversions.seqAsJavaList(logicalPlan.children());
        for (LogicalPlan child : children) {
            // Return value intentionally unused, as in the original code.
            SQLTransformerConverter.encodeLogicalPlan(encoder, child);
        }
        List<org.apache.spark.sql.catalyst.expressions.Expression> expressions = JavaConversions.seqAsJavaList(logicalPlan.expressions());
        for (org.apache.spark.sql.catalyst.expressions.Expression expression : expressions) {
            Expression pmmlExpression = ExpressionTranslator.translate(encoder, expression);
            if (pmmlExpression instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef)pmmlExpression;
                Field<?> field = SQLTransformerConverter.ensureField(encoder, fieldRef.getField());
                if (field != null) {
                    result.add(field);
                    continue;
                }
                // Unresolvable reference: fall through and wrap it in a derived field.
            }
            FieldName name;
            if (pmmlExpression instanceof AliasExpression) {
                AliasExpression aliasExpression = (AliasExpression)pmmlExpression;
                name = FieldName.create(aliasExpression.getName());
            } else {
                name = FieldNameUtil.create("sql", ExpressionUtil.format(expression));
            }
            DataType dataType = DatasetUtil.translateDataType(expression.dataType());
            OpType opType = ExpressionUtil.getOpType(dataType);
            pmmlExpression = AliasExpression.unwrap(pmmlExpression);
            // Run ensureField on every field referenced inside the expression
            // before the derived field is created.
            AbstractVisitor visitor = new AbstractVisitor(){

                @Override
                public VisitorAction visit(FieldRef fieldRef) {
                    SQLTransformerConverter.ensureField(encoder, fieldRef.getField());
                    return super.visit(fieldRef);
                }
            };
            visitor.applyTo(pmmlExpression);
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
            result.add(derivedField);
        }
        return result;
    }

    /**
     * Returns the encoder's field with the given name, creating a data field if
     * the lookup fails.
     *
     * @param encoder the encoder to query and, if needed, to create the data field on
     * @param name the field name
     * @return the existing or newly created field, or {@code null} if both the
     *         lookup and the creation threw {@link IllegalArgumentException}
     */
    private static Field<?> ensureField(SparkMLEncoder encoder, FieldName name) {
        try {
            return encoder.getField(name);
        }
        catch (IllegalArgumentException pmmlIae) {
            // Lookup failed; fall back to creating a data field.
            try {
                return encoder.createDataField(name);
            }
            catch (IllegalArgumentException sparkIae) {
                // Deliberately swallowed: callers treat null as "no such field".
                return null;
            }
        }
    }
}

