"use strict";
// Standard TypeScript-style `__extends` helper for ES5 class inheritance
// (this file looks like compiler output; see the sourceMappingURL at the
// bottom). Calling __extends(d, b) makes constructor `d` inherit both the
// static members and the prototype members of constructor `b`.
// An `__extends` already present on `this` is reused if one exists;
// otherwise one is built by the IIFE below.
var __extends = (this && this.__extends) || (function () {
    // Copies "static" members from b onto d, using the first option that works:
    //   1. Object.setPrototypeOf(d, b) when the runtime provides it;
    //   2. assigning d.__proto__ when object literals honour `__proto__`
    //      (feature-tested via `{ __proto__: [] } instanceof Array`);
    //   3. copying b's own enumerable properties onto d.
    var extendStatics = Object.setPrototypeOf ||
        ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
        function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; };
    return function (d, b) {
        extendStatics(d, b);
        // Surrogate constructor: creates an object whose prototype is
        // b.prototype without invoking b itself, and gives that object a
        // `constructor` property pointing back at d.
        function __() { this.constructor = d; }
        // When b is null, d.prototype becomes Object.create(null) (no parent).
        d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
    };
})();
Object.defineProperty(exports, "__esModule", { value: true });
var math_1 = require("./math");
var ndarray = require("./ndarray");
var ndarray_1 = require("./ndarray");
var addscaledmat_gpu_1 = require("./webgl/addscaledmat_gpu");
var argmaxequals_gpu_1 = require("./webgl/argmaxequals_gpu");
var argminmax_gpu_1 = require("./webgl/argminmax_gpu");
var batchnorm_gpu_1 = require("./webgl/batchnorm_gpu");
var binaryop_gpu_1 = require("./webgl/binaryop_gpu");
var concat3d_gpu_1 = require("./webgl/concat3d_gpu");
var conv_backprop_gpu_1 = require("./webgl/conv_backprop_gpu");
var conv_gpu_1 = require("./webgl/conv_gpu");
var copy_gpu_1 = require("./webgl/copy_gpu");
var gpgpu_context_1 = require("./webgl/gpgpu_context");
var gpgpu_math = require("./webgl/gpgpu_math");
var gpgpu_util = require("./webgl/gpgpu_util");
var logsumexp_gpu_1 = require("./webgl/logsumexp_gpu");
var max_pool_backprop_gpu_1 = require("./webgl/max_pool_backprop_gpu");
var minmax_gpu_1 = require("./webgl/minmax_gpu");
var mulmat_gpu_1 = require("./webgl/mulmat_gpu");
var pool_gpu_1 = require("./webgl/pool_gpu");
var reducesum_gpu_1 = require("./webgl/reducesum_gpu");
var resize_bilinear_gpu_1 = require("./webgl/resize_bilinear_gpu");
var texture_manager_1 = require("./webgl/texture_manager");
var unaryop_gpu_1 = require("./webgl/unaryop_gpu");
var webgl_util = require("./webgl/webgl_util");
/**
 * WebGL-backed implementation of math_1.NDArrayMath.
 *
 * Each *Internal operation builds a program object from one of the
 * ./webgl/* modules and executes it through compileAndRun. Compiled shader
 * binaries are memoized in `binaryCache`, keyed by
 * gpgpu_math.makeShaderKey(program, inputs, output).
 */
var NDArrayMathGPU = (function (_super) {
    __extends(NDArrayMathGPU, _super);
    /**
     * @param gpgpu Optional existing GPGPUContext. When null/undefined, a new
     *     WebGL context and GPGPUContext are created here. This instance then
     *     owns them and disposes the context in dispose().
     * @param safeMode Forwarded to the NDArrayMath base constructor. Defaults
     *     to true.
     */
    function NDArrayMathGPU(gpgpu, safeMode) {
        if (safeMode === void 0) { safeMode = true; }
        var _this = _super.call(this, safeMode) || this;
        // Compiled shader binaries, keyed by shader key (see compileAndRun).
        _this.binaryCache = {};
        if (gpgpu == null) {
            var gl = gpgpu_util.createWebGLContext();
            _this.gpgpu = new gpgpu_context_1.GPGPUContext(gl);
            _this.gpgpuCreatedLocally = true;
        }
        else {
            _this.gpgpu = gpgpu;
            _this.gpgpuCreatedLocally = false;
        }
        _this.textureManager = new texture_manager_1.TextureManager(_this.gpgpu);
        // Hands the context and texture manager to the ndarray module.
        // NOTE(review): presumably stored module-wide; confirm in ndarray.js.
        ndarray.initializeGPU(_this.gpgpu, _this.textureManager);
        return _this;
    }
    /** Returns the GPGPUContext this instance runs its programs on. */
    NDArrayMathGPU.prototype.getGPGPUContext = function () {
        return this.gpgpu;
    };
    /**
     * Copies `arr` into a freshly allocated array with the same shape.
     * The copy is done as a 2D texture-to-texture copy over the source's
     * texture shape, and the result is then reshaped back to arr.shape.
     * (The parameter is deliberately not called `ndarray`, so it does not
     * shadow the module import.)
     */
    NDArrayMathGPU.prototype.cloneInternal = function (arr) {
        var texShape = arr.getTextureShapeRC();
        var source = arr.as2D(texShape[0], texShape[1]);
        var output = this.makeOutputArray(texShape);
        this.copy2D(source, [0, 0], texShape, output, [0, 0], texShape);
        return output.reshape(arr.shape);
    };
    /**
     * Copies the [beginRowCol, beginRowCol + sizeRowCol) window of `input`
     * into a new array whose logical shape and texture shape are both
     * `sizeRowCol`.
     */
    NDArrayMathGPU.prototype.slice2DInternal = function (input, beginRowCol, sizeRowCol) {
        var result = ndarray_1.NDArray.make(sizeRowCol, {
            texture: this.textureManager.acquireTexture(sizeRowCol),
            textureShapeRC: sizeRowCol
        });
        this.copy2DInternal(input, beginRowCol, sizeRowCol, result, [0, 0], sizeRowCol);
        return result;
    };
    /**
     * Runs a Copy2DProgram that writes the source window into `dest` at
     * destBeginRowCol. The window offsets are applied through the
     * program's custom setup function.
     */
    NDArrayMathGPU.prototype.copy2DInternal = function (source, sourceBeginRowCol, sourceSizeRowCol, dest, destBeginRowCol, destSizeRowCol) {
        var program = new copy_gpu_1.Copy2DProgram(sourceSizeRowCol[1], destSizeRowCol[1]);
        var customSetup = program.getCustomSetupFunc(sourceBeginRowCol, destBeginRowCol, destSizeRowCol);
        this.compileAndRun(program, [source], dest, customSetup);
    };
    /** Concatenates two 3D arrays along `axis` via Concat3DProgram. */
    NDArrayMathGPU.prototype.concat3DInternal = function (x1, x2, axis) {
        var program = new concat3d_gpu_1.Concat3DProgram(x1.shape, x2.shape, axis);
        return this.compileAndRun(program, [x1, x2]);
    };
    /**
     * Runs AddScaledMatProgram with inputs [a, b, c1, c2].
     * NOTE(review): presumably computes c1 * a + c2 * b; confirm in
     * addscaledmat_gpu.
     */
    NDArrayMathGPU.prototype.scaledArrayAddInternal = function (c1, a, c2, b) {
        var program = new addscaledmat_gpu_1.AddScaledMatProgram(a.shape, b.shape);
        return this.compileAndRun(program, [a, b, c1, c2]);
    };
    /** Elementwise negation (UnaryOp.NEG). */
    NDArrayMathGPU.prototype.negInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.NEG);
        return this.compileAndRun(program, [a]);
    };
    /**
     * Allocates an NDArray with logical `shape`, backed by a texture from the
     * texture manager. The texture dimensions come from
     * webgl_util.getTextureShapeFromLogicalShape.
     */
    NDArrayMathGPU.prototype.makeOutputArray = function (shape) {
        var textureShapeRC = webgl_util.getTextureShapeFromLogicalShape(this.gpgpu.gl, shape);
        var texture = this.textureManager.acquireTexture(textureShapeRC);
        return ndarray_1.NDArray.make(shape, { texture: texture, textureShapeRC: textureShapeRC });
    };
    /**
     * Compiles `program`, or reuses the cached binary for the same shader key,
     * and runs it over `inputs`.
     *
     * @param program Program object exposing at least `outputShape`.
     * @param inputs Input arrays, in the order the program expects.
     * @param output Optional destination. When null/undefined, a new array of
     *     program.outputShape is allocated.
     * @param customSetup Optional callback forwarded to gpgpu_math.runProgram.
     * @returns The output array.
     */
    NDArrayMathGPU.prototype.compileAndRun = function (program, inputs, output, customSetup) {
        var _this = this;
        if (output == null) {
            output = this.makeOutputArray(program.outputShape);
        }
        var key = gpgpu_math.makeShaderKey(program, inputs, output);
        var binary = this.getAndSaveBinary(key, function () {
            return gpgpu_math.compileProgram(_this.gpgpu, program, inputs, output);
        });
        gpgpu_math.runProgram(binary, inputs, output, customSetup);
        return output;
    };
    /** Matrix multiply of a and b with the given operand orientations. */
    NDArrayMathGPU.prototype.matMulInternal = function (a, b, aOrientation, bOrientation) {
        var program = new mulmat_gpu_1.MatMulProgram(a.shape, b.shape, aOrientation, bOrientation);
        return this.compileAndRun(program, [a, b]);
    };
    /** Elementwise a * b via BinaryOpProgram. */
    NDArrayMathGPU.prototype.multiplyInternal = function (a, b) {
        var program = new binaryop_gpu_1.BinaryOpProgram('*', a.shape, b.shape);
        return this.compileAndRun(program, [a, b]);
    };
    /**
     * Batch normalization via BatchNormProgram.
     * Inputs are [x, mean, variance], followed by `offset` and then `scale`
     * when each is provided. A missing offset/scale is passed to the program
     * as a null shape.
     * @param varianceEpsilon Defaults to 0.000001.
     */
    NDArrayMathGPU.prototype.batchNormalization3DInternal = function (x, mean, variance, varianceEpsilon, scale, offset) {
        if (varianceEpsilon === void 0) { varianceEpsilon = 0.000001; }
        var inputs = [x, mean, variance];
        var offsetShape = null;
        if (offset != null) {
            offsetShape = offset.shape;
            inputs.push(offset);
        }
        var scaleShape = null;
        if (scale != null) {
            scaleShape = scale.shape;
            inputs.push(scale);
        }
        var program = new batchnorm_gpu_1.BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
        return this.compileAndRun(program, inputs);
    };
    /** Not implemented on the GPU backend; always throws. */
    NDArrayMathGPU.prototype.switchDimInternal = function (a, newDim) {
        throw new Error('Not yet implemented!');
    };
    /** Sum reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.sumInternal = function (a) {
        var program = new reducesum_gpu_1.ReduceSumProgram(a.size);
        return this.compileAndRun(program, [a]);
    };
    /** Arg-min reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.argMinInternal = function (a) {
        var program = new argminmax_gpu_1.ArgMinMaxProgram(a.size, 'min');
        return this.compileAndRun(program, [a]);
    };
    /** Arg-max reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.argMaxInternal = function (a) {
        var program = new argminmax_gpu_1.ArgMinMaxProgram(a.size, 'max');
        return this.compileAndRun(program, [a]);
    };
    /** Runs ArgMaxEqualsProgram over x1 and x2. */
    NDArrayMathGPU.prototype.argMaxEqualsInternal = function (x1, x2) {
        var program = new argmaxequals_gpu_1.ArgMaxEqualsProgram(x1.size, x2.size);
        return this.compileAndRun(program, [x1, x2]);
    };
    /**
     * Not implemented on the GPU backend; always throws. (The parameter is
     * not called `ndarray`, so it does not shadow the module import.)
     */
    NDArrayMathGPU.prototype.topKInternal = function (arr, k) {
        throw new Error('topK GPU not yet implemented!');
    };
    /** Min reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.minInternal = function (a) {
        var program = new minmax_gpu_1.MinMaxProgram(a.size, 'min');
        return this.compileAndRun(program, [a]);
    };
    /** Max reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.maxInternal = function (a) {
        var program = new minmax_gpu_1.MinMaxProgram(a.size, 'max');
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise a / b via BinaryOpProgram. */
    NDArrayMathGPU.prototype.divideInternal = function (a, b) {
        var program = new binaryop_gpu_1.BinaryOpProgram('/', a.shape, b.shape);
        return this.compileAndRun(program, [a, b]);
    };
    /** Elementwise a + b via BinaryOpProgram. */
    NDArrayMathGPU.prototype.addInternal = function (a, b) {
        var program = new binaryop_gpu_1.BinaryOpProgram('+', a.shape, b.shape);
        return this.compileAndRun(program, [a, b]);
    };
    /** Elementwise a - b via BinaryOpProgram. */
    NDArrayMathGPU.prototype.subInternal = function (a, b) {
        var program = new binaryop_gpu_1.BinaryOpProgram('-', a.shape, b.shape);
        return this.compileAndRun(program, [a, b]);
    };
    /** Log-sum-exp reduction; the program is sized by `a.size`. */
    NDArrayMathGPU.prototype.logSumExpInternal = function (a) {
        var program = new logsumexp_gpu_1.LogSumExpProgram(a.size);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise exp (UnaryOp.EXP). */
    NDArrayMathGPU.prototype.expInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.EXP);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise log (UnaryOp.LOG). */
    NDArrayMathGPU.prototype.logInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.LOG);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise ReLU (UnaryOp.RELU). */
    NDArrayMathGPU.prototype.reluInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.RELU);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise sigmoid (UnaryOp.SIGMOID). */
    NDArrayMathGPU.prototype.sigmoidInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.SIGMOID);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise tanh (UnaryOp.TANH). */
    NDArrayMathGPU.prototype.tanhInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.TANH);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise sin (UnaryOp.SIN). */
    NDArrayMathGPU.prototype.sinInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.SIN);
        return this.compileAndRun(program, [a]);
    };
    /** Elementwise step (UnaryOp.STEP). */
    NDArrayMathGPU.prototype.stepInternal = function (a) {
        var program = new unaryop_gpu_1.UnaryOpProgram(a.shape, unaryop_gpu_1.UnaryOp.STEP);
        return this.compileAndRun(program, [a]);
    };
    /**
     * 2D convolution via Conv2DProgram.
     * The field size is read from weights.shape[0] and the output depth from
     * weights.shape[3]. `bias` is appended to the inputs only when non-null.
     */
    NDArrayMathGPU.prototype.conv2dInternal = function (x, weights, bias, stride, zeroPad) {
        var fieldSize = weights.shape[0];
        var outputDepth = weights.shape[3];
        var program = new conv_gpu_1.Conv2DProgram(x.shape, fieldSize, outputDepth, stride, zeroPad, bias != null);
        var inputs = bias != null ? [x, weights, bias] : [x, weights];
        return this.compileAndRun(program, inputs);
    };
    /**
     * Convolution backprop. Returns { dx, db, dw }:
     *   dw from conv2dDerWeights,
     *   db from conv2dDerBias,
     *   dx from a bias-free conv2dTransposeInternal of dy.
     */
    NDArrayMathGPU.prototype.conv2dBackPropInternal = function (x, dy, weights, stride, pad) {
        var fSize = weights.shape[0];
        var dw = this.conv2dDerWeights(x, dy, fSize, stride, pad);
        var db = this.conv2dDerBias(dy);
        var dx = this.conv2dTransposeInternal(dy, weights, null, stride, pad);
        return { dx: dx, db: db, dw: dw };
    };
    /**
     * Transposed convolution via Conv2DTransposeProgram.
     * The original input depth is read from weights.shape[2] and the field
     * size from weights.shape[0].
     */
    NDArrayMathGPU.prototype.conv2dTransposeInternal = function (x, weights, bias, origStride, origPad) {
        var origInputDepth = weights.shape[2];
        var fieldSize = weights.shape[0];
        var program = new conv_backprop_gpu_1.Conv2DTransposeProgram(x.shape, fieldSize, origInputDepth, origStride, origPad, bias != null);
        var inputs = bias != null ? [x, weights, bias] : [x, weights];
        return this.compileAndRun(program, inputs);
    };
    /** Weight gradient of conv2d; the output depth is read from dY.shape[2]. */
    NDArrayMathGPU.prototype.conv2dDerWeights = function (x, dY, fSize, stride, zeroPad) {
        var outputDepth = dY.shape[2];
        var program = new conv_backprop_gpu_1.Conv2DDerWeightsProgram(x.shape, fSize, outputDepth, stride, zeroPad);
        return this.compileAndRun(program, [x, dY]);
    };
    /** Bias gradient of conv2d via Conv2DDerBiasProgram. */
    NDArrayMathGPU.prototype.conv2dDerBias = function (dY) {
        var program = new conv_backprop_gpu_1.Conv2DDerBiasProgram(dY.shape);
        return this.compileAndRun(program, [dY]);
    };
    /** Max pooling (Pool2DProgram 'max', values only). */
    NDArrayMathGPU.prototype.maxPoolInternal = function (x, fSize, stride, pad) {
        var program = new pool_gpu_1.Pool2DProgram(x.shape, fSize, stride, pad, 'max', false);
        return this.compileAndRun(program, [x]);
    };
    /** Min pooling (Pool2DProgram 'min', values only). */
    NDArrayMathGPU.prototype.minPoolInternal = function (x, fSize, stride, pad) {
        var program = new pool_gpu_1.Pool2DProgram(x.shape, fSize, stride, pad, 'min', false);
        return this.compileAndRun(program, [x]);
    };
    /** Average pooling (Pool2DProgram 'avg', values only). */
    NDArrayMathGPU.prototype.avgPoolInternal = function (x, fSize, stride, pad) {
        var program = new pool_gpu_1.Pool2DProgram(x.shape, fSize, stride, pad, 'avg', false);
        return this.compileAndRun(program, [x]);
    };
    /**
     * Max-pool backprop in two passes:
     *   1. A 'max' Pool2DProgram pass with getPositions = true records the
     *      positions of the maxima.
     *   2. MaxPool2DBackpropProgram consumes dy and those positions.
     * The intermediate positions array is always disposed, even if the
     * second pass throws.
     */
    NDArrayMathGPU.prototype.maxPoolBackpropInternal = function (dy, x, fSize, origStride, origPad) {
        var getPositions = true;
        var maxPoolPositionsProgram = new pool_gpu_1.Pool2DProgram(x.shape, fSize, origStride, origPad, 'max', getPositions);
        var maxPoolPositions = this.compileAndRun(maxPoolPositionsProgram, [x]);
        try {
            var maxPoolBackPropProgram = new max_pool_backprop_gpu_1.MaxPool2DBackpropProgram(dy.shape, fSize, origStride, origPad);
            return this.compileAndRun(maxPoolBackPropProgram, [dy, maxPoolPositions]);
        }
        finally {
            // Runs after the backprop result is computed, or on error, so the
            // intermediate texture is never leaked.
            maxPoolPositions.dispose();
        }
    };
    /** Bilinear resize of a 3D array to newShape2D via ResizeBilinear3DProgram. */
    NDArrayMathGPU.prototype.resizeBilinear3DInternal = function (x, newShape2D, alignCorners) {
        var program = new resize_bilinear_gpu_1.ResizeBilinear3DProgram(x.shape, newShape2D, alignCorners);
        return this.compileAndRun(program, [x]);
    };
    /**
     * Returns the cached binary for `key`. On a miss, it calls getBinary(),
     * stores the result, and returns it.
     */
    NDArrayMathGPU.prototype.getAndSaveBinary = function (key, getBinary) {
        // Own-property check: `key in obj` would also match names inherited
        // from Object.prototype (e.g. "constructor") on this plain object.
        if (!Object.prototype.hasOwnProperty.call(this.binaryCache, key)) {
            this.binaryCache[key] = getBinary();
        }
        return this.binaryCache[key];
    };
    /** Returns the TextureManager used to allocate output textures. */
    NDArrayMathGPU.prototype.getTextureManager = function () {
        return this.textureManager;
    };
    /**
     * Releases GPU resources in this order:
     *   1. deletes every cached WebGL program and clears the cache;
     *   2. disposes the texture manager;
     *   3. disposes the GPGPUContext, but only if this instance created it.
     */
    NDArrayMathGPU.prototype.dispose = function () {
        for (var key in this.binaryCache) {
            if (Object.prototype.hasOwnProperty.call(this.binaryCache, key)) {
                this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
            }
        }
        // Drop references to the deleted programs so the cache never hands
        // out, or re-deletes, a program that no longer exists.
        this.binaryCache = {};
        this.textureManager.dispose();
        if (this.gpgpuCreatedLocally) {
            this.gpgpu.dispose();
        }
    };
    return NDArrayMathGPU;
}(math_1.NDArrayMath));
exports.NDArrayMathGPU = NDArrayMathGPU;
//# sourceMappingURL=math_gpu.js.map