/*
 * Decompiled with CFR 0.152.
 */
package deepboof.misc;

import deepboof.Tensor;
import deepboof.misc.TensorOps_F32;
import deepboof.misc.TensorOps_F64;
import deepboof.tensors.Tensor_F32;
import deepboof.tensors.Tensor_F64;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

/**
 * Static helpers for building shape arrays, validating tensor shapes, and dispatching a few
 * simple operations to the type specific {@link TensorOps_F32} / {@link TensorOps_F64} classes.
 */
public class TensorOps {
    /**
     * Copies the arguments into a new mutable list, preserving their order.
     *
     * @param elements values to place in the list
     * @return a new {@link ArrayList} containing {@code elements} in order
     */
    public static <T> List<T> WT(T ... elements) {
        List<T> list = new ArrayList<T>(elements.length);
        for (T element : elements) {
            list.add(element);
        }
        return list;
    }

    /**
     * Returns the varargs array itself; no copy is made.
     *
     * @param elements values of the array
     * @return the same array instance that was passed in
     */
    public static int[] WI(int ... elements) {
        return elements;
    }

    /**
     * Creates a new array made of {@code a} followed by every value in {@code elements}.
     *
     * @param a value stored at index 0 of the result
     * @param elements values copied to indexes 1..N of the result; not modified
     * @return new array of length {@code elements.length + 1}
     */
    public static int[] WI(int a, int[] elements) {
        int[] out = new int[1 + elements.length];
        out[0] = a;
        System.arraycopy(elements, 0, out, 1, elements.length);
        return out;
    }

    /**
     * Creates a new array made of every value in {@code elements} followed by {@code a}.
     *
     * @param elements values copied to the start of the result; not modified
     * @param a value stored in the last index of the result
     * @return new array of length {@code elements.length + 1}
     */
    public static int[] WI(int[] elements, int a) {
        int[] out = new int[1 + elements.length];
        System.arraycopy(elements, 0, out, 0, elements.length);
        out[elements.length] = a;
        return out;
    }

    /**
     * Creates a new array containing every value in {@code elements} except the first one.
     *
     * @param elements source array, must contain at least one value; not modified
     * @return new array of length {@code elements.length - 1}
     * @throws IllegalArgumentException if {@code elements} is empty
     */
    public static int[] TH(int[] elements) {
        // Fail with a descriptive error instead of the NegativeArraySizeException that
        // 'new int[-1]' would otherwise raise.
        if (elements.length == 0) {
            throw new IllegalArgumentException("Array must contain at least one element");
        }
        int[] out = new int[elements.length - 1];
        System.arraycopy(elements, 1, out, 0, out.length);
        return out;
    }

    /**
     * Applies {@link #WI(int, int[])} to every array in the list.
     *
     * @param a value put in front of each array
     * @param list arrays to extend; neither the list nor its arrays are modified
     * @return new list with one newly created array per input array, in the same order
     */
    public static List<int[]> WI(int a, List<int[]> list) {
        List<int[]> output = new ArrayList<int[]>(list.size());
        for (int[] elements : list) {
            output.add(TensorOps.WI(a, elements));
        }
        return output;
    }

    /**
     * Creates a new {@link Tensor_F64} whose shape is the input's shape with a leading axis of
     * length 1. The object returned by {@code input.getData()} is handed to the new tensor via
     * {@code setData}; whether that shares or copies storage depends on {@code setData}.
     *
     * @param input tensor to wrap; only {@link Tensor_F64} is supported
     * @return the newly created tensor
     * @throws RuntimeException if {@code input} is not a {@link Tensor_F64}
     */
    // The cast to T is safe: this point is only reached when input is a Tensor_F64,
    // and the returned object is a Tensor_F64.
    @SuppressWarnings("unchecked")
    public static <T extends Tensor> T AD(T input) {
        if (!(input instanceof Tensor_F64)) {
            throw new RuntimeException("Unsupported type");
        }
        Tensor_F64 out = new Tensor_F64();
        out.shape = TensorOps.WI(1, input.shape);
        ((Tensor)out).setData(input.getData());
        out.computeStrides();
        return (T)out;
    }

    /**
     * Sums {@link #tensorLength(int...)} over every shape in the list.
     *
     * @param shapes list of tensor shapes
     * @return total of the individual lengths; 0 for an empty list
     */
    public static int sumTensorLength(List<int[]> shapes) {
        int total = 0;
        for (int[] shape : shapes) {
            total += TensorOps.tensorLength(shape);
        }
        return total;
    }

    /**
     * Computes the product of all the values in {@code shape}.
     *
     * @param shape length of each axis
     * @return product of the axis lengths, or 0 if {@code shape} has no axes
     */
    public static int tensorLength(int ... shape) {
        if (shape.length == 0) {
            return 0;
        }
        int N = 1;
        for (int length : shape) {
            N *= length;
        }
        return N;
    }

    /**
     * Checks a list of tensors against a list of expected shapes, element by element, using
     * {@link #checkShape(String, int, int[], int[], boolean)}.
     *
     * @param which label put at the start of any error message
     * @param expected expected shape of each tensor
     * @param actual tensors whose shapes are checked
     * @param ignoreAxis0 if true, axis 0 of each actual tensor is skipped and {@code expected}
     *                    describes the remaining axes only
     * @throws IllegalArgumentException if the list sizes differ or any shape does not match
     */
    public static void checkShape(String which, List<int[]> expected, List<Tensor<?>> actual, boolean ignoreAxis0) {
        if (expected.size() != actual.size()) {
            throw new IllegalArgumentException(which + ": Unexpected number of tensors. " + expected.size() + " vs " + actual.size());
        }
        for (int i = 0; i < expected.size(); ++i) {
            int[] e = expected.get(i);
            int[] a = actual.get(i).getShape();
            TensorOps.checkShape(which, i, e, a, ignoreAxis0);
        }
    }

    /**
     * Checks that two tensors have the same number of axes and the same length along each axis.
     *
     * @param a first tensor
     * @param b second tensor
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void checkShape(Tensor_F64 a, Tensor_F64 b) {
        checkSameShape(a.shape, b.shape);
    }

    /**
     * Checks that two tensors have the same number of axes and the same length along each axis.
     *
     * @param a first tensor
     * @param b second tensor
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void checkShape(Tensor_F32 a, Tensor_F32 b) {
        checkSameShape(a.shape, b.shape);
    }

    /**
     * Compares an expected shape against an actual shape.
     *
     * @param which label put at the start of any error message
     * @param tensor index of the tensor reported in the error message; a negative value leaves
     *               the index out of the message
     * @param expected expected length of each axis
     * @param actual shape being checked
     * @param ignoreAxis0 if true, {@code actual} must have one more axis than {@code expected}
     *                    and {@code expected[i]} is compared against {@code actual[i + 1]}
     * @throws IllegalArgumentException if the number of axes or any compared axis length differs
     */
    public static void checkShape(String which, int tensor, int[] expected, int[] actual, boolean ignoreAxis0) {
        // Number of leading axes in 'actual' that are not described by 'expected'.
        int offset = ignoreAxis0 ? 1 : 0;
        int expectedDimension = expected.length + offset;
        if (expectedDimension != actual.length) {
            throw new IllegalArgumentException(errorHeader(which, tensor) + " dimension doesn't match, expected = " + expectedDimension + " found = " + actual.length);
        }
        for (int i = 0; i < expected.length; ++i) {
            if (expected[i] == actual[i + offset]) continue;
            // When axis 0 is ignored it is printed as a wildcard in the message.
            String found = ignoreAxis0 ? TensorOps.toStringShapeA(actual) : TensorOps.toStringShape(actual);
            throw new IllegalArgumentException(errorHeader(which, tensor) + " shapes don't match, expected = " + TensorOps.toStringShape(expected) + ", found = " + found);
        }
    }

    /**
     * Formats a shape with axis 0 replaced by {@code *}; e.g. {@code {5, 2}} becomes
     * {@code "( * , 2 , )"}.
     *
     * @param shape shape to format; the value at index 0 is never read
     * @return text description of the shape
     */
    public static String toStringShapeA(int[] shape) {
        StringBuilder out = new StringBuilder("( * , ");
        for (int i = 1; i < shape.length; ++i) {
            out.append(shape[i]).append(" , ");
        }
        return out.append(")").toString();
    }

    /**
     * Formats a shape; e.g. {@code {1, 2}} becomes {@code "( 1 , 2 , )"}.
     *
     * @param shape shape to format
     * @return text description of the shape
     */
    public static String toStringShape(int[] shape) {
        StringBuilder out = new StringBuilder("( ");
        for (int length : shape) {
            out.append(length).append(" , ");
        }
        return out.append(")").toString();
    }

    /**
     * Computes the product of the axis lengths from {@code startDimen} to the last axis.
     *
     * @param shape length of each axis
     * @param startDimen index of the first axis included in the product
     * @return product of {@code shape[startDimen]} through the last value, or 0 if
     *         {@code startDimen} is at or past the end of {@code shape}
     */
    public static int outerLength(int[] shape, int startDimen) {
        if (startDimen >= shape.length) {
            return 0;
        }
        int D = 1;
        for (int i = startDimen; i < shape.length; ++i) {
            D *= shape[i];
        }
        return D;
    }

    /**
     * Walks upwards from the current working directory and returns the first directory that
     * contains a directory whose name ends with "modules", a directory whose name ends with
     * "examples", and a file named "settings.gradle".
     *
     * @return the first matching directory
     * @throws RuntimeException if no matching directory is found or a directory along the way
     *                          cannot be listed
     */
    public static File pathToRoot() {
        for (File active = new File(".").getAbsoluteFile(); active != null; active = active.getParentFile()) {
            boolean foundModules = false;
            boolean foundExamples = false;
            boolean foundSettings = false;
            File[] children = active.listFiles();
            // listFiles() returns null when the directory can't be listed; give up searching.
            if (children == null) break;
            for (File d : children) {
                String name = d.getName();
                if (d.isDirectory()) {
                    if (name.endsWith("modules")) {
                        foundModules = true;
                    }
                    if (name.endsWith("examples")) {
                        foundExamples = true;
                    }
                }
                if (d.isFile() && name.equals("settings.gradle")) {
                    foundSettings = true;
                }
            }
            if (foundModules && foundExamples && foundSettings) {
                return active;
            }
        }
        throw new RuntimeException("Cant find the project root directory");
    }

    /**
     * Delegates to the {@code elementSum} implementation matching the tensor's type.
     *
     * @param tensor a {@link Tensor_F64} or {@link Tensor_F32}
     * @return value returned by the type specific {@code elementSum}
     * @throws IllegalArgumentException if the tensor is of any other type
     */
    public static double elementSum(Tensor tensor) {
        if (tensor instanceof Tensor_F64) {
            return TensorOps_F64.elementSum((Tensor_F64)tensor);
        }
        if (tensor instanceof Tensor_F32) {
            return TensorOps_F32.elementSum((Tensor_F32)tensor);
        }
        throw new IllegalArgumentException("Support not added yet for this tensor type");
    }

    /**
     * Delegates to the {@code fill} implementation matching the tensor's type.
     *
     * @param t a {@link Tensor_F64} or {@link Tensor_F32}
     * @param value value passed to the type specific {@code fill}; narrowed to {@code float}
     *              for a {@link Tensor_F32}
     * @throws IllegalArgumentException if the tensor is of any other type
     */
    public static void fill(Tensor t, double value) {
        if (t instanceof Tensor_F64) {
            TensorOps_F64.fill((Tensor_F64)t, value);
        } else if (t instanceof Tensor_F32) {
            TensorOps_F32.fill((Tensor_F32)t, (float)value);
        } else {
            throw new IllegalArgumentException("Support not added yet for this tensor type");
        }
    }

    /** Builds the prefix of a shape error message; the tensor index is left out when negative. */
    private static String errorHeader(String which, int tensor) {
        return tensor >= 0 ? which + ":  Tensor[" + tensor + "] " : which + ": ";
    }

    /** Throws if the two shapes differ in number of axes or in the length of any axis. */
    private static void checkSameShape(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Dimension of tensors do not match. " + a.length + " " + b.length);
        }
        for (int i = 0; i < a.length; ++i) {
            int da = a[i];
            int db = b[i];
            if (da == db) continue;
            throw new IllegalArgumentException("dimension " + i + "  does not match.  " + da + "  " + db);
        }
    }
}

