/*
 * Copyright © 2016 BDO-Emu authors. All rights reserved.
 * Viewing, editing, running and distribution of this software strongly prohibited.
 * Author: xTz, Anton Lasevich, Tibald
 */

package host.anzo.commons.io.binary;

import host.anzo.commons.unsafe.ByteBufferCleaner;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.text.TextStringBuilder;
import org.jetbrains.annotations.NotNull;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.Buffer;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

/**
 * @author ANZO
 * @since 07.05.2017
 */
@Slf4j
public @Getter class ByteBufferEx implements AutoCloseable {
	private ByteBuffer buffer;
	private final boolean isDirect;
	private final AtomicBoolean isDestroyed = new AtomicBoolean(false);
	private final AtomicReference<ByteBufferExState> byteBufferState = new AtomicReference<>(ByteBufferExState.Idle);

	/**
	 * Create ByteBufferEx with specified order and capacity
	 * @param capacity initial capacity of ByteBufferEx
	 * @param byteOrder byte order of ByteBufferEx
	 * @param isDirect {@code true} if buffer must be DirectByteBuffer
	 */
	public ByteBufferEx(int capacity, ByteOrder byteOrder, boolean isDirect) {
		this.isDirect = isDirect;
		this.buffer = isDirect ? ByteBuffer.allocateDirect(capacity) : ByteBuffer.allocate(capacity);
		this.buffer.order(byteOrder);
	}

	/**
	 * Wrap a specified byte array to ByteBufferEx
	 * @param data byte array to wrap
	 * @param byteOrder byte order of created ByteBufferEx
	 */
	public ByteBufferEx(byte @NotNull [] data, ByteOrder byteOrder) {
		this(data.length, byteOrder, false);
		this.buffer.put(data);
		this.buffer.position(0);
	}

	/**
	 * Wrap specified InputStream to ByteBufferEx
	 * @param inputStream InputStream to wrap
	 * @param byteOrder byte order of created ByteBufferEx
	 */
	public ByteBufferEx(InputStream inputStream, ByteOrder byteOrder) throws IOException {
		this(IOUtils.toByteArray(inputStream), byteOrder);
	}

	/**
	 * @return current cursor position in buffer
	 */
	public int position() {
		int position = 0;
		if (checkAccess()) {
			position = buffer.position();
			checkDestroy();
		}
		return position;
	}

	/**
	 * Set the current cursor position in byte buffer
	 * @param pos cursor position
	 * @return buffer with changed to specified cursor position
	 */
	public Buffer position(int pos) {
		Buffer buffer = null;
		if (checkAccess()) {
			buffer = this.buffer.position(pos);
			checkDestroy();
		}
		return buffer;
	}

	/**
	 * Set byte buffer limit
	 * @param limit limit to set
	 */
	public void limit(int limit) {
		if (checkAccess()) {
			buffer.limit(limit);
			checkDestroy();
		}
	}

	/**
	 * @return current limit of byte buffer
	 */
	public int limit() {
		int limit = 0;
		if (checkAccess()) {
			limit = buffer.limit();
			checkDestroy();
		}
		return limit;
	}

	/**
	 * @return remaining byte count from current cursor position to limit
	 */
	public int remaining() {
		int remaining = 0;
		if (checkAccess()) {
			remaining = buffer.remaining();
			checkDestroy();
		}
		return remaining;
	}

	/**
	 * @return {@code true} if buffer has remaining bytes from current cursor position
	 */
	public boolean hasRemaining() {
		boolean hasRemaining = false;
		if (checkAccess()) {
			hasRemaining = buffer.hasRemaining();
			checkDestroy();
		}
		return hasRemaining;
	}

	/**
	 * Flip byte buffer content
	 */
	public void flip() {
		if (checkAccess()) {
			buffer.flip();
			checkDestroy();
		}
	}

	/**
	 * @return byte buffer capacity
	 */
	public int capacity() {
		int capacity = 0;
		if (checkAccess()) {
			capacity = buffer.capacity();
			checkDestroy();
		}
		return capacity;
	}

	/**
	 * Compact byte buffer
	 */
	public void compact() {
		if (checkAccess()) {
			buffer.compact();
			checkDestroy();
		}
	}

	/**
	 * Clear byte buffer content
	 */
	public void clear() {
		if (checkAccess()) {
			buffer.clear();
			checkDestroy();
		}
	}

	/**
	 * Skip specified bytes count and set cursor position
	 * @param bytes bytes count
	 */
	public final void skip(int bytes) {
		if (checkAccess()) {
			if (buffer.remaining() < bytes) {
				throw new BufferUnderflowException();
			}
			buffer.position(buffer.position() + bytes);
			checkDestroy();
		}
	}

	/**
	 * Skip all remaining data and set cursor to end of buffer
	 */
	public final void skipAll() {
		if (checkAccess()) {
			buffer.position(buffer.limit());
			checkDestroy();
		}
	}

	/**
	 * Read bytes to specified byte array
	 * @param dst array to write
	 */
	public final void readB(byte[] dst) {
		if (checkAccess()) {
			buffer.get(dst);
			checkDestroy();
		}
	}

	/**
	 * Read specified bytes count to byte array
	 * @param len bytes count
	 * @return byte array
	 */
	public final byte @NotNull [] readB(int len) {
		final byte[] tmp = new byte[len];
		if (checkAccess()) {
			buffer.get(tmp);
			checkDestroy();
		}
		return tmp;
	}

	public final @NotNull ByteBufferEx readBuffer(int len) {
		final byte[] tmp = new byte[len];
		if (checkAccess()) {
			buffer.get(tmp);
			checkDestroy();
		}
		return new ByteBufferEx(tmp, ByteOrder.LITTLE_ENDIAN);
	}

	public final void readB(byte[] dst, int offset, int len) {
		if (checkAccess()) {
			buffer.get(dst, offset, len);
			checkDestroy();
		}
	}

	public final byte readC() {
		byte result = 0;
		if (checkAccess()) {
			result = (byte) (buffer.get() & 0xFF);
			checkDestroy();
		}
		return result;
	}

	public final int readCD() {
		int result = 0;
		if (checkAccess()) {
			result = buffer.get() & 0xFF;
			checkDestroy();
		}
		return result;
	}

	public final boolean readCB() {
		boolean result = false;
		if (checkAccess()) {
			result = (buffer.get() & 0xFF) == 1;
			checkDestroy();
		}
		return result;
	}
	
	public final boolean readCB(int align) {
		boolean result = false;
		if (checkAccess()) {
			result = (buffer.get() & 0xFF) == 1;
			buffer.get(new byte[align - 1]);
			checkDestroy();
		}
		return result;
	}

	public final int readHD() {
		int result = 0;
		if (checkAccess()) {
			result = buffer.getShort() & 0xFFFF;
			checkDestroy();
		}
		return result;
	}

	public final short readH() {
		short result = 0;
		if (checkAccess()) {
			result = (short) (buffer.getShort() & 0xFFFF);
			checkDestroy();
		}
		return result;
	}

	public final short readH(int position) {
		short result = 0;
		if (checkAccess()) {
			result = (short) (buffer.getShort(position) & 0xFFFF);
			checkDestroy();
		}
		return result;
	}

	public final int readD() {
		int result = 0;
		if (checkAccess()) {
			result = buffer.getInt();
			checkDestroy();
		}
		return result;
	}

	public final int readD(int position) {
		int result = 0;
		if (checkAccess()) {
			result = buffer.getInt(position);
			checkDestroy();
		}
		return result;
	}

	public final long readDQ() {
		long result = 0;
		if (checkAccess()) {
			result = buffer.getInt() & 0xFFFFFFFFL;
			checkDestroy();
		}
		return result;
	}

	public final long readDQ(int position) {
		long result = 0;
		if (checkAccess()) {
			result = buffer.getInt(position);
			checkDestroy();
		}
		return result;
	}

	public final int readD3() {
		int result = 0;
		if (checkAccess()) {
			result = buffer.get() & 0xFF;
			result |= ((buffer.get() << 8) & 0xFF00);
			result |= ((buffer.get() << 16) & 0xFF0000);
			checkDestroy();
		}
		return result;
	}

	public final long readQ() {
		long result = 0;
		if (checkAccess()) {
			result = buffer.getLong();
			checkDestroy();
		}
		return result;
	}

	public final float readF() {
		float result = 0;
		if (checkAccess()) {
			result = buffer.getFloat();
			checkDestroy();
		}
		return result;
	}

	public final String readStringUnicode(int padding) {
		final TextStringBuilder stringBuilder = new TextStringBuilder();
		if (checkAccess()) {
			for (char c; padding > 0; ) {
				padding -= 2;
				c = buffer.getChar();
				if (c == 0) {
					break;
				}
				stringBuilder.append(c);
			}
			skip(padding);
			checkDestroy();
		}
		return stringBuilder.build();
	}

	public final String readStringUnicodeNT() {
		final TextStringBuilder stringBuilder = new TextStringBuilder();
		if (checkAccess()) {
			for (char c; (c = buffer.getChar()) != 0;) {
				stringBuilder.append(c);
			}
			checkDestroy();
		}
		return stringBuilder.build();
	}

	public final String readStringUnicodeNT(int size) {
		final TextStringBuilder stringBuilder = new TextStringBuilder();
		if (checkAccess()) {
			for (char c; size > 0; size -= 2) {
				if ((c = buffer.getChar()) != 0) {
					stringBuilder.append(c);
				}
			}
			checkDestroy();
		}
		return stringBuilder.build();
	}

	/**
	 * Read ANSI string with known size
	 * @param size string size
	 * @return string from buffer with specified size
	 */
	public final @NotNull String readString(int size) {
		final byte[] bytes = new byte[size];
		if (checkAccess()) {
			buffer.get(bytes);
			checkDestroy();
		}
		return new String(bytes);
	}

	public final @NotNull String readString() {
		if (checkAccess()) {
			final int length = readHD();
			if (length > 0) {
				final byte[] bytes = new byte[length];
				buffer.get(bytes);
				checkDestroy();
				return new String(bytes);
			}
		}
		return "";
	}

	/**
	 * Read null terminated ANSI string
	 * @param size string size
	 * @return string from buffer with specified size
	 */
	public final @NotNull String readStringNT(int size) {
		final TextStringBuilder stringBuilder = new TextStringBuilder();
		if (checkAccess()) {
			final byte[] byteBuffer = readB(size);
			for (int byteIndex = 0; byteIndex < size; byteIndex++) {
				final byte charByte = byteBuffer[byteIndex];
				if (charByte == 0) {
					break;
				}
				stringBuilder.append((char) charByte);
			}
			checkDestroy();
		}
		return stringBuilder.get();
	}

	public final <T extends Enum<T>> T readEnum(Class<T> enumClass, T defaultValue) {
		final int ordinalValue = readCD();
		try {
			return enumClass.getEnumConstants()[ordinalValue];
		}
		catch (Exception e) {
			return defaultValue;
		}
	}

	public final void write(@NotNull Object value, Integer... params) {
		final Class<?> clazz = value.getClass();
		if (clazz == Short.class) {
			writeH((short)value);
		}
		else if (clazz == Integer.class) {
			writeD((int)value);
		}
		else if (clazz == Long.class) {
			writeQ((long)value);
		}
		else if (clazz == Double.class) {
			writeF((double)value);
		}
		else if (clazz == Float.class) {
			writeF((float)value);
		}
		else if (clazz == Byte.class) {
			writeC((byte)value);
		}
		else if (clazz == Boolean.class) {
			writeC((boolean)value ? 1 : 0);
		}
		else if (clazz == byte[].class) {
			writeB((byte[])value);
		}
		else if (clazz == String.class) {
			writeStringUnicode((String)value, params[0]);
		}
		else {
			log.error("Method didn't support write for clazz=[{}]", clazz);
		}
	}

	public final void writeC(boolean value) {
		if (checkAccess()) {
			buffer.put((byte) (value ? 1 : 0));
			checkDestroy();
		}
	}

	public final void writeC(int value) {
		if (checkAccess()) {
			buffer.put((byte) value);
			checkDestroy();
		}
	}

	public final void writeC(Enum<?> value) {
		if (checkAccess()) {
			buffer.put((byte) value.ordinal());
			checkDestroy();
		}
	}

	public final void writeC(byte[] byteArray) {
		if (checkAccess()) {
			for (int value : byteArray) {
				writeC(value);
			}
			checkDestroy();
		}
	}

	public final void writeH(boolean value) {
		if (checkAccess()) {
			buffer.putShort((short) (value ? 1 : 0));
			checkDestroy();
		}
	}

	public final void writeH(int value) {
		if (checkAccess()) {
			buffer.putShort((short) value);
			checkDestroy();
		}
	}

	public final void writeH(short[] shortArray) {
		if (checkAccess()) {
			for (short value : shortArray) {
				writeH(value);
			}
			checkDestroy();
		}
	}

	public final void writeH(int[] intArray) {
		if (checkAccess()) {
			for (int value : intArray) {
				writeH(value);
			}
			checkDestroy();
		}
	}

	public final void writeD(boolean value) {
		if (checkAccess()) {
			buffer.putInt(value ? 1 : 0);
			checkDestroy();
		}
	}

	public final void writeD(int value) {
		if (checkAccess()) {
			buffer.putInt(value);
			checkDestroy();
		}
	}

	public final void writeD(int value, int align) {
		if (checkAccess()) {
			buffer.putInt(value);
			buffer.put(new byte[align - 4]);
			checkDestroy();
		}
	}

	public final void writeD(long value) {
		if (checkAccess()) {
			buffer.putInt((int) (value & 0xFFFFFFFF));
			checkDestroy();
		}
	}

	public final void writeD(int[] intArray) {
		if (checkAccess()) {
			for (int value : intArray) {
				writeD(value);
			}
			checkDestroy();
		}
	}

	public final void writeQ(boolean value) {
		if (checkAccess()) {
			buffer.putLong(value ? 1 : 0);
			checkDestroy();
		}
	}

	public final void writeQ(long value) {
		if (checkAccess()) {
			buffer.putLong(value);
			checkDestroy();
		}
	}

	public final void writeQ(long[] longArray) {
		if (checkAccess()) {
			for (long value : longArray) {
				writeQ(value);
			}
			checkDestroy();
		}
	}

	public final void writeF(float value) {
		if (checkAccess()) {
			buffer.putFloat(value);
			checkDestroy();
		}
	}

	public final void writeF(float[] floatArray) {
		if (checkAccess()) {
			for (float value : floatArray) {
				writeF(value);
			}
			checkDestroy();
		}
	}

	public final void writeF(double value) {
		if (checkAccess()) {
			buffer.putFloat((float) value);
			checkDestroy();
		}
	}

	public final void writeF(int value) {
		if (checkAccess()) {
			buffer.putFloat(value);
			checkDestroy();
		}
	}

	public final void writeArray(Object[] objectArray) {
		if (checkAccess()) {
			for (Object object : objectArray) {
				write(object);
			}
			checkDestroy();
		}
	}

	public final void writeB(byte[] data) {
		if (checkAccess()) {
			buffer.put(data);
			checkDestroy();
		}
	}

	public final void writeB(int size) {
		if (checkAccess()) {
			buffer.put(new byte[size]);
			checkDestroy();
		}
	}

	/**
	 * Write specified value to buffer as 24bit integer
	 * @param value value to write
	 */
	public final void writeD3(int value) {
		if (checkAccess()) {
			buffer.put((byte) (value & 0xFF));
			buffer.put((byte) ((value & 0xFF00) >> 8));
			buffer.put((byte) ((value & 0xFF0000) >> 16));
			checkDestroy();
		}
	}

	public final void writeString(CharSequence charSequence, int size) {
		if (checkAccess()) {
			int length = 0;
			if (charSequence != null) {
				length = charSequence.length();
				for (int i = 0; i < length; i++) {
					buffer.put((byte) (charSequence.charAt(i) & 0xFF));
				}
			}
			if (length < size) {
				writeB(new byte[size - length]);
			}
			checkDestroy();
		}
	}

	public final void writeString(CharSequence charSequence) {
		if (checkAccess()) {
			if (charSequence != null) {
				int length = charSequence.length();
				buffer.putShort((short) length);
				for (int i = 0; i < length; i++) {
					buffer.put((byte) (charSequence.charAt(i) & 0xFF));
				}
			}
			checkDestroy();
		}
	}

	public final void writeStringUnicodeNT(CharSequence charSequence) {
		if (checkAccess()) {
			final int length = charSequence.length();
			buffer.putShort((short) (length * 2 + 2));
			for (int i = 0; i < length; i++) {
				buffer.putChar(charSequence.charAt(i));
			}
			buffer.putChar('\000');
			checkDestroy();
		}
	}

	public final void writeStringUnicode(CharSequence charSequence, int size) {
		if (checkAccess()) {
			if (charSequence == null) {
				buffer.put(new byte[size]);
			}
			else {
				int startPosition = buffer.position();
				for (int i = 0; i < charSequence.length(); i++) {
					if ((buffer.position() - startPosition) < size) {
						buffer.putChar(charSequence.charAt(i));
					}
				}
				int length = buffer.position() - startPosition;
				if (length < size) {
					writeB(new byte[size - length]);
				}
			}
			checkDestroy();
		}
	}

	public int getPosition() {
		return buffer.position();
	}

	public byte[] toByteArray() {
		if (buffer.hasArray()) {
			final byte[] array = buffer.array();
			return Arrays.copyOfRange(array, 0, buffer.limit());
		}
		return null;
	}

	@SuppressWarnings("UnusedReturnValue")
	public boolean writeToFile(Path path) {
		try (FileOutputStream fileOutputStream = new FileOutputStream(path.toFile(), false)) {
			try(FileChannel channel = fileOutputStream.getChannel()) {
				buffer.flip();
				return channel.write(buffer) > 0;
			}
		}
		catch (Exception e) {
			log.error("Error while writing buffer to file [{}]", path.toString(), e);
		}
		return false;
	}

	/**
	 * @return {@code true} if buffer can be accessed now
	 */
	public boolean checkAccess() {
		return !isDirect || byteBufferState.compareAndSet(ByteBufferExState.Idle, ByteBufferExState.Access) || byteBufferState.get() == ByteBufferExState.Access;
	}

	/**
	 * Check if buffer must be destroyed
	 */
	public void checkDestroy() {
		if (isDestroyed.get()) {
			byteBufferState.set(ByteBufferExState.Destroyed);
			destroy();
		}
		else {
			byteBufferState.set(ByteBufferExState.Idle);
		}
	}

	/**
	 * Set buffer must be destroyed after current operation completed
	 */
	public void tryDestroy() {
		isDestroyed.set(true);
		if (!isDirect || byteBufferState.compareAndSet(ByteBufferExState.Idle, ByteBufferExState.Destroyed)) {
			destroy();
		}
	}

	/**
	 * Destroy buffer and free native memory for DirectBuffer
	 */
	private void destroy() {
		if (buffer != null) {
			buffer.clear();

			if (isDirect) {
				ByteBufferCleaner.clean(buffer);
			}
			buffer = null;
		}
	}

	@Override
	public void close() throws Exception {
		destroy();
	}

	private enum ByteBufferExState {
		Idle,
		Access,
		Destroyed
	}
}