/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tvm.contrib;

import org.apache.tvm.Device;
import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.NDArray;

/**
 * Thin wrapper around a {@link Module} that looks up and caches the module's
 * {@code set_input}, {@code run}, {@code get_input}, {@code get_output},
 * {@code debug_get_output} and {@code load_params} functions, and exposes them
 * as chainable Java methods.
 *
 * <p>{@code debug_get_output} is optional: if looking it up throws
 * {@link IllegalArgumentException}, the wrapper is still constructed, but the
 * {@code debugGetOutput} overloads throw a {@link RuntimeException} when called.
 */
public class GraphModule {
    private final Module module;
    /** Device that inputs are staged onto before being passed to {@code set_input}. */
    private final Device device;
    private final Function fsetInput;
    private final Function frun;
    private final Function fgetOutput;
    private final Function fgetInput;
    /** {@code null} when the module's {@code debug_get_output} lookup failed. */
    private final Function fdebugGetOutput;
    private final Function floadParams;

    /**
     * Looks up all entry points on {@code module}.
     *
     * <p>Lookups are performed in a fixed order. Any failure of a required lookup
     * (everything except {@code debug_get_output}) propagates to the caller.
     *
     * @param module the module whose functions are wrapped
     * @param dev the device inputs are copied to when they live elsewhere
     */
    GraphModule(Module module, Device dev) {
        this.module = module;
        this.device = dev;
        fsetInput = module.getFunction("set_input");
        frun = module.getFunction("run");
        fgetInput = module.getFunction("get_input");
        fgetOutput = module.getFunction("get_output");
        fdebugGetOutput = lookupOptional(module, "debug_get_output");
        floadParams = module.getFunction("load_params");
    }

    /**
     * Returns {@code mod.getFunction(name)}, or {@code null} if that lookup throws
     * {@link IllegalArgumentException}.
     */
    private static Function lookupOptional(Module mod, String name) {
        try {
            return mod.getFunction(name);
        } catch (IllegalArgumentException ignored) {
            // Deliberately tolerated. The error raised by debugGetOutput suggests
            // this entry point only exists in runtimes built with USE_PROFILER = ON.
            return null;
        }
    }

    /**
     * Releases every cached function handle, then the underlying module.
     * The debug function is released only if it was found at construction time.
     */
    public void release() {
        fsetInput.release();
        frun.release();
        fgetInput.release();
        fgetOutput.release();
        if (fdebugGetOutput != null) {
            fdebugGetOutput.release();
        }
        floadParams.release();
        module.release();
    }

    /**
     * Calls {@code set_input} with a named input.
     *
     * @param key input name
     * @param value input data; copied to this module's device first if it lives elsewhere
     * @return this, for chaining
     */
    public GraphModule setInput(String key, NDArray value) {
        NDArray staged = stageOnDevice(value);
        fsetInput.pushArg(key).pushArg(staged).invoke();
        return this;
    }

    /**
     * Calls {@code set_input} with a positional input.
     *
     * @param key input index
     * @param value input data; copied to this module's device first if it lives elsewhere
     * @return this, for chaining
     */
    public GraphModule setInput(int key, NDArray value) {
        NDArray staged = stageOnDevice(value);
        fsetInput.pushArg(key).pushArg(staged).invoke();
        return this;
    }

    /**
     * Returns {@code value} itself if it is already on this module's device.
     * Otherwise allocates a new array of the same shape on that device, copies
     * {@code value} into it, and returns the copy.
     *
     * <p>NOTE(review): the copy is allocated via {@code NDArray.empty(shape, device)}
     * with no explicit dtype. Confirm that this matches {@code value}'s dtype.
     * The copy is also never released here; confirm whether {@code set_input}
     * retains it or copies from it.
     */
    private NDArray stageOnDevice(NDArray value) {
        if (value.device().equals(device)) {
            return value;
        }
        NDArray copy = NDArray.empty(value.shape(), device);
        value.copyTo(copy);
        return copy;
    }

    /**
     * Calls {@code run} on the module.
     *
     * @return this, for chaining
     */
    public GraphModule run() {
        frun.invoke();
        return this;
    }

    /**
     * Calls {@code get_input} with {@code (index, out)}.
     *
     * @param index input index
     * @param out destination array passed to {@code get_input} (presumably filled by it)
     * @return {@code out}
     */
    public NDArray getInput(int index, NDArray out) {
        fgetInput.pushArg(index).pushArg(out).invoke();
        return out;
    }

    /**
     * Calls {@code get_output} with {@code (index, out)}.
     *
     * @param index output index
     * @param out destination array passed to {@code get_output} (presumably filled by it)
     * @return {@code out}
     */
    public NDArray getOutput(int index, NDArray out) {
        fgetOutput.pushArg(index).pushArg(out).invoke();
        return out;
    }

    /**
     * Calls {@code debug_get_output} with a node name.
     *
     * @param node node name
     * @param out destination array passed to {@code debug_get_output}
     * @return {@code out}
     * @throws RuntimeException if the module has no {@code debug_get_output} function
     */
    public NDArray debugGetOutput(String node, NDArray out) {
        requireDebugFunction().pushArg(node).pushArg(out).invoke();
        return out;
    }

    /**
     * Calls {@code debug_get_output} with a node index.
     *
     * @param node node index
     * @param out destination array passed to {@code debug_get_output}
     * @return {@code out}
     * @throws RuntimeException if the module has no {@code debug_get_output} function
     */
    public NDArray debugGetOutput(int node, NDArray out) {
        requireDebugFunction().pushArg(node).pushArg(out).invoke();
        return out;
    }

    /** Returns the cached debug function, or throws if the lookup failed at construction. */
    private Function requireDebugFunction() {
        if (fdebugGetOutput == null) {
            throw new RuntimeException("Please compile runtime with USE_PROFILER = ON");
        }
        return fdebugGetOutput;
    }

    /**
     * Passes the given bytes to {@code load_params}.
     *
     * @param params serialized parameter bytes
     * @return this, for chaining
     */
    public GraphModule loadParams(byte[] params) {
        floadParams.pushArg(params).invoke();
        return this;
    }

    /**
     * Looks up an arbitrary function on the underlying module. The result is not
     * cached by this wrapper and is not released by {@link #release()}.
     *
     * @param key function name
     * @return whatever {@code Module.getFunction(key)} returns
     */
    public Function getFunction(String key) {
        return module.getFunction(key);
    }
}

