/*
 * Decompiled with CFR 0.152.
 */
package org.cicirello.permutations.distance;

import java.util.Arrays;
import org.cicirello.permutations.Permutation;
import org.cicirello.permutations.distance.NormalizedPermutationDistanceMeasurerDouble;

/**
 * Weighted variant of Kendall tau distance between two permutations.
 *
 * <p>Each element {@code e} of a permutation carries the weight {@code weights[e]}.
 * The distance is the sum, over every pair of elements whose relative order differs
 * between the two permutations, of the product of the two elements' weights.
 *
 * <p>An instance is bound to the permutation length given by the weight array it was
 * constructed with; see {@link #supportedLength()}.
 *
 * <p>Runtime of {@link #distancef(Permutation, Permutation)} is O(n log n) via a
 * merge-sort based weighted inversion count.
 */
public final class WeightedKendallTauDistance
implements NormalizedPermutationDistanceMeasurerDouble {
    /** Defensive copy of the per-element weights; index is the element value. */
    private final double[] weights;

    /** Sum of {@code weights[i] * weights[j]} over all pairs {@code i < j}. */
    private final double maxDistance;

    /**
     * Creates a distance measurer for permutations of length {@code weights.length}.
     *
     * @param weights weight of each element, indexed by element value; the array is
     *     copied, so later changes by the caller do not affect this instance
     * @throws NullPointerException if {@code weights} is null
     */
    public WeightedKendallTauDistance(double[] weights) {
        this.weights = weights.clone();
        // Sum of all pairwise products in one backward pass: suffixSum always holds
        // weights[i+1] + ... + weights[n-1] when it is multiplied by weights[i].
        double max = 0.0;
        double suffixSum = 0.0;
        for (int i = this.weights.length - 2; i >= 0; --i) {
            suffixSum += this.weights[i + 1];
            max += this.weights[i] * suffixSum;
        }
        this.maxDistance = max;
    }

    /**
     * Returns the permutation length this instance supports.
     *
     * @return the length of the weight array passed to the constructor
     */
    public int supportedLength() {
        return this.weights.length;
    }

    /**
     * Computes the weighted Kendall tau distance between two permutations.
     *
     * @param p1 first permutation
     * @param p2 second permutation
     * @return the sum of weight products over all pairs ordered differently in
     *     {@code p1} and {@code p2}
     * @throws IllegalArgumentException if either permutation's length differs from
     *     {@link #supportedLength()}
     */
    @Override
    public double distancef(Permutation p1, Permutation p2) {
        if (p1.length() != this.weights.length || p2.length() != this.weights.length) {
            throw new IllegalArgumentException("p1 and/or p2 not of supported length of this instance");
        }
        int[] invP1 = p1.getInverse();
        int[] arrayP2 = new int[invP1.length];
        double[] w = new double[this.weights.length];
        for (int i = 0; i < arrayP2.length; ++i) {
            int element = p2.get(i);
            // Relabel p2's element through p1's inverse (presumably its position in p1),
            // and re-index the weights by that label so the counting routine can look
            // weights up directly from the relabeled values.
            int relabeled = invP1[element];
            arrayP2[i] = relabeled;
            w[relabeled] = this.weights[element];
        }
        return this.countWeightedInversions(arrayP2, w);
    }

    /**
     * Returns the precomputed sum of weight products over all element pairs, i.e. the
     * distance obtained when every pair is inverted (the maximum when all weights are
     * non-negative).
     *
     * <p>NOTE(review): {@code length} is not validated or used; the value always
     * corresponds to {@link #supportedLength()}.
     *
     * @param length ignored by this implementation
     * @return the all-pairs weighted sum computed at construction time
     */
    @Override
    public double maxf(int length) {
        return this.maxDistance;
    }

    /**
     * Sorts {@code array} in place with merge sort and returns the weighted inversion
     * count: for each pair of values that appears out of order, the product of their
     * weights.
     *
     * @param array values to sort and count inversions in; sorted ascending on return
     * @param w weights indexed by the values stored in {@code array}
     * @return the total weighted inversion count of {@code array} as passed in
     */
    private double countWeightedInversions(int[] array, double[] w) {
        if (array.length <= 1) {
            return 0.0;
        }
        int m = array.length >> 1;
        int[] left = Arrays.copyOfRange(array, 0, m);
        int[] right = Arrays.copyOfRange(array, m, array.length);
        double weightedCount = this.countWeightedInversions(left, w) + this.countWeightedInversions(right, w);

        // leftSuffix[x] = total weight of left[x..end] (left is sorted at this point).
        // Computed once per merge so each cross-inversion lookup is O(1) instead of
        // re-summing the remainder of the left half for every right element.
        double[] leftSuffix = new double[left.length + 1];
        for (int x = left.length - 1; x >= 0; --x) {
            leftSuffix[x] = leftSuffix[x + 1] + w[left[x]];
        }

        int i = 0;
        int j = 0;
        int k = 0;
        while (i < left.length && j < right.length) {
            if (left[i] < right[j]) {
                array[k++] = left[i++];
            } else {
                // right[j] precedes every remaining left element in sorted order but
                // followed them in the input: each such pair is an inversion.
                weightedCount += w[right[j]] * leftSuffix[i];
                array[k++] = right[j++];
            }
        }
        // At most one of the halves still has elements; copy the leftovers.
        int leftRemaining = left.length - i;
        System.arraycopy(left, i, array, k, leftRemaining);
        System.arraycopy(right, j, array, k + leftRemaining, right.length - j);
        return weightedCount;
    }
}

