/*
 * Decompiled with CFR 0.152.
 */
package deepboof.graph;

import deepboof.Function;
import deepboof.Tensor;
import deepboof.graph.InputAddress;
import deepboof.graph.Node;
import deepboof.misc.TensorFactory;
import deepboof.misc.TensorOps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ddogleg.struct.Tuple2;

/**
 * Runs a list of {@link Node}s in list order, feeding each node the outputs of the nodes named in
 * its {@code sources}. The first node (the root) consumes the caller's input tensor and the last
 * node writes directly into the caller's output tensor; every other node's output is kept in
 * {@link #outputStorage}, keyed by node name. A node with more than one source must supply a
 * {@code combine} operator, which merges its inputs into a single tensor before the node's
 * function is applied.
 *
 * @param <T> tensor type processed by every node
 * @param <F> function type held by each node
 */
public class FunctionSequence<T extends Tensor<T>, F extends Function<T>> {
    /** Nodes in execution order. Node 0 is the root and must have no sources. */
    protected List<Node<T, F>> sequence;
    /** Node name to node, used to resolve the node names in each node's source addresses. */
    protected Map<String, Node<T, F>> lookup = new HashMap<String, Node<T, F>>();
    /**
     * Per-node work tensors keyed by node name: {@code d0} receives the node's function output and
     * {@code d1} receives the combined input of a multi-source node. {@code d1} is set to
     * {@code null} for single-source nodes when {@link #process} declares storage.
     */
    protected Map<String, Tuple2<T, T>> outputStorage = new HashMap<String, Tuple2<T, T>>();
    /** Creates the work tensors held in {@link #outputStorage}. */
    protected TensorFactory<T> factory;
    /** When true, {@link #initialize} prints every node's input and output shapes to stdout. */
    boolean verbose = false;

    /**
     * Creates a sequence over the given nodes. The list is stored by reference, not copied.
     *
     * @param sequence nodes in execution order; node names must be unique
     * @param type concrete tensor class used to allocate intermediate storage
     * @throws IllegalArgumentException if two nodes share the same name
     */
    public FunctionSequence(List<Node<T, F>> sequence, Class<T> type) {
        this.sequence = sequence;
        for (Node<T, F> n : sequence) {
            if (lookup.containsKey(n.name)) {
                throw new IllegalArgumentException("Conflict. Multiple nodes with the same name. " + n.name);
            }
            lookup.put(n.name, n);
        }
        this.factory = new TensorFactory<T>(type);
    }

    /**
     * Initializes every node's function (and combine operator, where present) for the given root
     * input shape, and allocates empty per-node work tensors.
     *
     * @param inputShape shape passed to the root node's function. Presumably this excludes the
     *     mini-batch dimension, since {@link #process} treats axis 0 of the input as the batch.
     *     TODO confirm.
     * @throws IllegalStateException if the sequence is empty
     * @throws RuntimeException if the root has sources, a later node has none, a source name
     *     cannot be resolved, or a multi-source node has no combine operator
     */
    public void initialize(int[] inputShape) {
        initializeSequence(inputShape);
    }

    private void initializeSequence(int[] inputShape) {
        if (sequence.isEmpty()) {
            throw new IllegalStateException("Function sequence is empty");
        }
        Node<T, F> root = sequence.get(0);
        if (root.sources.size() != 0) {
            throw new RuntimeException("Input sequence can't have a source address!");
        }
        root.function.initialize(inputShape);
        outputStorage.put(root.name, createStorage());
        if (verbose) {
            System.out.println("ROOT ========= " + root.name);
            printOutput(root, inputShape);
        }

        List<int[]> inputs = new ArrayList<int[]>();
        for (int i = 1; i < sequence.size(); ++i) {
            Node<T, F> node = sequence.get(i);
            if (verbose) {
                System.out.println("============== " + node.name);
            }
            outputStorage.put(node.name, createStorage());
            if (node.sources.size() == 0) {
                throw new RuntimeException("No sources!  Node = " + node.name);
            }

            // The output shapes of the source nodes are this node's input shapes
            inputs.clear();
            for (int j = 0; j < node.sources.size(); ++j) {
                InputAddress addr = node.sources.get(j);
                Node<T, F> src = lookup.get(addr.nodeName);
                if (src == null) {
                    throw new RuntimeException("Can't find input node from name. Bad network");
                }
                inputs.add(src.function.getOutputShape());
                if (verbose) {
                    System.out.println("   input addr " + addr.nodeName);
                }
            }

            int[] functionInput;
            if (inputs.size() == 1) {
                functionInput = inputs.get(0);
            } else {
                if (node.combine == null) {
                    throw new RuntimeException("Must specify a combine operator if there are multiple sources");
                }
                // Multiple sources are merged first; the function sees the combined shape
                node.combine.initialize(inputs);
                functionInput = node.combine.getOutputShape();
            }
            node.function.initialize(functionInput);
            if (verbose) {
                printOutput(node, functionInput);
            }
        }
    }

    /** Creates an empty pair of work tensors (function output, combined input) for one node. */
    private Tuple2<T, T> createStorage() {
        return new Tuple2<T, T>(factory.create(new int[0]), factory.create(new int[0]));
    }

    /** Prints one line describing the node's function class, input shape and output shape. */
    private void printOutput(Node<T, F> node, int[] input) {
        int[] output = node.function.getOutputShape();
        String sin = TensorOps.toStringShape(input);
        String sout = TensorOps.toStringShape(output);
        System.out.printf("%30s input %25s  out = %25s\n", node.function.getClass().getSimpleName(), sin, sout);
    }

    /**
     * Reshapes each node's work tensors for a mini-batch of {@code numBatch} samples. Every
     * intermediate tensor carries the batch as its leading dimension, including the combined
     * input of multi-source nodes.
     */
    private void declareOutputStorage(int numBatch) {
        if (sequence.size() == 1) {
            return;
        }
        int last = sequence.size() - 1;
        for (int i = 0; i <= last; ++i) {
            Node<T, F> node = sequence.get(i);
            Tuple2<T, T> storage = outputStorage.get(node.name);
            // The last node writes into the caller's output tensor, so its d0 is never used
            if (i != last) {
                storage.d0.reshape(TensorOps.WI(numBatch, node.function.getOutputShape()));
            }
            if (i == 0 || node.sources.size() == 1) {
                storage.d1 = null;
            } else {
                storage.d1.reshape(TensorOps.WI(numBatch, node.combine.getOutputShape()));
            }
        }
    }

    /**
     * Passes each parameter list to the function of the node with the matching name. Nodes
     * without an entry in the map are left untouched.
     *
     * @param nodeParameters node name to that node's parameter tensors
     */
    public void setParameters(Map<String, List<T>> nodeParameters) {
        for (int i = 0; i < sequence.size(); ++i) {
            Node<T, F> node = sequence.get(i);
            List<T> parameters = nodeParameters.get(node.name);
            if (parameters != null) {
                node.function.setParameters(parameters);
            }
        }
    }

    /**
     * Runs the whole sequence forward. {@link #initialize} must have been called first.
     *
     * @param input root input; axis 0 is used as the mini-batch size
     * @param output receives the last node's output
     */
    public void process(T input, T output) {
        if (sequence.size() == 1) {
            sequence.get(0).function.forward(input, output);
            return;
        }
        declareOutputStorage(input.length(0));

        Node<T, F> root = sequence.get(0);
        root.function.forward(input, outputStorage.get(root.name).d0);

        List<T> inputs = new ArrayList<T>();
        int last = sequence.size() - 1;
        for (int i = 1; i <= last; ++i) {
            Node<T, F> node = sequence.get(i);
            Tuple2<T, T> storage = outputStorage.get(node.name);
            // The last node writes straight into the caller's tensor
            T target = (i == last) ? output : storage.d0;
            gatherInputs(node, inputs);
            if (node.sources.size() == 1) {
                node.function.forward(inputs.get(0), target);
            } else {
                node.combine.combine(inputs, storage.d1);
                node.function.forward(storage.d1, target);
            }
        }
    }

    /** Replaces the contents of {@code inputs} with the stored outputs of {@code node}'s sources, in source order. */
    private void gatherInputs(Node<T, F> node, List<T> inputs) {
        inputs.clear();
        for (int j = 0; j < node.sources.size(); ++j) {
            InputAddress addr = node.sources.get(j);
            inputs.add(outputStorage.get(addr.nodeName).d0);
        }
    }

    /** Returns the live list of nodes backing this sequence. */
    public List<Node<T, F>> getSequence() {
        return sequence;
    }

    /**
     * Returns the stored output tensor of the node at {@code index}. The last node's entry is not
     * populated by {@link #process} (its output goes to the caller's tensor), and single-node
     * sequences never fill storage.
     */
    public T getNodeOutput(int index) {
        return outputStorage.get(sequence.get(index).name).d0;
    }

    /** Returns the output shape of the last node's function. */
    public int[] getOutputShape() {
        return sequence.get(sequence.size() - 1).function.getOutputShape();
    }

    /** Returns the tensor class used by this sequence's factory. */
    public Class<T> getTensorType() {
        return factory.getTensorType();
    }
}

