/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tvm;

import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.APIInternal;
import org.apache.tvm.Base;
import org.apache.tvm.TVMValue;
import org.apache.tvm.TVMValueLong;

/**
 * A (device type, device id) pair identifying where TVM data lives or code runs.
 *
 * <p>Device type codes follow the {@code kDL*} constants below. Codes at or above
 * {@link #RPC_SESS_MASK} encode a remote device: {@code code / RPC_SESS_MASK - 1} is the
 * remote session table id and {@code code % RPC_SESS_MASK} is the underlying device type.
 */
public class Device {
    static final int kDLCPU = 1;
    static final int kDLCUDA = 2;
    static final int kDLCUDAHost = 3;
    static final int kDLOpenCL = 4;
    static final int kDLVulkan = 7;
    static final int kDLMetal = 8;
    static final int kDLVPI = 9;
    static final int kDLROCM = 10;
    static final int kDLROCMHost = 11;
    static final int kDLExtDev = 12;
    static final int kDLCUDAManaged = 13;
    static final int kDLOneAPI = 14;
    static final int kDLWebGPU = 15;
    static final int kDLHexagon = 16;
    static final int kDLAOCL = 32;
    static final int kDLSDAccel = 33;
    static final int kOpenGL = 34;
    static final int kDLMicroDev = 35;

    /** Device type codes at or above this value denote devices reached through an RPC session. */
    private static final int RPC_SESS_MASK = 128;

    /** Attribute kinds passed to the {@code _GetDeviceAttr} packed function. */
    private static final int ATTR_EXIST = 0;
    private static final int ATTR_MAX_THREADS_PER_BLOCK = 1;
    private static final int ATTR_WARP_SIZE = 2;

    /** Device type code -> printable name, used by {@link #toString()}. */
    private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
    /** Device name (including aliases such as "cl") -> device type code. */
    private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();

    static {
        MASK2STR.put(kDLCPU, "cpu");
        MASK2STR.put(kDLCUDA, "cuda");
        MASK2STR.put(kDLOpenCL, "opencl");
        MASK2STR.put(kDLVulkan, "vulkan");
        MASK2STR.put(kDLMetal, "metal");
        MASK2STR.put(kDLVPI, "vpi");
        MASK2STR.put(kDLROCM, "rocm");
        MASK2STR.put(kDLExtDev, "ext_dev");
        MASK2STR.put(kDLWebGPU, "webgpu");
        MASK2STR.put(kDLHexagon, "hexagon");

        STR2MASK.put("cpu", kDLCPU);
        STR2MASK.put("cuda", kDLCUDA);
        STR2MASK.put("cl", kDLOpenCL);
        STR2MASK.put("opencl", kDLOpenCL);
        STR2MASK.put("vulkan", kDLVulkan);
        STR2MASK.put("metal", kDLMetal);
        STR2MASK.put("vpi", kDLVPI);
        STR2MASK.put("rocm", kDLROCM);
        STR2MASK.put("ext_dev", kDLExtDev);
        STR2MASK.put("webgpu", kDLWebGPU);
        STR2MASK.put("hexagon", kDLHexagon);
    }

    /** Device type code (one of the {@code kDL*} constants, possibly RPC-encoded). */
    public final int deviceType;
    /** Index of the device among devices of the same type. */
    public final int deviceId;

    /** Returns a CPU device with the given id. */
    public static Device cpu(int devId) {
        return new Device(kDLCPU, devId);
    }

    /** Returns CPU device 0. */
    public static Device cpu() {
        return Device.cpu(0);
    }

    /** Returns a CUDA device with the given id. */
    public static Device cuda(int devId) {
        return new Device(kDLCUDA, devId);
    }

    /** Returns CUDA device 0. */
    public static Device cuda() {
        return Device.cuda(0);
    }

    /** Returns an OpenCL device with the given id. */
    public static Device opencl(int devId) {
        return new Device(kDLOpenCL, devId);
    }

    /** Returns OpenCL device 0. */
    public static Device opencl() {
        return Device.opencl(0);
    }

    /** Returns a Vulkan device with the given id. */
    public static Device vulkan(int devId) {
        return new Device(kDLVulkan, devId);
    }

    /** Returns Vulkan device 0. */
    public static Device vulkan() {
        return Device.vulkan(0);
    }

    /** Returns a Metal device with the given id. */
    public static Device metal(int devId) {
        return new Device(kDLMetal, devId);
    }

    /** Returns Metal device 0. */
    public static Device metal() {
        return Device.metal(0);
    }

    /** Returns a VPI device with the given id. */
    public static Device vpi(int devId) {
        return new Device(kDLVPI, devId);
    }

    /** Returns VPI device 0. */
    public static Device vpi() {
        return Device.vpi(0);
    }

    /** Returns a Hexagon device with the given id. */
    public static Device hexagon(int devId) {
        return new Device(kDLHexagon, devId);
    }

    /** Returns Hexagon device 0. */
    public static Device hexagon() {
        return Device.hexagon(0);
    }

    /**
     * Creates a device from a numeric type code.
     *
     * @param deviceType device type code (see the {@code kDL*} constants)
     * @param deviceId device index
     */
    public Device(int deviceType, int deviceId) {
        this.deviceType = deviceType;
        this.deviceId = deviceId;
    }

    /**
     * Creates a device from a device name such as {@code "cpu"}, {@code "cuda"} or {@code "cl"}.
     *
     * @param deviceType device name
     * @param deviceId device index
     * @throws IllegalArgumentException if {@code deviceType} is null or not a known device name
     */
    public Device(String deviceType, int deviceId) {
        this(lookupDeviceType(deviceType), deviceId);
    }

    /**
     * Resolves a device name to its type code.
     * Replaces the former auto-unboxing, which failed with an uninformative NullPointerException.
     */
    private static int lookupDeviceType(String name) {
        Integer mask = name == null ? null : STR2MASK.get(name);
        if (mask == null) {
            throw new IllegalArgumentException("Unknown device type: " + name);
        }
        return mask;
    }

    /**
     * Queries a device attribute through the {@code _GetDeviceAttr} packed function.
     *
     * @param attrKind attribute selector (one of the {@code ATTR_*} constants)
     * @return the attribute value, read as a long
     */
    private long getDeviceAttr(int attrKind) {
        TVMValue ret = APIInternal.get("_GetDeviceAttr")
                .pushArg(this.deviceType)
                .pushArg(this.deviceId)
                .pushArg(attrKind)
                .invoke();
        return ((TVMValueLong) ret).value;
    }

    /** Returns whether the runtime reports this device as existing (attribute 0 is non-zero). */
    public boolean exist() {
        return getDeviceAttr(ATTR_EXIST) != 0L;
    }

    /** Returns the value the runtime reports for attribute 1 (max threads per block). */
    public long maxThreadsPerBlock() {
        return getDeviceAttr(ATTR_MAX_THREADS_PER_BLOCK);
    }

    /** Returns the value the runtime reports for attribute 2 (warp size). */
    public long warpSize() {
        return getDeviceAttr(ATTR_WARP_SIZE);
    }

    /**
     * Calls the native {@code tvmSynchronize} for this device and checks its return code with
     * {@code Base.checkCall}. Presumably this waits for pending work on the device; confirm
     * against the native runtime.
     */
    public void sync() {
        Base.checkCall(Base._LIB.tvmSynchronize(this.deviceType, this.deviceId));
    }

    /** Packs the type into the upper bits and the id into the lower bits; consistent with equals. */
    @Override
    public int hashCode() {
        return this.deviceType << 16 | this.deviceId;
    }

    /** Two devices are equal when both their type code and id match. */
    @Override
    public boolean equals(Object other) {
        if (other instanceof Device) {
            Device obj = (Device) other;
            return this.deviceId == obj.deviceId && this.deviceType == obj.deviceType;
        }
        return false;
    }

    /**
     * Formats the device as {@code name(id)}. RPC-encoded devices are shown as
     * {@code remote[tbl]:name(id)}.
     */
    @Override
    public String toString() {
        if (this.deviceType >= RPC_SESS_MASK) {
            int tblId = this.deviceType / RPC_SESS_MASK - 1;
            int devType = this.deviceType % RPC_SESS_MASK;
            return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), this.deviceId);
        }
        return String.format("%s(%d)", MASK2STR.get(this.deviceType), this.deviceId);
    }
}

