"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var conv_util = require("../conv_util");
var Pool2DProgram = (function () {
    /**
     * GPGPU program that performs 2D pooling ('max', 'min' or 'avg') over a
     * 3D input, producing one output element per invocation of the generated
     * GLSL `main()`.
     *
     * The generated shader (`userCode`) calls helpers it does not define
     * (getOutputCoords, getX, setOutput, isNaN); presumably these are
     * supplied by whatever compiles the program — TODO confirm.
     *
     * Behavior of the generated shader, as written below:
     * - Window taps that fall outside the input are skipped.
     * - The first NaN read in a window is written out immediately.
     * - min/max use `<=` / `>=`, so among equal values the last one in
     *   row-major scan order wins (this matters for computePositions).
     * - 'avg' divides by the full window area fSize * fSize, so skipped
     *   (out-of-range / padded) taps still count toward the divisor.
     * - A window with no in-range taps outputs 0.0.
     *
     * @param {number[]} xShape Input shape. Indexed as [rows, cols, depth]:
     *     [0]/[1] bound the window taps, [2] is passed as the output depth.
     * @param {number} fSize Side length of the square pooling window.
     *     Assumed integral, because it is emitted as a `<n>.0` GLSL literal.
     * @param {number} stride Step between windows on both spatial axes.
     *     Assumed integral (emitted as `<n>.0`).
     * @param {number} pad Offset subtracted from each window origin on both
     *     spatial axes. Assumed integral (emitted as `<n>.0`).
     * @param {string} poolType 'max', 'min' or 'avg'. Any value other than
     *     'min' or 'avg' pools with `>=`, i.e. behaves like 'max'.
     * @param {boolean} computePositions When true, output the flattened
     *     in-window index (wR * fSize + wC) of the selected element instead
     *     of its value.
     * @throws {Error} If poolType is 'avg' and computePositions is true.
     */
    function Pool2DProgram(xShape, fSize, stride, pad, poolType, computePositions) {
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        // Expression handed to setOutput() at the end of the shader.
        var returnValue = 'minMaxValue';
        if (computePositions) {
            returnValue = 'minMaxPosition';
        }
        else if (poolType === 'avg') {
            // Divides by the full window area, including skipped padded taps.
            returnValue = "avgValue / " + fSize * fSize + ".0";
        }
        // Tap coordinates are integer-valued floats. Comparing against
        // size - 0.5 (instead of size - 1) means "index >= size" without
        // depending on exact float equality.
        var xRowsLimit = xShape[0] - 0.5;
        var xColsLimit = xShape[1] - 0.5;
        // `params` should identify everything that changes the generated
        // shader. It is presumably used by the runtime to tell compiled
        // programs apart (e.g. a cache key) — confirm against the consumer.
        // poolType changes userCode (avg vs. min/max branch, comparison
        // operator, return expression), so it must be included. Without it,
        // max/min/avg programs with the same geometry get identical params.
        // It is appended last so existing indices keep their meaning.
        this.params = [stride, pad, fSize, computePositions, poolType];
        this.outputShape =
            conv_util.computeOutputShape3D(xShape, fSize, xShape[2], stride, pad);
        this.userCode = "\n      void main() {\n        vec3 coords = getOutputCoords();\n        float yR = coords.x;\n        float yC = coords.y;\n        float d = coords.z;\n\n        vec2 xRCCorner = vec2(yR, yC) * vec2(" + stride + ".0, " + stride + ".0) -\n            vec2(" + pad + ".0, " + pad + ".0);\n        float xRCorner = xRCCorner.x;\n        float xCCorner = xRCCorner.y;\n\n        // max/min x(?, ?, d) to get y(yR, yC, d).\n        // ? = to be determined\n        float minMaxValue = 0.0;\n        float minMaxValueFound = 0.0;\n        float minMaxPosition = 0.0;\n        float avgValue = 0.0;\n\n        for (int iwR = 0; iwR < " + fSize + "; iwR++) {\n          float wR = float(iwR);\n          float xR = xRCorner + wR;\n\n          if (xR < 0.0 || xR > " + xRowsLimit + ") {\n            continue;\n          }\n\n          for (int iwC = 0; iwC < " + fSize + "; iwC++) {\n            float wC = float(iwC);\n            float xC = xCCorner + wC;\n\n            if (xC < 0.0 || xC > " + xColsLimit + ") {\n              continue;\n            }\n\n            float value = getX(xR, xC, d);\n\n            if (isNaN(value)) {\n              setOutput(value);\n              return;\n            }\n\n            if (" + (poolType === 'avg') + ") {\n              avgValue += value;\n            } else {\n              // If a min / max value has already been found, use it. If not,\n              // use the current value.\n              float currMinMaxValue = mix(\n                  value, minMaxValue, minMaxValueFound);\n              if (value " + (poolType === 'min' ? '<=' : '>=') + " currMinMaxValue) {\n                minMaxValue = value;\n                minMaxValueFound = 1.0;\n                if (" + computePositions + ") {\n                  minMaxPosition = wR * " + fSize + ".0 + wC;\n                }\n              }\n            }\n          }\n        }\n        setOutput(" + returnValue + ");\n      }\n    ";
    }
    return Pool2DProgram;
}());
// Expose the pooling program on this module's CommonJS `exports` object.
exports.Pool2DProgram = Pool2DProgram;
//# sourceMappingURL=pool_gpu.js.map