/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * A {@link Block} that runs its child blocks one after another, feeding the output of each child
 * into the next one.
 *
 * <p>Children are appended with {@link #add(Block)}, {@link #add(Function)} and the {@code addAll}
 * overloads, and are executed in the order they were added.
 */
public class SequentialBlock extends AbstractBlock {

    /**
     * Encoding version passed to {@link AbstractBlock}. {@link #loadMetadata} reads input shapes
     * for this version and also accepts version 1 (which carries no extra metadata here).
     */
    private static final byte VERSION = 2;

    /** Matches the start of every line; used by {@link #toString()} to indent child output. */
    private static final Pattern LINE_START = Pattern.compile("(?m)^");

    /** Creates a sequential block with no children. */
    public SequentialBlock() {
        super(VERSION);
    }

    /**
     * Appends the given blocks, in array order.
     *
     * @param blocks the blocks to append; {@code null} elements are skipped (see {@link
     *     #add(Block)})
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Block... blocks) {
        this.addAll(Arrays.asList(blocks));
        return this;
    }

    /**
     * Appends the given blocks, in the collection's iteration order.
     *
     * @param blocks the blocks to append; {@code null} elements are skipped (see {@link
     *     #add(Block)})
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Collection<Block> blocks) {
        blocks.forEach(this::add);
        return this;
    }

    /**
     * Appends a single child block.
     *
     * <p>The child is registered under its simple class name via {@code addChildBlock}. Any
     * further naming or de-duplication is up to {@link AbstractBlock} and is not visible here.
     *
     * @param block the block to append; a {@code null} argument is silently ignored
     * @return this block, for chaining
     */
    public SequentialBlock add(Block block) {
        if (block != null) {
            this.addChildBlock(block.getClass().getSimpleName(), block);
        }
        return this;
    }

    /**
     * Appends a function as a child, wrapped in a {@link LambdaBlock}.
     *
     * @param f the function to run at this position in the sequence
     * @return this block, for chaining
     */
    public SequentialBlock add(Function<NDList, NDList> f) {
        this.add(new LambdaBlock(f));
        return this;
    }

    /**
     * Removes the most recently added child block.
     *
     * <p>NOTE(review): with no children this calls {@code children.remove(-1)}. The exception
     * thrown then depends on the child-list implementation (presumably an index-out-of-bounds
     * error); confirm.
     */
    public void removeLastBlock() {
        this.children.remove(this.children.size() - 1);
    }

    /**
     * Removes the last child block and, if {@code block} is non-null, appends {@code block} in
     * its place.
     *
     * @param block the replacement block; {@code null} just removes the last child
     */
    public void replaceLastBlock(Block block) {
        this.removeLastBlock();
        if (block != null) {
            this.add(block);
        }
    }

    /**
     * Runs every child in order, passing each child's output to the next child.
     *
     * <p>With no children, {@code inputs} is returned unchanged. The {@code params} argument is
     * not passed on: children are invoked through the three-argument {@code forward}.
     *
     * @param parameterStore the parameter store handed to every child
     * @param inputs the input to the first child
     * @param training whether the forward pass is for training
     * @param params extra parameters (not propagated to children)
     * @return the output of the last child, or {@code inputs} if there are no children
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList current = inputs;
        for (Block block : this.children.values()) {
            current = block.forward(parameterStore, current, training);
        }
        return current;
    }

    /**
     * Initializes each child in order. The shapes returned by one child's {@code initialize}
     * become the input shapes of the next child.
     *
     * @param manager the manager used for initialization
     * @param dataType the data type passed to every child
     * @param inputShapes the input shapes for the first child
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape[] shapes = inputShapes;
        for (Block child : this.getChildren().values()) {
            shapes = child.initialize(manager, dataType, shapes);
        }
    }

    /**
     * Computes the output shapes by chaining each child's {@code getOutputShapes}.
     *
     * @param manager the manager passed to every child
     * @param inputs the input shapes of the first child
     * @return the output shapes of the last child
     * @throws IllegalArgumentException if this block has no children
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        if (this.children.isEmpty()) {
            throw new IllegalArgumentException("The sequential block is empty");
        }
        Shape[] current = inputs;
        for (Block block : this.children.values()) {
            current = block.getOutputShapes(manager, current);
        }
        return current;
    }

    /**
     * Reads this block's metadata for the given encoding version.
     *
     * <p>For {@link #VERSION} the input shapes are read from {@code is}. Version 1 is accepted
     * without reading anything.
     *
     * @param version the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if {@code version} is neither 1 nor {@link #VERSION}
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /**
     * Renders {@code Sequential(}, then each child's {@code toString()} with every line prefixed
     * by a tab and followed by a newline, then {@code )}.
     *
     * @return a multi-line description of this block and its children
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Sequential(\n");
        for (Block block : this.children.values()) {
            // Same as String.replaceAll("(?m)^", "\t"), but without recompiling the regex per child.
            String blockString = LINE_START.matcher(block.toString()).replaceAll("\t");
            sb.append(blockString).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}

