/*
 * Decompiled with CFR 0.152.
 */
package edu.columbia.tjw.item.spark;

import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.base.SimpleRegressor;
import edu.columbia.tjw.item.base.SimpleStatus;
import edu.columbia.tjw.item.base.StandardCurveType;
import edu.columbia.tjw.item.base.raw.RawFittingGrid;
import edu.columbia.tjw.item.data.ItemStatusGrid;
import edu.columbia.tjw.item.fit.FitResult;
import edu.columbia.tjw.item.fit.GradientResult;
import edu.columbia.tjw.item.fit.ItemFitter;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.spark.ItemClassificationModel;
import edu.columbia.tjw.item.spark.ItemClassifierSettings;
import edu.columbia.tjw.item.spark.SparkGridAdapter;
import edu.columbia.tjw.item.util.EnumFamily;
import edu.columbia.tjw.item.util.random.RandomTool;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;

/**
 * Spark ML classifier that delegates model fitting to an {@link ItemFitter}.
 *
 * <p>The classifier is configured through an {@link ItemClassifierSettings} instance and,
 * optionally, a set of starting parameters that are pushed into the fitter before training.
 * The static helpers {@link #prepareSettings} and {@link #prepareData} build the settings and
 * the assembled feature column that this class expects.
 */
public class ItemClassifier
extends ProbabilisticClassifier<Vector, ItemClassifier, ItemClassificationModel>
implements Cloneable {
    private static final long serialVersionUID = 8990051165227355116L;

    /** Name of the constant (1.0) column added by {@link #prepareData} and used as the intercept regressor. */
    private static final String INTERCEPT_NAME = "ITEM_INTERCEPT";

    private final ItemClassifierSettings _settings;

    /** Optional starting parameters; {@code null} means the fitter starts from its own defaults. */
    private final ItemParameters<SimpleStatus, SimpleRegressor, StandardCurveType> _startingParams;

    // Lazily populated by uid(). Deliberately declared WITHOUT an initializer.
    // NOTE(review): presumably uid() can be invoked while the superclass is still being
    // constructed; an explicit initializer here would run afterwards and overwrite that
    // value -- confirm before changing.
    private String _uid;

    /**
     * Creates a classifier with no starting parameters.
     *
     * @param settings_ classifier settings; must not be {@code null}
     * @throws NullPointerException if {@code settings_} is {@code null}
     */
    public ItemClassifier(ItemClassifierSettings settings_) {
        this(settings_, null);
    }

    /**
     * Creates a classifier.
     *
     * @param settings_       classifier settings; must not be {@code null}
     * @param startingParams_ parameters pushed into the fitter before {@link #train}; may be {@code null}
     * @throws NullPointerException if {@code settings_} is {@code null}
     */
    public ItemClassifier(ItemClassifierSettings settings_, ItemParameters<SimpleStatus, SimpleRegressor, StandardCurveType> startingParams_) {
        if (null == settings_) {
            throw new NullPointerException("Settings cannot be null.");
        }
        this._settings = settings_;
        this._startingParams = startingParams_;
    }

    /**
     * Copies this estimator using Spark's {@code defaultCopy}.
     *
     * <p>NOTE(review): Spark documents {@code defaultCopy} as creating a new instance with the
     * same UID; this class declares no constructor taking a uid string -- verify that copying
     * actually works at runtime.
     *
     * @param paramMap_ extra parameters to apply to the copy
     * @return the copied classifier
     */
    public ItemClassifier copy(ParamMap paramMap_) {
        return (ItemClassifier)this.defaultCopy(paramMap_);
    }

    /**
     * @return the settings this classifier was constructed with (never {@code null})
     */
    public ItemClassifierSettings getSettings() {
        return this._settings;
    }

    /**
     * Builds a fitter over the given data and wraps the fitter's grid in a {@link RawFittingGrid}.
     *
     * @param data_ dataset containing the label and features columns configured on this classifier
     * @return a raw fitting grid constructed from the fitter's grid
     */
    public RawFittingGrid<SimpleStatus, SimpleRegressor> generateMaterializedGrid(Dataset<?> data_) {
        return new RawFittingGrid(this.generateFitter(data_).getGrid());
    }

    /**
     * Adapts the dataset into the status grid consumed by the fitter, using this classifier's
     * label/features column names and the regressors, from-status and intercept from the settings.
     */
    private ItemStatusGrid<SimpleStatus, SimpleRegressor> generateGrid(Dataset<?> data_) {
        String featureCol = this.getFeaturesCol();
        String labelCol = this.getLabelCol();
        return new SparkGridAdapter(data_, labelCol, featureCol, this._settings.getRegressors(), this._settings.getFromStatus(), this._settings.getIntercept());
    }

    /**
     * Creates a new fitter over the given dataset, configured from the settings.
     */
    private ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> generateFitter(Dataset<?> data_) {
        ItemStatusGrid<SimpleStatus, SimpleRegressor> data = this.generateGrid(data_);
        ItemFitter fitter = new ItemFitter(this._settings.getFactory(), (ItemRegressor)this._settings.getIntercept(), (ItemStatus)this._settings.getFromStatus(), data, this._settings.getSettings());
        return fitter;
    }

    /**
     * Adds the intercept column (constant 1.0) to the dataset and assembles the regressor
     * columns named in the settings into a single vector column.
     *
     * @param data_           input dataset; must contain a column for each regressor name other
     *                        than the intercept column added here
     * @param settings_       settings whose regressor list defines the assembled columns, in order
     * @param featuresColumn_ name of the vector column to create
     * @return the dataset with the intercept column and the assembled features column added
     */
    public static Dataset<Row> prepareData(Dataset<?> data_, ItemClassifierSettings settings_, String featuresColumn_) {
        List<SimpleRegressor> regs = settings_.getRegressors();
        String[] regNames = new String[regs.size()];
        int pointer = 0;
        for (SimpleRegressor reg : regs) {
            regNames[pointer++] = reg.name();
        }
        Dataset<Row> withIntercept = data_.withColumn(INTERCEPT_NAME, functions.lit((Object)1.0));
        VectorAssembler assembler = new VectorAssembler();
        assembler.setInputCols(regNames);
        assembler.setOutputCol(featuresColumn_);
        return assembler.transform(withIntercept);
    }

    /**
     * Same as the six-argument overload, using a default-constructed {@link ItemSettings}.
     */
    public static ItemClassifierSettings prepareSettings(Dataset<?> data_, String toStatusColumn_, List<String> featureList, Set<String> curveRegressors_, int maxParamCount_) {
        return ItemClassifier.prepareSettings(data_, toStatusColumn_, featureList, curveRegressors_, maxParamCount_, new ItemSettings());
    }

    /**
     * Derives classifier settings from the data.
     *
     * <p>The distinct non-null values of {@code toStatusColumn_} are read as numbers, converted
     * with {@code intValue()}, sorted ascending and used (as strings) to generate the status
     * family; the member at ordinal 0 of that family becomes the from-status. The regressor list
     * is the intercept name followed by {@code featureList}.
     *
     * <p>NOTE(review): non-integral label values are truncated by {@code intValue()}, so distinct
     * labels could collapse into one status -- confirm labels are always integral.
     *
     * @param data_            dataset containing the status column
     * @param toStatusColumn_  name of the numeric status (label) column
     * @param featureList      feature names; must be distinct and must not contain the intercept name
     * @param curveRegressors_ names eligible for curves; must all appear in the regressor list
     * @param maxParamCount_   maximum parameter count passed through to the settings
     * @param settings_        underlying fitter settings
     * @return the assembled classifier settings
     * @throws IllegalArgumentException if the status column holds no non-null values
     * @throws RuntimeException         if the features are not distinct, or a curve regressor is
     *                                  not in the regressor list
     */
    public static ItemClassifierSettings prepareSettings(Dataset<?> data_, String toStatusColumn_, List<String> featureList, Set<String> curveRegressors_, int maxParamCount_, ItemSettings settings_) {
        Iterator<Row> iter = data_.select(toStatusColumn_, new String[0]).distinct().toLocalIterator();
        // TreeSet gives a deterministic (ascending) status ordering regardless of row order.
        TreeSet<Integer> statSet = new TreeSet<Integer>();
        while (iter.hasNext()) {
            Object nextObj = iter.next().get(0);
            if (null == nextObj) {
                continue;
            }
            statSet.add(((Number)nextObj).intValue());
        }
        if (statSet.isEmpty()) {
            // Without at least one status there is no ordinal-0 member to use as the from-status.
            throw new IllegalArgumentException("No non-null statuses found in column: " + toStatusColumn_);
        }
        ArrayList<String> statList = new ArrayList<String>(statSet.size());
        for (Integer next : statSet) {
            statList.add(next.toString());
        }
        EnumFamily statFamily = SimpleStatus.generateFamily(statList);
        ArrayList<String> regList = new ArrayList<String>(featureList.size() + 1);
        regList.add(INTERCEPT_NAME);
        regList.addAll(featureList);
        HashSet<String> distinctSet = new HashSet<String>(regList);
        if (distinctSet.size() != regList.size()) {
            throw new RuntimeException("Non distinct features: " + regList.size());
        }
        if (!distinctSet.containsAll(curveRegressors_)) {
            throw new RuntimeException("All curve regressors must also be in the feature list.");
        }
        return new ItemClassifierSettings(settings_, INTERCEPT_NAME, (SimpleStatus)statFamily.getFromOrdinal(0), maxParamCount_, regList, curveRegressors_);
    }

    /**
     * Computes gradients of the given model's parameters over the given data.
     *
     * @throws NullPointerException if {@code model_} is {@code null}
     */
    public GradientResult computeGradients(Dataset<?> data_, ItemClassificationModel model_) {
        ItemClassifier.requireModel(model_);
        ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = this.generateFitter(data_);
        return fitter.getCalculator().computeGradients(model_.getParams());
    }

    /**
     * Computes a fit result for the given model's parameters over the given data
     * (no previous result is supplied to the calculator).
     *
     * @throws NullPointerException if {@code model_} is {@code null}
     */
    public FitResult<SimpleStatus, SimpleRegressor, StandardCurveType> computeFitResult(Dataset<?> data_, ItemClassificationModel model_) {
        ItemClassifier.requireModel(model_);
        ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = this.generateFitter(data_);
        return fitter.getCalculator().computeFitResult(model_.getParams(), null);
    }

    /**
     * Starts from a previous model's parameters and runs the fitter's annealing pass over the
     * configured curve regressors.
     *
     * @param data_      data to fit against
     * @param prevModel_ model whose parameters seed the fitter; must not be {@code null}
     * @return a model built from the fitter's latest result
     * @throws NullPointerException if {@code prevModel_} is {@code null}
     * @throws RuntimeException     wrapping any {@link ConvergenceException} raised while fitting
     */
    public ItemClassificationModel runAnnealing(Dataset<?> data_, ItemClassificationModel prevModel_) {
        ItemClassifier.requireModel(prevModel_);
        ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = this.generateFitter(data_);
        try {
            fitter.pushParameters("PrevModel", prevModel_.getParams());
            fitter.runAnnealingByEntry(this._settings.getCurveRegressors(), true);
            return this.buildModel(fitter);
        }
        catch (ConvergenceException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Starts from a previous model's parameters and continues fitting on the given data, using
     * whatever parameter budget remains under the configured maximum.
     *
     * @param data_      data to fit against
     * @param prevModel_ model whose parameters seed the fitter; must not be {@code null}
     * @return a model built from the fitter's latest result
     * @throws NullPointerException if {@code prevModel_} is {@code null}
     * @throws RuntimeException     wrapping any {@link ConvergenceException} raised while fitting
     */
    public ItemClassificationModel retrainModel(Dataset<?> data_, ItemClassificationModel prevModel_) {
        ItemClassifier.requireModel(prevModel_);
        ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = this.generateFitter(data_);
        fitter.pushParameters("PrevModel", prevModel_.getParams());
        return this.fitWithinBudget(fitter);
    }

    /**
     * Fits a model on the given data. If starting parameters were supplied at construction they
     * are pushed into the fitter first.
     *
     * @param data_ data to fit against
     * @return a model built from the fitter's latest result
     * @throws RuntimeException wrapping any {@link ConvergenceException} raised while fitting
     */
    public ItemClassificationModel train(Dataset<?> data_) {
        ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = this.generateFitter(data_);
        if (null != this._startingParams) {
            fitter.pushParameters("InitialParams", this._startingParams);
        }
        return this.fitWithinBudget(fitter);
    }

    /**
     * Returns this estimator's uid, generating a random 64-character string on first use.
     * Synchronized so that concurrent first calls observe a single value.
     */
    public synchronized String uid() {
        if (null == this._uid) {
            this._uid = RandomTool.randomString((int)64);
        }
        return this._uid;
    }

    /**
     * Runs fitModel with the parameter budget left over after the fitter's current best
     * parameters, then wraps the latest result as a model. Shared by train and retrainModel.
     */
    private ItemClassificationModel fitWithinBudget(ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter_) {
        int maxParams = this._settings.getMaxParamCount();
        try {
            int usedParams = fitter_.getBestParameters().getEffectiveParamCount();
            int remainingParams = maxParams - usedParams;
            fitter_.fitModel(this._settings.getNonCurveRegressors(), this._settings.getCurveRegressors(), remainingParams, false);
        }
        catch (ConvergenceException e) {
            throw new RuntimeException(e);
        }
        return this.buildModel(fitter_);
    }

    /**
     * Wraps the latest result on the fitter's chain, together with the settings, as a model.
     */
    private ItemClassificationModel buildModel(ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter_) {
        FitResult fitResult = fitter_.getChain().getLatestResults();
        return new ItemClassificationModel((FitResult<SimpleStatus, SimpleRegressor, StandardCurveType>)fitResult, this._settings);
    }

    /**
     * Fails fast on a null model, before any (potentially expensive) fitter is built from the data.
     */
    private static void requireModel(ItemClassificationModel model_) {
        if (null == model_) {
            throw new NullPointerException("Model cannot be null.");
        }
    }
}

