/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import breeze.linalg.$times$;
import breeze.linalg.BroadcastedColumns$;
import breeze.linalg.Broadcaster$;
import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.NumericOps;
import breeze.math.Semiring$;
import java.util.Random;
import org.apache.spark.ml.ann.BreezeUtil$;
import org.apache.spark.ml.ann.GeneralIceLayerModel;
import org.apache.spark.ml.ann.IceAffineLayer;
import org.apache.spark.ml.ann.IceAffineLayerModel$;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

@ScalaSignature(bytes="\u0006\u0001\u0005]e!B\u0001\u0003\u0001\ta!aE%dK\u00063g-\u001b8f\u0019\u0006LXM]'pI\u0016d'BA\u0002\u0005\u0003\r\tgN\u001c\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c2\u0001A\u0007\u0014!\tq\u0011#D\u0001\u0010\u0015\u0005\u0001\u0012!B:dC2\f\u0017B\u0001\n\u0010\u0005\u0019\te.\u001f*fMB\u0011A#F\u0007\u0002\u0005%\u0011aC\u0001\u0002\u0015\u000f\u0016tWM]1m\u0013\u000e,G*Y=fe6{G-\u001a7\t\u0011a\u0001!Q1A\u0005\u0002i\tqa^3jO\"$8o\u0001\u0001\u0016\u0003m\u00012\u0001H\u0011$\u001b\u0005i\"B\u0001\u0010 \u0003\u0019a\u0017N\\1mO*\t\u0001%\u0001\u0004ce\u0016,'0Z\u0005\u0003Eu\u00111\u0002R3og\u00164Vm\u0019;peB\u0011a\u0002J\u0005\u0003K=\u0011a\u0001R8vE2,\u0007\u0002C\u0014\u0001\u0005\u0003\u0005\u000b\u0011B\u000e\u0002\u0011],\u0017n\u001a5ug\u0002B\u0001\"\u000b\u0001\u0003\u0006\u0004%\tAK\u0001\u0006Y\u0006LXM]\u000b\u0002WA\u0011A\u0003L\u0005\u0003[\t\u0011a\"S2f\u0003\u001a4\u0017N\\3MCf,'\u000f\u0003\u00050\u0001\t\u0005\t\u0015!\u0003,\u0003\u0019a\u0017-_3sA!1\u0011\u0007\u0001C\u0001\u0005I\na\u0001P5oSRtDcA\u001a5kA\u0011A\u0003\u0001\u0005\u00061A\u0002\ra\u0007\u0005\u0006SA\u0002\ra\u000b\u0005\bo\u0001\u0011\r\u0011\"\u00019\u0003\u00059X#A\u001d\u0011\u0007qQ4%\u0003\u0002<;\tYA)\u001a8tK6\u000bGO]5y\u0011\u0019i\u0004\u0001)A\u0005s\u0005\u0011q\u000f\t\u0005\b\u007f\u0001\u0011\r\u0011\"\u0001\u001b\u0003\u0005\u0011\u0007BB!\u0001A\u0003%1$\u0001\u0002cA!91\t\u0001a\u0001\n\u0013Q\u0012\u0001B8oKNDq!\u0012\u0001A\u0002\u0013%a)\u0001\u0005p]\u0016\u001cx\fJ3r)\t9%\n\u0005\u0002\u000f\u0011&\u0011\u0011j\u0004\u0002\u0005+:LG\u000fC\u0004L\t\u0006\u0005\t\u0019A\u000e\u0002\u0007a$\u0013\u0007\u0003\u0004N\u0001\u0001\u0006KaG\u0001\u0006_:,7\u000f\t\u0005\b\u001f\u0002\u0001\r\u0011\"\u0003Q\u0003%qW\r\u001f;MCf,'/F\u0001\u0014\u0011\u001d\u0011\u0006\u00011A\u0005\nM\u000bQB\\3yi2\u000b\u00170\u001a:`I\u0015\fHCA$U\u0011\u00
1dY\u0015+!AA\u0002MAaA\u0016\u0001!B\u0013\u0019\u0012A\u00038fqRd\u0015-_3sA!)\u0001\f\u0001C!3\u0006!QM^1m)\r9%\f\u0018\u0005\u00067^\u0003\r!O\u0001\u0005I\u0006$\u0018\rC\u0003^/\u0002\u0007\u0011(\u0001\u0004pkR\u0004X\u000f\u001e\u0005\u0006?\u0002!\t\u0005Y\u0001\u0011G>l\u0007/\u001e;f!J,g\u000fR3mi\u0006$BaR1dI\")!M\u0018a\u0001s\u0005)A-\u001a7uC\")QL\u0018a\u0001s!)QM\u0018a\u0001s\u0005I\u0001O]3w\t\u0016dG/\u0019\u0005\u0006O\u0002!\t\u0005[\u0001\u000bg&tw\r\\3He\u0006$GCB$jU2\f8\u000fC\u0003cM\u0002\u0007\u0011\bC\u0003lM\u0002\u0007\u0011(A\u0003j]B,H\u000fC\u0003nM\u0002\u0007a.A\u0001n!\tqq.\u0003\u0002q\u001f\t\u0019\u0011J\u001c;\t\u000bI4\u0007\u0019A\u001d\u0002\u0017],\u0017n\u001a5u\u000fJ\fGM\u0011\u0005\u0006i\u001a\u0004\raG\u0001\nE&\f7o\u0012:bI\nCQA\u001e\u0001\u0005B]\fAa\u001a:bIR!q\t_={\u0011\u0015\u0011W\u000f1\u0001:\u0011\u0015YW\u000f1\u0001:\u0011\u0015YX\u000f1\u0001\u001c\u0003\u001d\u0019W/\\$sC\u0012DQ! \u0001\u0005By\fQa\u001a:bIJ\"BbR@\u0002\u0002\u0005\u0015\u0011\u0011BA\u0006\u0003\u001bAQA\u0019?A\u0002eBa!a\u0001}\u0001\u0004I\u0014!\u00038fqR$U\r\u001c;b\u0011\u0019\t9\u0001 a\u0001s\u0005)q-Y7nC\")1\u000e a\u0001s!)Q\f 
a\u0001s!1\u0011q\u0002?A\u0002m\tQaY;n\u000fJBq!a\u0005\u0001\t\u0003\n)\"\u0001\rd_6\u0004X\u000f^3Qe\u00164H)\u001a7uC\u0016C\b/\u00198eK\u0012$rbRA\f\u00033\tY\"!\b\u0002\"\u0005\r\u0012Q\u0005\u0005\u0007E\u0006E\u0001\u0019A\u001d\t\u000f\u0005\r\u0011\u0011\u0003a\u0001s!9\u0011qAA\t\u0001\u0004I\u0004bBA\u0010\u0003#\u0001\r!O\u0001\u000baJ,goT;uaV$\bBB/\u0002\u0012\u0001\u0007\u0011\b\u0003\u0004f\u0003#\u0001\r!\u000f\u0005\b\u0003O\t\t\u00021\u0001:\u0003%\u0001(/\u001a<HC6l\u0017\rC\u0004\u0002,\u0001!\t%!\f\u0002\u001f\u0005\u001cG/\u001b<bi&|g\u000eR3sSZ$2aIA\u0018\u0011\u0019Y\u0017\u0011\u0006a\u0001G!9\u00111\u0007\u0001\u0005B\u0005U\u0012!F1di&4\u0018\r^5p]N+7m\u001c8e\t\u0016\u0014\u0018N\u001e\u000b\u0004G\u0005]\u0002BB6\u00022\u0001\u00071\u0005C\u0004\u0002<\u0001!\t!!\u0010\u0002\u0019M,GOT3yi2\u000b\u00170\u001a:\u0015\u0007\u001d\u000by\u0004\u0003\u0004P\u0003s\u0001\raE\u0004\t\u0003\u0007\u0012\u0001\u0012\u0001\u0002\u0002F\u0005\u0019\u0012jY3BM\u001aLg.\u001a'bs\u0016\u0014Xj\u001c3fYB\u0019A#a\u0012\u0007\u000f\u0005\u0011\u0001\u0012\u0001\u0002\u0002JM)\u0011qI\u0007\u0002LA\u0019a\"!\u0014\n\u0007\u0005=sB\u0001\u0007TKJL\u0017\r\\5{C\ndW\rC\u00042\u0003\u000f\"\t!a\u0015\u0015\u0005\u0005\u0015\u0003\u0002CA,\u0003\u000f\"\t!!\u0017\u0002\u000b\u0005\u0004\b\u000f\\=\u0015\u000fM\nY&!\u0018\u0002`!1\u0011&!\u0016A\u0002-Ba\u0001GA+\u0001\u0004Y\u0002\u0002CA1\u0003+\u0002\r!a\u0019\u0002\rI\fg\u000eZ8n!\u0011\t)'a\u001c\u000e\u0005\u0005\u001d$\u0002BA5\u0003W\nA!\u001e;jY*\u0011\u0011QN\u0001\u0005U\u00064\u0018-\u0003\u0003\u0002r\u0005\u001d$A\u0002*b]\u0012|W\u000e\u0003\u0005\u0002v\u0005\u001dC\u0011AA<\u00035\u0011\u0018M\u001c3p[^+\u0017n\u001a5ugRIq)!\u001f\u0002~\u0005\u0005\u00151\u0011\u0005\b\u0003w\n\u0019\b1\u0001o\u0003\u0015qW/\\%o\u0011\u001d\ty(a\u001dA\u00029\faA\\;n\u001fV$\bB\u0002\r\u0002t\u0001\u00071\u0004\u0003\u0005\u0002b\u0005M\u0004\u0019AA2\u0011)\t9)a\u0012\u0002\u0002\u0013%\u0011\u0011R\u0001\fe
\u0016\fGMU3t_24X\r\u0006\u0002\u0002\fB!\u0011QRAJ\u001b\t\tyI\u0003\u0003\u0002\u0012\u0006-\u0014\u0001\u00027b]\u001eLA!!&\u0002\u0010\n1qJ\u00196fGR\u0004")
/*
 * Model for an affine layer: output = w * input + b, with an identity activation
 * (activationDeriv == 1, activationSecondDeriv == 0).
 *
 * All parameters live in the single `weights` vector and are shared, not copied:
 *  - w() is a numOut x numIn matrix view starting at weights.offset();
 *  - b() is a view of the numOut entries that immediately follow w.
 *
 * Matrices passed to the batch methods hold one sample per column (column count is used as the
 * batch size in grad/grad2).
 *
 * NOTE(review): this file was decompiled by CFR; names such as `ones_$eq` and
 * `DenseMatrix.mcD.sp` are decompiler renderings of Scala-generated members / specialized
 * classes (presumably `DenseMatrix$mcD$sp`) and may need adjusting to compile as Java.
 */
public class IceAffineLayerModel
implements GeneralIceLayerModel {
    private final DenseVector<Object> weights;
    private final IceAffineLayer layer;
    private final DenseMatrix<Object> w;
    private final DenseVector<Object> b;
    // Lazily (re)allocated in grad() to match the current batch size; mutated without locking.
    private DenseVector<Object> ones;
    // Layer whose activation derivatives are used by grad2/computePrevDeltaExpanded.
    private GeneralIceLayerModel nextLayer;

    /** Static forwarder to the companion object's randomWeights. */
    public static void randomWeights(int n, int n2, DenseVector<Object> denseVector, Random random) {
        IceAffineLayerModel$.MODULE$.randomWeights(n, n2, denseVector, random);
    }

    /** Static forwarder to the companion object's factory. */
    public static IceAffineLayerModel apply(IceAffineLayer iceAffineLayer, DenseVector<Object> denseVector, Random random) {
        return IceAffineLayerModel$.MODULE$.apply(iceAffineLayer, denseVector, random);
    }

    /** @return the backing parameter vector shared by {@link #w()} and {@link #b()} */
    public DenseVector<Object> weights() {
        return this.weights;
    }

    /** @return the layer description (numIn/numOut) this model was built from */
    public IceAffineLayer layer() {
        return this.layer;
    }

    /** @return the numOut x numIn weight matrix view into {@link #weights()} */
    public DenseMatrix<Object> w() {
        return this.w;
    }

    /** @return the numOut-length bias view into {@link #weights()} */
    public DenseVector<Object> b() {
        return this.b;
    }

    private DenseVector<Object> ones() {
        return this.ones;
    }

    private void ones_$eq(DenseVector<Object> x$1) {
        this.ones = x$1;
    }

    private GeneralIceLayerModel nextLayer() {
        return this.nextLayer;
    }

    private void nextLayer_$eq(GeneralIceLayerModel x$1) {
        this.nextLayer = x$1;
    }

    /**
     * Computes the layer output for a batch.
     *
     * <p>Every column of {@code output} is first overwritten with {@link #b()}, then {@code w * data}
     * is added via {@code BreezeUtil.dgemm} with alpha = beta = 1 (assumed BLAS-style
     * C := alpha*A*B + beta*C — confirm against BreezeUtil).
     *
     * @param data   layer input, one sample per column
     * @param output destination, overwritten in place
     */
    public void eval(DenseMatrix<Object> data, DenseMatrix<Object> output) {
        ((NumericOps)output.apply((Object)package$.MODULE$.$colon$colon(), (Object)$times$.MODULE$, Broadcaster$.MODULE$.canBroadcastColumns(DenseMatrix$.MODULE$.handholdCanMapRows()))).$colon$eq(this.b(), BroadcastedColumns$.MODULE$.broadcastInplaceOp2(DenseMatrix$.MODULE$.handholdCanMapRows(), DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet(), DenseMatrix$.MODULE$.canTraverseCols()));
        BreezeUtil$.MODULE$.dgemm(1.0, this.w(), data, 1.0, output);
    }

    /**
     * Back-propagates the error: {@code prevDelta := w^T * delta} (beta = 0, so prevDelta is
     * overwritten). For an affine layer the output is not needed and {@code output} is ignored.
     */
    public void computePrevDelta(DenseMatrix<Object> delta, DenseMatrix<Object> output, DenseMatrix<Object> prevDelta) {
        BreezeUtil$.MODULE$.dgemm(1.0, (DenseMatrix)this.w().t(DenseMatrix$.MODULE$.canTranspose()), delta, 0.0, prevDelta);
    }

    /**
     * Writes the gradient contributed by the single sample in column {@code m}.
     *
     * <p>Overwrites (does not accumulate) {@code weightGradB(i, k) = delta(i, m) * input(k, m)} and
     * {@code biasGradB(i) = delta(i, m)}; loop bounds come from {@code weightGradB}'s shape.
     */
    @Override
    public void singleGrad(DenseMatrix<Object> delta, DenseMatrix<Object> input, int m, DenseMatrix<Object> weightGradB, DenseVector<Object> biasGradB) {
        for (int i = 0; i < weightGradB.rows(); ++i) {
            double delta_i = delta.apply$mcD$sp(i, m);
            biasGradB.update$mcD$sp(i, delta_i);
            for (int k = 0; k < weightGradB.cols(); ++k) {
                double a_k = input.apply$mcD$sp(k, m);
                weightGradB.update$mcD$sp(i, k, delta_i * a_k);
            }
        }
    }

    /**
     * Adds the batch-averaged gradient into {@code cumGrad}.
     *
     * <p>The first {@code w.size()} entries of {@code cumGrad} (from its offset) receive
     * {@code (1/n) * delta * input^T}. The next {@code b.length()} entries receive
     * {@code (1/n) * delta * ones}, where n = {@code input.cols()}. Existing contents are kept
     * (beta = 1).
     */
    public void grad(DenseMatrix<Object> delta, DenseMatrix<Object> input, DenseVector<Object> cumGrad) {
        DenseMatrix.mcD.sp cumGradientOfWeights = new DenseMatrix.mcD.sp(this.w().rows(), this.w().cols(), cumGrad.data$mcD$sp(), cumGrad.offset());
        BreezeUtil$.MODULE$.dgemm(1.0 / (double)input.cols(), delta, (DenseMatrix)input.t(DenseMatrix$.MODULE$.canTranspose()), 1.0, (DenseMatrix)cumGradientOfWeights);
        // Reuse the all-ones vector across calls; reallocate only when the batch size changes.
        if (this.ones() == null || this.ones().length() != delta.cols()) {
            this.ones_$eq((DenseVector<Object>)DenseVector$.MODULE$.ones$mDc$sp(delta.cols(), ClassTag$.MODULE$.Double(), Semiring$.MODULE$.semiringD()));
        }
        DenseVector.mcD.sp cumGradientOfBias = new DenseVector.mcD.sp(cumGrad.data$mcD$sp(), cumGrad.offset() + this.w().size(), 1, this.b().length());
        BreezeUtil$.MODULE$.dgemv(1.0 / (double)input.cols(), delta, this.ones(), 1.0, (DenseVector)cumGradientOfBias);
    }

    /**
     * Adds the batch-averaged per-parameter second-order terms into {@code cumG2}.
     *
     * <p>For each sample m and output unit i, with f the next layer's activation and
     * o = {@code output(i, m)}:
     * <pre>
     *   scale = gamma(i, m) * f'(o)^2 + d(i, m) * f''(o)
     * </pre>
     * Here d is {@code nextDelta}, or {@code delta} when {@code nextDelta} is null. The bias slot i
     * receives {@code scale / n} and weight (i, k) receives {@code scale * input(k, m)^2 / n}, with
     * n = {@code gamma.cols()}.
     *
     * <p>Like {@link #grad}, only this batch's contribution is scaled by 1/n; values already in
     * {@code cumG2} are left intact. An empty batch adds nothing.
     *
     * <p>NOTE(review): requires {@link #setNextLayer} to have been called; otherwise NPE.
     */
    @Override
    public void grad2(DenseMatrix<Object> delta, DenseMatrix<Object> nextDelta, DenseMatrix<Object> gamma, DenseMatrix<Object> input, DenseMatrix<Object> output, DenseVector<Object> cumG2) {
        DenseMatrix.mcD.sp cumG2ofWeights = new DenseMatrix.mcD.sp(this.w().rows(), this.w().cols(), cumG2.data$mcD$sp(), cumG2.offset());
        DenseVector.mcD.sp cumG2ofBias = new DenseVector.mcD.sp(cumG2.data$mcD$sp(), cumG2.offset() + this.w().size(), 1, this.b().length());
        int batchSize = gamma.cols();
        if (batchSize == 0) {
            // Nothing to add; dividing by zero here would poison the accumulator with NaN/Inf.
            return;
        }
        double invObsCount = 1.0 / (double)batchSize;
        DenseMatrix<Object> targetDelta = nextDelta == null ? delta : nextDelta;
        for (int m = 0; m < batchSize; ++m) {
            for (int i = 0; i < cumG2ofWeights.rows(); ++i) {
                double gamma_i = gamma.apply$mcD$sp(i, m);
                double output_i = output.apply$mcD$sp(i, m);
                double fprime_i = this.nextLayer().activationDeriv(output_i);
                double fprime2_i = this.nextLayer().activationSecondDeriv(output_i);
                double delta_i = targetDelta.apply$mcD$sp(i, m);
                double scale = gamma_i * fprime_i * fprime_i + delta_i * fprime2_i;
                // Average per contribution (as grad() does) instead of rescaling the whole buffer.
                double avgScale = scale * invObsCount;
                cumG2ofBias.update$mcD$sp(i, cumG2ofBias.apply$mcD$sp(i) + avgScale);
                for (int k = 0; k < cumG2ofWeights.cols(); ++k) {
                    double a_k = input.apply$mcD$sp(k, m);
                    cumG2ofWeights.update$mcD$sp(i, k, cumG2ofWeights.apply$mcD$sp(i, k) + avgScale * a_k * a_k);
                }
            }
        }
    }

    /**
     * Computes {@code prevDelta} (via {@link #computePrevDelta}) and accumulates second-order terms
     * into {@code prevGamma} (existing contents are added to, not cleared).
     *
     * <p>For each sample m and output unit i, let o_i = {@code prevOutput(i, m)}, f the next layer's
     * activation, and d = {@code nextDelta}, or {@code delta} when {@code nextDelta} is null.
     * <ul>
     *   <li>{@code nextDelta != null}: prevGamma(k, m) += (gamma(i,m) f'(o_i)^2 + d(i,m) f''(o_i)) * w(i,k)^2.</li>
     *   <li>{@code nextDelta == null} ("next is loss"): prevGamma(k, m) += d(i,m) f''(o_i) w(i,k)^2,
     *       then += w(i,k) * sum_u (kron(u,i) - o_u) o_i w(u,k). The gamma term is not used on this
     *       path. NOTE(review): the (kron - o_u) o_i factor has the form of a softmax Jacobian;
     *       presumably this assumes a softmax-based loss layer — confirm.</li>
     * </ul>
     */
    @Override
    public void computePrevDeltaExpanded(DenseMatrix<Object> delta, DenseMatrix<Object> nextDelta, DenseMatrix<Object> gamma, DenseMatrix<Object> prevOutput, DenseMatrix<Object> output, DenseMatrix<Object> prevDelta, DenseMatrix<Object> prevGamma) {
        this.computePrevDelta(delta, output, prevDelta);
        boolean nextIsLoss = nextDelta == null;
        DenseMatrix<Object> targetDelta = nextIsLoss ? delta : nextDelta;
        for (int m = 0; m < gamma.cols(); ++m) {
            for (int i = 0; i < gamma.rows(); ++i) {
                double gamma_i = gamma.apply$mcD$sp(i, m);
                double prevOutput_i = prevOutput.apply$mcD$sp(i, m);
                double fprime_i = this.nextLayer().activationDeriv(prevOutput_i);
                double fprime2_i = this.nextLayer().activationSecondDeriv(prevOutput_i);
                double delta_i = targetDelta.apply$mcD$sp(i, m);
                double deltaTerm = delta_i * fprime2_i;
                double gammaTerm = gamma_i * fprime_i * fprime_i;
                double scale = gammaTerm + deltaTerm;
                if (nextIsLoss) {
                    for (int k = 0; k < this.w().cols(); ++k) {
                        double w_k = this.w().apply$mcD$sp(i, k);
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + deltaTerm * w_k * w_k);
                        double subSum = 0.0;
                        for (int u = 0; u < gamma.rows(); ++u) {
                            double w_u = this.w().apply$mcD$sp(u, k);
                            double output_u = prevOutput.apply$mcD$sp(u, m);
                            double kron = u == i ? 1.0 : 0.0;
                            subSum += (kron - output_u) * prevOutput_i * w_u;
                        }
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + subSum * w_k);
                    }
                } else {
                    for (int k = 0; k < this.w().cols(); ++k) {
                        double currWeight = this.w().apply$mcD$sp(i, k);
                        prevGamma.update$mcD$sp(k, m, prevGamma.apply$mcD$sp(k, m) + scale * currWeight * currWeight);
                    }
                }
            }
        }
    }

    /** Affine layers have an identity activation, so f'(x) = 1. */
    @Override
    public double activationDeriv(double input) {
        return 1.0;
    }

    /** Affine layers have an identity activation, so f''(x) = 0. */
    @Override
    public double activationSecondDeriv(double input) {
        return 0.0;
    }

    /** Sets the layer whose activation derivatives are used by grad2/computePrevDeltaExpanded. */
    @Override
    public void setNextLayer(GeneralIceLayerModel nextLayer) {
        this.nextLayer_$eq(nextLayer);
    }

    /**
     * Creates views {@code w} (numOut x numIn) and {@code b} (numOut) over {@code weights}'
     * storage. No data is copied.
     */
    public IceAffineLayerModel(DenseVector<Object> weights, IceAffineLayer layer) {
        this.weights = weights;
        this.layer = layer;
        this.w = new DenseMatrix.mcD.sp(layer.numOut(), layer.numIn(), weights.data$mcD$sp(), weights.offset());
        this.b = new DenseVector.mcD.sp(weights.data$mcD$sp(), weights.offset() + layer.numOut() * layer.numIn(), 1, layer.numOut());
        this.ones = null;
        this.nextLayer = null;
    }
}

