package org.v2u.toy.duck;

import org.apache.commons.dbutils.*;
import org.apache.commons.dbutils.handlers.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.AbstractMap.SimpleEntry;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * A small fluent SQL builder and executor built on Apache Commons DbUtils.
 *
 * <p>Builder methods such as {@link #addRaw}, {@link #add}, {@link #mark}, {@link #setQuoter} and
 * {@link #namingStrategy} return a modified copy and leave the receiver untouched, while
 * {@link #reset()} clears the receiver in place. Execution methods ({@code fetch*}, {@code update},
 * {@code insert*}, {@code delete}) run the accumulated SQL on the transaction connection when called
 * inside {@link #transaction}, otherwise on a connection borrowed from the {@link DataSource}.
 */
public class Duck {
    /** Mark name of the slot between {@code insert} and {@code into}; filled by {@link #onDuplicateIgnore()}. */
    public static final String IGNORE = "*IGNORE";
    /** Mark name of the select-list slot created by {@link #select(String)}; replaced by {@link #fields}. */
    public static final String FIELDS = "*FIELDS";
    /** Matches the {@code {key}}, {@code #{key}} and {@code @{key}} placeholders understood by {@link #add}. */
    private static final Pattern FORMAT_PH_PATTERN = Pattern.compile("([#@]?)\\{([^}]+?)}");
    private static final Pattern INT_PATTERN = Pattern.compile("^\\d+$");
    // "HTTPServer" -> "HTTP_Server": separates an abbreviation from the capitalized word after it.
    private static final Pattern ABBR_TO_WORD = Pattern.compile("([A-Z]+)([A-Z][a-z])");
    // "userId" -> "user_Id": separates a lower-case letter or digit from the following upper-case letter.
    private static final Pattern CAMEL_TO_SNAKE = Pattern.compile("([a-z\\d])([A-Z])");
    /** Per-class cache: field name -> its {@link Info} annotation, or null when the field has none. */
    private static final Map<Class<?>, Map<String, Info>> CLASS_INFO_CACHE = new ConcurrentHashMap<>();
    private static final Logger log = LoggerFactory.getLogger(Duck.class);
    /** SQL fragments, joined with single spaces by {@link #getSql()}. */
    protected final List<String> queryParts = new ArrayList<>();
    /** Positional bind values, in the order of the {@code ?} placeholders in {@link #queryParts}. */
    protected final List<Object> bindValues = new ArrayList<>();
    /** Mark name -> index into {@link #queryParts} of a fragment that {@link #mark} may replace later. */
    protected final Map<String, Integer> marks = new HashMap<>();
    /** Connection of the enclosing transaction, or null outside {@link #transaction}. */
    protected Connection txConn = null;
    protected Function<String, String> quoter = makeQuoter("`");
    private Function<String, String> namingStrategy = Duck::toSnake;
    protected QueryRunner runner;

    /**
     * Creates an empty builder that borrows a connection from {@code ds} for every execution.
     *
     * @param ds data source used for standalone statements and for opening transactions
     * @return a new, empty builder
     */
    public static Duck init(DataSource ds) {
        return new Duck(ds);
    }

    /**
     * Returns a copy of this builder bound to {@code conn}; used by {@link #transaction} so every
     * statement issued through the copy runs on the same connection.
     */
    protected Duck init(Connection conn) {
        Duck duck = copy();
        duck.runner = new QueryRunner();
        duck.txConn = conn;
        return duck;
    }

    protected Duck(DataSource ds) {
        runner = new QueryRunner(ds);
    }

    /** Returns the accumulated SQL: all query parts joined with single spaces. */
    public String getSql() {
        return String.join(" ", queryParts);
    }

    /** Returns a snapshot array of the positional bind values. */
    public Object[] getParams() {
        return bindValues.toArray();
    }

    /** Clears this builder's SQL parts, bind values and marks in place and returns {@code this}. */
    public Duck reset() {
        this.queryParts.clear();
        this.bindValues.clear();
        this.marks.clear();
        return this;
    }

    /**
     * Returns the connection of the enclosing transaction.
     *
     * @throws DuckException when called outside {@link #transaction}
     */
    public Connection getTxConn() {
        if (txConn == null) {
            throw new DuckException("Not in a transaction context");
        }
        return txConn;
    }

    /**
     * Chooses the connection for the next statement: the transaction connection when there is one,
     * otherwise a fresh connection from the DataSource (which {@link #execute} closes afterwards).
     *
     * @throws DuckException when neither a transaction connection nor a DataSource is available
     */
    protected Connection selectConn() throws SQLException {
        if (txConn != null) {
            return txConn;
        }

        if (runner.getDataSource() != null) {
            return runner.getDataSource().getConnection();
        }

        throw new DuckException("Unable to obtain database connection");
    }

    /**
     * Runs {@code action} on the connection chosen by {@link #selectConn()}, wrapping any
     * {@link SQLException} in a {@link DuckException}.
     */
    private <T> T execute(ConnectionCallback<T> action) {
        Connection conn = null;
        try {
            conn = selectConn();
            return action.execWithConn(conn);
        } catch (SQLException cause) {
            throw new DuckException(cause);
        } finally {
            // Only connections borrowed from the DataSource are closed here; the transaction
            // connection is owned and closed by transaction().
            if (conn != null && conn != txConn) {
                try {
                    conn.close();
                } catch (SQLException ex) {
                    log.warn("Failed to close connection", ex);
                }
            }
        }
    }

    /** Runs the accumulated SQL as a query and passes the result set to {@code rst}. */
    public <T> T fetch(final ResultSetHandler<T> rst) {
        debug();
        return execute(conn -> runner.query(conn, getSql(), rst, getParams()));
    }

    /**
     * Maps every result row to a {@code beanType} instance. Columns may also match a property's
     * {@link Info#name()} override.
     */
    public <T> List<T> fetchBeans(Class<T> beanType) {
        RowProcessor rowProcessor = new BasicRowProcessor(new DuckBeanProcessor(beanType));
        ResultSetHandler<List<T>> handler = new BeanListHandler<>(beanType, rowProcessor);
        return fetch(handler);
    }

    /**
     * Returns the single matching bean, or null when the query returns no rows.
     *
     * @throws DuckException when the query returns more than one row
     */
    public <T> T fetchBean(Class<T> beanType) {
        List<T> result = fetchBeans(beanType);
        if (result.size() > 1) {
            String errmsg = String.format(
              "Non-unique result: query returned %d rows when expecting exactly one row",
              result.size());
            throw new DuckException(errmsg);
        }
        return result.isEmpty() ? null : result.get(0);
    }

    /**
     * Fetches beans and indexes them by {@code keyExtractor}.
     *
     * @throws DuckException when two beans produce the same key
     */
    public <K, V> Map<K, V> fetchBeanMap(Class<V> type, Function<V, K> keyExtractor) {
        return fetchBeans(type).stream().collect(Collectors.toMap(keyExtractor, v -> v, (a, b) -> {
            throw new DuckException("duplicate key in map: " + keyExtractor.apply(a));
        }));
    }

    /** Fetches beans and groups them by {@code keyExtractor}; each group keeps row order. */
    public <K, V> Map<K, List<V>> fetchBeanGroup(Class<V> type, Function<V, K> keyExtractor) {
        return fetchBeans(type).stream().collect(Collectors.groupingBy(keyExtractor));
    }

    /** Returns every row as a column-name -> value map (DbUtils {@link MapListHandler}). */
    public List<Map<String, Object>> fetchMaps() {
        MapListHandler handler = new MapListHandler();
        return fetch(handler);
    }

    /** Returns the first row as a column-name -> value map (DbUtils {@link MapHandler}). */
    public Map<String, Object> fetchMap() {
        MapHandler handler = new MapHandler();
        return fetch(handler);
    }

    /**
     * Returns the first column of the first row. {@code retType} only drives type inference; the
     * value is cast unchecked.
     */
    public <T> T fetchScalar(Class<T> retType) {
        ScalarHandler<T> handler = new ScalarHandler<>();
        return fetch(handler);
    }

    /**
     * Starts {@code select * from table}; the {@code *} is the {@link #FIELDS} mark, which
     * {@link #fields} can replace.
     */
    public Duck select(String table) {
        return addRaw("select").mark(FIELDS, "*").addRaw("from " + this.quote(table));
    }

    /**
     * Replaces the select list created by {@link #select(String)} with the given column expressions.
     *
     * @throws DuckException when no {@link #FIELDS} mark exists
     */
    public Duck fields(String... fields) {
        if (!marks.containsKey(FIELDS)) {
            throw new DuckException("no fields mark!");
        }

        return mark(FIELDS, String.join(", ", fields));
    }

    /** Shorthand for {@code select(table).add("where " + where, v)}. */
    public Duck select(String table, String where, Object... v) {
        return select(table).add("where " + where, v);
    }

    /** Executes the accumulated SQL as an update and returns the affected row count. */
    public int update() {
        debug();
        return execute(conn -> runner.update(conn, getSql(), getParams()));
    }

    /** Updates the table named by the {@link Info} annotation on the bean's class. */
    public int update(Object bean, String where, Object... v) {
        Info info = mustGetTableConfig(bean);
        return update(info.name(), bean, where, v);
    }

    /**
     * Builds and executes {@code update table set col = ?, ... where <where>} from the bean's writable
     * properties (see {@link #beanToMap(Object, boolean, boolean)}).
     *
     * @throws DuckException when the bean has no column to update, or the statement fails
     */
    public int update(String tableName, Object bean, String where, Object... v) {
        Map<String, Object> params = beanToMap(bean, true, false);
        if (params.isEmpty()) {
            // "update t set  where ..." is never valid SQL; fail with a clear message instead.
            throw new DuckException("No columns to update for " + bean.getClass().getName());
        }
        SimpleEntry<String, Object[]> paramsKv = pairUpdate(params);
        String sql = String.format("update %s set %s", quote(tableName), paramsKv.getKey());
        return addRaw(sql, paramsKv.getValue()).add("where " + where, v).update();
    }

    /** Executes the accumulated SQL (expected to be a delete) and returns the affected row count. */
    public int delete() {
        return update();
    }

    /** Builds and executes {@code delete from table where <where>}. */
    public int delete(String table, String where, Object... v) {
        return addRaw("delete from " + quote(table)).add("where " + where, v).delete();
    }

    /**
     * Executes the accumulated insert and returns the first generated key exactly as the driver
     * reports it; the cast to {@code T} is unchecked.
     */
    public <T> T insert(Class<T> pkType) {
        debug();
        return execute(conn -> runner.insert(conn, getSql(), new ScalarHandler<>(), getParams()));
    }

    /**
     * Executes the accumulated insert and returns the generated key as a {@code Long}, or null when
     * no key was generated.
     */
    public Long insert() {
        Object pk = insert(Object.class);
        if (pk == null) {
            return null;
        }
        // The key's Java type is whatever the driver reports (BigInteger, Long, Integer, ...),
        // so convert through Number rather than assuming one concrete class.
        if (pk instanceof Number) {
            return ((Number) pk).longValue();
        }
        return Long.valueOf(pk.toString());
    }

    /**
     * Turns the pending {@link #addInsert} into {@code insert ignore}.
     *
     * @throws DuckException when there is no {@link #IGNORE} mark (addInsert was not called)
     */
    public Duck onDuplicateIgnore() {
        if (!marks.containsKey(IGNORE)) {
            throw new DuckException("no ignore mark!");
        }

        return mark(IGNORE, "ignore");
    }

    /**
     * Appends {@code insert into table (cols) values (?, ...)} built from the bean's insertable
     * properties; the empty {@link #IGNORE} mark can later become {@code ignore}.
     */
    public Duck addInsert(String table, Object bean) {
        Map<String, Object> params = beanToMap(bean, false, true);
        SimpleEntry<String, List<Object>> kv = pairInsert(params);
        // The single "?" receives the whole value list, which appendParams expands to one "?" per value.
        String sql = String.format("into %s (%s) values (?)", quote(table), kv.getKey());
        return addRaw("insert").mark(IGNORE, "").addRaw(sql, kv.getValue());
    }

    /** Like {@link #addInsert(String, Object)} using the table named by the bean class's {@link Info}. */
    public Duck addInsert(Object bean) {
        Info info = mustGetTableConfig(bean);
        return addInsert(info.name(), bean);
    }

    /**
     * Batch-inserts into the table named by the {@link Info} annotation of the first element's class.
     *
     * @throws DuckException when {@code batch} is null or empty, or the class has no {@link Info}
     */
    public List<String> insertBatch(List<?> batch, boolean ignoreDup) {
        // Validate before touching batch.get(0) so callers get the intended error, not an NPE/IOOBE.
        if (batch == null || batch.isEmpty()) {
            throw new DuckException("Batch params empty");
        }
        Info info = mustGetTableConfig(batch.get(0));
        return insertBatch(info.name(), batch, ignoreDup);
    }

    /**
     * Inserts every element of {@code batch} with one JDBC batch statement and returns the generated
     * keys as strings.
     *
     * @param ignoreDup when true, emits {@code insert ignore}
     * @throws DuckException when {@code batch} is null or empty, or the statement fails
     */
    public List<String> insertBatch(String table, List<?> batch, boolean ignoreDup) {
        if (batch == null || batch.isEmpty()) {
            throw new DuckException("Batch params empty");
        }

        // Convert every bean and collect the union of their columns in first-seen order. beanToMap
        // omits null properties by default, so rows can expose different column sets; using only the
        // first row's columns would silently drop values that later rows do carry.
        List<Map<String, Object>> rows = new ArrayList<>(batch.size());
        Set<String> columns = new LinkedHashSet<>();
        for (Object item : batch) {
            Map<String, Object> row = beanToMap(item, false, true);
            rows.add(row);
            columns.addAll(row.keySet());
        }
        List<String> keys = new ArrayList<>(columns);

        // Build the single parameterized statement shared by all rows.
        String ignore = ignoreDup ? "ignore" : "";
        String fieldStr = keys.stream().map(this::quote).collect(Collectors.joining(", "));
        String placeholders = String.join(", ", Collections.nCopies(keys.size(), "?"));
        String sql = String.format("insert %s into %s (%s) values (%s)", ignore, quote(table), fieldStr, placeholders);

        // One argument row per bean, ordered like `keys`; a column absent from a row binds NULL.
        Object[][] batchArgs = new Object[rows.size()][];
        for (int i = 0; i < rows.size(); i++) {
            Map<String, Object> params = rows.get(i);
            Object[] values = new Object[keys.size()];
            for (int j = 0; j < keys.size(); j++) {
                values[j] = params.get(keys.get(j));
            }
            batchArgs[i] = values;
        }

        if (log.isDebugEnabled()) {
            List<String> args = new ArrayList<>(batchArgs.length);
            for (Object[] arg : batchArgs) {
                args.add("> " + Arrays.toString(arg));
            }
            log.debug("SQL: {}\n{}", sql, String.join("\n", args));
        }

        List<?> pks = execute(conn -> runner.insertBatch(conn, sql, new ColumnListHandler<>(), batchArgs));
        return pks.stream().map(Object::toString).collect(Collectors.toList());
    }

    /**
     * Runs {@code action} inside a transaction on a dedicated connection. The transaction is committed
     * when the action returns normally and rolled back on any throwable. The connection's original
     * auto-commit mode is restored and the connection is closed afterwards.
     *
     * @param action receives a builder bound to the transaction connection
     * @return whatever {@code action} returns
     * @throws DuckException when called inside another transaction, when the action or the commit
     *     fails ("Transaction failed", cause attached), or when the rollback itself fails
     *     ("Failed to rollback transaction", triggering error attached as suppressed)
     */
    public <R> R transaction(Function<Duck, R> action) {
        if (txConn != null) {
            throw new DuckException("Nested transactions are not allowed");
        }
        Connection conn = null;
        Boolean isAutoCommit = null;
        try {
            conn = runner.getDataSource().getConnection();
            isAutoCommit = conn.getAutoCommit();
            conn.setAutoCommit(false);
            Duck tx = init(conn);
            R result = action.apply(tx);
            conn.commit();
            return result;
        } catch (Throwable e) {
            if (conn != null) {
                try {
                    conn.rollback();
                } catch (SQLException ex) {
                    DuckException rollbackFailure = new DuckException("Failed to rollback transaction", ex);
                    // Keep the error that triggered the rollback; otherwise it would be lost entirely.
                    rollbackFailure.addSuppressed(e);
                    throw rollbackFailure;
                }
            }
            throw new DuckException("Transaction failed", e);
        } finally {
            if (conn != null) {
                restoreAndClose(conn, isAutoCommit);
            }
        }
    }

    /**
     * Restores the auto-commit mode (when known) and always closes the connection, even when the
     * restore fails, so the connection is never leaked. Failures are logged, not thrown.
     */
    private static void restoreAndClose(Connection conn, Boolean autoCommit) {
        try {
            if (autoCommit != null) {
                conn.setAutoCommit(autoCommit);
            }
        } catch (SQLException e) {
            log.warn("Failed to restore auto-commit mode", e);
        } finally {
            try {
                conn.close();
            } catch (SQLException e) {
                log.warn("Failed to close connection", e);
            }
        }
    }

    /**
     * Returns the {@link Info} annotation on the bean's class (its {@code name()} is the table name).
     *
     * @throws DuckException when the class is not annotated
     */
    private Info mustGetTableConfig(Object bean) {
        Info info = bean.getClass().getAnnotation(Info.class);
        if (info == null) {
            throw new DuckException("table config required");
        }
        return info;
    }

    /**
     * Returns an independent copy of this builder: same data source, transaction connection, SQL
     * parts, bind values, marks, quoter and naming strategy.
     */
    protected Duck copy() {
        Duck query = new Duck(runner.getDataSource());
        query.txConn = txConn;
        query.queryParts.addAll(queryParts);
        query.bindValues.addAll(bindValues);
        query.marks.putAll(marks);
        query.quoter = quoter;
        // Carry the naming strategy too; otherwise any builder step would silently revert it to toSnake.
        query.namingStrategy = namingStrategy;

        return query;
    }

    /** Returns a copy that quotes identifiers with {@code quoter}. */
    public Duck setQuoter(Function<String, String> quoter) {
        Duck query = copy();
        query.quoter = quoter;
        return query;
    }

    /**
     * Builds an identifier quoter. Each dot-separated part is wrapped in {@code quote}, and any
     * {@code quote} characters inside it are doubled. For example, with a backtick, {@code db.t}
     * becomes {@code `db`.`t`}.
     *
     * @throws DuckException (from the returned function) for a null or blank identifier
     */
    public static Function<String, String> makeQuoter(String quote) {
        return identifier -> {
            if (identifier == null) {
                throw new DuckException("Identifier cannot be null");
            }

            if (identifier.trim().isEmpty()) {
                throw new DuckException("Identifier cannot be empty");
            }

            return Arrays.stream(identifier.split("\\."))
              .map(part -> quote + part.replace( quote, quote + quote) + quote)
              .collect(Collectors.joining("."));
        };
    }

    /** Quotes an identifier with the current quoter. */
    public String quote(String identity) {
        return quoter.apply(identity);
    }

    /** Calls {@link #add(String, Object...)} only when {@code yes} is true; otherwise returns {@code this}. */
    public Duck add(boolean yes, String tpl, Object... v) {
        if (yes) return add(tpl, v);
        return this;
    }

    /**
     * Appends a templated SQL fragment. Placeholders:
     * <ul>
     *   <li>{@code {k}} is replaced by the argument's {@code toString()} (raw text, not escaped);</li>
     *   <li>{@code @{k}} is replaced by the argument quoted as an identifier;</li>
     *   <li>{@code #{k}} becomes a {@code ?} bound to the argument.</li>
     * </ul>
     * {@code k} is a 1-based index into {@code v}, or a property/key name looked up in the last
     * element of {@code v} (a bean or Map). If the template has no placeholders, it is handed to
     * {@link #addRaw(String, Object...)} with {@code v} as positional {@code ?} values.
     *
     * @throws IllegalArgumentException for an out-of-range index, a missing named argument, or a
     *     null argument used in a raw/identifier placeholder
     */
    public Duck add(String tpl, Object... v) {
        Matcher m = FORMAT_PH_PATTERN.matcher(tpl);
        // Matcher.appendReplacement only accepts StringBuffer before Java 9.
        StringBuffer sb = new StringBuffer();
        List<Object> sqlParams = new ArrayList<>();
        // Resolved lazily from the last argument, and only if a named placeholder occurs.
        Map<String, Object> namedArgs = null;

        boolean find = false;
        while (m.find()) {
            find = true;

            Object arg;
            String key = m.group(2).trim();

            if (INT_PATTERN.matcher(key).find()) {
                int idx = Integer.parseInt(key);
                if (idx < 1 || idx > v.length) {
                    throw new IllegalArgumentException("Index " + (idx) + " out of range");
                }
                arg = v[idx - 1];
            } else {
                if (namedArgs == null) {
                    if (v.length == 0) {
                        throw new IllegalArgumentException("Named argument required at latest position");
                    }
                    namedArgs = beanToMap(v[v.length - 1]);
                }
                if (!namedArgs.containsKey(key)) {
                    throw new IllegalArgumentException("Named argument '"+ key +"' not found");
                }
                arg = namedArgs.get(key);
            }

            String ident = m.group(1);
            switch (ident) {
                case "": // raw replace
                    if (arg == null) throw new IllegalArgumentException("null not allowed: " + key);
                    m.appendReplacement(sb, Matcher.quoteReplacement(arg.toString()));
                    break;
                case "@": // quoted identifier replace
                    if (arg == null) throw new IllegalArgumentException("null not allowed: " + key);
                    m.appendReplacement(sb, Matcher.quoteReplacement(quote(arg.toString())));
                    break;
                case "#": // bound SQL parameter
                    sqlParams.add(arg);
                    m.appendReplacement(sb, "?");
                    break;
                default:
                    // Unreachable: FORMAT_PH_PATTERN only captures "", "@" or "#".
                    break;
            }
        }
        m.appendTail(sb);

        if (!find) {
            return addRaw(sb.toString(), v);
        }

        return addRaw(sb.toString(), sqlParams.toArray());
    }

    /** Calls {@link #addRaw(String, Object...)} only when {@code yes} is true; otherwise returns {@code this}. */
    public Duck addRaw(boolean yes, String sql, Object... v) {
        if (yes) return addRaw(sql, v);
        return this;
    }

    /**
     * Returns a copy with {@code sql} appended; each {@code ?} in {@code sql} is bound to the
     * corresponding element of {@code v} (see {@link #appendParams}).
     */
    public Duck addRaw(String sql, Object... v) {
        Duck query = copy();
        query.appendParams(sql, v);
        return query;
    }

    /** Returns a copy that maps bean property names to column names with {@code fn}. */
    public Duck namingStrategy(Function<String, String> fn) {
        Duck query = copy();
        query.namingStrategy = fn;
        return query;
    }

    /**
     * Returns a copy in which the fragment registered under {@code name} is replaced by {@code sql},
     * or, if {@code name} is new, {@code sql} is appended and registered under it. {@code sql} is
     * stored verbatim and gets no bind values.
     */
    public Duck mark(String name, String sql) {
        Duck query = copy();
        if (query.marks.containsKey(name)) {
            query.queryParts.set(query.marks.get(name), sql);
        } else {
            query.queryParts.add(sql);
            query.marks.put(name, query.queryParts.size() - 1);
        }
        return query;
    }

    /**
     * Splits {@code sql} on {@code ?} and appends each fragment together with its bind value, in place.
     * Collections and arrays (primitive arrays included, {@code byte[]} excluded) expand to one
     * comma-separated {@code ?} per element, which is what makes {@code in (?)} lists work. A
     * {@code byte[]} is bound as a single value.
     *
     * @throws DuckException when the number of {@code ?} differs from {@code args.length}
     */
    protected void appendParams(String sql, Object[] args) {
        // The trailing space keeps a final "?" from being swallowed: split() drops trailing empty strings.
        String[] parts = (sql + " ").split("\\?");

        if (parts.length != args.length + 1) {
            String errmsg = String.format("Placeholders length (%d) doesn't match parameters length (%d)",
                                          parts.length - 1,
                                          args.length);
            throw new DuckException(errmsg);
        }

        for (int i = 0; i < parts.length; i++) {
            if (i >= args.length) {
                // Only the text after the last placeholder remains.
                queryParts.add(parts[i]);
                return;
            }

            Object arg = args[i];

            if (arg instanceof Collection) {
                appendArray(parts[i], ((Collection<?>) arg).toArray());
            } else if (arg != null && arg.getClass().isArray() && !(arg instanceof byte[])) {
                appendArray(parts[i], toObjectArray(arg));
            } else {
                queryParts.add(parts[i] + '?');
                bindValues.add(arg);
            }
        }
    }

    /** Boxes any array, including primitive arrays such as {@code int[]}, into an {@code Object[]}. */
    private static Object[] toObjectArray(Object array) {
        if (array instanceof Object[]) {
            return (Object[]) array;
        }
        int length = java.lang.reflect.Array.getLength(array);
        Object[] boxed = new Object[length];
        for (int i = 0; i < length; i++) {
            boxed[i] = java.lang.reflect.Array.get(array, i);
        }
        return boxed;
    }

    /** Appends {@code sql} followed by {@code ?,?,...} (one per element) and binds all elements. */
    protected void appendArray(String sql, Object[] args) {
        StringBuilder placeholders = new StringBuilder();
        for (int i = 0; i < args.length; i++) {
            if (i > 0) placeholders.append(',');
            placeholders.append('?');
        }
        queryParts.add(sql + placeholders);
        bindValues.addAll(Arrays.asList(args));
    }

    /** Logs {@link #toString()} at debug level. */
    protected void debug() {
        if (log.isDebugEnabled()) {
            log.debug(this.toString());
        }
    }

    /**
     * Returns {@code "SQL: <sql> #[params]"}. The SQL is shown with everything from {@code #} or
     * {@code --} to end of line removed and newlines flattened; intended for logging only.
     */
    @Override
    public String toString() {
        String sql = getSql().replaceAll("(#|--)[^\\n]*", "").replace("\n", " ").trim();
        return "SQL: " + sql + " #" + Arrays.toString(getParams());
    }

    /**
     * Converts camelCase / PascalCase to snake_case, keeping abbreviations together
     * ({@code "userId"} becomes {@code "user_id"}, {@code "HTTPServer"} becomes {@code "http_server"}).
     * Returns null or empty input unchanged.
     */
    public static String toSnake(String input) {
        if (input == null || input.isEmpty()) {
            return input;
        }
        String s = ABBR_TO_WORD.matcher(input).replaceAll("$1_$2");
        s = CAMEL_TO_SNAKE.matcher(s).replaceAll("$1_$2");
        return s.toLowerCase();
    }

    /**
     * Builds a named-argument map from alternating key/value arguments. Each value is stored under its
     * key, and its bean properties (or map entries) are also stored under {@code key + "." + property}.
     *
     * @throws IllegalArgumentException when an odd number of arguments is given
     */
    public static Map<String, Object> mapOf(Object... v) {
        Map<String, Object> result = new HashMap<>();
        if (v.length % 2 != 0) throw new IllegalArgumentException("argument length must be even");
        for (int i = 0; i < v.length; i += 2) {
            String prefix = v[i].toString();
            Object value = v[i + 1];
            result.put(prefix, value);
            // A null value has no properties to flatten, and beanToMap would reject it.
            if (value == null) continue;
            for (Map.Entry<String, Object> kv : beanToMap(value).entrySet()) {
                result.put(prefix + "." + kv.getKey(), kv.getValue());
            }
        }
        return result;
    }

    /**
     * Returns a property-name -> value map for {@code bean}. A {@link Map} is copied with its keys
     * converted via {@code toString()}. Any other object contributes its public, non-static bean
     * getters, except those declared by {@code java.*} classes.
     *
     * @throws DuckException when introspection or a getter call fails (including for a null bean)
     */
    @SuppressWarnings("unchecked")
    public static Map<String, Object> beanToMap(Object bean) {
        if (bean instanceof Map) {
            Map<String, Object> result = new HashMap<>();
            ((Map<Object, Object>) bean).forEach((k, v) -> result.put(k.toString(), v));
            return result;
        }

        try {
            Map<String, Object> result = new HashMap<>();
            Class<?> clazz = bean.getClass();
            PropertyDescriptor[] props = Introspector.getBeanInfo(clazz, Object.class).getPropertyDescriptors();
            for (PropertyDescriptor pd : props) {
                Method getter = pd.getReadMethod();  // getXxx() / isXxx()
                if (getter == null || getter.getDeclaringClass().getName().startsWith("java.")) {
                    continue;
                }

                int modifiers = getter.getModifiers();
                if (!Modifier.isPublic(modifiers) || Modifier.isStatic(modifiers)) {
                    continue;
                }

                result.put(pd.getName(), getter.invoke(bean));
            }
            return result;
        } catch (Exception ex) {
            throw new DuckException(ex);
        }
    }

    /**
     * Converts {@code bean} to a column-name -> value map for writing. Column names come from the
     * naming strategy unless a field's {@link Info#name()} overrides them. Properties with
     * {@code exists=false} are skipped, as are those with {@code update=false} (when
     * {@code isUpdate}) or {@code insert=false} (when {@code isInsert}). Null values are omitted
     * unless the field sets {@code ignoreNull=false}.
     */
    protected Map<String, Object> beanToMap(Object bean, boolean isUpdate, boolean isInsert) {
        Map<String, Object> source = beanToMap(bean);
        Map<String, Object> result = new HashMap<>();
        Map<String, Info> infoMap = getCachedClassInfo(bean.getClass());
        for (Map.Entry<String, Object> kv : source.entrySet()) {
            boolean ignoreNull = true;
            String fieldName = kv.getKey();
            String columnName = namingStrategy.apply(fieldName);
            Info info = infoMap.get(fieldName);
            if (info != null) {
                if (!info.exists()) continue;
                if (isUpdate && !info.update()) continue;
                if (isInsert && !info.insert()) continue;

                ignoreNull = info.ignoreNull();

                String annoColumnName = info.name();
                if (annoColumnName != null && !annoColumnName.trim().isEmpty()) {
                    columnName = annoColumnName;
                }
            }

            if (kv.getValue() != null || !ignoreNull) {
                result.put(columnName, kv.getValue());
            }
        }
        return result;
    }

    /** Returns the quoted, comma-joined column list and the matching value list for an insert. */
    public SimpleEntry<String, List<Object>> pairInsert(Map<String, Object> data) {
        SimpleEntry<List<String>, List<Object>> kv = pair(data);
        String sql = String.join(", ", kv.getKey());
        return new SimpleEntry<>(sql, kv.getValue());
    }

    /** Returns the {@code `col` = ?, ...} assignment list and the matching values for an update. */
    public SimpleEntry<String, Object[]> pairUpdate(Map<String, Object> data) {
        SimpleEntry<List<String>, List<Object>> kv = pair(data);
        String sql = kv.getKey().stream().map(k -> String.format("%s = ?", k)).collect(Collectors.joining(", "));

        return new SimpleEntry<>(sql, kv.getValue().toArray());
    }

    /** Splits {@code data} into quoted keys and values that share the map's iteration order. */
    public SimpleEntry<List<String>, List<Object>> pair(Map<String, Object> data) {
        List<String> keys = new ArrayList<>(data.size());
        List<Object> values = new ArrayList<>(data.size());

        for (Map.Entry<String, Object> kv : data.entrySet()) {
            keys.add(quote(kv.getKey()));
            values.add(kv.getValue());
        }

        return new SimpleEntry<>(keys, values);
    }

    /**
     * Returns (and caches) field name -> {@link Info} for {@code type} and its superclasses. The value
     * is null for fields without the annotation. A subclass field shadows a same-named superclass field.
     */
    public static Map<String, Info> getCachedClassInfo(Class<?> type) {
        return CLASS_INFO_CACHE.computeIfAbsent(type, t -> {
            Map<String, Info> fieldInfoMap = new HashMap<>();
            // Walk up the hierarchy: beanToMap and BeanProcessor also see inherited properties, so
            // annotations on inherited fields must be honoured too.
            for (Class<?> c = t; c != null && c != Object.class; c = c.getSuperclass()) {
                for (Field field : c.getDeclaredFields()) {
                    // containsKey (not putIfAbsent) so an unannotated subclass field still shadows.
                    if (!fieldInfoMap.containsKey(field.getName())) {
                        fieldInfoMap.put(field.getName(), field.getAnnotation(Info.class));
                    }
                }
            }
            return fieldInfoMap;
        });
    }

    /** A unit of JDBC work executed on a connection chosen by {@link #execute}. */
    @FunctionalInterface
    protected interface ConnectionCallback<T> {
        T execWithConn(Connection conn) throws SQLException;
    }

    /**
     * Mapping metadata. On a class, {@link #name()} is the table name used by the bean-based
     * {@code update}, {@code addInsert} and {@code insertBatch} overloads. On a field, it controls how
     * the same-named bean property is written and read.
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.FIELD, ElementType.TYPE})
    public @interface Info {
        /** Table name (on a class) or column name override (on a field; empty uses the naming strategy). */
        String name() default "";

        /** When false, the property is never written by update or insert. */
        boolean exists() default true;

        /** When false, the property is skipped by bean-based updates. */
        boolean update() default true;

        /** When false, the property is skipped by addInsert / insertBatch. */
        boolean insert() default true;

        /** When true, a null property value is omitted instead of being written as NULL. */
        boolean ignoreNull() default true;
    }

    /** Unchecked exception thrown for every Duck error, including wrapped {@link SQLException}s. */
    public static class DuckException extends RuntimeException {
        public DuckException(String message) {
            super(message);
        }

        public DuckException(Throwable cause) {
            super(cause);
        }

        public DuckException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * BeanProcessor that matches a column against a property's {@link Info#name()} override when one
     * is set, otherwise against the property name. Matching is case-insensitive and also tries the
     * column name with underscores and spaces removed.
     */
    public static class DuckBeanProcessor extends BeanProcessor {
        private final Class<?> type;
        private final Map<String, Info> infoMap;

        public DuckBeanProcessor(Class<?> type) {
            this.type = type;
            this.infoMap = getCachedClassInfo(type);
        }

        @Override
        protected int[] mapColumnsToProperties(final ResultSetMetaData rsmd, final PropertyDescriptor[] props) throws SQLException {

            final int cols = rsmd.getColumnCount();
            // Index 0 is unused: JDBC column indexes start at 1.
            final int[] columnToProperty = new int[cols + 1];
            Arrays.fill(columnToProperty, PROPERTY_NOT_FOUND);

            for (int col = 1; col <= cols; col++) {
                String columnName = rsmd.getColumnLabel(col);

                if (null == columnName || columnName.isEmpty()) {
                    columnName = rsmd.getColumnName(col);
                }

                final String generousColumnName = columnName.replace("_", "")   // more idiomatic to Java
                  .replace(" ", "");  // can't have spaces in property names

                for (int i = 0; i < props.length; i++) {
                    String propName = props[i].getName();

                    Info info = infoMap.get(propName);
                    if (info != null && !info.name().isEmpty()) {
                        propName = info.name();
                    }

                    // See if either the column name or the generous one matches.
                    if (columnName.equalsIgnoreCase(propName) || generousColumnName.equalsIgnoreCase(propName)) {
                        columnToProperty[col] = i;
                        break;
                    }
                }
            }

            return columnToProperty;
        }
    }
}