package host.anzo.commons.graphics.image;

import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

/**
 * @author ANZO
 * @since 8/3/2023
 */
@Slf4j
public class HeatMapImage {
	private final Map<Integer, LinkedList<Point>> pointsMap = new HashMap<>();
	private int maxOccurrence = 1;

	private final BufferedImage backgroundImage;
	private final int backgroundWidth;
	private final int backgroundHeight;

	public HeatMapImage(@NotNull BufferedImage backgroundImage, final @NotNull Collection<Point> points) {
		this.backgroundImage = backgroundImage;
		this.backgroundWidth = backgroundImage.getWidth();
		this.backgroundHeight = backgroundImage.getHeight();

		for (final Point point : points) {
			final int hash = getKey(point);
			if (pointsMap.containsKey(hash)) {
				final LinkedList<Point> thisList = pointsMap.get(hash);
				thisList.add(point);
				if (thisList.size() > maxOccurrence) {
					maxOccurrence = thisList.size();
				}
			}
			else {
				final LinkedList<Point> newList = new LinkedList<>();
				newList.add(point);
				pointsMap.put(hash, newList);
			}
		}
	}

	/**
	 * creates the heatmap.
	 *
	 * @param multiplier this value will multiply the calculated opacity of every point.
	 *                   This leads to a HeatMap that is easier to read, especially
	 *                   when there are not too many points or the points are to spread out. Pass 1.0f for original.
	 */
	public BufferedImage createHeatMap(final float multiplier, final boolean withBackground) {
		final BufferedImage circle = loadImage("heatmap/heatmap_circle.png");
		if (circle == null) {
			return null;
		}

		final BufferedImage heatMap = new BufferedImage(backgroundWidth, backgroundHeight, BufferedImage.TYPE_4BYTE_ABGR);
		paintInColor(heatMap);

		for (LinkedList<Point> currentPoints : pointsMap.values()) {
			// calculate opaqueness based on the number of current point occurrences
			float opaque = currentPoints.size() / (float) maxOccurrence;

			// adjust opacity so the heatmap is easier to read
			opaque = opaque * multiplier;
			if (opaque > 1) {
				opaque = 1;
			}

			final Point currentPoint = currentPoints.get(0);

			// draw a circle which gets transparent from middle to outside
			// (which opaqueness is set to "opaque")
			// at the position specified by the center of the currentPoint
			addImage(heatMap, circle, opaque,
					(currentPoint.x - (circle.getWidth() / 2)),
					(currentPoint.y - (circle.getWidth() / 2)));
		}

		// negate the image
		negateImage(heatMap);

		// remap black/white with color spectrum from white over red, orange, yellow, green to blue
		remap(heatMap);

		if (withBackground) {
			// Blend a heat map to background image with 40% transparency
			final BufferedImage output = backgroundImage;
			addImage(output, heatMap, 0.4f);
			return output;
		}
		return heatMap;
	}

	/**
	 * Save heat map to file by specified path
	 * @param path file path
	 */
	public void save(String path) {
		try {
			ImageIO.write(createHeatMap(0.3f, true), "png", new File(path));
		}
		catch (final IOException e) {
			log.error("Error while saving heat map to path=[{}]", path, e);
		}
	}

	/**
	 * Remaps black and white picture with colors. It uses the colors from spectrum image file.
	 * The whiter a pixel is, the more it will get a color from the
	 * bottom of it. Black will stay black.
	 *
	 * @param heatMapBW black and white heat map
	 */
	private void remap(final @NotNull BufferedImage heatMapBW) {
		final BufferedImage colorGradiant = loadImage("heatmap/heatmap_spectrum.png");
		if (colorGradiant == null) {
			return;
		}
		final int width = heatMapBW.getWidth();
		final int height = heatMapBW.getHeight();
		final int gradientHight = colorGradiant.getHeight() - 1;
		for (int i = 0; i < width; i++) {
			for (int j = 0; j < height; j++) {
				// get heatMapBW color values:
				final int rGB = heatMapBW.getRGB(i, j);

				// calculate multiplier to be applied to the height of gradiant.
				float multiplier = rGB & 0xff; // blue
				multiplier *= ((rGB >>> 8)) & 0xff; // green
				multiplier *= (rGB >>> 16) & 0xff; // red
				multiplier /= 16581375; // 255f * 255f * 255f

				// apply multiplier
				final int y = (int) (multiplier * gradientHight);

				// remap values
				// calculate new value based on whiteness of heatMap
				// (the whiter, the more a color from the top of colorGradiant
				// will be chosen.
				final int mapedRGB = colorGradiant.getRGB(0, y);
				// set new value
				heatMapBW.setRGB(i, j, mapedRGB);
			}
		}
	}

	/**
	 * Returns a negated version of this image.
	 * @param img buffer to negate
	 */
	private void negateImage(final @NotNull BufferedImage img) {
		final int width = img.getWidth();
		final int height = img.getHeight();
		for (int x = 0; x < width; x++) {
			for (int y = 0; y < height; y++) {
				final int rgb = img.getRGB(x, y);
				// Swaps values
				// i.e., 255, 255, 255 (white)
				// becomes 0, 0, 0 (black)
				final int r = Math.abs(((rgb >>> 16) & 0xff) - 255); // red
				// inverted
				final int g = Math.abs(((rgb >>> 8) & 0xff) - 255); // green
				// inverted
				final int b = Math.abs((rgb & 0xff) - 255); // blue inverted
				// transform back to pixel value and set it
				img.setRGB(x, y, (r << 16) | (g << 8) | b);
			}
		}
	}

	/**
	 * Changes all pixels in the buffer to the provided color.
	 * @param buff buffer
	 */
	private void paintInColor(final @NotNull BufferedImage buff) {
		final Graphics2D g2 = buff.createGraphics();
		g2.setColor(Color.white);
		g2.fillRect(0, 0, buff.getWidth(), buff.getHeight());
		g2.dispose();
	}

	/**
	 * Prints the contents of buff2 on buff1 with the given opaque value
	 * starting at position 0, 0.
	 * @param buff1 buffer
	 * @param buff2 buffer to add to buff1
	 */
	private void addImage(final BufferedImage buff1, final BufferedImage buff2, float opaque) {
		addImage(buff1, buff2, opaque, 0, 0);
	}

	/**
	 * prints the contents of buff2 on buff1 with the given opaque value.
	 *
	 * @param buff1  buffer
	 * @param buff2  buffer
	 * @param opaque how opaque the second buffer should be drawn
	 * @param x      x position where the second buffer should be drawn
	 * @param y      y position where the second buffer should be drawn
	 */
	private void addImage(final @NotNull BufferedImage buff1, final BufferedImage buff2,
	                      final float opaque, final int x, final int y) {
		final Graphics2D g2d = buff1.createGraphics();
		g2d.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, opaque));
		g2d.drawImage(buff2, x, y, null);
		g2d.dispose();
	}

	/**
	 * returns a BufferedImage from the Image provided.
	 *
	 * @param fileName path to image
	 * @return loaded image
	 */
	private @Nullable BufferedImage loadImage(final String fileName) {
		try (final InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName)) {
			if (inputStream == null) {
				return null;
			}
			return ImageIO.read(inputStream);
		}
		catch (Exception e) {
			log.error("Error loading image path=[{}]", fileName, e);
			return null;
		}
	}

	/**
	 * returns a hash calculated by the given point.
	 *
	 * @param point a point
	 * @return hash value
	 */
	private int getKey(final @NotNull Point point) {
		return ((point.x << 19) | (point.y << 7));
	}
}