/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.preprocessing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.Lag;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

public class RollingAggregateTransformer
extends Transformer {
    private static final String FUNCTION_AVG = "avg";
    private static final String FUNCTION_MAX = "max";
    private static final String FUNCTION_MEAN = "mean";
    private static final String FUNCTION_MIN = "min";
    private static final String FUNCTION_PROD = "prod";
    private static final String FUNCTION_PRODUCT = "product";
    private static final String FUNCTION_SUM = "sum";

    public RollingAggregateTransformer(String module, String name) {
        super(module, name);
    }

    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        String function = this.getFunction();
        Integer n = this.getN();
        Lag.Aggregate aggregate = RollingAggregateTransformer.parseFunction(function);
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = features.get(i);
            Field field = feature.getField();
            Lag lag = new Lag(field.requireName()).setAggregate(aggregate).setN(n);
            DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create((String)aggregate.value(), (Object[])new Object[]{feature, n}), OpType.CONTINUOUS, DataType.DOUBLE, (Expression)lag);
            result.add((Feature)new ContinuousFeature((PMMLEncoder)encoder, (Field)derivedField));
        }
        return result;
    }

    public String getFunction() {
        return (String)this.getEnum("function", arg_0 -> ((RollingAggregateTransformer)this).getString(arg_0), Arrays.asList(FUNCTION_AVG, FUNCTION_MAX, FUNCTION_MEAN, FUNCTION_MIN, FUNCTION_PROD, FUNCTION_PRODUCT, FUNCTION_SUM));
    }

    public Integer getN() {
        return this.getInteger("n");
    }

    private static Lag.Aggregate parseFunction(String function) {
        switch (function) {
            case "avg": 
            case "mean": {
                return Lag.Aggregate.AVG;
            }
            case "max": {
                return Lag.Aggregate.MAX;
            }
            case "min": {
                return Lag.Aggregate.MIN;
            }
            case "prod": 
            case "product": {
                return Lag.Aggregate.PRODUCT;
            }
            case "sum": {
                return Lag.Aggregate.SUM;
            }
        }
        throw new IllegalArgumentException(function);
    }
}

