package org.clulab.scala_transformers.encoder;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.HashMap;
import java.util.Map;
import org.clulab.scala_transformers.encoder.math.Mathematics$;
import org.clulab.shaded.org.ejml.data.FMatrixRMaj;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;

/* compiled from: Encoder.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055b\u0001\u0002\u000b\u0016\u0001yA\u0001\"\n\u0001\u0003\u0006\u0004%\tA\n\u0005\t_\u0001\u0011\t\u0011)A\u0005O!A\u0001\u0007\u0001BC\u0002\u0013\u0005\u0011\u0007\u0003\u00056\u0001\t\u0005\t\u0015!\u00033\u0011!1\u0004A!A!\u0002\u00139\u0004\"\u0002 \u0001\t\u0003y\u0004\"\u0002#\u0001\t\u0003)\u0005\"\u0002#\u0001\t\u00031w!B5\u0016\u0011\u0003Qg!\u0002\u000b\u0016\u0011\u0003Y\u0007\"\u0002 \u000b\t\u0003a\u0007bB7\u000b\u0005\u0004%\tA\n\u0005\u0007]*\u0001\u000b\u0011B\u0014\t\u000b=TA\u0011\u00039\t\u000bMTA\u0011\u0003;\t\r}TA\u0011CA\u0001\u0011\u001d\t9A\u0003C\u0001\u0003\u0013Aq!a\u0004\u000b\t\u0003\t\t\u0002C\u0005\u0002\u0016)\t\n\u0011\"\u0001\u0002\u0018\t9QI\\2pI\u0016\u0014(B\u0001\f\u0018\u0003\u001d)gnY8eKJT!\u0001G\r\u0002%M\u001c\u0017\r\\1`iJ\fgn\u001d4pe6,'o\u001d\u0006\u00035m\taa\u00197vY\u0006\u0014'\"\u0001\u000f\u0002\u0007=\u0014xm\u0001\u0001\u0014\u0005\u0001y\u0002C\u0001\u0011$\u001b\u0005\t#\"\u0001\u0012\u0002\u000bM\u001c\u0017\r\\1\n\u0005\u0011\n#AB!osJ+g-\u0001\nf]\u000e|G-\u001a:F]ZL'o\u001c8nK:$X#A\u0014\u0011\u0005!jS\"A\u0015\u000b\u0005)Z\u0013aC8o]b\u0014XO\u001c;j[\u0016T\u0011\u0001L\u0001\u0003C&L!AL\u0015\u0003\u001d=\u0013H/\u00128wSJ|g.\\3oi\u0006\u0019RM\\2pI\u0016\u0014XI\u001c<je>tW.\u001a8uA\u0005qQM\\2pI\u0016\u00148+Z:tS>tW#\u0001\u001a\u0011\u0005!\u001a\u0014B\u0001\u001b*\u0005)y%\u000f^*fgNLwN\\\u0001\u0010K:\u001cw\u000eZ3s'\u0016\u001c8/[8oA\u0005Ian\u001c8MS:|\u0005\u000f\u001e\t\u0004AaR\u0014BA\u001d\"\u0005\u0019y\u0005\u000f^5p]B\u00111\bP\u0007\u0002+%\u0011Q(\u0006\u0002\r\u001d>tG*\u001b8fCJLG/_\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\t\u0001\u000b%i\u0011\t\u0003w\u0001AQ!\n\u0004A\u0002\u001dBQ\u0001\r\u0004A\u0002IBqA\u000e\u0004\u0011\u0002\u0003\u0007q'A\u0004g_J<\u0018M\u001d3\u0015\u0005\u0019{\u0006c\u0001\u0011H\u0013&\u0011\u0001*\t\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003\u0015rs!aS-\u000f\u00051;fBA'W\u001d\tqUK\u0004\u0002P):\u0011\u0001kU\u0007\u0002#*\u0011!+H\u0001\u0007yI|w\u000e\u001e \n\u0003qI!AG\u000e\n\u0005aI\u0012B\u0001\f\u0018\u0013\tAV#\u0001\u0003nCRD\u0017B\u0001.\\\u0003-i\u0015\r\u001e5f[\u0006$\u0018nY:\u000b\u0005a+\u0012BA/_\u0005)i\u0015\r\u001e5NCR\u0014\u0018\u000e\u001f\u0006\u00035nCQ\u0001Y\u0004A\u0002\u0005\fQBY1uG\"Le\u000e];u\u0013\u0012\u001c\bc\u0001\u0011HEB\u0019\u0001eR2\u0011\u0005\u0001\"\u0017BA3\"\u0005\u0011auN\\4\u0015\u0005%;\u0007\"\u00025\t\u0001\u0004\u0011\u0017\u0001C5oaV$\u0018\nZ:\u0002\u000f\u0015s7m\u001c3feB\u00111HC\n\u0003\u0015}!\u0012A[\u0001\u000f_J$XI\u001c<je>tW.\u001a8u\u0003=y'\u000f^#om&\u0014xN\\7f]R\u0004\u0013a\u00034s_6\u001cVm]:j_:$\"\u0001Q9\t\u000bIt\u0001\u0019\u0001\u001a\u0002\u0015=\u0014HoU3tg&|g.\u0001\npeR\u001cVm]:j_:4%o\\7GS2,GC\u0001\u001av\u0011\u00151x\u00021\u0001x\u0003!1\u0017\u000e\\3OC6,\u0007C\u0001=}\u001d\tI(\u0010\u0005\u0002QC%\u001110I\u0001\u0007!J,G-\u001a4\n\u0005ut(AB*ue&twM\u0003\u0002|C\u00051rN\u001d;TKN\u001c\u0018n\u001c8Ge>l'+Z:pkJ\u001cW\rF\u00023\u0003\u0007Aa!!\u0002\u0011\u0001\u00049\u0018\u0001\u0004:fg>,(oY3OC6,\u0017\u0001\u00034s_64\u0015\u000e\\3\u0015\u0007\u0001\u000bY\u0001\u0003\u0004\u0002\u000eE\u0001\ra^\u0001\u000e_:t\u00070T8eK24\u0015\u000e\\3\u0002\u0019\u0019\u0014x.\u001c*fg>,(oY3\u0015\u0007\u0001\u000b\u0019\u0002\u0003\u0004\u0002\u0006I\u0001\ra^\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001a\u0016\u0005\u0005e!fA\u001c\u0002\u001c-\u0012\u0011Q\u0004\t\u0005\u0003?\tI#\u0004\u0002\u0002\")!\u00111EA\u0013\u0003%)hn\u00195fG.,GMC\u0002\u0002(\u0005\n!\"\u00198o_R\fG/[8o\u0013\u0011\tY#!\t\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\
u001cW\r")
/* loaded from: input_file:org/clulab/scala_transformers/encoder/Encoder.class */
public class Encoder {
    private final OrtEnvironment encoderEnvironment;
    private final OrtSession encoderSession;
    private final Option<NonLinearity> nonLinOpt;

    public static Encoder fromResource(String str) {
        return Encoder$.MODULE$.fromResource(str);
    }

    public static Encoder fromFile(String str) {
        return Encoder$.MODULE$.fromFile(str);
    }

    public static OrtEnvironment ortEnvironment() {
        return Encoder$.MODULE$.ortEnvironment();
    }

    public OrtEnvironment encoderEnvironment() {
        return this.encoderEnvironment;
    }

    public OrtSession encoderSession() {
        return this.encoderSession;
    }

    public FMatrixRMaj[] forward(long[][] jArr) {
        HashMap hashMap = new HashMap();
        hashMap.put("token_ids", OnnxTensor.createTensor(encoderEnvironment(), jArr));
        FMatrixRMaj[] fromResult = Mathematics$.MODULE$.Math().fromResult(encoderSession().run(hashMap));
        this.nonLinOpt.foreach(nonLinearity -> {
            $anonfun$forward$1(fromResult, nonLinearity);
            return BoxedUnit.UNIT;
        });
        return fromResult;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public FMatrixRMaj forward(long[] jArr) {
        return (FMatrixRMaj) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(forward((long[][]) ((Object[]) new long[]{jArr})))).head();
    }

    public static final /* synthetic */ void $anonfun$forward$2(NonLinearity nonLinearity, FMatrixRMaj fMatrixRMaj) {
        Mathematics$.MODULE$.Math().map(fMatrixRMaj, (Function1<Object, Object>) f -> {
            return nonLinearity.compute(f);
        });
    }

    public static final /* synthetic */ void $anonfun$forward$1(FMatrixRMaj[] fMatrixRMajArr, NonLinearity nonLinearity) {
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(fMatrixRMajArr)).foreach(fMatrixRMaj -> {
            $anonfun$forward$2(nonLinearity, fMatrixRMaj);
            return BoxedUnit.UNIT;
        });
    }

    public Encoder(OrtEnvironment ortEnvironment, OrtSession ortSession, Option<NonLinearity> option) {
        this.encoderEnvironment = ortEnvironment;
        this.encoderSession = ortSession;
        this.nonLinOpt = option;
    }
}
