/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * A block that feeds the same input to every child block and merges the children's outputs with
 * a user-supplied combining function.
 *
 * <p>Children are run in the iteration order of the child map. Each child's {@link NDList} output
 * is collected into a list, and that list is passed to the combining function, whose result is
 * the output of this block.
 */
public class ParallelBlock extends AbstractBlock {

    private static final byte VERSION = 2;

    /** Matches the start of every line; used to indent nested block descriptions. */
    private static final Pattern LINE_START = Pattern.compile("(?m)^");

    private final Function<List<NDList>, NDList> function;

    /**
     * Creates an empty parallel block. Children can be added later with the {@code add} and
     * {@code addAll} methods.
     *
     * @param function combines the list of child outputs into this block's output
     */
    public ParallelBlock(Function<List<NDList>, NDList> function) {
        this(function, Collections.emptyList());
    }

    /**
     * Creates a parallel block with the given children.
     *
     * @param function combines the list of child outputs into this block's output
     * @param blocks the initial child blocks; {@code null} elements are skipped
     */
    public ParallelBlock(Function<List<NDList>, NDList> function, List<? extends Block> blocks) {
        super(VERSION);
        this.function = function;
        addAll(blocks);
    }

    /**
     * Adds every given block as a child, in order.
     *
     * @param blocks the blocks to add; {@code null} elements are skipped
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Block... blocks) {
        return addAll(Arrays.asList(blocks));
    }

    /**
     * Adds every block in the collection as a child, in iteration order.
     *
     * @param blocks the blocks to add; {@code null} elements are skipped
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Collection<? extends Block> blocks) {
        blocks.forEach(this::add);
        return this;
    }

    /**
     * Adds a single child block, registered under its class's simple name.
     *
     * @param block the block to add; a {@code null} value is silently ignored
     * @return this block, for chaining
     */
    public final ParallelBlock add(Block block) {
        if (block != null) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        return this;
    }

    /**
     * Wraps a function in a {@link LambdaBlock} and adds it as a child.
     *
     * @param f the function to wrap
     * @return this block, for chaining
     */
    public final ParallelBlock add(Function<NDList, NDList> f) {
        return add(new LambdaBlock(f));
    }

    /**
     * Runs every child on the same {@code inputs} and combines their outputs with the combining
     * function.
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        List<NDList> childOutputs =
                children.values()
                        .stream()
                        .map(block -> block.forward(parameterStore, inputs, training, params))
                        .collect(Collectors.toList());
        return function.apply(childOutputs);
    }

    /** Initializes every child with the same data type and input shapes given to this block. */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        for (Block child : getChildren().values()) {
            child.initialize(manager, dataType, inputShapes);
        }
    }

    /**
     * Infers the output shapes by building placeholder arrays with each child's output shapes and
     * running them through the combining function.
     *
     * @throws IllegalArgumentException if this block has no children
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        if (children.isEmpty()) {
            throw new IllegalArgumentException("The parallel block is empty");
        }
        // All placeholder arrays are scoped to this sub-manager and released when it closes.
        try (NDManager subManager = manager.newSubManager()) {
            List<NDList> childOutputs = new ArrayList<>();
            for (Block block : children.values()) {
                // Pass the scoped manager so that any temporaries a child allocates during
                // shape inference are released too, instead of leaking into `manager`.
                Shape[] shapes = block.getOutputShapes(subManager, inputShapes);
                NDList placeholders = new NDList(shapes.length);
                for (Shape shape : shapes) {
                    placeholders.add(subManager.create(shape));
                }
                childOutputs.add(placeholders);
            }
            NDList combined = function.apply(childOutputs);
            Shape[] outputShapes = new Shape[combined.size()];
            for (int i = 0; i < combined.size(); ++i) {
                outputShapes[i] = combined.get(i).getShape();
            }
            return outputShapes;
        }
    }

    /**
     * Reads this block's metadata.
     *
     * <p>Version 2 stores the input shapes and version 1 stores nothing extra. Any other version
     * is rejected.
     *
     * @throws MalformedModelException if {@code version} is neither 1 nor 2
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /** Returns {@code Parallel(...)} with each child's description on tab-indented lines. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Parallel(\n");
        for (Block block : children.values()) {
            String blockString = LINE_START.matcher(block.toString()).replaceAll("\t");
            sb.append(blockString).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}

