/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A block that maps items of type {@code T} to dense vectors of length {@code embeddingSize}.
 *
 * <p>Each known item is assigned an integer index, and the block owns a single {@code WEIGHT}
 * parameter named {@code "embedding"} of shape {@code (numItems, embeddingSize)}. When the
 * "use default" option is enabled, index {@code 0} is reserved for items that were not
 * registered, and registered items start at index {@code 1}.
 *
 * <p>Item-to-index lookup is backed by a {@link ConcurrentHashMap}, so {@code null} items are
 * not supported.
 *
 * @param <T> the type of item being embedded
 */
public class Embedding<T>
extends ParameterBlock {
    /** Encoding version written by {@link #saveParameters}; version 1 (no input shapes) is still readable. */
    private static final byte VERSION = 2;

    private final int embeddingSize;
    private final boolean useDefault;
    private final DataType dataType;
    /** Maps each registered item to its row index in the embedding matrix. */
    private final Map<T, Integer> embedder;
    /** Number of rows in the embedding matrix (including the default row, if enabled). */
    private final int numItems;
    private final Parameter embedding;

    /**
     * Creates an embedding from a {@link Builder}.
     *
     * @param builder the builder holding the items, embedding size, default flag and data type
     */
    Embedding(Builder<T> builder) {
        this.embeddingSize = builder.embeddingSize;
        this.useDefault = builder.useDefault;
        this.dataType = builder.dataType;
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT);
        this.embedder = new ConcurrentHashMap<>(builder.items.size());
        // Row 0 is reserved for the "unknown item" embedding when defaults are enabled.
        int nextIndex = this.useDefault ? 1 : 0;
        for (T item : builder.items) {
            // Every element consumes a row, even a duplicate (whose earlier index is overwritten),
            // so the parameter shape matches the number of supplied items.
            this.embedder.put(item, nextIndex++);
        }
        this.numItems = nextIndex;
        this.inputShapes = new Shape[] {new Shape(-1L)};
    }

    /**
     * Creates an embedding backed by an existing embedding matrix.
     *
     * <p>Row {@code i} of {@code embedding} is used for {@code items.get(i)}. No default row is
     * reserved, so embedding an unregistered item throws.
     *
     * @param embedding the embedding matrix; its dimension 1 is taken as the embedding size
     *     (assumed to have shape {@code (items.size(), embeddingSize)} — not validated here)
     * @param items the items, in row order
     */
    public Embedding(NDArray embedding, List<T> items) {
        this.embeddingSize = Math.toIntExact(embedding.getShape().get(1));
        this.useDefault = false;
        this.dataType = embedding.getDataType();
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT);
        this.embedding.setArray(embedding);
        this.numItems = items.size();
        this.embedder = new ConcurrentHashMap<>(this.numItems);
        for (int i = 0; i < items.size(); ++i) {
            this.embedder.put(items.get(i), i);
        }
        this.inputShapes = new Shape[] {new Shape(-1L)};
    }

    /**
     * Returns a new builder for an {@link Embedding}. Call {@link Builder#setType} to fix the
     * item type.
     *
     * @return a new builder
     */
    public static Builder<?> builder() {
        return new Builder<>();
    }

    /**
     * Returns the output shape, which is the input shape with {@code embeddingSize} appended.
     *
     * @param manager unused
     * @param inputShapes the input shapes; only the first one is used
     * @return a single-element array holding the output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0].addAll(new Shape(this.embeddingSize))};
    }

    /**
     * Returns the parameters owned directly by this block.
     *
     * @return a singleton list containing the embedding parameter
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.singletonList(this.embedding);
    }

    /**
     * Returns the shape of the named parameter.
     *
     * @param name the parameter name; only {@code "embedding"} is recognized
     * @param inputShapes unused
     * @return {@code (numItems, embeddingSize)} for the embedding parameter
     * @throws IllegalArgumentException if {@code name} is not {@code "embedding"}
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        if ("embedding".equals(name)) {
            return new Shape(this.numItems, this.embeddingSize);
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /**
     * Looks up the embeddings for a list containing a single array of item indices.
     *
     * <p>A 0-dimensional input is reshaped to length 1 for the operator, and the result is
     * reshaped back to a 1-D vector of length {@code embeddingSize}.
     *
     * @param parameterStore the store used to fetch the embedding matrix on the input's device
     * @param inputs a singleton list of item indices (for example, produced by {@link #embed})
     * @param params extra operator parameters, passed through to the engine
     * @return the embedded values
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDList opInputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = opInputs.head().getNDArrayInternal();
        NDList result = ex.embedding(opInputs, this.numItems, this.embeddingSize, this.dataType, params);
        if (inputs.singletonOrThrow().getShape().dimension() == 0) {
            // Undo the reshape(1) done in opInputs so a scalar index yields a plain vector.
            result = new NDList(result.singletonOrThrow().reshape(this.embeddingSize));
        }
        return result;
    }

    /**
     * Writes the encoding version, the input shapes and the embedding parameter.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        this.saveInputShapes(os);
        this.embedding.save(os);
    }

    /**
     * Reads parameters written by {@link #saveParameters}. Both the current version and
     * version 1 are accepted; version 1 streams carry no input shapes.
     *
     * @param manager the manager used to create the loaded arrays
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version is not supported
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.embedding.load(manager, is);
    }

    /**
     * Returns whether {@code item} was registered with this embedding.
     *
     * @param item the item to check (must not be {@code null})
     * @return {@code true} if the item has its own row
     */
    public boolean hasItem(T item) {
        return this.embedder.containsKey(item);
    }

    /** Builds the operator inputs: the (at least 1-D) index array plus the embedding matrix. */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        NDArray items = inputs.singletonOrThrow();
        Device device = items.getDevice();
        NDList ret = new NDList(2);
        if (items.getShape().dimension() == 0) {
            ret.add(items.reshape(1L));
        } else {
            ret.add(items);
        }
        ret.add(parameterStore.getValue(this.embedding, device));
        return ret;
    }

    /**
     * Converts items to their row indices.
     *
     * @param manager the manager used to create the result
     * @param items the items to convert
     * @return an int array of row indices, in the same order as {@code items}
     * @throws IllegalArgumentException if an item is unknown and defaults are disabled
     */
    public NDArray embed(NDManager manager, T[] items) {
        return manager.create(Arrays.stream(items).mapToInt(this::embedHelper).toArray());
    }

    /**
     * Converts one item to its row index.
     *
     * @param manager the manager used to create the result
     * @param item the item to convert
     * @return the row index, created via {@code manager.create(int)}
     * @throws IllegalArgumentException if the item is unknown and defaults are disabled
     */
    public NDArray embed(NDManager manager, T item) {
        return manager.create(this.embedHelper(item));
    }

    /** Returns the row index of {@code value}, or 0 (the default row) if it is unknown and defaults are on. */
    private int embedHelper(T value) {
        // Single lookup: ConcurrentHashMap never stores null values, so null means "absent".
        Integer index = this.embedder.get(value);
        if (index != null) {
            return index;
        }
        if (this.useDefault) {
            return 0;
        }
        throw new IllegalArgumentException("The provided item was not found");
    }

    /**
     * Builder for {@link Embedding}.
     *
     * @param <T> the item type
     */
    public static final class Builder<T> {
        private Class<T> embeddingType;
        private Collection<T> items;
        private int embeddingSize;
        private boolean useDefault = true;
        private DataType dataType = DataType.FLOAT32;

        Builder() {
        }

        /** Creates a typed builder that copies the size, default flag and data type from {@code parent}. */
        private Builder(Class<T> embeddingType, Builder<?> parent) {
            this.embeddingType = embeddingType;
            this.embeddingSize = parent.embeddingSize;
            this.useDefault = parent.useDefault;
            this.dataType = parent.dataType;
        }

        /**
         * Returns the item type set by {@link #setType}.
         *
         * @return the item class, or {@code null} if no type was set
         */
        public Class<T> getEmbeddingType() {
            return this.embeddingType;
        }

        /**
         * Returns a builder for a new item type. The embedding size, default flag and data type
         * are carried over; any items already set are not.
         *
         * @param embeddingType the item class
         * @param <U> the new item type
         * @return a new builder for {@code U}
         */
        public <U> Builder<U> setType(Class<U> embeddingType) {
            return new Builder<>(embeddingType, this);
        }

        /**
         * Sets the items to embed. Each element is assigned its own row.
         *
         * @param items the items (must not contain {@code null})
         * @return this builder
         */
        public Builder<T> setItems(Collection<T> items) {
            this.items = items;
            return this;
        }

        /**
         * Sets the length of each embedding vector.
         *
         * @param embeddingSize the embedding size; must be positive
         * @return this builder
         */
        public Builder<T> setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets whether unknown items map to a shared default row (index 0) instead of throwing.
         * Defaults to {@code true}.
         *
         * @param useDefault whether to reserve a default row
         * @return this builder
         */
        public Builder<T> optUseDefault(boolean useDefault) {
            this.useDefault = useDefault;
            return this;
        }

        /**
         * Sets the data type passed to the embedding operator. Defaults to {@code FLOAT32}.
         *
         * @param dataType the data type
         * @return this builder
         */
        public Builder<T> optDataType(DataType dataType) {
            this.dataType = dataType;
            return this;
        }

        /**
         * Builds the embedding.
         *
         * @return a new {@link Embedding}
         * @throws IllegalArgumentException if no items were set or the embedding size is not
         *     positive
         */
        public Embedding<T> build() {
            if (this.items == null) {
                throw new IllegalArgumentException("You must specify the items to embed");
            }
            // A negative size would otherwise only fail later, during shape creation.
            if (this.embeddingSize <= 0) {
                throw new IllegalArgumentException("You must specify the embedding size");
            }
            return new Embedding<>(this);
        }
    }
}

