/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.nn.meta;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import ml.shifu.guagua.example.nn.NNUtils;
import ml.shifu.guagua.io.HaltBytable;

/**
 * Mutable holder for a weight vector, an accumulated gradient vector, the
 * train/test errors and a training-size counter.
 *
 * <p>The state is (de)serialized through {@link #doWrite(DataOutput)} and
 * {@link #doReadFields(DataInput)}; these are presumably the hooks invoked by
 * {@code HaltBytable} (TODO confirm against the superclass).
 *
 * <p>Wire layout, in order: trainError (double), testError (double),
 * trainSize (long), weights (int length + doubles), gradients (int length +
 * doubles). A length of {@code -1} marks an array that was {@code null}.
 *
 * <p>Getters and setters expose the internal arrays directly; no defensive
 * copies are made.
 */
public class NNParams
extends HaltBytable {
    /** Length marker written in place of an array that is {@code null}. */
    private static final int NULL_ARRAY_LENGTH = -1;

    private double[] weights;
    private double[] gradients;
    private double testError = 0.0;
    private double trainError = 0.0;
    private long trainSize = 0L;

    /**
     * @return the current weight array (the internal reference, not a copy);
     *         {@code null} if it was never set
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * @param weights weight array to store; the reference is kept as-is
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    /**
     * @return the stored test error
     */
    public double getTestError() {
        return this.testError;
    }

    /**
     * @param testError test error to store
     */
    public void setTestError(double testError) {
        this.testError = testError;
    }

    /**
     * @return the stored train error
     */
    public double getTrainError() {
        return this.trainError;
    }

    /**
     * @param trainError train error to store
     */
    public void setTrainError(double trainError) {
        this.trainError = trainError;
    }

    /**
     * Adds {@code gradients} element-wise into the internal gradient
     * accumulator.
     *
     * <p>On first use the accumulator is allocated with the length of
     * {@code gradients}. If no weights are present yet, a weight array of the
     * same length is allocated and handed to {@code NNUtils.randomize}
     * (presumably to fill it with initial random values).
     *
     * @param gradients values to add; must not be {@code null} and must not be
     *                  longer than the existing accumulator
     */
    public void accumulateGradients(double[] gradients) {
        if (this.gradients == null) {
            // A new double[] is already zero-filled, so no explicit fill is needed.
            this.gradients = new double[gradients.length];
        }
        if (this.weights == null) {
            this.weights = new double[gradients.length];
            NNUtils.randomize(gradients.length, this.weights);
        }
        for (int i = 0; i < gradients.length; ++i) {
            this.gradients[i] += gradients[i];
        }
    }

    /**
     * @return the gradient accumulator (the internal reference, not a copy);
     *         {@code null} if nothing was set or accumulated yet
     */
    public double[] getGradients() {
        return this.gradients;
    }

    /**
     * @param gradients gradient array to store; the reference is kept as-is
     */
    public void setGradients(double[] gradients) {
        this.gradients = gradients;
    }

    /**
     * @return the stored training-size counter
     */
    public long getTrainSize() {
        return this.trainSize;
    }

    /**
     * @param trainSize training-size counter to store
     */
    public void setTrainSize(long trainSize) {
        this.trainSize = trainSize;
    }

    /**
     * Adds {@code size} to the training-size counter.
     *
     * @param size amount to add
     */
    public void accumulateTrainSize(long size) {
        this.trainSize = this.getTrainSize() + size;
    }

    /**
     * Sets the training-size counter to zero and zeroes the gradient
     * accumulator in place (if one exists). Weights and errors are untouched.
     */
    public void reset() {
        this.setTrainSize(0L);
        if (this.gradients != null) {
            Arrays.fill(this.gradients, 0.0);
        }
    }

    /**
     * Writes the errors, the training size and both arrays to {@code out}
     * using the layout described on this class.
     *
     * <p>A {@code null} weight or gradient array is written as length
     * {@code -1} instead of causing a {@link NullPointerException}.
     *
     * @param out destination
     * @throws IOException if writing to {@code out} fails
     */
    public void doWrite(DataOutput out) throws IOException {
        out.writeDouble(this.getTrainError());
        out.writeDouble(this.getTestError());
        out.writeLong(this.getTrainSize());
        NNParams.writeDoubleArray(out, this.getWeights());
        NNParams.writeDoubleArray(out, this.getGradients());
    }

    /**
     * Reads the state previously produced by {@link #doWrite(DataOutput)},
     * replacing all fields of this instance.
     *
     * @param in source
     * @throws IOException if reading from {@code in} fails
     */
    public void doReadFields(DataInput in) throws IOException {
        this.trainError = in.readDouble();
        this.testError = in.readDouble();
        this.trainSize = in.readLong();
        this.weights = NNParams.readDoubleArray(in);
        this.gradients = NNParams.readDoubleArray(in);
    }

    /**
     * @return a human-readable dump of all fields, including full array
     *         contents
     */
    public String toString() {
        return String.format("NNParams [testError=%s, trainError=%s, trainSize=%s, weights=%s, gradients=%s]", this.testError, this.trainError, this.trainSize, Arrays.toString(this.weights), Arrays.toString(this.gradients));
    }

    /** Writes a length-prefixed double array; {@code null} is encoded as length -1. */
    private static void writeDoubleArray(DataOutput out, double[] values) throws IOException {
        if (values == null) {
            out.writeInt(NULL_ARRAY_LENGTH);
            return;
        }
        out.writeInt(values.length);
        for (double value : values) {
            out.writeDouble(value);
        }
    }

    /** Reads a length-prefixed double array; length -1 decodes to {@code null}. */
    private static double[] readDoubleArray(DataInput in) throws IOException {
        int len = in.readInt();
        if (len == NULL_ARRAY_LENGTH) {
            return null;
        }
        double[] values = new double[len];
        for (int i = 0; i < len; ++i) {
            values[i] = in.readDouble();
        }
        return values;
    }
}

