/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.lnr;

import java.util.Arrays;
import java.util.Random;
import ml.shifu.guagua.example.lnr.LinearRegressionParams;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LinearRegressionMaster
implements MasterComputable<LinearRegressionParams, LinearRegressionParams> {
    private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionMaster.class);
    private static final Random RANDOM = new Random();
    private int inputNum;
    private double[] weights;
    private double learnRate;

    private void init(MasterContext<LinearRegressionParams, LinearRegressionParams> context) {
        this.inputNum = NumberFormatUtils.getInt((String)"lr.input.num", (int)2);
        this.learnRate = NumberFormatUtils.getDouble((String)"lr.learning.rate", (double)0.1);
    }

    public LinearRegressionParams compute(MasterContext<LinearRegressionParams, LinearRegressionParams> context) {
        if (context.isFirstIteration()) {
            this.init(context);
            this.weights = new double[this.inputNum + 1];
            for (int i = 0; i < this.weights.length; ++i) {
                this.weights[i] = RANDOM.nextDouble();
            }
        } else {
            double[] gradients = new double[this.inputNum + 1];
            double sumError = 0.0;
            int size = 0;
            for (LinearRegressionParams param : context.getWorkerResults()) {
                if (param != null) {
                    for (int i = 0; i < gradients.length; ++i) {
                        int n = i;
                        gradients[n] = gradients[n] + param.getParameters()[i];
                    }
                    sumError += param.getError();
                }
                ++size;
            }
            for (int i = 0; i < this.weights.length; ++i) {
                int n = i;
                this.weights[n] = this.weights[n] - this.learnRate * gradients[i];
            }
            LOG.info("DEBUG: Weights: {}", (Object)Arrays.toString(this.weights));
            LOG.info("Iteration {} with error {}", (Object)context.getCurrentIteration(), (Object)(sumError / (double)size));
        }
        return new LinearRegressionParams(this.weights);
    }
}

