/*
 * Decompiled with CFR 0.152.
 */
package epic.dense;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.storage.Zero;
import epic.dense.OutputEmbeddingTransform;
import epic.dense.Transform;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.immutable.Range;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;
import scala.util.Random;

public final class OutputEmbeddingTransform$
implements Serializable {
    public static OutputEmbeddingTransform$ MODULE$;

    static {
        new OutputEmbeddingTransform$();
    }

    /*
     * WARNING - void declaration
     */
    public DenseVector<Object> getIdentityEmbeddingWeights(int numOutputs, int outputDim, Random rng) {
        void require_requirement;
        boolean bl;
        boolean bl2 = bl = outputDim <= numOutputs;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        if (require_requirement == false) {
            throw new IllegalArgumentException("requirement failed: " + OutputEmbeddingTransform$.$anonfun$getIdentityEmbeddingWeights$1(numOutputs, outputDim));
        }
        DenseMatrix mat = DenseMatrix$.MODULE$.zeros$mDc$sp(numOutputs, outputDim, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, outputDim);
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                mat.update$mcD$sp(foreach$mVc$sp_i, foreach$mVc$sp_i, 1.0);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range2 = RichInt$.MODULE$.until$extension0(outputDim, numOutputs);
        if (range2 == null) {
            throw null;
        }
        Range foreach$mVc$sp_this2 = range2;
        if (!foreach$mVc$sp_this2.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this2.start();
            while (true) {
                mat.update$mcD$sp(foreach$mVc$sp_i, rng.nextInt(outputDim), 1.0);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this2.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this2.step();
            }
        }
        DenseVector biasInitializer = DenseVector$.MODULE$.zeros$mDc$sp(numOutputs, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        return DenseVector$.MODULE$.vertcat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new DenseVector[]{DenseVector$.MODULE$.apply$mDc$sp(mat.data$mcD$sp()), biasInitializer}), DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
    }

    public void clipEmbeddingNorms(DenseMatrix<Object> embeddings) {
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, embeddings.rows());
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                OutputEmbeddingTransform$.$anonfun$clipEmbeddingNorms$1(embeddings, foreach$mVc$sp_i);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
    }

    public void displayEmbeddingNorms(DenseMatrix<Object> embeddings) {
        DoubleRef avgNorm = DoubleRef.create((double)0.0);
        DoubleRef maxNorm = DoubleRef.create((double)0.0);
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, embeddings.rows());
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                OutputEmbeddingTransform$.$anonfun$displayEmbeddingNorms$1(embeddings, avgNorm, maxNorm, foreach$mVc$sp_i);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
        Predef$.MODULE$.println((Object)("Average norm: " + avgNorm.elem / (double)embeddings.rows() + ", max norm: " + maxNorm.elem));
    }

    public DenseVector<Object> getCoarsenedInitialEmbeddingWeights(int numOutputs, int outputDim, Function1<Object, Object> coarsenerForInitialization) {
        DenseMatrix mat = DenseMatrix$.MODULE$.zeros$mDc$sp(numOutputs, outputDim, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, numOutputs);
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                OutputEmbeddingTransform$.$anonfun$getCoarsenedInitialEmbeddingWeights$1(outputDim, coarsenerForInitialization, mat, foreach$mVc$sp_i);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
        DenseVector biasInitializer = DenseVector$.MODULE$.zeros$mDc$sp(numOutputs, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        return DenseVector$.MODULE$.vertcat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new DenseVector[]{DenseVector$.MODULE$.apply$mDc$sp(mat.data$mcD$sp()), biasInitializer}), DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
    }

    public <FV> OutputEmbeddingTransform<FV> apply(int numOutputs, int outputDim, Transform<FV, DenseVector<Object>> innerTransform, Option<Function1<Object, Object>> coarsenerForInitialization) {
        return new OutputEmbeddingTransform<FV>(numOutputs, outputDim, innerTransform, coarsenerForInitialization);
    }

    public <FV> Option<Tuple4<Object, Object, Transform<FV, DenseVector<Object>>, Option<Function1<Object, Object>>>> unapply(OutputEmbeddingTransform<FV> x$0) {
        if (x$0 == null) {
            return None$.MODULE$;
        }
        return new Some((Object)new Tuple4((Object)BoxesRunTime.boxToInteger((int)x$0.numOutputs()), (Object)BoxesRunTime.boxToInteger((int)x$0.outputDim()), x$0.innerTransform(), x$0.coarsenerForInitialization()));
    }

    public <FV> Option<Function1<Object, Object>> $lessinit$greater$default$4() {
        return None$.MODULE$;
    }

    public <FV> Option<Function1<Object, Object>> apply$default$4() {
        return None$.MODULE$;
    }

    private Object readResolve() {
        return MODULE$;
    }

    public static final /* synthetic */ String $anonfun$getIdentityEmbeddingWeights$1(int numOutputs$1, int outputDim$2) {
        return outputDim$2 + " " + numOutputs$1;
    }

    public static final /* synthetic */ void $anonfun$clipEmbeddingNorms$2(DenseMatrix embeddings$1, DoubleRef norm$1, int i$1, int j) {
        norm$1.elem += embeddings$1.apply$mcD$sp(i$1, j) * embeddings$1.apply$mcD$sp(i$1, j);
    }

    public static final /* synthetic */ void $anonfun$clipEmbeddingNorms$1(DenseMatrix embeddings$1, int i) {
        DoubleRef norm = DoubleRef.create((double)0.0);
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, embeddings$1.cols());
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                OutputEmbeddingTransform$.$anonfun$clipEmbeddingNorms$2(embeddings$1, norm, i, foreach$mVc$sp_i);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
        norm.elem = Math.sqrt(norm.elem);
        int n2 = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range2 = RichInt$.MODULE$.until$extension0(n2, embeddings$1.cols());
        if (range2 == null) {
            throw null;
        }
        Range foreach$mVc$sp_this2 = range2;
        if (!foreach$mVc$sp_this2.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this2.start();
            while (true) {
                embeddings$1.update$mcD$sp(i, foreach$mVc$sp_i, embeddings$1.apply$mcD$sp(i, foreach$mVc$sp_i) / norm.elem);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this2.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this2.step();
            }
        }
    }

    public static final /* synthetic */ void $anonfun$displayEmbeddingNorms$2(DenseMatrix embeddings$2, DoubleRef norm$2, int i$2, int j) {
        norm$2.elem += embeddings$2.apply$mcD$sp(i$2, j) * embeddings$2.apply$mcD$sp(i$2, j);
    }

    public static final /* synthetic */ void $anonfun$displayEmbeddingNorms$1(DenseMatrix embeddings$2, DoubleRef avgNorm$1, DoubleRef maxNorm$1, int i) {
        DoubleRef norm = DoubleRef.create((double)0.0);
        int n = 0;
        if (Predef$.MODULE$ == null) {
            throw null;
        }
        Range range = RichInt$.MODULE$.until$extension0(n, embeddings$2.cols());
        if (range == null) {
            throw null;
        }
        Range foreach$mVc$sp_this = range;
        if (!foreach$mVc$sp_this.isEmpty()) {
            int foreach$mVc$sp_i = foreach$mVc$sp_this.start();
            while (true) {
                OutputEmbeddingTransform$.$anonfun$displayEmbeddingNorms$2(embeddings$2, norm, i, foreach$mVc$sp_i);
                if (foreach$mVc$sp_i == foreach$mVc$sp_this.scala$collection$immutable$Range$$lastElement()) break;
                foreach$mVc$sp_i += foreach$mVc$sp_this.step();
            }
        }
        norm.elem = Math.sqrt(norm.elem);
        avgNorm$1.elem += norm.elem;
        maxNorm$1.elem = Math.max(maxNorm$1.elem, norm.elem);
    }

    public static final /* synthetic */ void $anonfun$getCoarsenedInitialEmbeddingWeights$1(int outputDim$1, Function1 coarsenerForInitialization$1, DenseMatrix mat$1, int i) {
        int j = (coarsenerForInitialization$1.apply$mcII$sp(i) % outputDim$1 + outputDim$1) % outputDim$1;
        mat$1.update$mcD$sp(i, j, 1.0);
    }

    private OutputEmbeddingTransform$() {
        MODULE$ = this;
    }
}

