package org.clulab.scala_transformers.encoder.math;

import ai.onnxruntime.OrtSession;
import org.clulab.shaded.org.ejml.data.FMatrixIterator;
import org.clulab.shaded.org.ejml.data.FMatrixRMaj;
import org.clulab.shaded.org.ejml.simple.SimpleMatrix;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.FloatRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: EjmlMath.scala */
/* loaded from: input_file:org/clulab/scala_transformers/encoder/math/EjmlMath$.class */
/**
 * Java view of the Scala singleton {@code object EjmlMath}: an implementation of {@link Math}
 * backed by EJML single-precision, row-major dense matrices ({@link FMatrixRMaj}).
 *
 * <p>Conventions enforced by the assertions below: a "row vector" is a matrix with exactly one
 * row (1 x N) and a "column vector" is a matrix with exactly one column (N x 1). Several methods
 * delegate their per-element work to synthetic Scala closure classes
 * ({@code EjmlMath$$anonfun$...}) whose bodies live in separate class files and are not visible
 * here.
 */
public final class EjmlMath$ implements Math {
    /**
     * The single instance of this Scala {@code object}.
     *
     * <p>It is initialised directly in the declaration. The decompiled form ({@code = null} plus
     * {@code MODULE$ = this} in the constructor) reassigned a {@code final} field, which javac
     * rejects. This form still creates exactly one instance during class initialisation.
     */
    public static final EjmlMath$ MODULE$ = new EjmlMath$();

    /**
     * Tests whether the matrix has the row-vector shape (exactly one row).
     *
     * @param fMatrixRMaj matrix to test
     * @return {@code true} iff {@code fMatrixRMaj.getNumRows() == 1}
     */
    public boolean isRowVector(FMatrixRMaj fMatrixRMaj) {
        return fMatrixRMaj.getNumRows() == 1;
    }

    /**
     * Tests whether the matrix has the column-vector shape (exactly one column).
     *
     * @param fMatrixRMaj matrix to test
     * @return {@code true} iff {@code fMatrixRMaj.getNumCols() == 1}
     */
    public boolean isColVector(FMatrixRMaj fMatrixRMaj) {
        return fMatrixRMaj.getNumCols() == 1;
    }

    /**
     * Converts the first output of an ONNX Runtime session result into one matrix per outer index.
     *
     * <p>The value of output 0 is cast to {@code float[][][]}. Each {@code float[][]} slice is then
     * mapped through the synthetic closure {@code EjmlMath$$anonfun$1} to an {@link FMatrixRMaj}.
     * NOTE(review): the closure body is not visible here. Presumably it is
     * {@code new FMatrixRMaj(float[][])} (one row per inner array). Confirm this against
     * EjmlMath.scala.
     *
     * @param result session result whose output 0 must hold a {@code float[][][]}; any other value
     *     type fails with a {@link ClassCastException}
     * @return one matrix per first-dimension entry of the output tensor
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj[] fromResult(OrtSession.Result result) {
        return (FMatrixRMaj[]) Predef$.MODULE$.refArrayOps((float[][][]) result.get(0).getValue()).map(new EjmlMath$$anonfun$1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(FMatrixRMaj.class)));
    }

    /**
     * Returns the column index of the maximum entry of a row vector.
     *
     * <p>The running best index starts at 0 and the running best value starts at element 0.
     * Indices {@code 1 until numCols} are then visited by the synthetic closure
     * {@code EjmlMath$$anonfun$argmax$1}, which can update both mutable refs. NOTE(review): the
     * comparison (and therefore tie-breaking) lives in that unseen closure, so confirm it there.
     * A 1x0 vector has no element 0, so confirm that callers never pass an empty vector.
     *
     * @param fMatrixRMaj a 1 x N row vector (Scala-asserted)
     * @return the index left in the running-best ref after the scan
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public int argmax(FMatrixRMaj fMatrixRMaj) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj));
        IntRef create = IntRef.create(0);
        // IntRef/FloatRef are Scala's boxes for the `var`s captured by the closure.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(1), fMatrixRMaj.getNumCols()).foreach$mVc$sp(new EjmlMath$$anonfun$argmax$1(fMatrixRMaj, create, FloatRef.create(fMatrixRMaj.get(create.elem))));
        return create.elem;
    }

    /**
     * Adds a column vector into {@code fMatrixRMaj} in place.
     *
     * <p>Iterates over the row indices {@code 0 until fMatrixRMaj.getNumRows()}. The per-row update
     * is done by the synthetic closure {@code EjmlMath$$anonfun$inplaceMatrixAddition$1}, whose
     * body is not visible here. NOTE(review): confirm how the vector's entries are aligned with
     * the matrix's columns.
     *
     * @param fMatrixRMaj matrix to modify in place
     * @param fMatrixRMaj2 an N x 1 column vector (Scala-asserted)
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public void inplaceMatrixAddition(FMatrixRMaj fMatrixRMaj, FMatrixRMaj fMatrixRMaj2) {
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj2));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), fMatrixRMaj.getNumRows()).foreach$mVc$sp(new EjmlMath$$anonfun$inplaceMatrixAddition$1(fMatrixRMaj, fMatrixRMaj2));
    }

    /**
     * Adds a row vector into row {@code i} of {@code fMatrixRMaj} in place.
     *
     * <p>Iterates over the column indices {@code 0 until fMatrixRMaj.getNumCols()}. The per-column
     * update is done by the synthetic closure
     * {@code EjmlMath$$anonfun$inplaceMatrixAddition$2}, whose body is not visible here.
     * Presumably each update adds the vector's entry into cell (i, col). TODO confirm.
     *
     * @param fMatrixRMaj matrix to modify in place
     * @param i target row index
     * @param fMatrixRMaj2 a 1 x N row vector (Scala-asserted)
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public void inplaceMatrixAddition(FMatrixRMaj fMatrixRMaj, int i, FMatrixRMaj fMatrixRMaj2) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj2));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), fMatrixRMaj.getNumCols()).foreach$mVc$sp(new EjmlMath$$anonfun$inplaceMatrixAddition$2(fMatrixRMaj, i, fMatrixRMaj2));
    }

    /**
     * Multiplies two matrices using EJML's {@link SimpleMatrix} wrapper.
     *
     * @param fMatrixRMaj left operand
     * @param fMatrixRMaj2 right operand
     * @return the product {@code fMatrixRMaj * fMatrixRMaj2}, unwrapped from the {@code SimpleMatrix}
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj mul(FMatrixRMaj fMatrixRMaj, FMatrixRMaj fMatrixRMaj2) {
        return (FMatrixRMaj) SimpleMatrix.wrap(fMatrixRMaj).mult(SimpleMatrix.wrap(fMatrixRMaj2)).getMatrix();
    }

    /**
     * @param fMatrixRMaj any matrix
     * @return its number of rows
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public int rows(FMatrixRMaj fMatrixRMaj) {
        return fMatrixRMaj.getNumRows();
    }

    /**
     * @param fMatrixRMaj any matrix
     * @return its number of columns
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public int cols(FMatrixRMaj fMatrixRMaj) {
        return fMatrixRMaj.getNumCols();
    }

    /**
     * Returns the length of a column vector.
     *
     * @param fMatrixRMaj an N x 1 column vector (Scala-asserted)
     * @return N, the number of rows
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public int length(FMatrixRMaj fMatrixRMaj) {
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj));
        return fMatrixRMaj.getNumRows();
    }

    /**
     * Stacks two column vectors vertically (first on top of second).
     *
     * @param fMatrixRMaj upper column vector (Scala-asserted)
     * @param fMatrixRMaj2 lower column vector (Scala-asserted)
     * @return a new column vector built by {@code SimpleMatrix.concatRows}; the result's shape is
     *     asserted as well
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj vertcat(FMatrixRMaj fMatrixRMaj, FMatrixRMaj fMatrixRMaj2) {
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj));
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj2));
        FMatrixRMaj fMatrixRMaj3 = (FMatrixRMaj) SimpleMatrix.wrap(fMatrixRMaj).concatRows(SimpleMatrix.wrap(fMatrixRMaj2)).getMatrix();
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj3));
        return fMatrixRMaj3;
    }

    /**
     * Allocates a new {@code i x i2} matrix.
     *
     * @param i number of rows
     * @param i2 number of columns
     * @return a freshly constructed {@code FMatrixRMaj(i, i2)}
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj zeros(int i, int i2) {
        return new FMatrixRMaj(i, i2);
    }

    /**
     * Replaces every element of the matrix, in place, with {@code function1} applied to it.
     *
     * <p>The iterator covers rows {@code [0, numRows)} and columns {@code [0, numCols)}, so every
     * element is visited. {@code apply$mcFF$sp} is Scala's float-to-float specialised entry point,
     * and using it avoids boxing the result.
     *
     * @param fMatrixRMaj matrix to modify in place
     * @param function1 a Scala {@code Float => Float} function
     */
    public void map(FMatrixRMaj fMatrixRMaj, Function1<Object, Object> function1) {
        FMatrixIterator it = fMatrixRMaj.iterator(true, 0, 0, fMatrixRMaj.getNumRows(), fMatrixRMaj.getNumCols());
        while (it.hasNext()) {
            it.set(function1.apply$mcFF$sp(Predef$.MODULE$.Float2float(it.next())));
        }
    }

    /**
     * Extracts row {@code i} as a new 1 x N matrix.
     *
     * @param fMatrixRMaj source matrix
     * @param i row index
     * @return the result of {@code SimpleMatrix.rows(i, i + 1)}, asserted to be a row vector
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj row(FMatrixRMaj fMatrixRMaj, int i) {
        FMatrixRMaj fMatrixRMaj2 = (FMatrixRMaj) SimpleMatrix.wrap(fMatrixRMaj).rows(i, i + 1).getMatrix();
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj2));
        return fMatrixRMaj2;
    }

    /**
     * Concatenates two row vectors horizontally (first, then second).
     *
     * @param fMatrixRMaj left row vector (Scala-asserted)
     * @param fMatrixRMaj2 right row vector (Scala-asserted)
     * @return a new row vector built by {@code SimpleMatrix.concatColumns}; the result's shape is
     *     asserted as well
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj horcat(FMatrixRMaj fMatrixRMaj, FMatrixRMaj fMatrixRMaj2) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj));
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj2));
        FMatrixRMaj fMatrixRMaj3 = (FMatrixRMaj) SimpleMatrix.wrap(fMatrixRMaj).concatColumns(SimpleMatrix.wrap(fMatrixRMaj2)).getMatrix();
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj3));
        return fMatrixRMaj3;
    }

    /**
     * Returns the contents of a row vector as a {@code float[]}.
     *
     * <p>NOTE(review): this returns {@code getData()}, which in EJML is the matrix's backing array
     * rather than a copy. Writes through the returned array therefore alias the matrix. EJML also
     * allows the backing array to be longer than {@code numRows * numCols}, for example after a
     * shrinking {@code reshape}. Confirm that callers never rely on the array length for such
     * matrices.
     *
     * @param fMatrixRMaj a 1 x N row vector (Scala-asserted)
     * @return the matrix's data array
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public float[] toArray(FMatrixRMaj fMatrixRMaj) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj));
        return fMatrixRMaj.getData();
    }

    /**
     * Reads element {@code i} of a row vector using EJML's linear index.
     *
     * @param fMatrixRMaj a 1 x N row vector (Scala-asserted)
     * @param i element index
     * @return {@code fMatrixRMaj.get(i)}
     */
    public float get(FMatrixRMaj fMatrixRMaj, int i) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj));
        return fMatrixRMaj.get(i);
    }

    /**
     * Writes element {@code i} of a row vector using EJML's linear index.
     *
     * @param fMatrixRMaj a 1 x N row vector (Scala-asserted), modified in place
     * @param i element index
     * @param f value to store
     */
    public void set(FMatrixRMaj fMatrixRMaj, int i, float f) {
        Predef$.MODULE$.assert(isRowVector(fMatrixRMaj));
        fMatrixRMaj.set(i, f);
    }

    /**
     * Builds a matrix from row arrays using EJML's {@code FMatrixRMaj(float[][])} constructor.
     *
     * @param fArr one inner array per row
     * @return the new matrix
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj mkMatrixFromRows(float[][] fArr) {
        return new FMatrixRMaj(fArr);
    }

    /**
     * Builds a matrix whose columns are the given arrays.
     *
     * <p>The number of columns is {@code fArr.length}. The number of rows is taken from the length
     * of the first array, so an empty {@code fArr} fails (Scala's {@code head} throws on an empty
     * array). The synthetic closure {@code EjmlMath$$anonfun$mkMatrixFromCols$1} fills the matrix
     * for each column index, and its body is not visible here. NOTE(review): confirm how that
     * closure handles ragged inputs.
     *
     * @param fArr one inner array per column
     * @return a new {@code fArr[0].length x fArr.length} matrix
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj mkMatrixFromCols(float[][] fArr) {
        int length = fArr.length;
        int length2 = ((float[]) Predef$.MODULE$.refArrayOps(fArr).head()).length;
        FMatrixRMaj fMatrixRMaj = new FMatrixRMaj(length2, length);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).foreach$mVc$sp(new EjmlMath$$anonfun$mkMatrixFromCols$1(fArr, length2, fMatrixRMaj));
        return fMatrixRMaj;
    }

    /**
     * Builds a column vector using EJML's {@code FMatrixRMaj(float[])} constructor.
     *
     * @param fArr vector entries
     * @return the new matrix, asserted to be a column vector
     */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public FMatrixRMaj mkColVector(float[] fArr) {
        FMatrixRMaj fMatrixRMaj = new FMatrixRMaj(fArr);
        Predef$.MODULE$.assert(isColVector(fMatrixRMaj));
        return fMatrixRMaj;
    }

    /** Erased-signature bridge for {@code Math.set}: unboxes the value and delegates. */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public /* bridge */ /* synthetic */ void set(Object obj, int i, Object obj2) {
        set((FMatrixRMaj) obj, i, BoxesRunTime.unboxToFloat(obj2));
    }

    /** Erased-signature bridge for {@code Math.get}: delegates and boxes the result. */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public /* bridge */ /* synthetic */ Object get(Object obj, int i) {
        return BoxesRunTime.boxToFloat(get((FMatrixRMaj) obj, i));
    }

    /** Erased-signature bridge for {@code Math.map}: casts the matrix and delegates. */
    @Override // org.clulab.scala_transformers.encoder.math.Math
    public /* bridge */ /* synthetic */ void map(Object obj, Function1 function1) {
        map((FMatrixRMaj) obj, (Function1<Object, Object>) function1);
    }

    /** Private: the only instance is {@link #MODULE$}. */
    private EjmlMath$() {
    }
}
