/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.forward.ConfigConvolve2D;
import deepboof.forward.SpatialConvolve2D;
import deepboof.forward.SpatialPadding2D_F64;
import deepboof.impl.forward.standard.SpatialWindowImage;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Direct (non-FFT, non-im2col) 2D spatial convolution over {@code double} tensors.
 *
 * <p>For each output pixel, the {@code C x HH x WW} input window is gathered into
 * {@link #cacheLocal}. It is then dotted with each of the {@code F} kernels, and that kernel's
 * bias is added to the result.</p>
 *
 * <p>Layouts, as declared in {@link #_initialize()}:</p>
 * <ul>
 *   <li>weights: {@code (F, C, HH, WW)}</li>
 *   <li>bias: {@code (F)}</li>
 *   <li>output shape: {@code (F, Ho, Wo)}. The output tensor itself is indexed with a leading
 *       batch index (NOTE(review): presumably added by the base class).</li>
 * </ul>
 */
public class SpatialConvolve2D_F64
extends SpatialWindowImage<Tensor_F64, SpatialPadding2D_F64>
implements SpatialConvolve2D<Tensor_F64> {
    /** Number of kernels; each kernel fills one output channel. */
    protected int F;
    /** Kernel weights, shape {@code (F, C, HH, WW)}, read sequentially from {@code startIndex}. */
    protected Tensor_F64 weights;
    /** One bias value per kernel, shape {@code (F)}. */
    protected Tensor_F64 bias;
    /** Scratch copy of the current input window, ordered channel, then row, then column. */
    protected double[] cacheLocal = new double[0];

    /**
     * Creates the layer.
     *
     * @param config convolution configuration; its {@code F} sets the number of kernels
     * @param padding padding strategy used when the window overlaps the image border
     */
    public SpatialConvolve2D_F64(ConfigConvolve2D config, SpatialPadding2D_F64 padding) {
        super(config, padding);
        F = config.F;
    }

    /**
     * Declares the output shape and the two parameter tensors (weights first, then bias).
     */
    @Override
    public void _initialize() {
        super._initialize();
        shapeOutput = TensorOps.WI(F, Ho, Wo);
        // The order here must match the indices read in _setParameters: 0 = weights, 1 = bias.
        shapeParameters.add(TensorOps.WI(F, C, HH, WW));
        shapeParameters.add(TensorOps.WI(F));
    }

    /**
     * Stores the weight and bias tensors and sizes the window cache to one kernel's worth of
     * weights.
     *
     * @param parameters {@code [weights, bias]}
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        weights = parameters.get(0);
        bias = parameters.get(1);
        cacheLocal = new double[C * HH * WW];
    }

    /**
     * Runs the convolution by delegating to the base class's image traversal.
     * NOTE(review): that traversal presumably calls forwardAt_inner / forwardAt_border per pixel.
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        super.forwardImage(input, output);
    }

    /**
     * Handles a window that lies fully inside the input image.
     *
     * <p>The window is copied row by row straight from the input's backing array. Consecutive
     * window rows are assumed to be {@code W} elements apart, and elements within a row are
     * assumed contiguous.</p>
     */
    @Override
    protected void forwardAt_inner(Tensor_F64 input, int batch, int inY, int inX, int outY, int outX) {
        final double[] source = input.d;
        int dst = 0;
        for (int c = 0; c < C; c++) {
            // Flat index of the window's top-left element within this channel.
            int rowStart = input.idx(batch, c, inY, inX);
            for (int row = 0; row < HH; row++, rowStart += W) {
                System.arraycopy(source, rowStart, cacheLocal, dst, WW);
                dst += WW;
            }
        }
        dotWindowWithKernels(batch, outY, outX);
    }

    /**
     * Handles a window that overlaps the image border.
     *
     * <p>Each element is read through the padding wrapper, which supplies values outside the
     * image.</p>
     */
    @Override
    protected void forwardAt_border(SpatialPadding2D_F64 padded, int batch, int padY, int padX, int outY, int outX) {
        int dst = 0;
        for (int c = 0; c < C; c++) {
            for (int y = padY; y < padY + HH; y++) {
                for (int x = padX; x < padX + WW; x++) {
                    cacheLocal[dst++] = padded.get(batch, c, y, x);
                }
            }
        }
        dotWindowWithKernels(batch, outY, outX);
    }

    /**
     * Dots the cached window with every kernel and writes each result plus its bias to
     * {@code output[batch, kernel, outY, outX]}.
     */
    private void dotWindowWithKernels(int batch, int outY, int outX) {
        final int windowSize = C * HH * WW;
        final double[] kernelData = weights.d;
        final Tensor_F64 out = (Tensor_F64) this.output;
        // Kernels are stored back to back, so one running index walks all of them in order.
        int kernelPos = weights.startIndex;
        for (int k = 0; k < F; k++) {
            double acc = 0.0;
            for (int i = 0; i < windowSize; i++) {
                acc += cacheLocal[i] * kernelData[kernelPos++];
            }
            acc += bias.d[bias.idx(k)];
            out.d[out.idx(batch, k, outY, outX)] = acc;
        }
    }

    /** @return {@code Tensor_F64.class}; this layer only operates on double tensors */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }

    /** @return the configuration passed to the constructor, typed as {@link ConfigConvolve2D} */
    @Override
    public ConfigConvolve2D getConfiguration() {
        return (ConfigConvolve2D) config;
    }
}

