/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tvm.rpc;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.TVMContext;
import org.apache.tvm.rpc.RPC;
import org.apache.tvm.rpc.TVMRemoteContext;

/**
 * Client-side handle to a remote TVM RPC session.
 *
 * <p>Wraps the session {@link Module} and provides helpers to build device contexts that refer
 * to devices on the remote side, to upload/download data through the server-side
 * {@code tvm.rpc.server.*} functions, and to load modules by path through the RPC API.
 *
 * <p>NOTE(review): the server-function cache is a plain {@link HashMap} with no
 * synchronization, so an instance should be confined to one thread unless callers lock
 * externally.
 */
public class RPCSession {
    /** Device type code used by {@link #cpu(int)}. */
    private static final int DEVICE_CPU = 1;
    /** Device type code used by {@link #gpu(int)}. */
    private static final int DEVICE_GPU = 2;
    /** Device type code used by {@link #cl(int)}. */
    private static final int DEVICE_OPENCL = 4;
    /** Device type code used by {@link #vulkan(int)}. */
    private static final int DEVICE_VULKAN = 7;
    /** Device type code used by {@link #metal(int)}. */
    private static final int DEVICE_METAL = 8;

    /**
     * Multiplier used to fold this session's table index into a device type code, so that
     * the resulting context is tied to this session.
     */
    private static final int RPC_SESS_MASK = 128;

    /** Server-side function invoked by {@link #upload(byte[], String)}. */
    private static final String SERVER_UPLOAD = "tvm.rpc.server.upload";
    /** Server-side function invoked by {@link #download(String)}. */
    private static final String SERVER_DOWNLOAD = "tvm.rpc.server.download";

    private final Module session;
    private final int tblIndex;
    /** Lazily populated cache of server-side functions, keyed by their global name. */
    private final Map<String, Function> remoteFuncs = new HashMap<String, Function>();

    /**
     * Wraps an RPC session module and queries its session table index.
     *
     * @param sess the session module
     */
    RPCSession(Module sess) {
        this.session = sess;
        this.tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(this.session).invoke().asLong();
    }

    /**
     * Looks up a function by name on the session module.
     *
     * @param name the function name
     * @return the function returned by the session module's {@code getFunction}
     */
    public Function getFunction(String name) {
        return this.session.getFunction(name);
    }

    /**
     * Creates a context for a device on the remote side, identified by a device-type name.
     *
     * @param devType device type name, interpreted by {@link TVMContext}'s constructor
     * @param devId device index
     * @return a remote context bound to this session
     */
    public TVMContext context(String devType, int devId) {
        // Only used to translate the name into its numeric device type code.
        TVMContext ctx = new TVMContext(devType, devId);
        return this.context(ctx.deviceType, devId);
    }

    /**
     * Creates a context for device 0 of the named device type on the remote side.
     *
     * @param devType device type name
     * @return a remote context bound to this session
     */
    public TVMContext context(String devType) {
        return this.context(devType, 0);
    }

    /**
     * Creates a context for a device on the remote side, identified by a device type code.
     *
     * @param devType numeric device type code
     * @param devId device index
     * @return a remote context whose device type has this session's index encoded into it
     */
    public TVMContext context(int devType, int devId) {
        return new TVMRemoteContext(this.encodeDeviceType(devType), devId, this);
    }

    /**
     * Creates a context for device 0 of the given device type code on the remote side.
     *
     * @param devType numeric device type code
     * @return a remote context bound to this session
     */
    public TVMContext context(int devType) {
        return this.context(devType, 0);
    }

    /** Returns a remote CPU context for device {@code devId}. */
    public TVMContext cpu(int devId) {
        return this.context(DEVICE_CPU, devId);
    }

    /** Returns a remote CPU context for device 0. */
    public TVMContext cpu() {
        return this.cpu(0);
    }

    /** Returns a remote GPU context for device {@code devId}. */
    public TVMContext gpu(int devId) {
        return this.context(DEVICE_GPU, devId);
    }

    /** Returns a remote GPU context for device 0. */
    public TVMContext gpu() {
        return this.gpu(0);
    }

    /** Returns a remote OpenCL context for device {@code devId}. */
    public TVMContext cl(int devId) {
        return this.context(DEVICE_OPENCL, devId);
    }

    /** Returns a remote OpenCL context for device 0. */
    public TVMContext cl() {
        return this.cl(0);
    }

    /** Returns a remote Vulkan context for device {@code devId}. */
    public TVMContext vulkan(int devId) {
        return this.context(DEVICE_VULKAN, devId);
    }

    /** Returns a remote Vulkan context for device 0. */
    public TVMContext vulkan() {
        return this.vulkan(0);
    }

    /** Returns a remote Metal context for device {@code devId}. */
    public TVMContext metal(int devId) {
        return this.context(DEVICE_METAL, devId);
    }

    /** Returns a remote Metal context for device 0. */
    public TVMContext metal() {
        return this.metal(0);
    }

    /**
     * Sends {@code data} to the server by invoking {@code tvm.rpc.server.upload(target, data)}.
     *
     * @param data bytes to upload
     * @param target name under which the server should store the data
     * @throws IllegalArgumentException if {@code target} is {@code null}
     */
    public void upload(byte[] data, String target) {
        if (target == null) {
            throw new IllegalArgumentException("Please specify the upload target");
        }
        this.getCachedRemoteFunction(SERVER_UPLOAD).pushArg(target).pushArg(data).invoke();
    }

    /**
     * Reads a local file fully into memory and uploads it under {@code target}.
     *
     * @param data local file to upload
     * @param target name under which the server should store the data
     * @throws IOException if the file is too large, cannot be opened, or is only partially read
     * @throws IllegalArgumentException if {@code target} is {@code null}
     */
    public void upload(File data, String target) throws IOException {
        byte[] blob = RPCSession.getBytesFromFile(data);
        this.upload(blob, target);
    }

    /**
     * Uploads a local file using its own file name as the target name.
     *
     * @param data local file to upload
     * @throws IOException if the file is too large, cannot be opened, or is only partially read
     */
    public void upload(File data) throws IOException {
        this.upload(data, data.getName());
    }

    /**
     * Invokes {@code tvm.rpc.server.download(path)} on the server and returns the result as
     * bytes (presumably the contents of the remote file at {@code path}).
     *
     * @param path path on the server
     * @return the bytes returned by the server-side function
     */
    public byte[] download(String path) {
        return this.getCachedRemoteFunction(SERVER_DOWNLOAD).pushArg(path).invoke().asBytes();
    }

    /**
     * Calls the {@code LoadRemoteModule} RPC API with this session and {@code path}.
     *
     * @param path module path, as understood by the server
     * @return the module returned by the RPC API
     */
    public Module loadModule(String path) {
        return RPC.getApi("LoadRemoteModule").pushArg(this.session).pushArg(path).invoke().asModule();
    }

    /** Adds this session's table-index offset to a local device type code. */
    private int encodeDeviceType(int devType) {
        return devType + (this.tblIndex + 1) * RPC_SESS_MASK;
    }

    /** Returns the named session function, looking it up only on first use. */
    private Function getCachedRemoteFunction(String globalName) {
        Function func = this.remoteFuncs.get(globalName);
        if (func == null) {
            func = this.getFunction(globalName);
            this.remoteFuncs.put(globalName, func);
        }
        return func;
    }

    /**
     * Reads an entire local file into a byte array.
     *
     * @param file the file to read
     * @return the file's bytes; the array length is the size reported by {@link File#length()}
     * @throws IOException if the file is larger than {@link Integer#MAX_VALUE} bytes, cannot be
     *     opened, or reaches end-of-stream before {@code length()} bytes were read
     */
    private static byte[] getBytesFromFile(File file) throws IOException {
        long length = file.length();
        // A Java array cannot be larger than Integer.MAX_VALUE elements.
        if (length > Integer.MAX_VALUE) {
            throw new IOException("File " + file.getName() + " is too large!");
        }
        byte[] bytes = new byte[(int) length];
        int offset = 0;
        try (InputStream is = new FileInputStream(file)) {
            // read() may return fewer bytes than requested, so keep reading until full or EOF.
            while (offset < bytes.length) {
                int numRead = is.read(bytes, offset, bytes.length - offset);
                if (numRead < 0) {
                    break; // Premature end of stream; reported below.
                }
                offset += numRead;
            }
        }
        if (offset < bytes.length) {
            throw new IOException("Could not completely read file " + file.getName());
        }
        return bytes;
    }
}

