/*
 * Decompiled with CFR 0.152.
 */
package org.sentrysoftware.winrm.service.client.encryption;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import org.apache.cxf.helpers.IOUtils;
import org.apache.cxf.message.Message;
import org.sentrysoftware.winrm.service.client.auth.ntlm.NTCredentialsWithEncryption;
import org.sentrysoftware.winrm.service.client.encryption.ByteArrayUtils;
import org.sentrysoftware.winrm.service.client.encryption.NtlmEncryptionUtils;

/**
 * Decrypts server responses whose {@code Content-Type} starts with
 * {@code multipart/encrypted}, verifies their NTLM signature and replaces the
 * message content with the decrypted bytes.
 *
 * <p>Parsing state for the message being processed is kept in instance fields and
 * no synchronization is performed, so one instance must not process several
 * messages concurrently.
 */
public class Decryptor {
    /** Message key under which the response content type is looked up. */
    private static final String CONTENT_TYPE_KEY = "Content-Type";
    /** Content type prefix identifying an encrypted response. */
    private static final String ENCRYPTED_CONTENT_TYPE_PREFIX = "multipart/encrypted";
    /** Boundary line that opens each part of the multipart body. */
    private static final String BOUNDARY_LINE = "--Encrypted Boundary\r\n";
    /** Boundary line that terminates the multipart body. */
    private static final String CLOSING_BOUNDARY_LINE = "--Encrypted Boundary--\r\n";
    /** Charset for delimiter/byte conversions; fixed so parsing does not depend on the platform default. */
    private static final java.nio.charset.Charset DELIMITER_CHARSET = StandardCharsets.UTF_8;
    /** Size in bytes of the little-endian signature length prefix at the start of the encrypted part. */
    private static final int SIGNATURE_LENGTH_FIELD_SIZE = 4;
    /** Offset of the sequence number inside a signature; the checksum ends where it begins. */
    private static final int SEQUENCE_NUMBER_OFFSET = 12;
    /** Negotiate flag 0x00080000 (NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY in MS-NLMP). */
    private static final long NEGOTIATE_EXTENDED_SESSION_SECURITY = 524288L;

    private final NTCredentialsWithEncryption credentials;
    /** Complete response body as read from the message. */
    private byte[] rawBytes;
    /** Bytes of the last multipart part: length prefix, signature and sealed data. */
    private byte[] encryptedPayloadBytes;
    /** Current read position inside {@link #rawBytes}. */
    private int index;
    private int lastBlockStart;
    private int lastBlockEnd;
    private byte[] signatureBytes;
    private byte[] sealedBytes;
    private byte[] unsealedBytes;

    /**
     * @param credentials credentials carrying the decryption state; may be {@code null},
     *                    in which case any encrypted response is rejected by {@link #handle(Message)}
     */
    public Decryptor(NTCredentialsWithEncryption credentials) {
        this.credentials = credentials;
    }

    /**
     * Decrypts the message in place when it is encrypted, and rejects unencrypted
     * messages once the credentials are authenticated.
     *
     * @param message the incoming response
     * @throws IllegalStateException if an encrypted payload arrives without authenticated
     *         credentials, if an unencrypted payload arrives while authenticated, or if
     *         decryption fails (the failure is attached as the cause)
     */
    public void handle(Message message) {
        Object contentType = message.get(CONTENT_TYPE_KEY);
        boolean isEncrypted = contentType != null
                && contentType.toString().startsWith(ENCRYPTED_CONTENT_TYPE_PREFIX);
        if (isEncrypted) {
            if (this.credentials == null) {
                throw new IllegalStateException("Encrypted payload from server when no credentials with encryption known");
            }
            if (!this.credentials.isAuthenticated()) {
                throw new IllegalStateException("Encrypted payload from server when not authenticated");
            }
            try {
                this.decrypt(message);
            }
            catch (Exception e) {
                // Boundary of the decryption pipeline: every failure is surfaced uniformly, cause preserved.
                throw new IllegalStateException(e);
            }
        } else if (this.credentials != null && this.credentials.isAuthenticated()) {
            throw new IllegalStateException("Unencrypted payload from server when authenticated and encryption is required");
        }
    }

    /**
     * Reads the whole message body, extracts the encrypted part, unseals it, verifies the
     * signature and installs the decrypted bytes as the new message content.
     *
     * @param message the encrypted response
     * @throws IOException if reading the body or computing the signature fails
     * @throws IllegalStateException if the body is malformed or the signature does not verify
     */
    void decrypt(Message message) throws IOException {
        try (InputStream in = message.getContent(InputStream.class)) {
            this.rawBytes = IOUtils.readBytesFromStream(in);
        }
        this.unwrap();

        int payloadLength = this.encryptedPayloadBytes.length;
        if (payloadLength < SIGNATURE_LENGTH_FIELD_SIZE) {
            throw new IllegalStateException(String.format(
                    "Invalid format for response from server; encrypted payload too short (%d bytes) to hold a signature length",
                    payloadLength));
        }
        int signatureLength = (int) ByteArrayUtils.readLittleEndianUnsignedInt(this.encryptedPayloadBytes, 0);
        // The length comes from the wire: reject negative (overflowed) or out-of-range values
        // instead of letting Arrays.copyOfRange zero-pad or fail with an unrelated exception.
        if (signatureLength < 0 || signatureLength > payloadLength - SIGNATURE_LENGTH_FIELD_SIZE) {
            throw new IllegalStateException(String.format(
                    "Invalid format for response from server; signature length %d does not fit in encrypted payload of %d bytes",
                    signatureLength, payloadLength));
        }
        int sealedStart = SIGNATURE_LENGTH_FIELD_SIZE + signatureLength;
        this.signatureBytes = Arrays.copyOfRange(this.encryptedPayloadBytes, SIGNATURE_LENGTH_FIELD_SIZE, sealedStart);
        this.sealedBytes = Arrays.copyOfRange(this.encryptedPayloadBytes, sealedStart, payloadLength);

        // Order matters: the stateful decryptor is used first for the payload, then for the signature.
        this.unseal();
        this.verify();
        message.setContent(InputStream.class, new ByteArrayInputStream(this.unsealedBytes));
    }

    /**
     * Recomputes the signature of the unsealed bytes and compares its checksum and sequence
     * number with those received; on success the incoming sequence number is incremented.
     *
     * @throws IOException if the signature computation fails
     * @throws IllegalStateException on checksum or sequence number mismatch
     */
    private void verify() throws IOException {
        long seqNum = ByteArrayUtils.readLittleEndianUnsignedInt(this.signatureBytes, SEQUENCE_NUMBER_OFFSET);
        // The checksum occupies bytes 4..12 when the flag is negotiated, otherwise bytes 8..12.
        int checkSumOffset = this.credentials.hasNegotiateFlag(NEGOTIATE_EXTENDED_SESSION_SECURITY) ? 4 : 8;
        byte[] checksum = Arrays.copyOfRange(this.signatureBytes, checkSumOffset, SEQUENCE_NUMBER_OFFSET);
        try (ByteArrayOutputStream signature = new ByteArrayOutputStream()) {
            NtlmEncryptionUtils.calculateSignature(this.unsealedBytes, seqNum, signature, this.credentials, NTCredentialsWithEncryption::getServerSigningKey, this.credentials.getStatefulDecryptor()::update);
            byte[] expectedSignature = signature.toByteArray();
            byte[] expectedChecksum = Arrays.copyOfRange(expectedSignature, checkSumOffset, SEQUENCE_NUMBER_OFFSET);
            long expectedSeqNum = ByteArrayUtils.readLittleEndianUnsignedInt(expectedSignature, SEQUENCE_NUMBER_OFFSET);
            // Constant-time comparison: the checksum authenticates data received from the network.
            if (!MessageDigest.isEqual(checksum, expectedChecksum)) {
                throw new IllegalStateException(String.format("Checksum mismatch\n%s--\n%s", ByteArrayUtils.formatHexDump(checksum), ByteArrayUtils.formatHexDump(expectedChecksum)));
            }
            if (expectedSeqNum != seqNum) {
                throw new IllegalStateException(String.format("Sequence number mismatch: %d != %d", seqNum, expectedSeqNum));
            }
        }
        this.credentials.getSequenceNumberIncoming().incrementAndGet();
    }

    /**
     * Locates the last part of the multipart body in {@link #rawBytes} and stores its bytes
     * (everything between the part's first line and the closing boundary) in
     * {@link #encryptedPayloadBytes}.
     *
     * @throws IllegalStateException if the expected boundaries are missing or overlap
     */
    void unwrap() {
        this.index = 0;
        this.skipOver(BOUNDARY_LINE);
        // Skip the first part entirely, then the first line of the second part.
        this.skipUntil("\n" + BOUNDARY_LINE);
        this.skipUntil("\r\n");
        this.lastBlockStart = this.index;

        int closingBoundaryStart = this.rawBytes.length - CLOSING_BOUNDARY_LINE.length();
        if (closingBoundaryStart < this.lastBlockStart) {
            throw new IllegalStateException(String.format(
                    "Invalid format for response from server; no room for closing boundary after payload start %d\n%s",
                    this.lastBlockStart, ByteArrayUtils.formatHexDump(this.rawBytes)));
        }
        this.lastBlockEnd = closingBoundaryStart;
        this.index = closingBoundaryStart;
        this.skipOver(CLOSING_BOUNDARY_LINE);
        this.encryptedPayloadBytes = Arrays.copyOfRange(this.rawBytes, this.lastBlockStart, this.lastBlockEnd);
    }

    /**
     * Requires {@code s} to occur at the current position and advances past it.
     *
     * @param s the expected text
     * @throws IllegalStateException if the bytes at the current position differ or run out
     */
    void skipOver(String s) {
        this.skipOver(s.getBytes(DELIMITER_CHARSET));
    }

    /**
     * Requires {@code expected} to occur at the current position and advances past it.
     *
     * @param expected the expected bytes
     * @throws IllegalStateException if the bytes at the current position differ or run out
     */
    void skipOver(byte[] expected) {
        for (int i = 0; i < expected.length; ++i) {
            if (this.index >= this.rawBytes.length) {
                throw new IllegalStateException(String.format("Invalid format for response from server; terminated early (%d) when expecting '%s'\n%s", i, new String(expected, DELIMITER_CHARSET), ByteArrayUtils.formatHexDump(this.rawBytes)));
            }
            if (expected[i] != this.rawBytes[this.index++]) {
                // Reported positions are one past the offending byte, as the read position has already advanced.
                throw new IllegalStateException(String.format("Invalid format for response from server; mismatch at position %d (%d) when expecting '%s'\n%s", this.index, i + 1, new String(expected, DELIMITER_CHARSET), ByteArrayUtils.formatHexDump(this.rawBytes)));
            }
        }
    }

    /**
     * Searches forward from the current position for the first occurrence of {@code str}.
     * On success {@code lastBlockStart} is the position the search started from,
     * {@code lastBlockEnd} is where the occurrence begins, and the current position is
     * moved just past the occurrence.
     *
     * @param str the delimiter to look for
     * @throws IllegalStateException if the data ends before a complete occurrence is found
     */
    void skipUntil(String str) {
        byte[] expected = str.getBytes(DELIMITER_CHARSET);
        int nextBlock = this.index;
        search:
        while (true) {
            for (int i = 0; i < expected.length; ++i) {
                if (nextBlock + i >= this.rawBytes.length) {
                    // No later start position can hold a full match either, so give up here.
                    throw new IllegalStateException(String.format("Invalid format for response from server; terminated early (%d) when looking for '%s'\n%s", i, new String(expected, DELIMITER_CHARSET), ByteArrayUtils.formatHexDump(this.rawBytes)));
                }
                if (expected[i] != this.rawBytes[nextBlock + i]) {
                    // Mismatch: restart the comparison from the next candidate start position.
                    ++nextBlock;
                    continue search;
                }
            }
            break;
        }
        this.lastBlockStart = this.index;
        this.lastBlockEnd = nextBlock;
        this.index = nextBlock + expected.length;
    }

    /** Passes the sealed bytes through the credentials' stateful decryptor and stores the result. */
    private void unseal() {
        this.unsealedBytes = this.credentials.getStatefulDecryptor().update(this.sealedBytes);
    }
}

