/*
 * Decompiled with CFR 0.152.
 */
package breeze.optimize.linear;

import breeze.generic.UFunc;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.linalg.operators.OpMulMatrix$;
import breeze.math.MutableInnerProductVectorSpace;
import breeze.optimize.linear.ConjugateGradient$;
import breeze.optimize.linear.ConjugateGradient$State$;
import breeze.util.Implicits$;
import breeze.util.LazyLogger;
import breeze.util.SerializableLogging;
import java.io.Serializable;
import scala.Conversion;
import scala.Function1;
import scala.Predef;
import scala.Predef$;
import scala.Product;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.StringOps$;
import scala.collection.immutable.Seq;
import scala.math.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.LazyVals$;
import scala.runtime.Scala3RunTime$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;

/**
 * Conjugate-gradient solver for {@code (B + normSquaredPenalty * I) x = a}, with an optional stop
 * on the sphere {@code norm(x) = maxNormValue}.
 *
 * <p>The initial residual is {@code a - B*x0 - normSquaredPenalty*x0}, and every step subtracts
 * {@code alpha * (B*d + normSquaredPenalty*d)} from it. The iteration therefore drives
 * {@code (B + normSquaredPenalty*I) x} towards {@code a}.
 *
 * <p>NOTE(review): CG presumes that this operator is symmetric positive definite. Nothing in this
 * class checks that.
 *
 * <p>The solution, residual and search-direction vectors are updated <em>in place</em>. The same
 * three objects are shared by every {@link State} produced by {@link #iterations}, so an earlier
 * state reflects the latest values rather than a snapshot.
 *
 * <p>This file is CFR-decompiled Scala; {@code $}-mangled member names mirror the bytecode.
 *
 * @param <T> vector type, manipulated through the {@code space} operations
 * @param <M> operator type, applied to vectors through {@code mult}
 */
public class ConjugateGradient<T, M>
implements SerializableLogging {
    // Backing field for the SerializableLogging logger; transient, so it is rebuilt after deserialization.
    private volatile transient LazyLogger breeze$util$SerializableLogging$$_the_logger;
    private final double maxNormValue;
    private final int maxIterations;
    private final double normSquaredPenalty;
    private final double tolerance;
    // Vector-space operations for T: dot, norm, copy, zeroLike, arithmetic.
    public final MutableInnerProductVectorSpace<T, Object> breeze$optimize$linear$ConjugateGradient$$space;
    // Matrix-vector product B * v.
    private final UFunc.UImpl2<OpMulMatrix$, M, T, T> mult;
    // Companion/factory for the inner State case class, created once per solver.
    public final ConjugateGradient$State$ State$lzy1;

    /** Default for constructor parameter 1 ({@code maxNormValue}), delegated to the companion object. */
    public static double $lessinit$greater$default$1() {
        return ConjugateGradient$.MODULE$.$lessinit$greater$default$1();
    }

    /** Default for constructor parameter 2 ({@code maxIterations}), delegated to the companion object. */
    public static int $lessinit$greater$default$2() {
        return ConjugateGradient$.MODULE$.$lessinit$greater$default$2();
    }

    /** Default for constructor parameter 3 ({@code normSquaredPenalty}), delegated to the companion object. */
    public static double $lessinit$greater$default$3() {
        return ConjugateGradient$.MODULE$.$lessinit$greater$default$3();
    }

    /** Default for constructor parameter 4 ({@code tolerance}), delegated to the companion object. */
    public static double $lessinit$greater$default$4() {
        return ConjugateGradient$.MODULE$.$lessinit$greater$default$4();
    }

    /**
     * Creates a solver.
     *
     * <p>The constructor uses the class's own {@code T}/{@code M}. The decompiled form redeclared
     * {@code <T, M>} here, which shadowed them and made the field assignments ill-typed.
     *
     * @param maxNormValue once a step would give {@code norm(x) >= maxNormValue}, x is moved onto
     *     that sphere along the search direction and iteration stops
     * @param maxIterations maximum number of CG steps; values {@code <= 0} disable the cap
     * @param normSquaredPenalty coefficient of the identity term added to {@code B}
     * @param tolerance iteration stops once {@code norm(residual) <= tolerance}
     * @param space vector-space operations for {@code T}
     * @param mult implementation of the product {@code B * v}
     */
    public ConjugateGradient(double maxNormValue, int maxIterations, double normSquaredPenalty, double tolerance, MutableInnerProductVectorSpace<T, Object> space, UFunc.UImpl2<OpMulMatrix$, M, T, T> mult) {
        this.maxNormValue = maxNormValue;
        this.maxIterations = maxIterations;
        this.normSquaredPenalty = normSquaredPenalty;
        this.tolerance = tolerance;
        this.breeze$optimize$linear$ConjugateGradient$$space = space;
        this.mult = mult;
        this.State$lzy1 = new ConjugateGradient$State$(this);
        SerializableLogging.$init$(this);
    }

    @Override
    public LazyLogger breeze$util$SerializableLogging$$_the_logger() {
        return this.breeze$util$SerializableLogging$$_the_logger;
    }

    @Override
    public void breeze$util$SerializableLogging$$_the_logger_$eq(LazyLogger x$1) {
        this.breeze$util$SerializableLogging$$_the_logger = x$1;
    }

    /**
     * Solves starting from a zero vector shaped like {@code a}.
     *
     * @param a right-hand side
     * @param B operator applied through {@code mult}
     * @return the final iterate x
     */
    public T minimize(T a, M B) {
        return this.minimize(a, B, this.breeze$optimize$linear$ConjugateGradient$$space.zeroLike().apply(a));
    }

    /**
     * Solves starting from {@code initX}.
     *
     * @param a right-hand side
     * @param B operator applied through {@code mult}
     * @param initX starting point; it is copied and is not modified
     * @return the final iterate x
     */
    public T minimize(T a, M B, T initX) {
        return (T)this.minimizeAndReturnResidual(a, B, initX)._1();
    }

    /** Returns the factory/companion for {@link State}. */
    public final ConjugateGradient$State$ State() {
        return this.State$lzy1;
    }

    /**
     * Runs {@link #iterations} to completion.
     *
     * @param a right-hand side
     * @param B operator applied through {@code mult}
     * @param initX starting point; it is copied and is not modified
     * @return pair of (final x, final residual) taken from the last state
     */
    public Tuple2<T, T> minimizeAndReturnResidual(T a, M B, T initX) {
        State state = (State)Implicits$.MODULE$.scEnrichIterator((Iterator)this.iterations(a, B, initX)).last();
        Object object = Predef$.MODULE$.ArrowAssoc(state.x());
        return Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(object, state.residual());
    }

    /**
     * Returns the lazy sequence of solver states.
     *
     * <p>The sequence starts with the initial state and ends with, and includes, the first state
     * whose {@code converged} flag is set.
     *
     * <p>Each step computes {@code alpha = norm(r)^2 / (d.(B d) + normSquaredPenalty * d.d)} and
     * the trial point {@code nextX = x + alpha * d}. Then:
     * <ul>
     *   <li>If {@code norm(nextX) >= maxNormValue}, x advances along d only as far as the sphere
     *       {@code norm(x) = maxNormValue}. The residual is updated to match, and the state is
     *       marked converged.</li>
     *   <li>Otherwise x becomes nextX and the residual and direction get the standard CG update.
     *       The state is marked converged when {@code norm(r) <= tolerance}. If
     *       {@code maxIterations > 0}, it is also marked converged once exactly
     *       {@code maxIterations} steps have been taken.</li>
     * </ul>
     * x, r and d are updated in place and are shared by all yielded states.
     *
     * @param a right-hand side
     * @param B operator applied through {@code mult}
     * @param initX starting point; it is copied and is not modified
     */
    public Iterator<State> iterations(T a, M B, T initX) {
        return Implicits$.MODULE$.scEnrichIterator(scala.package$.MODULE$.Iterator().iterate((Object)this.initialState(a, B, initX), (Function1 & Serializable)state -> {
            State state2;
            Object r = state.residual();
            Object d = state.direction();
            // Read before r is mutated below. rtr is a lazily cached r.r, and r is updated in place,
            // so a later first access would see the new residual.
            double rtr = state.rtr();
            T Bd = this.mult.apply(B, d);
            double dtd = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).dot(d, this.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
            double alpha = package$.MODULE$.pow(BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(r, this.breeze$optimize$linear$ConjugateGradient$$space.normImpl())), 2.0) / (BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).dot(Bd, this.breeze$optimize$linear$ConjugateGradient$$space.dotVV())) + this.normSquaredPenalty * dtd);
            Object nextX = ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(state.x())).$plus(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).$times(BoxesRunTime.boxToDouble((double)alpha), this.breeze$optimize$linear$ConjugateGradient$$space.mulVS_M()), this.breeze$optimize$linear$ConjugateGradient$$space.addVV());
            double xnorm = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(nextX, this.breeze$optimize$linear$ConjugateGradient$$space.normImpl()));
            if (xnorm >= this.maxNormValue) {
                // The full step would leave the norm ball: stop on its boundary instead.
                this.logger().info(() -> this.iterations$$anonfun$3$$anonfun$1(state, xnorm));
                double xtd = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(state.x())).dot(d, this.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
                double xtx = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(state.x())).dot(state.x(), this.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
                double normSquare = this.maxNormValue * this.maxNormValue;
                // alphaNext is the larger root of dtd*t^2 + 2*xtd*t + (xtx - normSquare) = 0,
                // i.e. norm(x + t*d)^2 = normSquare. For xtd >= 0 the rationalised form
                // (normSquare - xtx) / (xtd + radius) avoids cancellation between radius and xtd.
                double radius = package$.MODULE$.sqrt(xtd * xtd + dtd * (normSquare - xtx));
                double alphaNext = xtd >= 0.0 ? (normSquare - xtx) / (xtd + radius) : (radius - xtd) / dtd;
                if (Predef$.MODULE$.double2Double(alphaNext).isNaN()) {
                    throw Scala3RunTime$.MODULE$.assertFailed((Object)("" + xtd + " " + normSquare + " " + xtx + "  " + xtd + " " + radius + " " + dtd));
                }
                // x += alphaNext * d
                breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble((double)alphaNext), d, state.x(), this.breeze$optimize$linear$ConjugateGradient$$space.scaleAddVV());
                // r -= alphaNext * (B d + normSquaredPenalty * d)
                breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble((double)(-alphaNext)), ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(Bd)).$plus(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).$times$colon$times(BoxesRunTime.boxToDouble((double)this.normSquaredPenalty), this.breeze$optimize$linear$ConjugateGradient$$space.mulVS()), this.breeze$optimize$linear$ConjugateGradient$$space.addVV()), r, this.breeze$optimize$linear$ConjugateGradient$$space.scaleAddVV());
                state2 = this.State().apply(state.x(), r, d, state.iter() + 1, true);
            } else {
                // x := nextX
                ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(state.x())).$colon$eq(nextX, this.breeze$optimize$linear$ConjugateGradient$$space.setIntoVV());
                // r -= alpha * (B d + normSquaredPenalty * d)
                ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(r)).$minus$eq(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(Bd)).$plus(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).$times$colon$times(BoxesRunTime.boxToDouble((double)this.normSquaredPenalty), this.breeze$optimize$linear$ConjugateGradient$$space.mulVS()), this.breeze$optimize$linear$ConjugateGradient$$space.addVV()))).$times$colon$times(BoxesRunTime.boxToDouble((double)alpha), this.breeze$optimize$linear$ConjugateGradient$$space.mulVS()), this.breeze$optimize$linear$ConjugateGradient$$space.subIntoVV());
                double newrtr = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(r)).dot(r, this.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
                // d = beta * d + r  (Fletcher-Reeves ratio of successive r.r)
                double beta = newrtr / rtr;
                ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).$colon$times$eq(BoxesRunTime.boxToDouble((double)beta), this.breeze$optimize$linear$ConjugateGradient$$space.mulIntoVS());
                ((NumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(d)).$plus$eq(r, this.breeze$optimize$linear$ConjugateGradient$$space.addIntoVV());
                double normr = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(r, this.breeze$optimize$linear$ConjugateGradient$$space.normImpl()));
                boolean withinTolerance = normr <= this.tolerance;
                // state.iter() steps preceded this one, so state.iter() + 1 steps have now been taken.
                // Stopping at ">= maxIterations" caps the run at exactly maxIterations steps.
                boolean hitIterationCap = this.maxIterations > 0 && state.iter() + 1 >= this.maxIterations;
                boolean converged = withinTolerance || hitIterationCap;
                if (withinTolerance) {
                    // Reaching tolerance takes precedence, so the "> tolerance" message is never logged falsely.
                    this.logger().info(() -> this.iterations$$anonfun$5$$anonfun$3(state, normr));
                } else if (hitIterationCap) {
                    this.logger().info(() -> this.iterations$$anonfun$4$$anonfun$2(state, normr));
                } else {
                    this.logger().info(() -> this.iterations$$anonfun$6$$anonfun$4(state, normr));
                }
                state2 = this.State().apply(state.x(), r, d, state.iter() + 1, converged);
            }
            return state2;
        })).takeUpToWhere((Function1 & Serializable)_$1 -> _$1.converged());
    }

    /**
     * Builds iteration 0.
     *
     * <ul>
     *   <li>x0 is a private copy of {@code initX}, so the in-place updates made by
     *       {@link #iterations} never write into the caller's vector.</li>
     *   <li>The residual is {@code r0 = a - B*initX - normSquaredPenalty*initX}.</li>
     *   <li>The first direction is {@code d0 = copy(r0)}.</li>
     * </ul>
     * The state is marked converged immediately when {@code norm(r0) <= tolerance}.
     */
    private State initialState(T a, M B, T initX) {
        // Copy first: x is later mutated in place with := and axpy.
        T x0 = (T)this.breeze$optimize$linear$ConjugateGradient$$space.copy().apply(initX);
        Object r = ((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(a)).$minus(this.mult.apply(B, initX), this.breeze$optimize$linear$ConjugateGradient$$space.subVV()))).$minus(((ImmutableNumericOps)((Conversion)this.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(initX)).$times$colon$times(BoxesRunTime.boxToDouble((double)this.normSquaredPenalty), this.breeze$optimize$linear$ConjugateGradient$$space.mulVS()), this.breeze$optimize$linear$ConjugateGradient$$space.subVV());
        Object d = this.breeze$optimize$linear$ConjugateGradient$$space.copy().apply(r);
        double rnorm = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(r, this.breeze$optimize$linear$ConjugateGradient$$space.normImpl()));
        return this.State().apply(x0, r, d, 0, rnorm <= this.tolerance);
    }

    /** Log message: the trial step would leave the norm ball. */
    private final String iterations$$anonfun$3$$anonfun$1(State state$1, double xnorm$1) {
        return StringOps$.MODULE$.format$extension("%s boundary reached! norm(x): %.3f >= maxNormValue %s", (Seq)ScalaRunTime$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)state$1.iter()), BoxesRunTime.boxToDouble((double)xnorm$1), BoxesRunTime.boxToDouble((double)this.maxNormValue)}));
    }

    /** Log message: the iteration cap was reached; reports the number of steps taken. */
    private final String iterations$$anonfun$4$$anonfun$2(State state$2, double normr$1) {
        return StringOps$.MODULE$.format$extension("max iteration %s reached! norm(residual): %.3f > tolerance %s.", (Seq)ScalaRunTime$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)(state$2.iter() + 1)), BoxesRunTime.boxToDouble((double)normr$1), BoxesRunTime.boxToDouble((double)this.tolerance)}));
    }

    /** Log message: residual norm fell within tolerance. */
    private final String iterations$$anonfun$5$$anonfun$3(State state$3, double normr$2) {
        return StringOps$.MODULE$.format$extension("%s converged! norm(residual): %.3f <= tolerance %s.", (Seq)ScalaRunTime$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)state$3.iter()), BoxesRunTime.boxToDouble((double)normr$2), BoxesRunTime.boxToDouble((double)this.tolerance)}));
    }

    /** Log message: ordinary progress, residual still above tolerance. */
    private final String iterations$$anonfun$6$$anonfun$4(State state$4, double normr$3) {
        return StringOps$.MODULE$.format$extension("%s: norm(residual): %.3f > tolerance %s.", (Seq)ScalaRunTime$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)state$4.iter()), BoxesRunTime.boxToDouble((double)normr$3), BoxesRunTime.boxToDouble((double)this.tolerance)}));
    }

    /*
     * Illegal identifiers - consider using --renameillegalidents true
     */
    /**
     * One step of the solver.
     *
     * <p>A state holds:
     * <ul>
     *   <li>the current x, residual and direction. These are the mutable vectors shared with
     *       every other state of the same run.</li>
     *   <li>the step index.</li>
     *   <li>whether iteration has stopped.</li>
     * </ul>
     * This is a Scala case class rendered as Java: it has Product members, copy and _1.._5.
     */
    public class State
    implements Product,
    Serializable {
        // Offset of the lazy-val flag word below. It is resolved by field name, so the
        // (decompiler-emitted, Java-illegal) name "0bitmap$1" must stay in sync with this string.
        public static final long OFFSET$0 = LazyVals$.MODULE$.getOffset(State.class, "0bitmap$1");
        public long 0bitmap$1;
        private final Object x;
        private final Object residual;
        private final Object direction;
        private final int iter;
        private final boolean converged;
        // Cached value of rtr(), valid once the flag word reaches state 3.
        public double rtr$lzy1;
        private final ConjugateGradient<T, M> $outer;

        public State(ConjugateGradient $outer, T x, T residual, T direction, int iter, boolean converged) {
            this.x = x;
            this.residual = residual;
            this.direction = direction;
            this.iter = iter;
            this.converged = converged;
            if ($outer == null) {
                throw new NullPointerException();
            }
            this.$outer = $outer;
        }

        public int hashCode() {
            int n = -889275714;
            n = Statics.mix((int)n, (int)this.productPrefix().hashCode());
            n = Statics.mix((int)n, (int)Statics.anyHash(this.x()));
            n = Statics.mix((int)n, (int)Statics.anyHash(this.residual()));
            n = Statics.mix((int)n, (int)Statics.anyHash(this.direction()));
            n = Statics.mix((int)n, (int)this.iter());
            n = Statics.mix((int)n, (int)(this.converged() ? 1231 : 1237));
            return Statics.finalizeHash((int)n, (int)5);
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        /**
         * Case-class equality: the other object must be a State of the same outer solver, with
         * equal fields, and must accept this via canEqual.
         */
        public boolean equals(Object x$0) {
            if (this == x$0) return true;
            Object object = x$0;
            if (!(object instanceof State)) return false;
            if (((State)object).breeze$optimize$linear$ConjugateGradient$State$$$outer() != this.$outer) return false;
            State state = (State)object;
            if (this.iter() != state.iter()) return false;
            if (this.converged() != state.converged()) return false;
            if (!BoxesRunTime.equals(this.x(), state.x())) return false;
            if (!BoxesRunTime.equals(this.residual(), state.residual())) return false;
            if (!BoxesRunTime.equals(this.direction(), state.direction())) return false;
            if (!state.canEqual(this)) return false;
            return true;
        }

        public String toString() {
            return ScalaRunTime$.MODULE$._toString((Product)this);
        }

        public boolean canEqual(Object that) {
            return that instanceof State;
        }

        public int productArity() {
            return 5;
        }

        public String productPrefix() {
            return "State";
        }

        /** Returns field {@code n} in declaration order (x, residual, direction, iter, converged). */
        public Object productElement(int n) {
            Object object;
            int n2 = n;
            switch (n2) {
                case 0: {
                    object = this._1();
                    break;
                }
                case 1: {
                    object = this._2();
                    break;
                }
                case 2: {
                    object = this._3();
                    break;
                }
                case 3: {
                    object = BoxesRunTime.boxToInteger((int)this._4());
                    break;
                }
                case 4: {
                    object = BoxesRunTime.boxToBoolean((boolean)this._5());
                    break;
                }
                default: {
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger((int)n).toString());
                }
            }
            return object;
        }

        /** Returns the name of field {@code n}, matching {@link #productElement}. */
        public String productElementName(int n) {
            String string;
            int n2 = n;
            switch (n2) {
                case 0: {
                    string = "x";
                    break;
                }
                case 1: {
                    string = "residual";
                    break;
                }
                case 2: {
                    string = "direction";
                    break;
                }
                case 3: {
                    string = "iter";
                    break;
                }
                case 4: {
                    string = "converged";
                    break;
                }
                default: {
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger((int)n).toString());
                }
            }
            return string;
        }

        public T x() {
            return this.x;
        }

        public T residual() {
            return this.residual;
        }

        public T direction() {
            return this.direction;
        }

        public int iter() {
            return this.iter;
        }

        public boolean converged() {
            return this.converged;
        }

        /**
         * Lazily computes and caches {@code residual . residual}.
         *
         * <p>Flag values used here: 0 = not computed; 1 = being computed by the thread that won the
         * CAS; 3 = cached. Other callers wait until the flag changes. On failure the flag is reset
         * to 0.
         *
         * <p>Because residual is mutated in place, the cached value reflects the residual at the
         * time of the first call.
         */
        public double rtr() {
            long l;
            long l2;
            while ((l2 = LazyVals$.MODULE$.STATE(l = LazyVals$.MODULE$.get((Object)this, OFFSET$0), 0)) != 3L) {
                if (l2 == 0L) {
                    if (!LazyVals$.MODULE$.CAS((Object)this, OFFSET$0, l, 1, 0)) continue;
                    try {
                        double d;
                        this.rtr$lzy1 = d = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.$outer.breeze$optimize$linear$ConjugateGradient$$space.hasOps()).apply(this.residual())).dot(this.residual(), this.$outer.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
                        LazyVals$.MODULE$.setFlag((Object)this, OFFSET$0, 3, 0);
                        return d;
                    }
                    catch (Throwable throwable) {
                        LazyVals$.MODULE$.setFlag((Object)this, OFFSET$0, 0, 0);
                        throw throwable;
                    }
                }
                LazyVals$.MODULE$.wait4Notification((Object)this, OFFSET$0, l, 0);
            }
            return this.rtr$lzy1;
        }

        public State copy(T x, T residual, T direction, int iter, boolean converged) {
            return new State(this.$outer, x, residual, direction, iter, converged);
        }

        public Object copy$default$1() {
            return this.x();
        }

        public Object copy$default$2() {
            return this.residual();
        }

        public Object copy$default$3() {
            return this.direction();
        }

        public int copy$default$4() {
            return this.iter();
        }

        public boolean copy$default$5() {
            return this.converged();
        }

        public T _1() {
            return this.x();
        }

        public T _2() {
            return this.residual();
        }

        public T _3() {
            return this.direction();
        }

        public int _4() {
            return this.iter();
        }

        public boolean _5() {
            return this.converged();
        }

        public final ConjugateGradient<T, M> breeze$optimize$linear$ConjugateGradient$State$$$outer() {
            return this.$outer;
        }
    }
}

