package org.deeplearning4j.plot;

import com.google.common.util.concurrent.AtomicDouble;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.quadtree.QuadTree;
import org.deeplearning4j.clustering.vptree.VpTreeNode;
import org.deeplearning4j.clustering.vptree.VpTreePoint;
import org.deeplearning4j.clustering.vptree.VpTreePointINDArray;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/plot/BarnesHutTsne.class */
/**
 * Barnes-Hut approximation of t-SNE.
 *
 * <p>Input similarities are computed sparsely over the {@code K = (int) (3 * perplexity)} nearest
 * neighbours of each point (found with a VP-tree) and stored in a row-pointer / column / value
 * layout ({@link #rows}, {@link #cols}, {@link #vals}). The KL cost and the gradient are
 * approximated with a {@link QuadTree} built over the current embedding {@code y}. When
 * {@code theta == 0}, {@link #fit()} falls back to the exact implementation in {@link Tsne}.
 *
 * <p>NOTE(review): {@code y}, {@code maxIter}, {@code momentum}, {@code tolerance},
 * {@code yIncs}, {@code gains}, {@code log} and the other fields not declared here are inherited
 * from {@link Tsne}; their exact semantics are defined there.
 */
public class BarnesHutTsne extends Tsne implements Model {
    /** Number of points as passed to the constructor; not read elsewhere in this class. */
    private int n;
    /** Input dimensionality as passed to the constructor; not read elsewhere in this class. */
    private int d;
    /** Target perplexity of the per-point conditional input distributions. */
    private double perplexity;
    /** Barnes-Hut accuracy/speed trade-off; {@code 0} selects exact t-SNE in {@link #fit()}. */
    private double theta;
    /** Row pointers: the neighbours of point {@code i} are entries {@code [rows[i], rows[i + 1])}. */
    private INDArray rows;
    /** Neighbour indices, an {@code N x K} matrix (one row per point). */
    private INDArray cols;
    /** Input similarities, an {@code N x K} matrix aligned entry-for-entry with {@link #cols}. */
    private INDArray vals;
    /** Similarity matrix most recently passed to {@link #gradient(INDArray)}; read by {@link #getGradient()}. */
    private INDArray p;
    /** Input data, one point per row. */
    private INDArray x;
    /** Target dimensionality handed to the exact t-SNE path. */
    private int numDimensions;
    /** Key under which the embedding update is stored in the gradient lookup table. */
    public static final String Y_GRAD = "yIncs";

    /**
     * Creates a Barnes-Hut t-SNE model.
     *
     * @param x input data, one point per row
     * @param n number of points (stored only)
     * @param d input dimensionality (stored only)
     * @param y initial embedding
     * @param numDimensions target dimensionality used by the exact ({@code theta == 0}) path
     * @param perplexity target perplexity
     * @param theta Barnes-Hut accuracy parameter; {@code 0} selects exact t-SNE
     * @param maxIter number of optimisation iterations run by {@link #fit()}
     * @param stopLyingIteration iteration at which the input similarities are divided by 4
     * @param switchMomentumIteration iteration at which {@code finalMomentum} replaces {@code momentum}
     * @param momentum initial momentum
     * @param finalMomentum momentum used from {@code switchMomentumIteration} on
     * @param learningRate step size used when AdaGrad is disabled
     */
    public BarnesHutTsne(INDArray x, int n, int d, INDArray y, int numDimensions, double perplexity,
            double theta, int maxIter, int stopLyingIteration, int switchMomentumIteration,
            double momentum, double finalMomentum, double learningRate) {
        this.x = x;
        this.n = n;
        this.d = d;
        this.y = y;
        this.numDimensions = numDimensions;
        this.perplexity = perplexity;
        this.theta = theta;
        this.maxIter = maxIter;
        this.stopLyingIteration = stopLyingIteration;
        this.switchMomentumIteration = switchMomentumIteration;
        this.momentum = momentum;
        this.finalMomentum = finalMomentum;
        this.learningRate = learningRate;
    }

    /**
     * Computes sparse input similarities over the {@code K = (int) (3 * perplexity)} nearest
     * neighbours of every point.
     *
     * <p>A per-point precision (beta) is calibrated by bisection so that the entropy of each
     * conditional distribution approaches {@code log(perplexity)}. Fills {@link #rows} (row
     * pointers, {@code rows[i] = i * K} for {@code i = 0..N}), {@link #cols} and {@link #vals}.
     * One task per point runs on a fixed pool sized to the available processors; each task
     * writes only its own row.
     *
     * @param input data matrix, one point per row
     * @param perplexity target perplexity
     * @return {@link #vals}, the {@code N x K} similarity matrix
     * @throws IllegalStateException if computing any point's similarities fails
     */
    @Override // org.deeplearning4j.plot.Tsne
    public INDArray computeGaussianPerplexity(final INDArray input, double perplexity) {
        final int numPoints = input.rows();
        final int k = (int) (3.0d * perplexity);
        this.rows = Nd4j.zeros(numPoints + 1);
        this.cols = Nd4j.zeros(numPoints, k);
        this.vals = Nd4j.zeros(numPoints, k);
        // Every point has exactly k neighbours, so the last pointer must be numPoints * k.
        for (int i = 1; i <= numPoints; i++) {
            this.rows.putScalar(i, this.rows.getDouble(i - 1) + k);
        }

        final INDArray betas = Nd4j.ones(numPoints, 1);
        // Named logU so it does not shadow the `log` logger used below.
        final double logU = Math.log(perplexity);
        final List<VpTreePointINDArray> dataPoints = VpTreePointINDArray.dataPoints(input);
        final VpTreeNode tree = VpTreeNode.buildVpTree(dataPoints);
        log.info("Calculating probabilities of data similarities...");
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        List<Future<?>> futures = new ArrayList<>(numPoints);
        for (int i = 0; i < numPoints; i++) {
            // Progress is reported as tasks are submitted, not as they complete.
            if (i % 500 == 0) {
                log.info("Handled " + i + " records");
            }
            final int row = i;
            futures.add(executor.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    computeRowSimilarities(row, input, dataPoints, tree, betas, logU, k);
                }
            }));
        }
        executor.shutdown();
        try {
            executor.awaitTermination(1L, TimeUnit.DAYS);
            // Surface failures that submit() would otherwise swallow silently.
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to compute data similarities", e.getCause());
        }
        return this.vals;
    }

    /**
     * Calibrates beta for point {@code row} and writes its {@code k} neighbour indices and
     * similarities into row {@code row} of {@link #cols} / {@link #vals}.
     *
     * <p>Only row {@code row} of {@code betas}, {@code cols} and {@code vals} is written.
     */
    private void computeRowSimilarities(int row, INDArray input, List<VpTreePointINDArray> dataPoints,
            VpTreeNode tree, INDArray betas, double logU, int k) {
        double betaMin = Double.NEGATIVE_INFINITY;
        double betaMax = Double.POSITIVE_INFINITY;
        // k + 1 neighbours are requested; entry 0 is presumed to be the query point itself — TODO confirm ordering.
        Counter neighbors = tree.findNearByPointsWithDistancesK((VpTreePoint) dataPoints.get(row), k + 1);
        INDArray point = input.slice(row);
        // The neighbour distances do not depend on beta, so build them once.
        INDArray distances = toNDArray(neighbors);
        double beta = betas.getDouble(row);
        Pair<INDArray, INDArray> hp = hBeta(point, distances, beta);
        INDArray probs = hp.getSecond();
        INDArray hDiff = hp.getFirst().sub(Double.valueOf(logU));
        for (int iter = 0;
                BooleanIndexing.and(Transforms.abs(hDiff), Conditions.greaterThan(Double.valueOf(this.tolerance)))
                        && iter < 50;
                iter++) {
            if (BooleanIndexing.and(hDiff, Conditions.greaterThan(0))) {
                // Entropy too high: beta must grow, so the current beta becomes the lower bound.
                betaMin = beta;
                beta = Double.isInfinite(betaMax) ? beta * 2.0d : (beta + betaMax) / 2.0d;
            } else {
                // Entropy too low: beta must shrink, so the current beta becomes the upper bound.
                betaMax = beta;
                beta = Double.isInfinite(betaMin) ? beta / 2.0d : (beta + betaMin) / 2.0d;
            }
            betas.putScalar(row, beta);
            hp = hBeta(point, distances, beta);
            // Keep the probabilities in sync with the beta they were computed for.
            probs = hp.getSecond();
            hDiff = hp.getFirst().subi(Double.valueOf(logU));
        }
        INDArray normalized = probs.div(probs.sum(Integer.MAX_VALUE));
        INDArray indices = toIndex(neighbors);
        for (int j = 0; j < k; j++) {
            // Same offset (j + 1) for index and value so cols[row][j] and vals[row][j] refer to
            // the same neighbour.
            this.cols.putScalar(new int[]{row, j}, indices.getDouble(j + 1));
            this.vals.putScalar(new int[]{row, j}, normalized.getDouble(j + 1));
        }
    }

    /**
     * Returns the input data supplied to the constructor.
     *
     * @return the input matrix
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        return this.x;
    }

    /** No validation is performed. */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
    }

    /**
     * Records {@code p} as the current input-similarity matrix, then returns the approximate cost
     * ({@link #score()}) together with the embedding update from {@link #getGradient()}.
     *
     * @param p input-similarity matrix (as produced by {@link #computeGaussianPerplexity})
     * @return pair of (cost, embedding update stored under {@link #Y_GRAD})
     */
    @Override // org.deeplearning4j.plot.Tsne
    protected Pair<Double, INDArray> gradient(INDArray p) {
        this.p = p;
        return new Pair<>(Double.valueOf(score()), getGradient().gradientLookupTable().get(Y_GRAD));
    }

    /**
     * Dense gradient with respect to the embedding, as in exact t-SNE.
     *
     * <p>Row {@code i} of the result is {@code sum_j (pq[i] * num[i])_j * (y_i - y_j)}.
     *
     * @param numPoints number of rows to compute
     * @param pq presumably {@code P - Q} — TODO confirm against {@link Tsne}
     * @param num presumably the Student-t numerator matrix — TODO confirm against {@link Tsne}
     * @return gradient with the same shape as {@code y}
     */
    @Override // org.deeplearning4j.plot.Tsne
    public INDArray getYGradient(int numPoints, INDArray pq, INDArray num) {
        INDArray yGrad = Nd4j.create(this.y.shape());
        for (int i = 0; i < numPoints; i++) {
            INDArray weights = Nd4j.tile(pq.getRow(i).mul(num.getRow(i)), new int[]{this.y.columns(), 1}).transpose();
            INDArray diffs = this.y.getRow(i).broadcast(this.y.shape()).sub(this.y);
            yGrad.putRow(i, weights.mul(diffs).sum(0));
        }
        return yGrad;
    }

    /**
     * Returns the point indices of {@code counter}'s keys, in {@link Counter#getSortedKeys()} order.
     *
     * @param counter neighbours keyed by point
     * @return vector of point indices
     */
    public INDArray toIndex(Counter<VpTreePointINDArray> counter) {
        INDArray indices = Nd4j.create(counter.size());
        List<VpTreePointINDArray> keys = counter.getSortedKeys();
        for (int i = 0; i < keys.size(); i++) {
            indices.putScalar(i, keys.get(i).getIndex());
        }
        return indices;
    }

    /**
     * Returns the counts (neighbour distances) of {@code counter}'s keys, in
     * {@link Counter#getSortedKeys()} order, so they line up with {@link #toIndex(Counter)}.
     *
     * @param counter neighbours keyed by point, counts holding distances
     * @return vector of counts
     */
    public INDArray toNDArray(Counter<VpTreePointINDArray> counter) {
        INDArray values = Nd4j.create(counter.size());
        List<VpTreePointINDArray> keys = counter.getSortedKeys();
        for (int i = 0; i < keys.size(); i++) {
            values.putScalar(i, counter.getCount(keys.get(i)));
        }
        return values;
    }

    /**
     * Gaussian kernel and entropy for a given precision.
     *
     * <p>Computes {@code P = exp(-x * beta * distances)}, {@code H = log(sum P) + sum_0(x * P) * beta
     * * distances / sum P}, then normalises {@code P} by its sum.
     *
     * <p>NOTE(review): standard t-SNE derives P and H from the distances alone; multiplying by the
     * raw data row {@code x} here looks suspect and needs matching shapes — confirm intent.
     *
     * @param x data row passed by {@link #computeGaussianPerplexity}
     * @param distances neighbour distances
     * @param beta precision
     * @return pair of (entropy H, normalised P)
     */
    public Pair<INDArray, INDArray> hBeta(INDArray x, INDArray distances, double beta) {
        INDArray exp = Transforms.exp(x.neg().muli(Double.valueOf(beta)).muli(distances));
        INDArray sum = exp.sum(Integer.MAX_VALUE);
        INDArray h = Transforms.log(sum).addi(x.mul(exp).sum(0).muli(Double.valueOf(beta)).muli(distances).divi(sum));
        exp.divi(sum);
        return new Pair<>(h, exp);
    }

    /**
     * Runs the optimisation.
     *
     * <p>With {@code theta == 0} the exact superclass implementation computes {@code y}.
     * Otherwise the sparse similarities are computed once and {@code maxIter} steps are taken.
     * Momentum switches at {@code switchMomentumIteration`, and the similarities are divided by 4
     * at {@code stopLyingIteration}.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        if (this.theta == 0.0d) {
            this.y = super.calculate(this.x, this.numDimensions, this.perplexity);
            return;
        }
        INDArray similarities = computeGaussianPerplexity(this.x, this.perplexity);
        for (int i = 0; i < this.maxIter; i++) {
            step(similarities, i);
            if (i == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (i == this.stopLyingIteration) {
                similarities.divi(4);
            }
            if (this.iterationListener != null) {
                this.iterationListener.iterationDone(i);
            }
        }
    }

    /** No-op: updates are applied through {@link #getGradient()}. */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
    }

    /**
     * Approximates the KL-divergence cost {@code sum p_ij * log(p_ij / q_ij)} over all stored
     * neighbour pairs.
     *
     * <p>{@code q_ij} is normalised by the Barnes-Hut estimate of the partition sum obtained from a
     * {@link QuadTree} over {@code y}. Requires {@link #rows}, {@link #cols} and {@link #vals} to
     * have been filled by {@link #computeGaussianPerplexity}.
     *
     * @return the approximate cost
     */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        QuadTree tree = new QuadTree(this.y);
        // Per-point buffer sized for a 2-D embedding; reused below as the difference vector.
        INDArray buff = Nd4j.create(2);
        AtomicDouble sumQ = new AtomicDouble(0.0d);
        int numPoints = this.y.rows();
        for (int i = 0; i < numPoints; i++) {
            tree.computeNonEdgeForces(i, this.theta, buff, sumQ);
        }
        double cost = 0.0d;
        for (int i = 0; i < numPoints; i++) {
            int begin = (int) this.rows.getDouble(i);
            int end = (int) this.rows.getDouble(i + 1);
            for (int e = begin; e < end; e++) {
                int col = e - begin;
                int neighbor = (int) this.cols.getDouble(i, col);
                buff.assign(this.y.slice(i));
                buff.subi(this.y.slice(neighbor));
                double q = (1.0d / (1.0d + Nd4j.getBlasWrapper().dot(buff, buff))) / sumQ.doubleValue();
                double pij = this.vals.getDouble(i, col);
                // The same tiny epsilon on both sides guards against log(0) and division by zero.
                cost += pij * Math.log((pij + Float.MIN_VALUE) / (q + Float.MIN_VALUE));
            }
        }
        return cost;
    }

    /**
     * Not supported.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray transform(INDArray data) {
        return null;
    }

    /**
     * Not supported.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        return null;
    }

    /**
     * Not supported.
     *
     * @return always {@code 0}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        return 0;
    }

    /** No-op. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray params) {
    }

    /** No-op; use {@link #fit()}. */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray data) {
    }

    /** No-op. */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray input) {
    }

    /**
     * Computes the Barnes-Hut gradient and the resulting embedding update.
     *
     * <p>Attractive (edge) forces come from the sparse similarities. Repulsive (non-edge) forces
     * and the partition sum come from a {@link QuadTree} over {@code y}. Per-coordinate gains grow
     * by 0.2 where the gradient sign differs from the previous update, shrink by a factor of 0.8
     * otherwise, and are clamped below at {@code minGain}. The update is
     * {@code yIncs = momentum * yIncs - step}, where the step uses AdaGrad or {@code learningRate}.
     * Requires {@link #gradient(INDArray)} to have set {@link #p}.
     *
     * @return gradient whose lookup table maps {@link #Y_GRAD} to {@code yIncs}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient getGradient() {
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros(this.y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones(this.y.shape());
        }
        AtomicDouble sumQ = new AtomicDouble(0.0d);
        // Forces act on embedding points, so the buffers share y's shape (as yIncs and gains do).
        INDArray posF = Nd4j.create(this.y.shape());
        INDArray negF = Nd4j.create(this.y.shape());
        // The tree partitions the current embedding, matching score().
        QuadTree tree = new QuadTree(this.y);
        tree.computeEdgeForces(this.rows, this.cols, this.p, this.p.rows(), posF);
        for (int i = 0; i < this.y.rows(); i++) {
            // One force row per point, like the per-point buffer used in score().
            tree.computeNonEdgeForces(i, this.theta, negF.slice(i), sumQ);
        }
        INDArray dC = posF.subi(negF.divi(sumQ));
        INDArray signDiffers = dC.cond(Conditions.greaterThan(0)).neqi(this.yIncs.cond(Conditions.greaterThan(0)));
        INDArray signAgrees = dC.cond(Conditions.greaterThan(0)).eqi(this.yIncs.cond(Conditions.greaterThan(0)));
        this.gains = this.gains.add(Double.valueOf(0.2d)).muli(signDiffers)
                .addi(this.gains.mul(Double.valueOf(0.8d)).muli(signAgrees));
        BooleanIndexing.applyWhere(this.gains, Conditions.lessThan(Double.valueOf(this.minGain)), new Value(Double.valueOf(this.minGain)));
        INDArray step = this.gains.mul(dC);
        if (this.useAdaGrad) {
            step = this.adaGrad.getGradient(step);
        } else {
            step.muli(Double.valueOf(this.learningRate));
        }
        this.yIncs.muli(Double.valueOf(this.momentum)).subi(step);
        DefaultGradient result = new DefaultGradient();
        result.gradientLookupTable().put(Y_GRAD, this.yIncs);
        return result;
    }

    /**
     * Not supported.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        return null;
    }

    /**
     * Not supported.
     *
     * @return always {@code 0}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        return 0;
    }

    /**
     * Not supported.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        return null;
    }

    /** No-op. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
    }
}
