/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

public class Linear
extends AbstractBlock {
    private static final byte VERSION = 4;
    private long units;
    private long inputFeatures;
    private Shape inputShape;
    private Parameter weight;
    private Parameter bias;

    Linear(Builder builder) {
        super((byte)4);
        this.units = builder.units;
        this.weight = this.addParameter(new Parameter("weight", this, ParameterType.WEIGHT), (Shape[] inputShapes) -> new Shape(this.units, this.inputFeatures));
        if (builder.bias) {
            this.bias = this.addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(this.units));
        }
    }

    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray weightArr = parameterStore.getValue(this.weight, device, training);
        NDArray biasArr = parameterStore.getValue(this.bias, device, training);
        return Linear.linear(input, weightArr, biasArr);
    }

    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        return new Shape[]{this.inputShape.addAll(new Shape(this.units))};
    }

    @Override
    public PairList<String, Shape> describeInput() {
        return new PairList<String, Shape>(Collections.singletonList("linearInput"), Collections.singletonList(this.inputShape));
    }

    @Override
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Shape input = inputShapes[0];
        this.inputFeatures = input.get(input.dimension() - 1);
        this.inputShape = input.slice(0, input.dimension() - 1);
    }

    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeLong(this.units);
        os.writeLong(this.inputFeatures);
        os.write(this.inputShape.getEncoded());
    }

    @Override
    public void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedModelException {
        if (version < 1 || version > 4) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        if (version == 4) {
            this.units = is.readLong();
            this.inputFeatures = is.readLong();
        } else if (version == 2) {
            if (is.readBoolean()) {
                throw new IllegalArgumentException("flatten is not supported!");
            }
            this.inputFeatures = is.readLong();
        } else if (version == 3) {
            this.units = is.readLong();
            if (is.readBoolean()) {
                throw new IllegalArgumentException("flatten is not supported!");
            }
            this.inputFeatures = is.readLong();
        } else {
            this.inputFeatures = Shape.decode(is).size();
        }
        this.inputShape = Shape.decode(is);
    }

    public static NDList linear(NDArray input, NDArray weight) {
        return Linear.linear(input, weight, null);
    }

    public static NDList linear(NDArray input, NDArray weight, NDArray bias) {
        return input.getNDArrayInternal().linear(input, weight, bias);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static final class Builder {
        private long units;
        private boolean bias = true;

        Builder() {
        }

        public Builder setUnits(long units) {
            this.units = units;
            return this;
        }

        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        public Linear build() {
            Preconditions.checkArgument(this.units > 0L, "You must specify unit");
            return new Linear(this);
        }
    }
}

