/*
 * Decompiled with CFR 0.152.
 */
package deepboof.io.torch7;

import deepboof.io.torch7.ConvertTorchToBoofForward;
import deepboof.io.torch7.TorchType;
import deepboof.io.torch7.struct.TorchBoolean;
import deepboof.io.torch7.struct.TorchByteStorage;
import deepboof.io.torch7.struct.TorchCharStorage;
import deepboof.io.torch7.struct.TorchDoubleStorage;
import deepboof.io.torch7.struct.TorchFloatStorage;
import deepboof.io.torch7.struct.TorchGeneric;
import deepboof.io.torch7.struct.TorchList;
import deepboof.io.torch7.struct.TorchLongStorage;
import deepboof.io.torch7.struct.TorchNumber;
import deepboof.io.torch7.struct.TorchObject;
import deepboof.io.torch7.struct.TorchReference;
import deepboof.io.torch7.struct.TorchReferenceable;
import deepboof.io.torch7.struct.TorchStorage;
import deepboof.io.torch7.struct.TorchString;
import deepboof.io.torch7.struct.TorchTensor;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Base class for readers of Torch7 serialized files. It walks the serialized object graph and
 * converts it into {@link TorchObject} instances. Subclasses supply the primitive readers that are
 * declared abstract at the bottom of this class.
 *
 * <p>Objects that carry a Torch7 object index (torch classes, tables, recursive functions) are
 * recorded in {@link #masterTable}. A later back-reference to the same index then resolves to the
 * already-parsed instance. Instances are not thread-safe: {@link #parse(File)} mutates
 * {@link #stream}, {@link #input} and {@link #masterTable} without synchronization.
 */
public abstract class ParseTorch7 {
    /** Raw stream of the file currently being parsed; closed when {@link #parse(File)} returns or throws. */
    protected FileInputStream stream;
    /** {@link DataInput} view over {@link #stream}; presumably consumed by subclass primitive readers. */
    protected DataInput input;
    /** Parsed objects keyed by their Torch7 object index, used to resolve back-references. */
    protected Map<Integer, TorchReferenceable> masterTable = new HashMap<Integer, TorchReferenceable>();
    /** When true, diagnostics are printed to standard out while parsing. */
    protected boolean verbose = false;

    /**
     * Parses the file and returns the first top-level object in it.
     *
     * @param file Torch7 file to read
     * @return the first top-level object, cast to the caller's expected type
     * @throws IOException if reading fails or the file contains no top-level objects
     */
    @SuppressWarnings("unchecked")
    public <T extends TorchObject> T parseOne(File file) throws IOException {
        List<TorchObject> objects = this.parse(file);
        if (objects.isEmpty()) {
            throw new IOException("No objects found in " + file);
        }
        return (T) objects.get(0);
    }

    /**
     * Parses the first top-level object in the file and converts it with
     * {@link ConvertTorchToBoofForward#convert}.
     *
     * @param file Torch7 file to read
     * @return the converted object
     * @throws IOException if reading fails
     */
    public <T> T parseIntoBoof(File file) throws IOException {
        return ConvertTorchToBoofForward.convert(this.parseOne(file));
    }

    /**
     * Parses every top-level object in the file until end of file.
     *
     * <p>The back-reference table is reset on each call. Top-level NIL values appear as
     * {@code null} entries in the returned list.
     *
     * @param file Torch7 file to read
     * @return all top-level objects in file order
     * @throws IOException if reading fails for a reason other than reaching end of file
     */
    public List<TorchObject> parse(File file) throws IOException {
        this.masterTable = new HashMap<Integer, TorchReferenceable>();
        this.stream = new FileInputStream(file);
        this.input = new DataInputStream(this.stream);
        List<TorchObject> list = new ArrayList<TorchObject>();
        try {
            while (true) {
                TorchObject next = this.parseNext(true);
                list.add(next);
            }
        } catch (EOFException ignored) {
            // End of file is how the loop terminates; the file has no explicit object count.
        } finally {
            // Close on every exit path, not just EOF, so a malformed file does not leak the handle.
            this.stream.close();
        }
        return list;
    }

    /**
     * Reads the next serialized value and returns it.
     *
     * <p>A back-reference to an already-parsed object resolves to the instance stored in
     * {@link #masterTable}. A newly parsed {@link TorchReferenceable} is registered there under its index.
     *
     * @param useCached whether a previously seen object index should resolve to the cached instance
     * @return the parsed object, or {@code null} for NIL and unsupported types
     */
    @SuppressWarnings("unchecked")
    private <T extends TorchObject> T parseNext(boolean useCached) throws IOException {
        TorchType type = this.readType();
        if (this.verbose) {
            System.out.println("========== Type = " + type);
        }
        TorchObject found = null;
        switch (type) {
            case TORCH: {
                int index = this.readS32();
                found = this.lookupObject(index, useCached);
                if (found == null) {
                    found = this.parseTorchObject(index);
                }
                break;
            }
            case RECUR_FUNCTION: {
                int index = this.readS32();
                found = this.lookupObject(index, useCached);
                if (found == null) {
                    found = this.parseRecurFunction(index);
                }
                break;
            }
            case TABLE: {
                // Torch7's File:readObject treats tables like torch objects: an index is written
                // first, and an index that was already written is a back-reference with no body.
                int index = this.readS32();
                found = this.lookupObject(index, useCached);
                if (found == null) {
                    found = this.parseTable(index);
                }
                break;
            }
            case STRING: {
                found = this.parseString();
                break;
            }
            case BOOLEAN: {
                found = this.parseBoolean();
                break;
            }
            case NUMBER: {
                found = this.parseNumber();
                break;
            }
            case NIL: {
                if (this.verbose) {
                    System.out.println("  ignoring nil");
                }
                break;
            }
            default: {
                // NOTE(review): any payload of an unsupported type is not consumed here, which
                // likely leaves the stream misaligned for subsequent reads.
                if (this.verbose) {
                    System.out.println("Unsupported object type " + type);
                }
                break;
            }
        }
        if (found instanceof TorchReference) {
            // Swap the placeholder for the instance parsed earlier under the same index.
            found = this.masterTable.get(((TorchReference) found).id);
        } else if (found instanceof TorchReferenceable) {
            TorchReferenceable referenceable = (TorchReferenceable) found;
            this.masterTable.put(referenceable.index, referenceable);
        }
        return (T) found;
    }

    /**
     * Reads a recursive-function record. The function source is only logged; the object that
     * follows it is parsed and returned in its place.
     *
     * @param index object index of the function (currently unused)
     */
    private TorchObject parseRecurFunction(int index) throws IOException {
        String dumped = this.readString();
        if (this.verbose) {
            System.out.println("   not sure what to do with recur functions.  Here's their string:");
            System.out.println("   " + dumped);
        }
        return this.parseNext(true);
    }

    /**
     * Parses the body of a torch class instance.
     *
     * <p>{@code torch.*Storage} and {@code torch.*Tensor} classes are decoded directly, with CUDA
     * types mapped to their float equivalents. Any other class is read as a generic table of fields.
     *
     * @param index object index of this instance
     * @return the parsed instance with index, version and class name filled in
     * @throws RuntimeException if a non-tensor class is not followed by a key/value table
     */
    private TorchObject parseTorchObject(int index) throws IOException {
        int version = this.stringToVersionNumber(this.readString());
        String className = this.readString();
        if (this.verbose) {
            System.out.println("  index = " + index + "  version = " + version + "  className = " + className);
        }
        TorchReferenceable ret = null;
        if (className.startsWith("torch.")) {
            className = this.cudaToFloat(className);
            if (className.endsWith("Storage")) {
                ret = this.parseStorage(className);
            } else if (className.endsWith("Tensor")) {
                ret = this.parseTensor();
            }
        }
        if (ret == null) {
            TorchObject inner = this.parseNext(true);
            if (!(inner instanceof TorchGeneric)) {
                // null (NIL/unsupported) or a list-shaped table cannot supply a field map.
                throw new RuntimeException("Probably an unsupported type.  Add support for " + className);
            }
            TorchGeneric generic = new TorchGeneric();
            generic.map = ((TorchGeneric) inner).map;
            ret = generic;
        }
        ret.index = index;
        ret.version = version;
        ret.torchName = className;
        return ret;
    }

    /**
     * Returns a {@link TorchReference} placeholder if {@code index} was already parsed and caching
     * is enabled; otherwise returns {@code null}, meaning the object body follows in the stream.
     */
    private TorchObject lookupObject(int index, boolean useCached) {
        if (useCached && this.masterTable.containsKey(index)) {
            if (this.verbose) {
                System.out.println("reference index = " + index);
            }
            TorchReference ret = new TorchReference();
            ret.id = index;
            return ret;
        }
        return null;
    }

    /** Maps {@code torch.CudaStorage}/{@code torch.CudaTensor} to their float counterparts; other names pass through. */
    private String cudaToFloat(String className) {
        if (className.equals("torch.CudaStorage")) {
            return "torch.FloatStorage";
        }
        if (className.equals("torch.CudaTensor")) {
            return "torch.FloatTensor";
        }
        return className;
    }

    /**
     * Parses a tensor header followed by its storage object.
     *
     * @return the tensor; shape and storage are left unset when the dimension count is zero
     */
    private TorchTensor parseTensor() throws IOException {
        TorchTensor t = new TorchTensor();
        int dimension = this.readS32();
        if (dimension != 0) {
            t.shape = this.readShape(dimension);
            if (this.verbose) {
                System.out.println("   shape dimension = " + t.shape.length);
            }
            // A second dimension-sized array follows (presumably the strides); read and discarded.
            this.readShape(dimension);
            // The stored offset is 1-based; convert to a 0-based index.
            t.startIndex = (int) this.readS64() - 1;
            t.storage = (TorchStorage) this.parseNext(true);
            if (this.verbose && t.storage != null
                    && (t.storage.size() != t.length() || t.startIndex != 0)) {
                System.out.println("subtensor.  storage " + t.storage.size() + "  tensor " + t.length() + "  offset " + t.startIndex);
            }
        } else {
            int a = this.readS32();
            long b = this.readS64();
            if (this.verbose) {
                System.out.println("    no dimension.  Weird variable " + a + " " + b);
            }
        }
        return t;
    }

    /**
     * Parses the element count and data of a storage of the given (already CUDA-normalized) class.
     *
     * @param name torch storage class name, e.g. {@code torch.FloatStorage}
     * @throws IOException if the storage type is not supported or reading fails
     */
    private TorchStorage parseStorage(String name) throws IOException {
        TorchStorage out;
        int size = (int) this.readS64();
        switch (name) {
            case "torch.LongStorage": {
                TorchLongStorage t = new TorchLongStorage(size);
                for (int i = 0; i < size; ++i) {
                    t.data[i] = this.readS64();
                }
                out = t;
                break;
            }
            case "torch.FloatStorage": {
                TorchFloatStorage t = new TorchFloatStorage(size);
                this.readArrayFloat(size, t.data);
                out = t;
                break;
            }
            case "torch.DoubleStorage": {
                TorchDoubleStorage t = new TorchDoubleStorage(size);
                this.readArrayDouble(size, t.data);
                out = t;
                break;
            }
            case "torch.ByteStorage": {
                TorchByteStorage t = new TorchByteStorage(size);
                this.readArrayByte(size, t.data);
                out = t;
                break;
            }
            case "torch.CharStorage": {
                // Allocates ceil(size / 2) elements; presumably two bytes are packed per Java
                // char by readArrayChar — confirm against the subclass implementation.
                TorchCharStorage t = new TorchCharStorage(size / 2 + size % 2);
                this.readArrayChar(size, t.data);
                out = t;
                break;
            }
            default: {
                throw new IOException("Unsupported storage type.  Please add support " + name);
            }
        }
        if (this.verbose) {
            out.printSummary();
        }
        return out;
    }

    /** Debugging aid: prints the next {@code N} bytes of {@link #input} as hex. Consumes the bytes. */
    private void printNextHex(int N) throws IOException {
        for (int i = 0; i < N; ++i) {
            System.out.printf("%02x ", this.input.readByte() & 0xFF);
        }
        System.out.println();
    }

    /**
     * Parses the body of a table whose object index has already been read.
     *
     * <p>Keys must be strings or numbers. A non-empty table whose keys are exactly
     * 1, 2, ..., n becomes a {@link TorchList}; every other table becomes a {@link TorchGeneric}.
     *
     * @param index object index of the table
     * @throws RuntimeException on unsupported key types, duplicate keys, or non-sequential numeric keys
     */
    private TorchObject parseTable(int index) throws IOException {
        int size = this.readS32();
        if (this.verbose) {
            System.out.println("  idx = " + index);
            System.out.println("  size = " + size);
        }
        HashMap<Object, TorchObject> map = new HashMap<Object, TorchObject>();
        for (int i = 0; i < size; ++i) {
            TorchObject rawKey = this.parseNext(true);
            Object key;
            if (rawKey instanceof TorchString) {
                key = ((TorchString) rawKey).message;
            } else if (rawKey instanceof TorchNumber) {
                key = ((TorchNumber) rawKey).value;
            } else {
                throw new RuntimeException("Add support for " + rawKey);
            }
            TorchObject value = this.parseNext(true);
            if ("_type".equals(key) && value instanceof TorchString) {
                TorchString s = (TorchString) value;
                s.message = this.cudaToFloat(s.message);
            }
            // containsKey (not put()'s return value) so a duplicate is caught even if the earlier value was null.
            if (map.containsKey(key)) {
                throw new RuntimeException("Probably a bug in the parser.  Same key assigned twice");
            }
            map.put(key, value);
        }
        if (size > 0 && this.isList(map)) {
            List<Double> keys = new ArrayList<Double>(map.size());
            for (Object k : map.keySet()) {
                keys.add((Double) k);
            }
            Collections.sort(keys);
            TorchList t = new TorchList();
            t.index = index;
            for (int i = 0; i < keys.size(); ++i) {
                Double number = keys.get(i);
                // Exact comparison so fractional keys (e.g. 2.5) are not truncated into a match.
                if (number.doubleValue() != i + 1) {
                    throw new RuntimeException("Not actually a complete sequential list");
                }
                t.list.add(map.get(number));
            }
            return t;
        }
        TorchGeneric t = new TorchGeneric();
        t.map = map;
        t.index = index;
        return t;
    }

    /** Returns true if every key in the map is a {@link Double} (vacuously true for an empty map). */
    private boolean isList(Map<Object, TorchObject> map) {
        for (Object key : map.keySet()) {
            if (!(key instanceof Double)) {
                return false;
            }
        }
        return true;
    }

    /** Reads a string value. */
    private TorchObject parseString() throws IOException {
        TorchString ret = new TorchString();
        ret.message = this.readString();
        if (this.verbose) {
            System.out.println("   " + ret.message);
        }
        return ret;
    }

    /** Reads a boolean value. */
    private TorchObject parseBoolean() throws IOException {
        TorchBoolean ret = new TorchBoolean();
        ret.value = this.readBoolean();
        if (this.verbose) {
            System.out.println("   " + ret.value);
        }
        return ret;
    }

    /** Reads a numeric value; numbers are read with {@link #readDouble()}. */
    private TorchObject parseNumber() throws IOException {
        TorchNumber ret = new TorchNumber();
        ret.value = this.readDouble();
        return ret;
    }

    /**
     * Parses a version header of the form {@code "V <n>"}.
     *
     * @param line the version string read from the file
     * @return the integer version number
     * @throws RuntimeException if the string is not in {@code "V <n>"} form
     * @throws NumberFormatException if the text after {@code "V "} is not an integer
     */
    public int stringToVersionNumber(String line) {
        if (line.length() < 3 || line.charAt(0) != 'V' || line.charAt(1) != ' ') {
            throw new RuntimeException("Old format.  Add support for this");
        }
        return Integer.parseInt(line.substring(2));
    }

    /** Reads {@code dimension} integers describing a tensor shape (or a same-sized companion array). */
    public abstract int[] readShape(int dimension) throws IOException;

    /** Reads the type tag of the next serialized value. */
    public abstract TorchType readType() throws IOException;

    public abstract boolean readBoolean() throws IOException;

    public abstract double readDouble() throws IOException;

    public abstract float readFloat() throws IOException;

    public abstract String readString() throws IOException;

    public abstract long readS64() throws IOException;

    public abstract int readS32() throws IOException;

    public abstract int readU8() throws IOException;

    /** Reads {@code length} doubles into {@code storage}. */
    public abstract void readArrayDouble(int length, double[] storage) throws IOException;

    /** Reads {@code length} floats into {@code storage}. */
    public abstract void readArrayFloat(int length, float[] storage) throws IOException;

    /** Reads {@code length} serialized chars into {@code storage}; see the note in parseStorage on sizing. */
    public abstract void readArrayChar(int length, char[] storage) throws IOException;

    /** Reads {@code length} bytes into {@code storage}. */
    public abstract void readArrayByte(int length, byte[] storage) throws IOException;

    /** Returns whether diagnostic output is enabled. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /**
     * Enables or disables diagnostic output.
     *
     * @return this parser, for chaining
     */
    public ParseTorch7 setVerbose(boolean verbose) {
        this.verbose = verbose;
        return this;
    }
}

