/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.feature.SQLTransformer;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar;
import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.internal.SessionState;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.StringFeature;
import org.jpmml.sparkml.ExpressionMapping;
import org.jpmml.sparkml.ExpressionTranslator;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Converts a Spark ML {@link SQLTransformer} into PMML by translating each
 * top-level expression of its SQL statement's parsed logical plan into a
 * {@link DerivedField} and a matching {@link Feature}.
 *
 * <p>Supported top-level expressions are: aliased expressions
 * ({@code expr AS name}), unaliased expressions (parsed as {@link UnresolvedAlias}),
 * and a bare star ({@code *}), which passes through all schema features.
 * Anything else is rejected with an {@link IllegalArgumentException}.
 */
public class SQLTransformerConverter
extends FeatureConverter<SQLTransformer> {
    public SQLTransformerConverter(SQLTransformer sqlTransformer) {
        super(sqlTransformer);
    }

    /**
     * Parses the transformer's SQL statement and encodes one feature per
     * top-level expression (or all schema features for {@code *}).
     *
     * <p>Each non-star expression is translated via {@link ExpressionTranslator},
     * registered with the encoder under its (alias or derived) name, and appended
     * to the result in statement order.
     *
     * @param encoder the encoder that resolves input features and receives the new derived fields
     * @return the encoded features, in the order they appear in the statement
     * @throws IllegalArgumentException if a top-level expression kind or a
     *         translated data type is not supported
     */
    @Override
    public List<Feature> encodeFeatures(final SparkMLEncoder encoder) {
        SQLTransformer transformer = (SQLTransformer)this.getTransformer();
        String statement = transformer.getStatement();

        // Borrow the SQL parser of the existing (or a newly created) session;
        // only the unresolved logical plan is needed here.
        SparkSession sparkSession = SparkSession.builder().getOrCreate();
        SessionState sessionState = sparkSession.sessionState();
        ParserInterface parserInterface = sessionState.sqlParser();
        LogicalPlan logicalPlan = parserInterface.parsePlan(statement);

        // Column references are typed by looking up the already-encoded feature of that name.
        ExpressionTranslator.DataTypeResolver dataTypeResolver = new ExpressionTranslator.DataTypeResolver(){

            @Override
            public DataType getDataType(String name) {
                Feature feature = encoder.getOnlyFeature(name);
                return feature.getDataType();
            }
        };

        List<Feature> result = new ArrayList<>();

        List<org.apache.spark.sql.catalyst.expressions.Expression> expressions = JavaConversions.seqAsJavaList(logicalPlan.expressions());
        for (org.apache.spark.sql.catalyst.expressions.Expression expression : expressions) {
            org.apache.spark.sql.catalyst.expressions.Expression child;
            String name;

            if (expression instanceof Alias) {
                Alias alias = (Alias)expression;
                child = alias.child();
                name = alias.name();
            } else if (expression instanceof UnresolvedAlias) {
                UnresolvedAlias unresolvedAlias = (UnresolvedAlias)expression;
                child = unresolvedAlias.child();
                // Synthesize a name from the expression text with quote characters stripped;
                // NOTE(review): presumably approximates Spark's generated column name — verify.
                name = "(" + child.toString().replace("'", "") + ")";
            } else if (expression instanceof UnresolvedStar) {
                // "SELECT *": pass every schema feature through unchanged.
                result.addAll(encoder.getSchemaFeatures());
                continue;
            } else {
                throw new IllegalArgumentException(String.valueOf(expression));
            }

            ExpressionMapping expressionMapping = ExpressionTranslator.translate(child, dataTypeResolver);
            DataType dataType = expressionMapping.getDataType();

            // Validate the data type before creating the derived field.
            OpType opType = toOpType(dataType);

            Expression pmmlExpression = expressionMapping.getTo();
            DerivedField derivedField = encoder.createDerivedField(FieldName.create(name), opType, dataType, pmmlExpression);

            Feature feature = toFeature(encoder, derivedField, dataType);

            encoder.putOnlyFeature(name, feature);
            result.add(feature);
        }

        return result;
    }

    /**
     * Encodes the features for their side effect of registering them with the encoder.
     */
    @Override
    public void registerFeatures(SparkMLEncoder encoder) {
        this.encodeFeatures(encoder);
    }

    /**
     * Maps a translated expression's data type to the operational type of its derived field.
     *
     * @throws IllegalArgumentException if the data type is not STRING, BOOLEAN, INTEGER or DOUBLE
     */
    private static OpType toOpType(DataType dataType) {
        switch (dataType) {
            case STRING:
            case BOOLEAN:
                return OpType.CATEGORICAL;
            case INTEGER:
            case DOUBLE:
                return OpType.CONTINUOUS;
            default:
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
        }
    }

    /**
     * Wraps a derived field in the feature class matching its data type.
     *
     * @throws IllegalArgumentException if the data type is not STRING, BOOLEAN, INTEGER or DOUBLE
     */
    private static Feature toFeature(SparkMLEncoder encoder, DerivedField derivedField, DataType dataType) {
        switch (dataType) {
            case STRING:
                return new StringFeature((PMMLEncoder)encoder, (Field)derivedField);
            case INTEGER:
            case DOUBLE:
                return new ContinuousFeature((PMMLEncoder)encoder, (Field)derivedField);
            case BOOLEAN:
                return new BooleanFeature((PMMLEncoder)encoder, (Field)derivedField);
            default:
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
        }
    }
}

