/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import breeze.linalg.$times$;
import breeze.linalg.BroadcastedColumns$;
import breeze.linalg.Broadcaster$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.NumericOps;
import breeze.math.Semiring$;
import java.io.Serializable;
import java.util.Random;
import org.apache.spark.ml.ann.BreezeUtil$;
import org.apache.spark.ml.ann.GeneralIceLayerModel;
import org.apache.spark.ml.ann.IceAffineLayer;
import org.apache.spark.ml.ann.IceAffineLayerModel$;
import scala.Function1;
import scala.Predef$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005]e!B\u0010!\u0001\u0001R\u0003\u0002C\u001b\u0001\u0005\u000b\u0007I\u0011A\u001c\t\u0011\r\u0003!\u0011!Q\u0001\naB\u0001\u0002\u0012\u0001\u0003\u0006\u0004%\t!\u0012\u0005\t\u0013\u0002\u0011\t\u0011)A\u0005\r\"1!\n\u0001C\u0001A-Cqa\u0014\u0001C\u0002\u0013\u0005\u0001\u000b\u0003\u0004U\u0001\u0001\u0006I!\u0015\u0005\b+\u0002\u0011\r\u0011\"\u00018\u0011\u00191\u0006\u0001)A\u0005q!9q\u000b\u0001a\u0001\n\u00139\u0004b\u0002-\u0001\u0001\u0004%I!\u0017\u0005\u0007?\u0002\u0001\u000b\u0015\u0002\u001d\t\u000f\u0001\u0004\u0001\u0019!C\u0005C\"9!\r\u0001a\u0001\n\u0013\u0019\u0007BB3\u0001A\u0003&\u0011\u0007C\u0003g\u0001\u0011\u0005s\rC\u0003m\u0001\u0011\u0005S\u000eC\u0003t\u0001\u0011\u0005C\u000fC\u0004\u0002\u0004\u0001!\t%!\u0002\t\u000f\u0005=\u0001\u0001\"\u0011\u0002\u0012!9\u0011Q\u0005\u0001\u0005B\u0005\u001d\u0002bBA\u001e\u0001\u0011\u0005\u0013Q\b\u0005\b\u0003\u0003\u0002A\u0011IA\"\u0011\u001d\t9\u0005\u0001C\u0001\u0003\u0013:\u0001\"!\u0014!\u0011\u0003\u0001\u0013q\n\u0004\b?\u0001B\t\u0001IA)\u0011\u0019Q%\u0004\"\u0001\u0002Z!9\u00111\f\u000e\u0005\u0002\u0005u\u0003bBA<5\u0011\u0005\u0011\u0011\u0010\u0005\n\u0003\u000fS\u0012\u0011!C\u0005\u0003\u0013\u00131#S2f\u0003\u001a4\u0017N\\3MCf,'/T8eK2T!!\t\u0012\u0002\u0007\u0005tgN\u0003\u0002$I\u0005\u0011Q\u000e\u001c\u0006\u0003K\u0019\nQa\u001d9be.T!a\n\u0015\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005I\u0013aA8sON\u0019\u0001aK\u0019\u0011\u00051zS\"A\u0017\u000b\u00039\nQa]2bY\u0006L!\u0001M\u0017\u0003\r\u0005s\u0017PU3g!\t\u00114'D\u0001!\u0013\t!\u0004E\u0001\u000bHK:,'/\u00197JG\u0016d\u0015-_3s\u001b>$W\r\\\u0001\bo\u0016Lw\r\u001b;t\u0007\u0001)\u0012\u0001\u000f\t\u0004sy\u0002U\"\u0001\u001e\u000b\u0005mb\u0014A\u00027j]\u0006dwMC\u0001>\u0003\u0019\u0011'/Z3{K&\u0011qH\u000f\u0002\f\t\u0016t7/\u001a,fGR|'\u000f\u0005\u0002-\u0003&\u0011!)\f\u0002\u0007\t>,(\r\\3\u0002\u0011],\u0017n\u001a5ug\u0002\nQ\u0001\\1zKJ,\u0012A\u0012\t\u0003
e\u001dK!\u0001\u0013\u0011\u0003\u001d%\u001bW-\u00114gS:,G*Y=fe\u00061A.Y=fe\u0002\na\u0001P5oSRtDc\u0001'N\u001dB\u0011!\u0007\u0001\u0005\u0006k\u0015\u0001\r\u0001\u000f\u0005\u0006\t\u0016\u0001\rAR\u0001\u0002oV\t\u0011\u000bE\u0002:%\u0002K!a\u0015\u001e\u0003\u0017\u0011+gn]3NCR\u0014\u0018\u000e_\u0001\u0003o\u0002\n\u0011AY\u0001\u0003E\u0002\nAa\u001c8fg\u0006AqN\\3t?\u0012*\u0017\u000f\u0006\u0002[;B\u0011AfW\u0005\u000396\u0012A!\u00168ji\"9alCA\u0001\u0002\u0004A\u0014a\u0001=%c\u0005)qN\\3tA\u0005Ia.\u001a=u\u0019\u0006LXM]\u000b\u0002c\u0005ia.\u001a=u\u0019\u0006LXM]0%KF$\"A\u00173\t\u000fys\u0011\u0011!a\u0001c\u0005Qa.\u001a=u\u0019\u0006LXM\u001d\u0011\u0002\t\u00154\u0018\r\u001c\u000b\u00045\"T\u0007\"B5\u0011\u0001\u0004\t\u0016\u0001\u00023bi\u0006DQa\u001b\tA\u0002E\u000baa\\;uaV$\u0018\u0001E2p[B,H/\u001a)sKZ$U\r\u001c;b)\u0011Qf\u000e]9\t\u000b=\f\u0002\u0019A)\u0002\u000b\u0011,G\u000e^1\t\u000b-\f\u0002\u0019A)\t\u000bI\f\u0002\u0019A)\u0002\u0013A\u0014XM\u001e#fYR\f\u0017AC:j]\u001edWm\u0012:bIR1!,\u001e<y{~DQa\u001c\nA\u0002ECQa\u001e\nA\u0002E\u000bQ!\u001b8qkRDQ!\u001f\nA\u0002i\f\u0011!\u001c\t\u0003YmL!\u0001`\u0017\u0003\u0007%sG\u000fC\u0003\u007f%\u0001\u0007\u0011+A\u0006xK&<\u0007\u000e^$sC\u0012\u0014\u0005BBA\u0001%\u0001\u0007\u0001(A\u0005cS\u0006\u001cxI]1e\u0005\u0006!qM]1e)\u001dQ\u0016qAA\u0005\u0003\u0017AQa\\\nA\u0002ECQa^\nA\u0002ECa!!\u0004\u0014\u0001\u0004A\u0014aB2v[\u001e\u0013\u0018\rZ\u0001\u0006OJ\fGM\r\u000b\u000e5\u0006M\u0011QCA\r\u0003;\ty\"!\t\t\u000b=$\u0002\u0019A)\t\r\u0005]A\u00031\u0001R\u0003%qW\r\u001f;EK2$\u0018\r\u0003\u0004\u0002\u001cQ\u0001\r!U\u0001\u0006O\u0006lW.\u0019\u0005\u0006oR\u0001\r!\u0015\u0005\u0006WR\u0001\r!\u0015\u0005\u0007\u0003G!\u0002\u0019\u0001\u001d\u0002\u000b\r,Xn\u0012\u001a\u00021\r|W\u000e];uKB\u0013XM\u001e#fYR\fW\t\u001f9b]\u0012,G\rF\b[\u0003S\tY#!\f\u00020\u0005M\u0012QGA\u001c\u0011\u0015yW\u00031\u0001R\u0011\u0019\t9\"\u0006a\u0001#\"1\u00111D\u000bA\u000
2ECa!!\r\u0016\u0001\u0004\t\u0016A\u00039sKZ|U\u000f\u001e9vi\")1.\u0006a\u0001#\")!/\u0006a\u0001#\"1\u0011\u0011H\u000bA\u0002E\u000b\u0011\u0002\u001d:fm\u001e\u000bW.\\1\u0002\u001f\u0005\u001cG/\u001b<bi&|g\u000eR3sSZ$2\u0001QA \u0011\u00159h\u00031\u0001A\u0003U\t7\r^5wCRLwN\\*fG>tG\rR3sSZ$2\u0001QA#\u0011\u00159x\u00031\u0001A\u00031\u0019X\r\u001e(fqRd\u0015-_3s)\rQ\u00161\n\u0005\u0006Ab\u0001\r!M\u0001\u0014\u0013\u000e,\u0017I\u001a4j]\u0016d\u0015-_3s\u001b>$W\r\u001c\t\u0003ei\u0019BAG\u0016\u0002TA\u0019A&!\u0016\n\u0007\u0005]SF\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002P\u0005)\u0011\r\u001d9msR9A*a\u0018\u0002b\u0005\r\u0004\"\u0002#\u001d\u0001\u00041\u0005\"B\u001b\u001d\u0001\u0004A\u0004bBA39\u0001\u0007\u0011qM\u0001\u0007e\u0006tGm\\7\u0011\t\u0005%\u00141O\u0007\u0003\u0003WRA!!\u001c\u0002p\u0005!Q\u000f^5m\u0015\t\t\t(\u0001\u0003kCZ\f\u0017\u0002BA;\u0003W\u0012aAU1oI>l\u0017!\u0004:b]\u0012|WnV3jO\"$8\u000fF\u0005[\u0003w\ny(a!\u0002\u0006\"1\u0011QP\u000fA\u0002i\fQA\\;n\u0013:Da!!!\u001e\u0001\u0004Q\u0018A\u00028v[>+H\u000fC\u00036;\u0001\u0007\u0001\bC\u0004\u0002fu\u0001\r!a\u001a\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003\u0017\u0003B!!$\u0002\u00146\u0011\u0011q\u0012\u0006\u0005\u0003#\u000by'\u0001\u0003mC:<\u0017\u0002BAK\u0003\u001f\u0013aa\u00142kK\u000e$\b")
/**
 * Model for an affine layer of the "Ice" network code, computing {@code output = w * input + b}.
 *
 * <p>{@link #w()} and {@link #b()} are views into {@link #weights()}. The {@code numOut x numIn}
 * weight matrix starts at {@code weights.offset()}, and the {@code numOut} bias entries follow
 * immediately after it (Breeze's default matrix layout, column-major -- confirm if it matters).
 * This layer's own activation is the identity: {@link #activationDeriv} returns 1 and
 * {@link #activationSecondDeriv} returns 0. The second-order methods ({@link #grad2},
 * {@link #computePrevDeltaExpanded}) instead use the derivatives of the layer registered via
 * {@link #setNextLayer}, which must be set before those methods run.
 *
 * <p>NOTE(review): this source is CFR output of compiled Scala. {@code DenseMatrix.mcD.sp} /
 * {@code DenseVector.mcD.sp} stand for Breeze's double-specialized classes
 * ({@code DenseMatrix$mcD$sp}, ...). {@code apply$mcD$sp} / {@code update$mcD$sp} are the
 * specialized double accessors.
 */
public class IceAffineLayerModel
implements GeneralIceLayerModel {
    /** Flat parameter vector backing both {@link #w} and {@link #b}. */
    private final DenseVector<Object> weights;
    /** Layer configuration; supplies {@code numIn()} / {@code numOut()}. */
    private final IceAffineLayer layer;
    /** {@code numOut x numIn} weight-matrix view into {@link #weights}. */
    private final DenseMatrix<Object> w;
    /** {@code numOut}-long bias view into {@link #weights}, stored right after {@link #w}. */
    private final DenseVector<Object> b;
    /** Cached all-ones vector sized to the last batch seen by {@link #grad}; (re)allocated lazily. */
    private DenseVector<Object> ones;
    /** Layer whose activation derivatives the second-order methods use; null until set. */
    private GeneralIceLayerModel nextLayer;

    /**
     * Static forwarder to the companion object's {@code randomWeights}.
     * NOTE(review): presumably fills {@code denseVector} with random initial weights for an
     * {@code n x n2} layer -- confirm in {@code IceAffineLayerModel$}.
     */
    public static void randomWeights(int n, int n2, DenseVector<Object> denseVector, Random random) {
        IceAffineLayerModel$.MODULE$.randomWeights(n, n2, denseVector, random);
    }

    /** Static forwarder to the companion object's {@code apply} factory. */
    public static IceAffineLayerModel apply(IceAffineLayer iceAffineLayer, DenseVector<Object> denseVector, Random random) {
        return IceAffineLayerModel$.MODULE$.apply(iceAffineLayer, denseVector, random);
    }

    /** @return the flat parameter vector backing {@link #w()} and {@link #b()} */
    public DenseVector<Object> weights() {
        return this.weights;
    }

    /** @return the layer configuration this model was built from */
    public IceAffineLayer layer() {
        return this.layer;
    }

    /** @return the {@code numOut x numIn} weight-matrix view */
    public DenseMatrix<Object> w() {
        return this.w;
    }

    /** @return the {@code numOut}-long bias view */
    public DenseVector<Object> b() {
        return this.b;
    }

    private DenseVector<Object> ones() {
        return this.ones;
    }

    private void ones_$eq(DenseVector<Object> x$1) {
        this.ones = x$1;
    }

    private GeneralIceLayerModel nextLayer() {
        return this.nextLayer;
    }

    private void nextLayer_$eq(GeneralIceLayerModel x$1) {
        this.nextLayer = x$1;
    }

    /**
     * Forward pass: writes {@code w * data + b} into {@code output}.
     *
     * @param data   input batch, one observation per column
     * @param output destination; every column is first overwritten with {@code b}, then
     *               {@code w * data} is added (dgemm with beta = 1.0)
     */
    public void eval(DenseMatrix<Object> data, DenseMatrix<Object> output) {
        ((NumericOps)output.apply((Object)package$.MODULE$.$colon$colon(), (Object)$times$.MODULE$, Broadcaster$.MODULE$.canBroadcastColumns(DenseMatrix$.MODULE$.handholdCanMapRows()))).$colon$eq(this.b(), BroadcastedColumns$.MODULE$.broadcastInplaceOp2(DenseMatrix$.MODULE$.handholdCanMapRows(), DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet(), DenseMatrix$.MODULE$.canTraverseCols()));
        BreezeUtil$.MODULE$.dgemm(1.0, this.w(), data, 1.0, output);
    }

    /**
     * Back-propagates deltas: overwrites {@code prevDelta} with {@code w^T * delta}
     * (dgemm with beta = 0.0). Because this layer's activation is the identity, no derivative
     * factor is applied.
     *
     * @param delta     deltas at this layer's output, one column per observation
     * @param output    unused by this layer
     * @param prevDelta destination for the previous layer's deltas
     */
    public void computePrevDelta(DenseMatrix<Object> delta, DenseMatrix<Object> output, DenseMatrix<Object> prevDelta) {
        BreezeUtil$.MODULE$.dgemm(1.0, (DenseMatrix)this.w().t(DenseMatrix$.MODULE$.canTranspose()), delta, 0.0, prevDelta);
    }

    /**
     * Writes (does not accumulate) the gradient for the single observation in column {@code m}:
     * {@code biasGradB(i) = delta(i, m)} and {@code weightGradB(i, k) = delta(i, m) * input(k, m)}.
     */
    @Override
    public void singleGrad(DenseMatrix<Object> delta, DenseMatrix<Object> input, int m, DenseMatrix<Object> weightGradB, DenseVector<Object> biasGradB) {
        for (int i = 0; i < weightGradB.rows(); ++i) {
            double delta_i = delta.apply$mcD$sp(i, m);
            biasGradB.update$mcD$sp(i, delta_i);
            for (int k = 0; k < weightGradB.cols(); ++k) {
                double a_k = input.apply$mcD$sp(k, m);
                weightGradB.update$mcD$sp(i, k, delta_i * a_k);
            }
        }
    }

    /**
     * Adds the batch-averaged gradient into {@code cumGrad}.
     *
     * <p>{@code cumGrad} is laid out like {@link #weights()}: weight-matrix entries starting at
     * {@code cumGrad.offset()}, bias entries right after them. Both BLAS calls use beta = 1.0, so
     * existing contents are kept and only the new contribution is scaled by
     * {@code 1 / input.cols()}.
     */
    public void grad(DenseMatrix<Object> delta, DenseMatrix<Object> input, DenseVector<Object> cumGrad) {
        double invBatchSize = 1.0 / (double)input.cols();
        DenseMatrix.mcD.sp cumGradientOfWeights = new DenseMatrix.mcD.sp(this.w().rows(), this.w().cols(), cumGrad.data$mcD$sp(), cumGrad.offset());
        BreezeUtil$.MODULE$.dgemm(invBatchSize, delta, (DenseMatrix)input.t(DenseMatrix$.MODULE$.canTranspose()), 1.0, (DenseMatrix)cumGradientOfWeights);
        // Bias gradient = delta * ones; reuse the cached ones vector while the batch size is unchanged.
        if (this.ones() == null || this.ones().length() != delta.cols()) {
            this.ones_$eq((DenseVector<Object>)DenseVector$.MODULE$.ones$mDc$sp(delta.cols(), ClassTag$.MODULE$.Double(), Semiring$.MODULE$.semiringD()));
        }
        DenseVector.mcD.sp cumGradientOfBias = new DenseVector.mcD.sp(cumGrad.data$mcD$sp(), cumGrad.offset() + this.w().size(), 1, this.b().length());
        BreezeUtil$.MODULE$.dgemv(invBatchSize, delta, this.ones(), 1.0, (DenseVector)cumGradientOfBias);
    }

    /**
     * Adds batch-averaged second-order terms for this layer's parameters into {@code cumG2}.
     * NOTE(review): this looks like a diagonal second-derivative estimate -- confirm.
     *
     * <p>For every observation {@code m} and output unit {@code i}:
     * {@code scale = gamma(i,m) * f'(o)^2 + d(i,m) * f''(o)}, where {@code o = output(i,m)}, f is
     * the next layer's activation, and {@code d} is {@code nextDelta} when non-null, else
     * {@code delta}. {@code scale} is added to bias entry {@code i}, and
     * {@code scale * input(k,m)^2} to weight entry {@code (i,k)}. The batch sums are divided by
     * {@code gamma.cols()} before being added.
     *
     * <p>Only the new contribution is averaged, matching {@link #grad}. Previously accumulated
     * values in {@code cumG2} are not rescaled, and entries outside this layer's weight/bias
     * region are not touched. An empty batch leaves {@code cumG2} unchanged instead of
     * producing NaN from {@code 1.0 / 0}. Requires {@link #setNextLayer} to have been called
     * when the batch is non-empty.
     */
    @Override
    public void grad2(DenseMatrix<Object> delta, DenseMatrix<Object> nextDelta, DenseMatrix<Object> gamma, DenseMatrix<Object> input, DenseMatrix<Object> output, DenseVector<Object> cumG2) {
        int numObs = gamma.cols();
        if (numObs == 0) {
            return;
        }
        DenseMatrix.mcD.sp cumG2ofWeights = new DenseMatrix.mcD.sp(this.w().rows(), this.w().cols(), cumG2.data$mcD$sp(), cumG2.offset());
        DenseVector.mcD.sp cumG2ofBias = new DenseVector.mcD.sp(cumG2.data$mcD$sp(), cumG2.offset() + this.w().size(), 1, this.b().length());
        DenseMatrix<Object> targetDelta = nextDelta != null ? nextDelta : delta;
        GeneralIceLayerModel next = this.nextLayer();
        int rows = cumG2ofWeights.rows();
        int cols = cumG2ofWeights.cols();
        // Per-call accumulators, so the 1/numObs scaling applies only to this batch's contribution.
        double[] batchWeightG2 = new double[rows * cols];
        double[] batchBiasG2 = new double[rows];
        for (int m = 0; m < numObs; ++m) {
            for (int i = 0; i < rows; ++i) {
                double gamma_i = gamma.apply$mcD$sp(i, m);
                double prevOutput_i = output.apply$mcD$sp(i, m);
                double fprime_i = next.activationDeriv(prevOutput_i);
                double fprime2_i = next.activationSecondDeriv(prevOutput_i);
                double delta_i = targetDelta.apply$mcD$sp(i, m);
                double scale = gamma_i * fprime_i * fprime_i + delta_i * fprime2_i;
                batchBiasG2[i] += scale;
                for (int k = 0; k < cols; ++k) {
                    double a_k = input.apply$mcD$sp(k, m);
                    batchWeightG2[k * rows + i] += scale * a_k * a_k;
                }
            }
        }
        double invObsCount = 1.0 / (double)numObs;
        for (int i = 0; i < rows; ++i) {
            cumG2ofBias.update$mcD$sp(i, cumG2ofBias.apply$mcD$sp(i) + batchBiasG2[i] * invObsCount);
            for (int k = 0; k < cols; ++k) {
                cumG2ofWeights.update$mcD$sp(i, k, cumG2ofWeights.apply$mcD$sp(i, k) + batchWeightG2[k * rows + i] * invObsCount);
            }
        }
    }

    /**
     * Computes {@code prevDelta} via {@link #computePrevDelta}, then adds second-order
     * propagation terms into {@code prevGamma}. The terms are accumulated, not overwritten, so
     * the caller must prepare {@code prevGamma}.
     *
     * <p>For each observation {@code m} and output unit {@code i}, with the next layer's
     * derivatives evaluated at {@code o_i = prevOutput(i, m)}:
     * <ul>
     *   <li>If {@code nextDelta != null}: {@code prevGamma(k,m) += (gamma_i f'^2 + d_i f'') * w(i,k)^2},
     *       where {@code d = nextDelta}.</li>
     *   <li>If {@code nextDelta == null} (the original code calls this the "next is loss" case),
     *       {@code d = delta} and {@code prevGamma(k,m) += d_i f'' * w(i,k)^2 + w(i,k) *
     *       sum_u (kron(u,i) - o_u) * o_i * w(u,k)}. The gamma term is not used here.
     *       NOTE(review): the sum resembles a softmax Jacobian -- confirm.</li>
     * </ul>
     * Requires {@link #setNextLayer} to have been called when {@code gamma} is non-empty.
     */
    @Override
    public void computePrevDeltaExpanded(DenseMatrix<Object> delta, DenseMatrix<Object> nextDelta, DenseMatrix<Object> gamma, DenseMatrix<Object> prevOutput, DenseMatrix<Object> output, DenseMatrix<Object> prevDelta, DenseMatrix<Object> prevGamma) {
        this.computePrevDelta(delta, output, prevDelta);
        boolean nextIsLoss = nextDelta == null;
        DenseMatrix<Object> targetDelta = nextIsLoss ? delta : nextDelta;
        GeneralIceLayerModel next = this.nextLayer();
        DenseMatrix<Object> weightMatrix = this.w();
        int numUnits = gamma.rows();
        int numInputs = weightMatrix.cols();
        for (int m = 0; m < gamma.cols(); ++m) {
            for (int i = 0; i < numUnits; ++i) {
                double gamma_i = gamma.apply$mcD$sp(i, m);
                double prevOutput_i = prevOutput.apply$mcD$sp(i, m);
                double fprime_i = next.activationDeriv(prevOutput_i);
                double fprime2_i = next.activationSecondDeriv(prevOutput_i);
                double delta_i = targetDelta.apply$mcD$sp(i, m);
                double deltaTerm = delta_i * fprime2_i;
                double gammaTerm = gamma_i * fprime_i * fprime_i;
                double scale = gammaTerm + deltaTerm;
                if (!nextIsLoss) {
                    for (int k = 0; k < numInputs; ++k) {
                        double currWeight = weightMatrix.apply$mcD$sp(i, k);
                        double computedGamma = scale * currWeight * currWeight;
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + computedGamma);
                    }
                } else {
                    for (int k = 0; k < numInputs; ++k) {
                        double w_k = weightMatrix.apply$mcD$sp(i, k);
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + deltaTerm * w_k * w_k);
                        double subSum = 0.0;
                        for (int u = 0; u < numUnits; ++u) {
                            double w_u = weightMatrix.apply$mcD$sp(u, k);
                            double output_u = prevOutput.apply$mcD$sp(u, m);
                            double kron = u == i ? 1.0 : 0.0;
                            subSum += (kron - output_u) * prevOutput_i * w_u;
                        }
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + subSum * w_k);
                    }
                }
            }
        }
    }

    /** Identity activation: first derivative is always 1. */
    @Override
    public double activationDeriv(double input) {
        return 1.0;
    }

    /** Identity activation: second derivative is always 0. */
    @Override
    public double activationSecondDeriv(double input) {
        return 0.0;
    }

    /** Registers the layer whose activation derivatives {@link #grad2} and {@link #computePrevDeltaExpanded} use. */
    @Override
    public void setNextLayer(GeneralIceLayerModel nextLayer) {
        this.nextLayer_$eq(nextLayer);
    }

    /**
     * Creates the model over an existing parameter vector (no copy). {@code w} views the first
     * {@code numOut * numIn} entries starting at {@code weights.offset()}, and {@code b} views
     * the following {@code numOut} entries.
     */
    public IceAffineLayerModel(DenseVector<Object> weights, IceAffineLayer layer) {
        this.weights = weights;
        this.layer = layer;
        this.w = new DenseMatrix.mcD.sp(layer.numOut(), layer.numIn(), weights.data$mcD$sp(), weights.offset());
        this.b = new DenseVector.mcD.sp(weights.data$mcD$sp(), weights.offset() + layer.numOut() * layer.numIn(), 1, layer.numOut());
        this.ones = null;
        this.nextLayer = null;
    }
}

