package org.v2u.stupidql;

import org.apache.commons.dbutils.*;
import org.apache.commons.dbutils.handlers.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.AbstractMap.SimpleEntry;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * A minimal SQL builder and executor on top of Apache Commons DbUtils.
 *
 * <p>Builder methods ({@link #add}, {@link #addRaw}, {@link #mark}, {@link #setQuoter}, ...) never
 * modify the receiver: each returns a modified copy, so a partially built query can be reused as a
 * template. SQL fragments are kept in {@link #queryParts} and the bind values of each fragment at
 * the same index in {@link #bindValues}.
 */
public class StupidQL {
    /** Mark name of the (initially empty) slot between {@code insert} and {@code into}. */
    public static final String IGNORE = "*IGNORE";
    /** Mark name of the column list of a {@code select}; replaced by {@link #fields(String...)}. */
    public static final String FIELDS = "*FIELDS";
    /** Template placeholders: {@code {key}} (raw), {@code @{key}} (quoted), {@code #{key}} (bound). */
    private static final Pattern FORMAT_PH_PATTERN = Pattern.compile("([#@]?)\\{([^}]+?)}");
    private static final Pattern INT_PATTERN = Pattern.compile("^\\d+$");
    /** Splits an acronym from a following word, e.g. "HTTPServer" -> "HTTP_Server". */
    private static final Pattern ABBR_TO_WORD = Pattern.compile("([A-Z]+)([A-Z][a-z])");
    /** Splits a lower-case letter or digit from a following capital, e.g. "userId" -> "user_Id". */
    private static final Pattern CAMEL_TO_SNAKE = Pattern.compile("([a-z\\d])([A-Z])");
    /** Per class: field name -> its {@link Info} annotation (null if absent), superclasses included. */
    private static final Map<Class<?>, Map<String, Info>> CLASS_INFO_CACHE = new ConcurrentHashMap<>();
    private static final Logger log = LoggerFactory.getLogger(StupidQL.class);
    /** SQL fragments, joined with a single space by {@link #getSql()}. */
    protected final List<String> queryParts = new ArrayList<>();
    /** Bind values of each fragment; always index-aligned with {@link #queryParts}. */
    protected final List<List<Object>> bindValues = new ArrayList<>();
    /** One parameter row per execution when this is a batch insert, otherwise null. */
    protected Object[][] batchBindValues = null;
    /** Mark name -> index in {@link #queryParts} of a replaceable fragment. */
    protected final Map<String, Integer> marks = new HashMap<>();
    /** Connection of the enclosing {@link #transaction}, or null outside a transaction. */
    protected Connection txConn = null;
    /** Identifier quoting function; backticks by default. */
    protected Function<String, String> quoter = makeQuoter("`");
    /** Maps Java class/property names to table/column names; snake_case by default. */
    private Function<String, String> namingStrategy = StupidQL::toSnake;
    protected QueryRunner runner;

    /**
     * Creates an empty builder that borrows a connection from {@code ds} for each execution.
     *
     * @param ds data source used outside of transactions
     * @return an empty query builder
     */
    public static StupidQL init(DataSource ds) {
        return new StupidQL(ds);
    }

    /**
     * Returns a copy of this builder bound to an open transactional connection. The copy's runner
     * has no data source, so every execution goes through {@code conn}.
     */
    protected StupidQL init(Connection conn) {
        StupidQL bound = copy();
        bound.runner = new QueryRunner();
        bound.txConn = conn;
        return bound;
    }

    protected StupidQL(DataSource ds) {
        runner = new QueryRunner(ds);
    }

    /** Returns the SQL built so far, fragments separated by single spaces. */
    public String getSql() {
        return String.join(" ", queryParts);
    }

    /** Returns the bind values of all fragments, flattened in SQL order. */
    public Object[] getParams() {
        List<Object> flattened = new ArrayList<>();
        for (List<Object> fragmentValues : bindValues) {
            flattened.addAll(fragmentValues);
        }
        return flattened.toArray();
    }

    /**
     * Returns a copy with no SQL, bind values, marks or batch rows, keeping the connection
     * binding, quoter and naming strategy.
     */
    public StupidQL reset() {
        StupidQL cleared = copy();
        cleared.queryParts.clear();
        cleared.bindValues.clear();
        cleared.marks.clear();
        cleared.batchBindValues = null;
        return cleared;
    }

    /**
     * Returns the connection of the enclosing transaction.
     *
     * @throws StupidException when this builder is not bound to a transaction
     */
    public Connection getTxConn() {
        if (txConn == null) {
            throw new StupidException("Not in a transaction context");
        }
        return txConn;
    }

    /**
     * Returns the transactional connection if bound, otherwise a fresh one from the data source
     * (which the caller must close).
     */
    protected Connection selectConn() throws SQLException {
        if (txConn != null) {
            return txConn;
        }
        if (runner.getDataSource() != null) {
            return runner.getDataSource().getConnection();
        }
        throw new StupidException("Unable to obtain database connection");
    }

    /**
     * Runs {@code action} with a connection, closing it afterwards unless it belongs to a
     * transaction. SQL errors are rethrown as {@link StupidException}.
     */
    private <T> T execute(ConnectionCallback<T> action) {
        Connection conn = null;
        try {
            conn = selectConn();
            return action.execWithConn(conn);
        } catch (SQLException cause) {
            throw new StupidException(cause);
        } finally {
            // Transactional connections are owned (and closed) by transaction().
            if (conn != null && conn != txConn) {
                try {
                    conn.close();
                } catch (SQLException ex) {
                    log.warn("Failed to close connection: {}", ex.getMessage());
                }
            }
        }
    }

    /** Executes the built query and converts the result set with {@code rst}. */
    public <T> T fetch(final ResultSetHandler<T> rst) {
        debug();
        return execute(conn -> runner.query(conn, getSql(), rst, getParams()));
    }

    /** Executes the query and maps every row to a {@code beanType} instance. */
    public <T> List<T> fetchBeans(Class<T> beanType) {
        RowProcessor rowProcessor = new BasicRowProcessor(new StupidBeanProcessor(beanType));
        ResultSetHandler<List<T>> handler = new BeanListHandler<>(beanType, rowProcessor);
        return fetch(handler);
    }

    /**
     * Executes the query and maps its single row to a bean.
     *
     * @return the bean, or null when the query returns no rows
     * @throws StupidException when the query returns more than one row
     */
    public <T> T fetchBean(Class<T> beanType) {
        List<T> result = fetchBeans(beanType);
        if (result.size() > 1) {
            String error = String.format(
              "Non-unique result: query returned %d rows when expecting exactly one row",
              result.size());
            throw new StupidException(error);
        }
        return result.isEmpty() ? null : result.get(0);
    }

    /**
     * Executes the query and indexes the beans by {@code keyExtractor}.
     *
     * @throws StupidException when two beans share a key
     */
    public <K, V> Map<K, V> fetchBeanMap(Class<V> type, Function<V, K> keyExtractor) {
        return fetchBeans(type).stream().collect(Collectors.toMap(keyExtractor, v -> v, (a, b) -> {
            throw new StupidException("duplicate key in map: " + keyExtractor.apply(a));
        }));
    }

    /** Executes the query and groups the beans by {@code keyExtractor}. */
    public <K, V> Map<K, List<V>> fetchBeanGroup(Class<V> type, Function<V, K> keyExtractor) {
        return fetchBeans(type).stream().collect(Collectors.groupingBy(keyExtractor));
    }

    /** Executes the query and returns every row as a column-name -> value map (DbUtils MapListHandler). */
    public List<Map<String, Object>> fetchMaps() {
        return fetch(new MapListHandler());
    }

    /** Executes the query and converts the first row with DbUtils' MapHandler. */
    public Map<String, Object> fetchMap() {
        return fetch(new MapHandler());
    }

    /**
     * Executes the query and reads a single value with DbUtils' ScalarHandler.
     *
     * @param retType only used to infer {@code T}; no conversion is performed
     */
    public <T> T fetchScalar(Class<T> retType) {
        ScalarHandler<T> handler = new ScalarHandler<>();
        return fetch(handler);
    }

    /** Starts {@code select * from table where <where>}; {@code where} is a template (see {@link #add}). */
    public StupidQL select(String table, String where, Object... v) {
        return select(table).add("where " + where, v);
    }

    /** Starts {@code select * from table}, with the column list held in the {@link #FIELDS} mark. */
    public StupidQL select(String table) {
        return addRaw("select").mark(FIELDS, "*").addRaw("from " + this.quote(table));
    }

    /**
     * Replaces the selected column list of a query started with {@link #select(String)}.
     *
     * @throws StupidException when the query has no {@link #FIELDS} mark
     */
    public StupidQL fields(String... fields) {
        if (!marks.containsKey(FIELDS)) {
            throw new StupidException("no fields mark!");
        }
        return mark(FIELDS, String.join(", ", fields));
    }

    /** Executes the built statement and returns the affected row count. */
    public int update() {
        debug();
        return execute(conn -> runner.update(conn, getSql(), getParams()));
    }

    /** Updates the table derived from {@code bean}'s class; see {@link #update(String, Object, String, Object...)}. */
    public int update(Object bean, String where, Object... v) {
        return update(getTableName(bean), bean, where, v);
    }

    /**
     * Executes {@code update table set ... where <where>} using the updatable, non-ignored columns
     * of {@code bean}.
     *
     * @throws StupidException when {@code bean} yields no columns to set
     */
    public int update(String table, Object bean, String where, Object... v) {
        Map<String, Object> params = beanToColumnMap(bean, true, false);
        if (params.isEmpty()) {
            throw new StupidException("No columns to update for table " + table);
        }
        SimpleEntry<List<String>, List<Object>> kv = pair(params);
        String assignments = kv.getKey().stream()
            .map(column -> column + " = ?")
            .collect(Collectors.joining(", "));
        String sql = String.format("update %s set %s", quote(table), assignments);

        // Bind each column value as exactly one parameter: appendParams would treat Collection
        // and array values as IN-lists and expand them into several placeholders.
        StupidQL query = copy();
        query.appendSqlAndArgs(sql, kv.getValue());
        return query.add("where " + where, v).update();
    }

    /** Alias of {@link #update()} for delete statements. */
    public int delete() {
        return update();
    }

    /** Executes {@code delete from table where <where>}; {@code where} is a template (see {@link #add}). */
    public int delete(String table, String where, Object... v) {
        return addRaw("delete from " + quote(table)).add("where " + where, v).delete();
    }

    /**
     * Executes the built insert and returns the generated key read by DbUtils' ScalarHandler.
     *
     * @param pkType only used to infer {@code T}
     */
    public <T> T insert(Class<T> pkType) {
        debug();
        return execute(conn -> runner.insert(conn, getSql(), new ScalarHandler<>(), getParams()));
    }

    /** Executes the built insert and returns the generated key as a Long, or null if none. */
    public Long insert() {
        Number pk = insert(Number.class);
        return pk == null ? null : pk.longValue();
    }

    /** Inserts {@code bean} into the table derived from its class and returns the generated key. */
    public Long insert(Object bean) {
        return addInsert(bean).insert();
    }

    /** Inserts {@code bean} into {@code table} and returns the generated key. */
    public Long insert(String table, Object bean) {
        return addInsert(table, bean).insert();
    }

    /** Appends an insert of {@code bean} into the table derived from its class. */
    public StupidQL addInsert(Object bean) {
        return addInsert(getTableName(bean), bean);
    }

    /**
     * Appends {@code insert <IGNORE mark> into table (cols) values (...)} for the insertable,
     * non-ignored columns of {@code bean}.
     */
    public StupidQL addInsert(String table, Object bean) {
        Map<String, Object> params = beanToColumnMap(bean, false, true);
        SimpleEntry<List<String>, List<Object>> kv = pair(params);
        String fields = String.join(", ", kv.getKey());
        // The single "?" receives the whole value list as one Collection argument, which
        // appendParams expands into one placeholder per value.
        String sql = String.format("into %s (%s) values (?)", quote(table), fields);
        return addRaw("insert").mark(IGNORE, "").addRaw(sql, kv.getValue());
    }

    /**
     * Builds a batch insert into the table derived from the first element's class.
     *
     * @throws StupidException when {@code batch} is null or empty
     */
    public StupidQL addInsertBatch(List<?> batch) {
        if (batch == null || batch.isEmpty()) {
            throw new StupidException("Batch params empty");
        }
        return addInsertBatch(getTableName(batch.get(0)), batch);
    }

    /**
     * Builds a batch insert of every element of {@code batch} into {@code table}, to be executed
     * with {@link #insertBatch()}. Any SQL already built on this instance is discarded.
     *
     * <p>The column list is the union of the columns of all elements; an element that does not
     * provide a column is bound as NULL for it.
     *
     * @throws StupidException when {@code table} is blank or {@code batch} is null or empty
     */
    public StupidQL addInsertBatch(String table, List<?> batch) {
        if (table == null || table.trim().isEmpty()) {
            throw new StupidException("table name required");
        }
        if (batch == null || batch.isEmpty()) {
            throw new StupidException("Batch params empty");
        }

        // Work on a clean copy: a batch statement cannot share SQL with a previously built query.
        StupidQL query = this.reset();

        List<Map<String, Object>> rows = new ArrayList<>(batch.size());
        for (Object item : batch) {
            rows.add(beanToColumnMap(item, false, true));
        }

        // Union of all rows' columns: beanToColumnMap may drop null-valued properties, so the
        // first row alone can miss columns that later rows populate.
        Set<String> columnSet = new LinkedHashSet<>();
        for (Map<String, Object> row : rows) {
            columnSet.addAll(row.keySet());
        }
        List<String> columns = new ArrayList<>(columnSet);

        // SQL template shared by every row of the batch.
        String columnList = columns.stream().map(this::quote).collect(Collectors.joining(", "));
        String placeholders = String.join(", ", Collections.nCopies(columns.size(), "?"));
        String sql = String.format("into %s (%s) values (%s)", quote(table), columnList, placeholders);

        query = query.addRaw("insert").mark(IGNORE, "");
        query.appendSqlAndArgs(sql, new ArrayList<>());

        // One parameter row per element, in column order.
        Object[][] batchArgs = new Object[rows.size()][columns.size()];
        for (int i = 0; i < rows.size(); i++) {
            Map<String, Object> row = rows.get(i);
            for (int j = 0; j < columns.size(); j++) {
                batchArgs[i][j] = row.get(columns.get(j));
            }
        }

        query.batchBindValues = batchArgs;
        return query;
    }

    /**
     * Executes a batch insert built by {@link #addInsertBatch} and returns the generated keys.
     *
     * @throws StupidException when this query is not a batch insert
     */
    public <R> List<R> insertBatch(Class<R> pkType) {
        if (this.batchBindValues == null) {
            throw new StupidException("not a batch operation");
        }
        debug();
        return execute(conn -> runner.insertBatch(conn, getSql(), new ColumnListHandler<>(), this.batchBindValues));
    }

    /** Executes a batch insert and returns the generated keys as Longs. */
    public List<Long> insertBatch() {
        return insertBatch(Number.class)
          .stream()
          .map(Number::longValue)
          .collect(Collectors.toList());
    }

    /**
     * Runs {@code action} inside a transaction on a dedicated connection: commits when it returns,
     * rolls back when it throws. The connection's auto-commit and isolation settings are restored
     * and the connection is always closed.
     *
     * @throws StupidException wrapping the failure, or the rollback failure (with the original
     *         failure attached as suppressed) when rolling back fails
     */
    public <R> R transaction(Function<StupidQL, R> action) {
        if (txConn != null) {
            throw new StupidException("Nested transactions are not allowed");
        }
        Connection conn = null;
        Boolean originalAutoCommit = null;
        Integer originalIsolation = null;
        try {
            conn = runner.getDataSource().getConnection();
            originalAutoCommit = conn.getAutoCommit();
            originalIsolation = conn.getTransactionIsolation();

            conn.setAutoCommit(false);
            R result = action.apply(init(conn));
            conn.commit();
            return result;
        } catch (Throwable e) {
            if (conn != null) {
                try {
                    conn.rollback();
                } catch (SQLException rollbackFailure) {
                    StupidException error = new StupidException("Failed to rollback transaction", rollbackFailure);
                    // Keep the failure that triggered the rollback visible to the caller.
                    error.addSuppressed(e);
                    throw error;
                }
            }
            throw new StupidException("Transaction failed", e);
        } finally {
            if (conn != null) {
                restoreAndClose(conn, originalAutoCommit, originalIsolation);
            }
        }
    }

    /** Restores the settings changed by {@link #transaction}, then closes the connection even if restoring fails. */
    private static void restoreAndClose(Connection conn, Boolean autoCommit, Integer isolation) {
        try {
            if (autoCommit != null) {
                conn.setAutoCommit(autoCommit);
            }
            if (isolation != null) {
                conn.setTransactionIsolation(isolation);
            }
        } catch (SQLException e) {
            log.warn("Failed to restore connection state", e);
        } finally {
            try {
                conn.close();
            } catch (SQLException e) {
                log.warn("Failed to close connection", e);
            }
        }
    }

    /**
     * Returns the table name for {@code bean}: the non-blank {@code @Info(name)} on its class, or
     * the naming strategy applied to the class's simple name.
     */
    public String getTableName(Object bean) {
        Class<?> beanClass = bean.getClass();
        Info info = beanClass.getAnnotation(Info.class);
        if (info != null && !info.name().trim().isEmpty()) {
            return info.name();
        }
        return namingStrategy.apply(beanClass.getSimpleName());
    }

    /** Returns an independent copy of this builder's state and configuration. */
    protected StupidQL copy() {
        StupidQL query = new StupidQL(runner.getDataSource());
        query.txConn = txConn;
        query.queryParts.addAll(queryParts);
        query.bindValues.addAll(bindValues);
        query.marks.putAll(marks);
        query.quoter = quoter;
        query.namingStrategy = namingStrategy;
        query.batchBindValues = batchBindValues;
        return query;
    }

    /** Returns a copy that quotes identifiers with {@code quoter}. */
    public StupidQL setQuoter(Function<String, String> quoter) {
        StupidQL query = copy();
        query.quoter = quoter;
        return query;
    }

    /**
     * Creates a quoter that wraps each dot-separated part of an identifier in {@code quote},
     * doubling any {@code quote} occurring inside a part.
     *
     * @throws StupidException (from the returned function) for a null or blank identifier
     */
    public static Function<String, String> makeQuoter(String quote) {
        return identifier -> {
            if (identifier == null) {
                throw new StupidException("Identifier cannot be null");
            }
            if (identifier.trim().isEmpty()) {
                throw new StupidException("Identifier cannot be empty");
            }
            return Arrays.stream(identifier.split("\\."))
              .map(part -> quote + part.replace(quote, quote + quote) + quote)
              .collect(Collectors.joining("."));
        };
    }

    /** Quotes {@code identity} with the configured quoter. */
    public String quote(String identity) {
        return quoter.apply(identity);
    }

    /** {@link #add(String, Object...)} when {@code yes} is true, otherwise returns this unchanged. */
    public StupidQL add(boolean yes, String tpl, Object... v) {
        if (yes) return add(tpl, v);
        return this;
    }

    /**
     * Expands a template. {@code {key}} is replaced by the argument's text, {@code @{key}} by the
     * quoted argument, and {@code #{key}} by a {@code ?} whose value is collected. A numeric key
     * is a 1-based index into {@code v}; any other key is looked up in {@link #beanToMap} of the
     * last element of {@code v}.
     *
     * @return the expanded SQL and its bind values; when the template has no placeholders, all of
     *         {@code v} is returned as bind values (for {@code ?}-style SQL)
     * @throws IllegalArgumentException for an out-of-range index, a missing named argument, or a
     *         null argument in a raw or quoted placeholder
     */
    public SimpleEntry<String, List<Object>> parseText(String tpl, Object... v) {
        Matcher m = FORMAT_PH_PATTERN.matcher(tpl);
        StringBuffer sb = new StringBuffer();
        List<Object> sqlParams = new ArrayList<>();

        // Resolved lazily so that only templates with named keys require a trailing bean or map.
        AtomicReference<Map<String, Object>> namedArgs = new AtomicReference<>(null);

        Runnable namedArgsInit = () -> {
            if (namedArgs.get() != null) return;

            if (v.length > 0) {
                Object arg = v[v.length - 1];
                namedArgs.set(beanToMap(arg));
            } else {
                throw new IllegalArgumentException("Named argument required at latest position");
            }
        };

        boolean find = false;
        while (m.find()) {
            find = true;

            Object arg;
            String key = m.group(2).trim();

            if (INT_PATTERN.matcher(key).find()) {
                int idx = Integer.parseInt(key);
                if (idx < 1 || idx > v.length) {
                    throw new IllegalArgumentException("Index " + (idx) + " out of range");
                }
                arg = v[idx - 1];
            } else {
                namedArgsInit.run();
                if (!namedArgs.get().containsKey(key)) {
                    throw new IllegalArgumentException("Named argument '"+ key +"' not found");
                }
                arg = namedArgs.get().get(key);
            }

            String identity = m.group(1);
            switch (identity) {
                case "": // raw replacement
                    if (arg == null) throw new IllegalArgumentException("null not allowed: " + key);
                    m.appendReplacement(sb, Matcher.quoteReplacement(arg.toString()));
                    break;
                case "@": // quoted identifier
                    if (arg == null) throw new IllegalArgumentException("null not allowed: " + key);
                    m.appendReplacement(sb, Matcher.quoteReplacement(quote(arg.toString())));
                    break;
                case "#": // bind parameter
                    sqlParams.add(arg);
                    m.appendReplacement(sb, "?");
                    break;
            }
        }
        m.appendTail(sb);

        if (!find) {
            return new SimpleEntry<>(sb.toString(), Arrays.asList(v));
        }

        return new SimpleEntry<>(sb.toString(), sqlParams);
    }

    /** Appends a template fragment (see {@link #parseText}); {@code ?} arguments go through {@link #addRaw}. */
    public StupidQL add(String tpl, Object... v) {
        SimpleEntry<String, List<Object>> parsed = parseText(tpl, v);
        return addRaw(parsed.getKey(), parsed.getValue().toArray());
    }

    /** {@link #addRaw(String, Object...)} when {@code yes} is true, otherwise returns this unchanged. */
    public StupidQL addRaw(boolean yes, String sql, Object... v) {
        if (yes) return addRaw(sql, v);
        return this;
    }

    /** Appends {@code sql} with one argument per {@code ?} (see {@link #appendParams}). */
    public StupidQL addRaw(String sql, Object... v) {
        StupidQL query = copy();
        query.appendParams(sql, v);
        return query;
    }

    /** Returns a copy that maps class/property names to table/column names with {@code fn}. */
    public StupidQL namingStrategy(Function<String, String> fn) {
        StupidQL query = copy();
        query.namingStrategy = fn;
        return query;
    }

    /**
     * Sets the named replaceable fragment {@code name} to the expanded template {@code sql}:
     * replaces it in place when the mark exists, otherwise appends it and records its position.
     */
    public StupidQL mark(String name, String sql, Object ...v) {
        StupidQL query = copy();
        SimpleEntry<String, List<Object>> parsed = parseText(sql, v);
        Integer idx = query.marks.get(name);
        if (idx != null) {
            query.queryParts.set(idx, parsed.getKey());
            query.bindValues.set(idx, parsed.getValue());
        } else {
            // Store the expanded SQL, matching the bind values parsed from it.
            query.appendSqlAndArgs(parsed.getKey(), parsed.getValue());
            query.marks.put(name, query.queryParts.size() - 1);
        }
        return query;
    }

    /**
     * Appends {@code sql}, binding {@code args[i]} to the i-th {@code ?}. A Collection or array
     * argument (other than {@code byte[]}, which is one binary value) is expanded into one
     * placeholder per element.
     *
     * @throws StupidException when the number of {@code ?} differs from {@code args.length}
     */
    protected void appendParams(String sql, Object[] args) {
        // The trailing space makes split() keep a segment after a final "?".
        String[] parts = (sql + " ").split("\\?");

        if (parts.length != args.length + 1) {
            String msg = String.format("Placeholders length (%d) doesn't match parameters length (%d)",
                                          parts.length - 1,
                                          args.length);
            throw new StupidException(msg);
        }

        List<String> localParts = new ArrayList<>();
        List<Object> localArgs = new ArrayList<>();

        for (int i = 0; i < parts.length; i++) {
            if (i >= args.length) {
                localParts.add(parts[i]);
                continue;
            }

            Object arg = args[i];
            if (arg instanceof Collection) {
                Collection<?> listArgs = (Collection<?>) arg;
                localParts.add(parts[i] + makeArrayPlaceHolders(listArgs.size()));
                localArgs.addAll(listArgs);
            } else if (arg != null && arg.getClass().isArray() && !(arg instanceof byte[])) {
                List<Object> listArgs = arrayToList(arg);
                localParts.add(parts[i] + makeArrayPlaceHolders(listArgs.size()));
                localArgs.addAll(listArgs);
            } else {
                // Scalars, nulls and byte[] (binary data) bind to a single placeholder.
                localParts.add(parts[i] + '?');
                localArgs.add(arg);
            }
        }

        appendSqlAndArgs(String.join(" ", localParts), localArgs);
    }

    /** Copies any array, including primitive arrays, into a list of (boxed) elements. */
    private static List<Object> arrayToList(Object array) {
        int length = java.lang.reflect.Array.getLength(array);
        List<Object> items = new ArrayList<>(length);
        for (int i = 0; i < length; i++) {
            items.add(java.lang.reflect.Array.get(array, i));
        }
        return items;
    }

    /** Appends one fragment and its bind values, keeping both lists aligned. */
    protected void appendSqlAndArgs(String sql, List<Object> args) {
        queryParts.add(sql);
        bindValues.add(args);
    }

    /** Returns {@code len} comma-separated {@code ?} placeholders ({@code ""} for 0). */
    public String makeArrayPlaceHolders(int len) {
        return String.join(",", Collections.nCopies(Math.max(len, 0), "?"));
    }

    /** Logs the SQL and parameters at debug level. */
    protected void debug() {
        if (log.isDebugEnabled()) {
            log.debug(this.toString());
        }
    }

    /** Renders the SQL on one line (with {@code #} and {@code --} line comments stripped) plus parameters. */
    @Override
    public String toString() {
        String sql = getSql().replaceAll("(#|--)[^\\n]*", "").replace("\n", " ").trim();
        return "SQL: " + sql + " #" + Arrays.toString(getParams());
    }

    /** Converts camelCase/PascalCase (including acronyms) to lower snake_case; null/empty pass through. */
    public static String toSnake(String input) {
        if (input == null || input.isEmpty()) {
            return input;
        }
        String s = ABBR_TO_WORD.matcher(input).replaceAll("$1_$2");
        s = CAMEL_TO_SNAKE.matcher(s).replaceAll("$1_$2");
        return s.toLowerCase();
    }

    /**
     * Builds a map from alternating key/value arguments. Each value's properties (see
     * {@link #beanToMap}) are also added under {@code "key.property"}.
     *
     * @throws IllegalArgumentException when given an odd number of arguments
     */
    public static Map<String, Object> mapOf(Object... v) {
        if (v.length % 2 != 0) throw new IllegalArgumentException("argument length must be even");
        Map<String, Object> result = new HashMap<>();
        for (int i = 0; i < v.length; i += 2) {
            String prefix = v[i].toString();
            Object value = v[i + 1];
            result.put(prefix, value);
            for (Map.Entry<String, Object> kv : beanToMap(value).entrySet()) {
                result.put(prefix + "." + kv.getKey(), kv.getValue());
            }
        }
        return result;
    }

    /**
     * Returns {@code bean}'s readable properties as a map. A Map is copied with stringified keys;
     * null yields an empty map. Getters declared by {@code java.*} classes are skipped, so JDK
     * value types such as String produce an empty map.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Map<String, Object> beanToMap(Object bean) {
        Map<String, Object> result = new HashMap<>();
        if (bean == null) {
            return result;
        }
        if (bean instanceof Map) {
            ((Map) bean).forEach((k, v) -> result.put(k.toString(), v));
            return result;
        }

        try {
            PropertyDescriptor[] props = Introspector.getBeanInfo(bean.getClass(), Object.class).getPropertyDescriptors();
            for (PropertyDescriptor pd : props) {
                Method getter = pd.getReadMethod();
                if (getter == null || getter.getDeclaringClass().getName().startsWith("java.")) {
                    continue;
                }
                int modifiers = getter.getModifiers();
                if (!Modifier.isPublic(modifiers) || Modifier.isStatic(modifiers)) {
                    continue;
                }
                result.put(pd.getName(), getter.invoke(bean));
            }
            return result;
        } catch (Exception ex) {
            throw new StupidException(ex);
        }
    }

    /**
     * Converts {@code bean} into a column -> value map for insert/update. For a Map, keys are
     * passed through the naming strategy and null values dropped. For a bean, each declared field
     * is honoured against its {@link Info}: skipped when {@code exists}/{@code update}/{@code insert}
     * forbid it, renamed by a non-blank {@code name}, and dropped when null unless
     * {@code ignoreNull} is false.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected Map<String, Object> beanToColumnMap(Object bean, boolean isUpdate, boolean isInsert) {
        Map<String, Object> result = new HashMap<>();

        if (bean instanceof Map) {
            ((Map) bean).forEach((k, v) -> {
                if (v != null) result.put(namingStrategy.apply(k.toString()), v);
            });
            return result;
        }

        Map<String, Object> source = beanToMap(bean);
        Map<String, Info> infoMap = getCachedClassInfo(bean.getClass());
        for (Map.Entry<String, Info> kv : infoMap.entrySet()) {
            boolean ignoreNull = true;
            String fieldName = kv.getKey();
            String columnName = namingStrategy.apply(fieldName);
            Info info = kv.getValue();
            if (info != null) {
                if (!info.exists()) continue;
                if (isUpdate && !info.update()) continue;
                if (isInsert && !info.insert()) continue;

                ignoreNull = info.ignoreNull();

                String annoColumnName = info.name();
                if (annoColumnName != null && !annoColumnName.trim().isEmpty()) {
                    columnName = annoColumnName;
                }
            }

            Object fieldValue = source.get(fieldName);
            if (!ignoreNull || fieldValue != null) {
                result.put(columnName, fieldValue);
            }
        }
        return result;
    }

    /** Splits {@code data} into quoted column names and their values, in the same iteration order. */
    public SimpleEntry<List<String>, List<Object>> pair(Map<String, Object> data) {
        List<String> keys = new ArrayList<>(data.size());
        List<Object> values = new ArrayList<>(data.size());
        for (Map.Entry<String, Object> kv : data.entrySet()) {
            keys.add(quote(kv.getKey()));
            values.add(kv.getValue());
        }
        return new SimpleEntry<>(keys, values);
    }

    /**
     * Returns an unmodifiable, cached map from each non-static, non-synthetic field of
     * {@code type} and its non-{@code java.*} superclasses to that field's {@link Info}
     * annotation (null when absent). Subclass fields override same-named superclass fields.
     */
    public static Map<String, Info> getCachedClassInfo(Class<?> type) {
        Map<String, Info> cached = CLASS_INFO_CACHE.get(type);
        if (cached != null) {
            return cached;
        }

        // Built outside computeIfAbsent on purpose: resolving the superclass recurses into this
        // method, and ConcurrentHashMap forbids a mapping function from updating the same map.
        Map<String, Info> fieldInfoMap = new HashMap<>();
        Class<?> superClass = type.getSuperclass();
        if (superClass != null && !superClass.getName().startsWith("java.")) {
            fieldInfoMap.putAll(getCachedClassInfo(superClass));
        }
        for (Field field : type.getDeclaredFields()) {
            // Static and compiler-generated fields (e.g. this$0) are never columns.
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            fieldInfoMap.put(field.getName(), field.getAnnotation(Info.class));
        }

        Map<String, Info> computed = Collections.unmodifiableMap(fieldInfoMap);
        Map<String, Info> previous = CLASS_INFO_CACHE.putIfAbsent(type, computed);
        return previous != null ? previous : computed;
    }

    /** A unit of JDBC work executed with a supplied connection. */
    @FunctionalInterface
    interface ConnectionCallback<T> {
        T execWithConn(Connection conn) throws SQLException;
    }

    /**
     * DbUtils bean processor that also honours a non-blank {@link Info#name()} when matching
     * result columns to bean properties.
     */
    public static class StupidBeanProcessor extends BeanProcessor {
        private final Map<String, Info> infoMap;

        public StupidBeanProcessor(Class<?> type) {
            this.infoMap = getCachedClassInfo(type);
        }

        /**
         * Maps each column to the first property whose name (or {@code @Info} name) equals the
         * column label case-insensitively, either verbatim or with underscores and spaces removed.
         */
        @Override
        protected int[] mapColumnsToProperties(final ResultSetMetaData rsmd, final PropertyDescriptor[] props) throws SQLException {
            final int cols = rsmd.getColumnCount();
            // Index 0 is unused: JDBC column indexes start at 1.
            final int[] columnToProperty = new int[cols + 1];
            Arrays.fill(columnToProperty, PROPERTY_NOT_FOUND);

            for (int col = 1; col <= cols; col++) {
                String columnName = rsmd.getColumnLabel(col);
                if (columnName == null || columnName.isEmpty()) {
                    columnName = rsmd.getColumnName(col);
                }

                // Lenient form so "user_name" or "user name" match a property "userName".
                final String generousColumnName = columnName.replace("_", "").replace(" ", "");

                for (int i = 0; i < props.length; i++) {
                    String propName = props[i].getName();
                    Info info = infoMap.get(propName);
                    if (info != null && !info.name().trim().isEmpty()) {
                        propName = info.name();
                    }
                    if (columnName.equalsIgnoreCase(propName) || generousColumnName.equalsIgnoreCase(propName)) {
                        columnToProperty[col] = i;
                        break;
                    }
                }
            }

            return columnToProperty;
        }
    }

    /** Unchecked exception for every failure raised by this library. */
    public static class StupidException extends RuntimeException {
        public StupidException(String message) {
            super(message);
        }

        public StupidException(Throwable cause) {
            super(cause);
        }

        public StupidException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * Table (on a class) or column (on a field) mapping options.
     * {@code name} overrides the derived name when non-blank. {@code exists}, {@code update} and
     * {@code insert} exclude the field from all, update or insert statements. {@code ignoreNull}
     * omits the column when its value is null.
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.FIELD, ElementType.TYPE})
    public @interface Info {
        String name() default "";

        boolean exists() default true;

        boolean update() default true;

        boolean insert() default true;

        boolean ignoreNull() default true;
    }
}