/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.forward.ConfigSpatial;
import deepboof.forward.SpatialAveragePooling;
import deepboof.forward.SpatialPadding2D_F32;
import deepboof.impl.forward.standard.SpatialWindowChannel;
import deepboof.tensors.Tensor_F32;
import java.util.List;

/**
 * Spatial average pooling, evaluated one channel at a time.
 *
 * <p>Each output element is the mean of the input values covered by an {@code HH x WW} window.
 * Windows handled by {@link #forwardAt_inner} divide by the fixed window area
 * ({@link #poolingSize}). Windows handled by {@link #forwardAt_border} first adjust their row and
 * column bounds with the padding's clipping offsets. They then divide by the area of the adjusted
 * region instead of the full window area.
 */
public class SpatialAveragePooling_F32
extends SpatialWindowChannel<Tensor_F32, SpatialPadding2D_F32>
implements SpatialAveragePooling<Tensor_F32> {
    /** Area of a full pooling window ({@code HH * WW}), computed in {@link #_initialize()}. */
    protected float poolingSize;

    /**
     * Creates an average-pooling layer.
     *
     * @param config window configuration, passed unchanged to the superclass
     * @param padding padding accessor used for border windows, passed unchanged to the superclass
     */
    public SpatialAveragePooling_F32(ConfigSpatial config, SpatialPadding2D_F32 padding) {
        super(config, padding);
    }

    /**
     * Derives the output shape and the full-window area after the superclass has initialized.
     *
     * <p>The output shape is a copy of the input shape with entries 1 and 2 replaced by
     * {@code Ho} and {@code Wo}. NOTE(review): the input shape presumably excludes the mini-batch
     * dimension, since tensors are indexed with four coordinates elsewhere. Confirm this.
     *
     * @throws IllegalArgumentException if the input shape does not have exactly three entries
     */
    @Override
    public void _initialize() {
        super._initialize();
        final int dimensions = shapeInput.length;
        if (dimensions != 3) {
            throw new IllegalArgumentException("Expected 3D spatial tensor");
        }
        final int[] resultShape = shapeInput.clone();
        resultShape[1] = Ho;
        resultShape[2] = Wo;
        shapeOutput = resultShape;
        poolingSize = HH * WW;
    }

    /**
     * Ignores the supplied list; this layer reads no parameters.
     *
     * @param parameters unused
     */
    @Override
    public void _setParameters(List<Tensor_F32> parameters) {
        // intentionally empty: nothing to store
    }

    /**
     * Runs the pooling by delegating to the superclass's per-channel driver.
     *
     * @param input tensor being pooled
     * @param output destination tensor, passed through to the driver
     */
    @Override
    public void _forward(Tensor_F32 input, Tensor_F32 output) {
        forwardChannel(input, output);
    }

    /**
     * Averages one window that lies entirely inside the input and stores it in the output field.
     *
     * <p>Values are read straight from {@code input.d}. The walk starts at the window's top-left
     * element and advances by {@code W} elements between window rows. This assumes each row of
     * the channel is stored contiguously with a row stride of {@code W}; confirm against the
     * tensor layout.
     */
    @Override
    protected void forwardAt_inner(Tensor_F32 input, int batch, int channel, int inY, int inX, int outY, int outX) {
        final float[] values = input.d;
        final int rowStride = W;
        final int windowWidth = WW;
        float total = 0.0f;
        int rowStart = input.idx(batch, channel, inY, inX);
        for (int row = 0; row < HH; row++, rowStart += rowStride) {
            final int rowEnd = rowStart + windowWidth;
            for (int k = rowStart; k < rowEnd; k++) {
                total += values[k];
            }
        }
        final Tensor_F32 target = (Tensor_F32) output;
        final int destination = target.idx(batch, channel, outY, outX);
        target.d[destination] = total / poolingSize;
    }

    /**
     * Averages one window that touches the padded border and stores it in the output field.
     *
     * <p>The half-open row range {@code [padY, padY + HH)} and column range
     * {@code [padX, padX + WW)} are each shifted by the offsets returned from
     * {@code getClippingOffsetRow} / {@code getClippingOffsetCol}. The mean is then taken over the
     * resulting rectangle, and the divisor is that rectangle's area.
     */
    @Override
    protected void forwardAt_border(SpatialPadding2D_F32 padded, int batch, int channel, int padY, int padX, int outY, int outX) {
        final int topRow = padY + padded.getClippingOffsetRow(padY);
        final int bottomUnclipped = padY + HH;
        final int bottomRow = bottomUnclipped + padded.getClippingOffsetRow(bottomUnclipped);
        final int leftCol = padX + padded.getClippingOffsetCol(padX);
        final int rightUnclipped = padX + WW;
        final int rightCol = rightUnclipped + padded.getClippingOffsetCol(rightUnclipped);

        float total = 0.0f;
        for (int y = topRow; y < bottomRow; y++) {
            for (int x = leftCol; x < rightCol; x++) {
                total += padded.get(batch, channel, y, x);
            }
        }

        final int area = (bottomRow - topRow) * (rightCol - leftCol);
        final Tensor_F32 target = (Tensor_F32) output;
        final int destination = target.idx(batch, channel, outY, outX);
        target.d[destination] = total / (float) area;
    }

    /** @return {@code Tensor_F32.class} */
    @Override
    public Class<Tensor_F32> getTensorType() {
        return Tensor_F32.class;
    }

    /** @return the configuration held by the superclass */
    @Override
    public ConfigSpatial getConfiguration() {
        return config;
    }
}

