/*
 * Decompiled with CFR 0.152.
 */
package epic.dense;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.storage.Zero;
import epic.dense.OutputEmbeddingTransform;
import epic.dense.Transform;
import scala.Console$;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.immutable.Range;
import scala.collection.immutable.Range$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.util.Random;

/**
 * Singleton module (decompiled Scala {@code object}) accompanying {@link OutputEmbeddingTransform}.
 *
 * <p>Provides weight-initialisation helpers for output embeddings, row-norm utilities for embedding
 * matrices, and the {@code apply}/{@code unapply} factory/extractor pair (presumably the case-class
 * companion methods — confirm against the Scala source).
 */
public final class OutputEmbeddingTransform$
implements Serializable {
    /** The singleton instance of this module. */
    public static final OutputEmbeddingTransform$ MODULE$;

    static {
        // A static final field may only be assigned in a static initializer, not from a constructor.
        MODULE$ = new OutputEmbeddingTransform$();
    }

    /**
     * Builds initial weights for an output embedding of {@code numOutputs} rows by {@code outputDim} columns.
     *
     * <p>Rows {@code 0 .. outputDim-1} form an identity block (row {@code i} has a 1.0 in column {@code i}).
     * Each remaining row gets a single 1.0 in a column drawn from {@code rng}. The matrix's backing data
     * array is then flattened and followed by a zero bias vector of length {@code numOutputs}.
     *
     * @param numOutputs number of output rows; must be at least {@code outputDim}
     * @param outputDim  embedding dimension (number of columns)
     * @param rng        source of the random column choice for rows beyond the identity block
     * @return the flattened embedding matrix followed by {@code numOutputs} zero biases
     * @throws IllegalArgumentException if {@code outputDim > numOutputs}
     */
    public DenseVector<Object> getIdentityEmbeddingWeights(int numOutputs, int outputDim, Random rng) {
        if (outputDim > numOutputs) {
            throw new IllegalArgumentException("requirement failed: " + outputDim + " " + numOutputs);
        }
        DenseMatrix mat = DenseMatrix$.MODULE$.zeros$mDc$sp(numOutputs, outputDim, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        for (int i = 0; i < outputDim; ++i) {
            mat.update$mcD$sp(i, i, 1.0);
        }
        // Draw in row order so the result is reproducible for a given rng state.
        for (int i = outputDim; i < numOutputs; ++i) {
            mat.update$mcD$sp(i, rng.nextInt(outputDim), 1.0);
        }
        return withZeroBias(mat, numOutputs);
    }

    /**
     * Rescales each row of {@code embeddings} in place to unit L2 norm.
     *
     * <p>Despite the name, every row is rescaled, not only rows whose norm exceeds some bound.
     * Rows whose norm is exactly zero are left untouched, because dividing by zero would fill them with NaN.
     *
     * @param embeddings matrix whose rows are normalised in place
     */
    public void clipEmbeddingNorms(DenseMatrix<Object> embeddings) {
        int rows = embeddings.rows();
        int cols = embeddings.cols();
        for (int i = 0; i < rows; ++i) {
            double norm = rowNorm(embeddings, i);
            if (norm == 0.0) {
                continue;
            }
            for (int j = 0; j < cols; ++j) {
                embeddings.update$mcD$sp(i, j, embeddings.apply$mcD$sp(i, j) / norm);
            }
        }
    }

    /**
     * Prints the average and maximum row L2 norm of {@code embeddings} to the Scala console.
     *
     * <p>The maximum starts from 0.0. A matrix with zero rows reports an average of NaN (0.0 / 0).
     *
     * @param embeddings matrix whose row norms are summarised
     */
    public void displayEmbeddingNorms(DenseMatrix<Object> embeddings) {
        int rows = embeddings.rows();
        double normSum = 0.0;
        double maxNorm = 0.0;
        for (int i = 0; i < rows; ++i) {
            double norm = rowNorm(embeddings, i);
            normSum += norm;
            maxNorm = Math.max(maxNorm, norm);
        }
        String summary = "Average norm: " + (normSum / (double)rows) + ", max norm: " + maxNorm;
        Console$.MODULE$.println((Object)summary);
    }

    /**
     * Builds initial weights in which output row {@code i} is one-hot at column
     * {@code floorMod(coarsenerForInitialization(i), outputDim)}. The matrix's backing data array is then
     * flattened and followed by a zero bias vector of length {@code numOutputs}.
     *
     * @param numOutputs                 number of output rows
     * @param outputDim                  embedding dimension (number of columns)
     * @param coarsenerForInitialization maps an output index to an int; reduced modulo {@code outputDim}
     *                                   into {@code [0, outputDim)} for positive {@code outputDim}
     * @return the flattened embedding matrix followed by {@code numOutputs} zero biases
     * @throws ArithmeticException if {@code outputDim == 0} and {@code numOutputs > 0}
     */
    public DenseVector<Object> getCoarsenedInitialEmbeddingWeights(int numOutputs, int outputDim, Function1<Object, Object> coarsenerForInitialization) {
        DenseMatrix mat = DenseMatrix$.MODULE$.zeros$mDc$sp(numOutputs, outputDim, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        for (int i = 0; i < numOutputs; ++i) {
            // floorMod keeps negative coarsener outputs in range without the overflow risk of (x % m + m) % m.
            int col = Math.floorMod(coarsenerForInitialization.apply$mcII$sp(i), outputDim);
            mat.update$mcD$sp(i, col, 1.0);
        }
        return withZeroBias(mat, numOutputs);
    }

    /** Concatenates {@code mat}'s backing data array with a zero bias vector of length {@code numOutputs}. */
    private DenseVector<Object> withZeroBias(DenseMatrix mat, int numOutputs) {
        DenseVector flattened = DenseVector$.MODULE$.apply$mDc$sp(mat.data$mcD$sp());
        DenseVector bias = DenseVector$.MODULE$.zeros$mDc$sp(numOutputs, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
        return DenseVector$.MODULE$.vertcat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new DenseVector[]{flattened, bias}), DenseVector$.MODULE$.dv_dv_UpdateOp_Double_OpSet(), ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$);
    }

    /** Returns the L2 norm of row {@code row} of {@code m}. */
    private double rowNorm(DenseMatrix<Object> m, int row) {
        int cols = m.cols();
        double sumOfSquares = 0.0;
        for (int j = 0; j < cols; ++j) {
            double v = m.apply$mcD$sp(row, j);
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Constructs an {@link OutputEmbeddingTransform} from its four components. */
    public <FV> OutputEmbeddingTransform<FV> apply(int numOutputs, int outputDim, Transform<FV, DenseVector<Object>> innerTransform, Option<Function1<Object, Object>> coarsenerForInitialization) {
        return new OutputEmbeddingTransform<FV>(numOutputs, outputDim, innerTransform, coarsenerForInitialization);
    }

    /** Extracts the four constructor components, or returns {@code None} for a null argument. */
    public <FV> Option<Tuple4<Object, Object, Transform<FV, DenseVector<Object>>, Option<Function1<Object, Object>>>> unapply(OutputEmbeddingTransform<FV> x$0) {
        return x$0 == null ? None$.MODULE$ : new Some((Object)new Tuple4((Object)BoxesRunTime.boxToInteger((int)x$0.numOutputs()), (Object)BoxesRunTime.boxToInteger((int)x$0.outputDim()), x$0.innerTransform(), x$0.coarsenerForInitialization()));
    }

    /** Default value ({@code None}) for the fourth constructor parameter. */
    public <FV> Option<Function1<Object, Object>> $lessinit$greater$default$4() {
        return None$.MODULE$;
    }

    /** Default value ({@code None}) for the fourth {@code apply} parameter. */
    public <FV> Option<Function1<Object, Object>> apply$default$4() {
        return None$.MODULE$;
    }

    /** Preserves the singleton across Java serialization. */
    private Object readResolve() {
        return MODULE$;
    }

    private OutputEmbeddingTransform$() {
    }
}

