/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.odkl;

import com.github.fommil.netlib.BLAS;
import org.apache.spark.ml.odkl.HasNetlibBlas;
import org.apache.spark.ml.odkl.HasNetlibBlas$class;
import org.apache.spark.ml.odkl.ModelWithSummary;
import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.apache.spark.mllib.linalg.DenseMatrix$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.odkl.MatrixUtils$;
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import scala.Function1;
import scala.Function2;
import scala.Function3;
import scala.Function4;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

/**
 * Singleton helper for DSVRGD (presumably the distributed SVRG optimizer — confirm), as emitted by
 * CFR from a Scala {@code object}: the sole instance is published in {@link #MODULE$} and
 * {@link #readResolve()} keeps it unique across Java serialization.
 *
 * <p>Contents:
 * <ul>
 *   <li>per-batch loss / update-term kernels for least squares ({@link #linear}) and logistic
 *       regression ({@link #logistic});</li>
 *   <li>an initializer for logistic weight matrices ({@link #logisticInitialization});</li>
 *   <li>per-label weight-change metrics ({@link #linearWeightsDistance},
 *       {@link #logisticWeightsDistance});</li>
 *   <li>three named {@code ModelWithSummary.Block}s, presumably used to report training history
 *       in the model summary — confirm against callers.</li>
 * </ul>
 * The BLAS helpers required by {@code HasNetlibBlas} are delegated to the trait's
 * implementation class {@code HasNetlibBlas$class}.
 *
 * <p>NOTE(review): this is decompiler output; the anonymous {@code new Serializable(args){...}}
 * closures and {@code name$N} fields are CFR artifacts of Scala lambdas and captured variables.
 */
public final class DSVRGD$
implements Serializable,
HasNetlibBlas {
    /** The singleton instance; assigned by the private constructor, which the static initializer invokes. */
    public static final DSVRGD$ MODULE$;
    /** Summary block named "lossHistory". */
    private final ModelWithSummary.Block LossHistory;
    /** Summary block named "weightDiffHistory". */
    private final ModelWithSummary.Block WeightDiffHistory;
    /** Summary block named "weightNormHistory". */
    private final ModelWithSummary.Block WeightNormHistory;

    // Creating the instance runs the constructor, which assigns MODULE$.
    static {
        new DSVRGD$();
    }

    /** Delegates to {@code HasNetlibBlas$class.f2jBLAS}. */
    @Override
    public BLAS f2jBLAS() {
        return HasNetlibBlas$class.f2jBLAS(this);
    }

    /** Delegates to {@code HasNetlibBlas$class.blas}. */
    @Override
    public BLAS blas() {
        return HasNetlibBlas$class.blas(this);
    }

    /** Delegates to {@code HasNetlibBlas$class.dscal}; presumably scales {@code data} in place by {@code a}. */
    @Override
    public void dscal(double a, double[] data) {
        HasNetlibBlas$class.dscal(this, a, data);
    }

    /** Delegates to {@code HasNetlibBlas$class.axpy}; presumably {@code y += a * x} in place. */
    @Override
    public void axpy(double a, double[] x, double[] y) {
        HasNetlibBlas$class.axpy((HasNetlibBlas)this, a, x, y);
    }

    /** Delegates to {@code HasNetlibBlas$class.axpy} for a {@link Vector} source; presumably {@code y += a * x}. */
    @Override
    public void axpy(double a, Vector x, double[] y) {
        HasNetlibBlas$class.axpy((HasNetlibBlas)this, a, x, y);
    }

    /** Delegates to {@code HasNetlibBlas$class.copy}; presumably copies {@code x} into {@code y}. */
    @Override
    public void copy(double[] x, double[] y) {
        HasNetlibBlas$class.copy(this, x, y);
    }

    /** @return the summary block named "lossHistory" */
    public ModelWithSummary.Block LossHistory() {
        return this.LossHistory;
    }

    /** @return the summary block named "weightDiffHistory" */
    public ModelWithSummary.Block WeightDiffHistory() {
        return this.WeightDiffHistory;
    }

    /** @return the summary block named "weightNormHistory" */
    public ModelWithSummary.Block WeightNormHistory() {
        return this.WeightNormHistory;
    }

    /**
     * Least-squares kernel for one batch: computes residuals, the weight update term and the
     * per-label loss.
     *
     * <p>Shapes implied by the BLAS calls (TODO confirm against callers): {@code weights} is
     * labels x features; {@code features} is features x samples (n = {@code features.numCols()});
     * {@code labels} and {@code marginCache} are labels x samples; {@code updateTerm} is
     * labels x features; {@code lossCache} has one entry per label.
     *
     * <p>NOTE(review): the loss accumulated here is (1/n)·Σr², whose gradient is 2x the
     * {@code updateTerm} computed below; confirm the factor of 2 is intentional.
     *
     * @param weights     current weight matrix
     * @param features    batch features, one column per sample
     * @param labels      target values, same shape as {@code marginCache}
     * @param updateTerm  output, overwritten with (1/n) * residuals * features^T
     * @param marginCache scratch, overwritten with the residuals weights * features - labels
     * @param lossCache   accumulator; (1/n) * sum of squared residuals is ADDED per label
     *                    (it is not reset here)
     */
    public void linear(Matrix weights, DenseMatrix features, DenseMatrix labels, DenseMatrix updateTerm, DenseMatrix marginCache, DenseVector lossCache) {
        // marginCache = weights * features (beta = 0 discards the previous contents).
        BLAS$.MODULE$.gemm(1.0, weights, features, 0.0, marginCache);
        // marginCache -= labels, turning predictions into residuals.
        this.axpy(-1.0, labels.values(), marginCache.values());
        double multiplier = 1.0 / (double)features.numCols();
        // updateTerm = (1/n) * residuals * features^T.
        BLAS$.MODULE$.gemm(multiplier, (Matrix)marginCache, features.transpose(), 0.0, updateTerm);
        // Closure (CFR artifact): for each active entry (row = label, col = sample), add (1/n) * residual^2.
        marginCache.foreachActive((Function3)new Serializable(lossCache, multiplier){
            public static final long serialVersionUID = 0L;
            private final DenseVector lossCache$4;
            private final double multiplier$1;

            public final void apply(int label, int sample2, double v) {
                this.lossCache$4.values()[label] = this.lossCache$4.values()[label] + this.multiplier$1 * v * v;
            }
            {
                this.lossCache$4 = lossCache$4;
                this.multiplier$1 = multiplier$1;
            }
        });
    }

    /**
     * Logistic-regression kernel for one batch: computes the per-label logistic loss and the
     * weight update term.
     *
     * <p>Shapes are as in {@link #linear} (TODO confirm). Writing z = (weights * features)[l, s]
     * and y for the label value:
     * <ul>
     *   <li>{@code marginCache} first holds the negated margin m = -z;</li>
     *   <li>the per-entry loss added to {@code lossCache[l]} is (1/n) * (log(1 + e^m) - (1 - y) * m),
     *       which equals (1/n) * (log(1 + e^z) - y * z);</li>
     *   <li>the per-entry value returned by the closure is 1 / (1 + e^m) - y = sigmoid(z) - y.</li>
     * </ul>
     * NOTE(review): this assumes {@code MatrixUtils.applyNonZeros(labels, marginCache, f)} writes
     * f's return value back into {@code marginCache}, so the final gemm uses (sigmoid(z) - y);
     * also confirm which entries it visits (all entries vs. only non-zero labels).
     *
     * @param weights     current weight matrix
     * @param features    batch features, one column per sample (n = {@code features.numCols()})
     * @param labels      label values, same shape as {@code marginCache}
     * @param updateTerm  output, overwritten with (1/n) * marginCache * features^T
     * @param marginCache scratch, overwritten as described above
     * @param lossCache   accumulator; per-label loss is ADDED (it is not reset here)
     */
    public void logistic(Matrix weights, DenseMatrix features, DenseMatrix labels, DenseMatrix updateTerm, DenseMatrix marginCache, DenseVector lossCache) {
        // marginCache = -(weights * features).
        BLAS$.MODULE$.gemm(-1.0, weights, features, 0.0, marginCache);
        double multiplier = 1.0 / (double)features.numCols();
        // Closure (CFR artifact): accumulates the loss and returns the per-entry residual sigmoid(z) - y.
        MatrixUtils$.MODULE$.applyNonZeros((Matrix)labels, marginCache, (Function4<Object, Object, Object, Object, Object>)new Serializable(lossCache, multiplier){
            public static final long serialVersionUID = 0L;
            private final DenseVector lossCache$5;
            private final double multiplier$2;

            public final double apply(int label, int sample2, double labelValue, double margin) {
                this.lossCache$5.values()[label] = this.lossCache$5.values()[label] + this.multiplier$2 * (MLUtils$.MODULE$.log1pExp(margin) - (1.0 - labelValue) * margin);
                return 1.0 / (1.0 + Math.exp(margin)) - labelValue;
            }
            {
                this.lossCache$5 = lossCache$5;
                this.multiplier$2 = multiplier$2;
            }
        });
        // updateTerm = (1/n) * marginCache * features^T.
        BLAS$.MODULE$.gemm(multiplier, (Matrix)marginCache, features.transpose(), 0.0, updateTerm);
    }

    /**
     * Builds an initial logistic weight matrix in which only the last feature column is populated.
     *
     * <p>Column 0 of each row of {@code data} is read as a label {@link Vector}. These vectors are
     * aggregated with a {@link MultivariateOnlineSummarizer} via {@code treeAggregate}. Entry
     * (label, numFeatures - 1) is then set to log(numNonzeros(label) / count), the log of the
     * fraction of rows in which that label is non-zero; every other entry is 0.0. This assumes
     * {@code MatrixUtils.transformDense} applies the function to every entry of the
     * numLabels x numFeatures zero matrix and returns the result (TODO confirm).
     *
     * <p>NOTE(review): the last feature is presumably the intercept column — confirm. A label that
     * is never non-zero yields Math.log(0.0) = -Infinity. The value is log(p) rather than the
     * logit log(p / (1 - p)); confirm this is intended.
     *
     * @param data        rows whose column 0 holds the label vector
     * @param numLabels   number of rows of the result
     * @param numFeatures number of columns of the result
     * @return the initial weight matrix
     */
    public Matrix logisticInitialization(DataFrame data, int numLabels, int numFeatures) {
        // Extract the label vector (column 0) from each row.
        RDD qual$2 = data.map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Vector apply(Row x$5) {
                return (Vector)x$5.getAs(0);
            }
        }, ClassTag$.MODULE$.apply(Vector.class));
        // treeAggregate: zero value, per-element seqOp (add) and combOp (merge), default depth.
        MultivariateOnlineSummarizer x$10 = new MultivariateOnlineSummarizer();
        Serializable x$11 = new Serializable(){
            public static final long serialVersionUID = 0L;

            public final MultivariateOnlineSummarizer apply(MultivariateOnlineSummarizer a, Vector v) {
                return a.add(v);
            }
        };
        Serializable x$12 = new Serializable(){
            public static final long serialVersionUID = 0L;

            public final MultivariateOnlineSummarizer apply(MultivariateOnlineSummarizer a, MultivariateOnlineSummarizer b) {
                return a.merge(b);
            }
        };
        int x$13 = qual$2.treeAggregate$default$4((Object)x$10);
        MultivariateOnlineSummarizer stat = (MultivariateOnlineSummarizer)qual$2.treeAggregate((Object)x$10, (Function2)x$11, (Function2)x$12, x$13, ClassTag$.MODULE$.apply(MultivariateOnlineSummarizer.class));
        // Only the last feature column gets a non-zero (log-frequency) value.
        return MatrixUtils$.MODULE$.transformDense(DenseMatrix$.MODULE$.zeros(numLabels, numFeatures), (Function3<Object, Object, Object, Object>)new Serializable(numFeatures, stat){
            public static final long serialVersionUID = 0L;
            private final int numFeatures$1;
            private final MultivariateOnlineSummarizer stat$1;

            public final double apply(int label, int feature, double weight) {
                return feature == this.numFeatures$1 - 1 ? Math.log(this.stat$1.numNonzeros().apply(label) / (double)this.stat$1.count()) : 0.0;
            }
            {
                this.numFeatures$1 = numFeatures$1;
                this.stat$1 = stat$1;
            }
        });
    }

    /**
     * Relative Euclidean change of one weight row:
     * ||old[label, :] - new[label, :]|| / ||new[label, :]||.
     *
     * <p>Iterates j over [0, newWeights.numCols()); assumes {@code oldWeights} has at least that
     * many columns.
     *
     * <p>NOTE(review): there is no guard for an all-zero new row. In that case the division yields
     * Infinity, or NaN if the old row is also zero. {@link #logisticWeightsDistance} returns 2.0 in
     * its degenerate case instead.
     *
     * @param oldWeights weights before the update
     * @param newWeights weights after the update
     * @param label      row index to compare
     * @return the relative distance of the two rows
     */
    public double linearWeightsDistance(Matrix oldWeights, DenseMatrix newWeights, int label) {
        // DoubleRef boxes let the closure below mutate these accumulators (Scala captured vars).
        DoubleRef diff = new DoubleRef(0.0);
        DoubleRef sum = new DoubleRef(0.0);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), newWeights.numCols()).foreach$mVc$sp((Function1)new Serializable(oldWeights, newWeights, label, diff, sum){
            public static final long serialVersionUID = 0L;
            private final Matrix oldWeights$1;
            private final DenseMatrix newWeights$3;
            private final int label$5;
            private final DoubleRef diff$1;
            private final DoubleRef sum$2;

            public final void apply(int j) {
                this.apply$mcVI$sp(j);
            }

            // sum += new^2; diff += (old - new)^2
            public void apply$mcVI$sp(int j) {
                this.sum$2.elem += this.newWeights$3.apply(this.label$5, j) * this.newWeights$3.apply(this.label$5, j);
                this.diff$1.elem += (this.oldWeights$1.apply(this.label$5, j) - this.newWeights$3.apply(this.label$5, j)) * (this.oldWeights$1.apply(this.label$5, j) - this.newWeights$3.apply(this.label$5, j));
            }
            {
                this.oldWeights$1 = oldWeights$1;
                this.newWeights$3 = newWeights$3;
                this.label$5 = label$5;
                this.diff$1 = diff$1;
                this.sum$2 = sum$2;
            }
        });
        return Math.sqrt(diff.elem) / Math.sqrt(sum.elem);
    }

    /**
     * Cosine distance (1 - cosine similarity) between row {@code label} of the old and new
     * weights.
     *
     * <p>Iterates j over [0, newWeights.numCols()); assumes {@code oldWeights} has at least that
     * many columns.
     *
     * @param oldWeights weights before the update
     * @param newWeights weights after the update
     * @param label      row index to compare
     * @return a value nominally in [0, 2]; 2.0 when the product of the rows' squared norms is not
     *         positive (i.e. either row is all zeros)
     */
    public double logisticWeightsDistance(Matrix oldWeights, DenseMatrix newWeights, int label) {
        // DoubleRef boxes let the closure below mutate these accumulators (Scala captured vars).
        DoubleRef cor = new DoubleRef(0.0);
        DoubleRef sumNew = new DoubleRef(0.0);
        DoubleRef sumOld = new DoubleRef(0.0);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), newWeights.numCols()).foreach$mVc$sp((Function1)new Serializable(oldWeights, newWeights, label, cor, sumNew, sumOld){
            public static final long serialVersionUID = 0L;
            private final Matrix oldWeights$2;
            private final DenseMatrix newWeights$4;
            private final int label$6;
            private final DoubleRef cor$1;
            private final DoubleRef sumNew$1;
            private final DoubleRef sumOld$1;

            public final void apply(int j) {
                this.apply$mcVI$sp(j);
            }

            // Accumulate both squared norms and the dot product of the two rows.
            public void apply$mcVI$sp(int j) {
                this.sumNew$1.elem += this.newWeights$4.apply(this.label$6, j) * this.newWeights$4.apply(this.label$6, j);
                this.sumOld$1.elem += this.oldWeights$2.apply(this.label$6, j) * this.oldWeights$2.apply(this.label$6, j);
                this.cor$1.elem += this.oldWeights$2.apply(this.label$6, j) * this.newWeights$4.apply(this.label$6, j);
            }
            {
                this.oldWeights$2 = oldWeights$2;
                this.newWeights$4 = newWeights$4;
                this.label$6 = label$6;
                this.cor$1 = cor$1;
                this.sumNew$1 = sumNew$1;
                this.sumOld$1 = sumOld$1;
            }
        });
        return sumNew.elem * sumOld.elem > 0.0 ? 1.0 - cor.elem / Math.sqrt(sumNew.elem * sumOld.elem) : 2.0;
    }

    /** Serialization hook: replaces any deserialized copy with the singleton {@link #MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Publishes the singleton, runs the {@code HasNetlibBlas} trait initializer and creates the
     * three named summary blocks.
     */
    private DSVRGD$() {
        MODULE$ = this;
        HasNetlibBlas$class.$init$(this);
        this.LossHistory = new ModelWithSummary.Block("lossHistory");
        this.WeightDiffHistory = new ModelWithSummary.Block("weightDiffHistory");
        this.WeightNormHistory = new ModelWithSummary.Block("weightNormHistory");
    }
}

