package org.zalando.straw;

import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


/**
 * Consumes a chunked HTTPS stream of Nakadi event batches on a background thread.
 *
 * <p>After {@link #start()}, a single worker thread repeatedly connects to the configured URL and
 * sends a {@code GET} carrying a bearer token from {@link #loadToken()}. Unless it was constructed
 * with an empty cursor map, it also sends an {@code X-Nakadi-Cursors} header built from the
 * highest offset handled so far per partition.
 *
 * <p>Each batch is passed to {@link #handleEvents(String)}. The batch's cursor is recorded only
 * when that call returns normally. If it throws, the connection is dropped and re-established
 * from the previously recorded cursors.
 *
 * <p>Any failure is reported to {@link #logError(Exception)} and followed by a two-second pause
 * before reconnecting. Subclasses customise behaviour by overriding the protected hooks.
 */
public class Straw {

    /** One second in milliseconds, used for the socket read timeout and the retry pause. */
    public static final int SECOND = 1000;

    /** Offset sentinel that {@code Cursor.toString()} renders as {@code "BEGIN"}. */
    public static final long OFFSET_BEGIN = -1;

    /** HTTP/1.1 requires CRLF line endings, independent of the platform line separator. */
    private static final String CRLF = "\r\n";

    /** Raised when the response status is not 200; the message is the status code and reason text. */
    public static final class HttpException extends Exception {
        public HttpException(String message) {
            super(message);
        }
    }

    /**
     * Immutable (partition, offset) pair.
     *
     * <p>{@link #toString()} produces the JSON object format sent in the {@code X-Nakadi-Cursors}
     * header.
     */
    static final class Cursor {

        /**
         * Parses the cursor of a raw batch line.
         *
         * <p>The first run of digits is taken as the partition and the second as the offset.
         * NOTE(review): this assumes every batch starts with its cursor object, e.g.
         * {@code {"cursor":{"partition":"0","offset":"4"},"events":[...]}}.
         *
         * @param line one complete batch as received from the stream
         * @return the cursor found in {@code line}
         * @throws Exception with {@code line} as its message and the parse failure as its cause,
         *     when two numbers cannot be found
         */
        static Cursor extract(String line) throws Exception {
            try (Scanner scanner = new Scanner(line)) {
                int partition = Integer.parseInt(scanner.findInLine("\\d+"));
                long offset = Long.parseLong(scanner.findInLine("\\d+"));
                return new Cursor(partition, offset);
            } catch (NumberFormatException e) {
                // findInLine returns null when no digits remain; parseInt/parseLong then throw NFE
                throw new Exception(line, e);
            }
        }

        final int partition;
        final long offset;

        Cursor(int partition, long offset) {
            this.partition = partition;
            this.offset = offset;
        }

        @Override public String toString() {
            // Locale.ROOT guarantees ASCII digits regardless of the JVM's default locale
            return String.format(Locale.ROOT, "{\"partition\":\"%d\",\"offset\":\"%s\"}", partition, offset());
        }

        private String offset() {
            return offset == OFFSET_BEGIN ? "BEGIN" : Long.toString(offset);
        }
    }

    private final ExecutorService _executor = Executors.newSingleThreadExecutor();
    private final URL _url;
    /** Highest offset handled per partition; after construction only the worker thread touches it. */
    private final Map<Integer, Long> _cursors;
    /** True when constructed without cursors: no {@code X-Nakadi-Cursors} header is ever sent. */
    private final boolean _allPartitions;

    /**
     * Creates a consumer; nothing is fetched until {@link #start()} is called.
     *
     * @param url the HTTPS stream endpoint; port 443 is used when the URL has no explicit port
     * @param cursors starting offset per partition (copied). An empty map means no cursor header
     *     is sent, leaving the starting position to the server.
     */
    public Straw(URL url, Map<Integer, Long> cursors) {
        _url = Objects.requireNonNull(url, "url");
        _cursors = new HashMap<>(cursors);
        _allPartitions = cursors.isEmpty();
    }

    /** Starts the background loop that keeps (re)connecting and consuming the stream. */
    public void start() {
        _executor.submit(() -> {
            while (!Thread.currentThread().isInterrupted()) {
                fetchStream();
            }
        });
    }

    /**
     * Asks the background loop to stop by interrupting its thread.
     *
     * <p>Blocking socket reads generally do not react to interruption. The loop therefore exits
     * after the next line arrives, or after the 60-second read timeout fires; that timeout is
     * reported through {@link #logError(Exception)}.
     */
    public void stop() {
        _executor.shutdownNow();
    }

    /** Returns the bearer token for the {@code Authorization} header; defaults to env var {@code TOKEN}. */
    protected String loadToken() throws Exception {
        return System.getenv("TOKEN");
    }

    /**
     * Handles one raw batch line. Throwing prevents the batch's cursor from being recorded and
     * tears down the current connection.
     */
    protected void handleEvents(String json) throws Exception {
        logDebug("handleEvents: " + json);
    }

    protected void logDebug(String message) { System.out.println("DEBUG: " + message); }

    protected void logInfo(String message) { System.out.println("INFO: " + message); }

    protected void logError(Exception e) { e.printStackTrace(System.err); }

    /** Runs one connection lifecycle: connect, request, consume until EOF or error. */
    private void fetchStream() {
        logInfo("fetchStream: " + (_allPartitions ? "all partitions (END)" : cursorString()));
        try (SSLSocket socket = openSocket()) {
            sendRequest(socket);
            BufferedReader in = new BufferedReader(
                    new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8));
            skipHeaders(in);
            readBatches(in);
        } catch (Exception e) {
            logError(e);
            tryToSleep(2 * SECOND);
        }
    }

    /** Opens a TLS socket to the endpoint with hostname verification and a 60s read timeout. */
    private SSLSocket openSocket() throws IOException {
        SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(_url.getHost(), port());
        try {
            // the raw SSLSocketFactory does not verify the server's hostname; without this the
            // bearer token could be sent to any host presenting a trusted certificate
            SSLParameters params = socket.getSSLParameters();
            params.setEndpointIdentificationAlgorithm("HTTPS");
            socket.setSSLParameters(params);
            socket.setSoTimeout(60 * SECOND);
            return socket;
        } catch (IOException | RuntimeException e) {
            try {
                socket.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Reads the chunked body and dispatches each complete batch to {@link #handleBatch(String)}.
     *
     * <p>This is a simple state machine over chunked encoding; each line is SIZE, DATA or EMPTY.
     * A single Nakadi batch can span multiple chunks, for example:
     * {@code SIZE DATA SIZE DATA SIZE DATA EMPTY}.
     */
    private void readBatches(BufferedReader in) throws Exception {
        StringBuilder batch = new StringBuilder();
        boolean isData = false;
        String line;
        while (!Thread.currentThread().isInterrupted() && (line = in.readLine()) != null) {
            if (isData) {
                batch.append(line);
                isData = false;
            } else if (line.isEmpty()) {
                handleBatch(batch.toString());
                batch.setLength(0);
            } else {
                // a chunk-size line: the next line carries that chunk's data
                isData = true;
            }
        }
    }

    /** Hands a non-empty batch to {@link #handleEvents(String)}, then advances its partition's cursor. */
    private void handleBatch(String batch) throws Exception {
        if (!batch.isEmpty()) {
            Cursor cursor = Cursor.extract(batch);
            handleEvents(batch);
            // no exception, so we can update _cursors
            if (cursor.offset > _cursors.getOrDefault(cursor.partition, OFFSET_BEGIN)) {
                _cursors.put(cursor.partition, cursor.offset);
            }
        }
    }

    /**
     * Performs the TLS handshake and writes the streaming {@code GET} request.
     *
     * @throws IllegalStateException if {@link #loadToken()} yields no usable token
     * @throws IOException if the request cannot be written
     */
    private void sendRequest(SSLSocket socket) throws Exception {
        String token = loadToken();
        if (token == null || token.trim().isEmpty()) {
            throw new IllegalStateException("no bearer token available from loadToken()");
        }
        socket.startHandshake();
        PrintWriter out = new PrintWriter(new BufferedWriter(
                new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8)));
        out.print("GET " + requestPath() + " HTTP/1.1" + CRLF);
        out.print("Host: " + hostHeader() + CRLF);
        out.print("User-Agent: straw" + CRLF);
        if (!_allPartitions) {
            out.print("X-Nakadi-Cursors: " + cursorString() + CRLF);
        }
        out.print("Authorization: Bearer " + token.trim() + CRLF);
        out.print(CRLF);
        out.flush();
        // PrintWriter swallows IOExceptions; surface write failures instead of waiting for a reply
        if (out.checkError()) {
            throw new IOException("failed to send HTTP request to " + _url.getHost());
        }
    }

    /** Explicit URL port, or 443 when none is given. */
    private int port() {
        return _url.getPort() == -1 ? 443 : _url.getPort();
    }

    /** Value for the {@code Host} header; includes the port only when it is not the default 443. */
    private String hostHeader() {
        int explicitPort = _url.getPort();
        return explicitPort == -1 || explicitPort == 443
                ? _url.getHost()
                : _url.getHost() + ":" + explicitPort;
    }

    /** Request target: the URL path (or {@code /} when empty) plus the query string if present. */
    private String requestPath() {
        String path = _url.getPath().isEmpty() ? "/" : _url.getPath();
        return _url.getQuery() == null ? path : path + "?" + _url.getQuery();
    }

    /** Renders the tracked cursors as a JSON-style array, e.g. {@code [{"partition":"0","offset":"4"}]}. */
    private String cursorString() {
        List<Cursor> result = new ArrayList<>(_cursors.size());
        for (Map.Entry<Integer, Long> entry : _cursors.entrySet()) {
            result.add(new Cursor(entry.getKey(), entry.getValue()));
        }
        return result.toString();
    }

    /**
     * Validates the status line and consumes headers up to and including the blank separator line.
     *
     * @throws HttpException if the status is not 200, or the status line is malformed
     * @throws IOException if the connection closes before a status line is received
     */
    private static void skipHeaders(BufferedReader in) throws IOException, HttpException {
        String statusLine = in.readLine();
        if (statusLine == null) {
            throw new IOException("connection closed before HTTP status line");
        }
        String[] parts = statusLine.split("\\s", 2); // parts[0] is the protocol, e.g. HTTP/1.1
        if (parts.length < 2) {
            throw new HttpException(statusLine);
        }
        String status = parts[1];
        if (!status.startsWith("200")) {
            throw new HttpException(status);
        }
        String line;
        while ((line = in.readLine()) != null && !line.trim().isEmpty()) {
            // header line: ignored
        }
    }

    /** Sleeps for {@code millis}; on interruption restores the interrupt flag and returns early. */
    private static void tryToSleep(int millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            // keep the flag set so the loop in start() notices the stop request and exits
            Thread.currentThread().interrupt();
        }
    }
}
