/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.mendmix.mybatis.plugin.pagination;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.dromara.mendmix.common.model.PageParams;
import org.dromara.mendmix.mybatis.datasource.DatabaseType;

/**
 * Helpers that turn a plain SELECT statement into a paginated query or into the matching
 * row-count query.
 *
 * <p>SQL is inspected with a lightweight lexical scan that only understands parentheses and
 * single-quoted literals; it is not a full SQL parser (comments, for example, are not
 * recognised). Whenever the scan cannot take a statement apart safely,
 * {@link #getCountSql(String)} falls back to wrapping the whole statement as a derived table.
 */
public class PageSqlUtils {
    private static final char PAIR_CLOSE_CHAR = ')';
    private static final char PAIR_OPEN_CHAR = '(';
    private static final char QUOTE_CHAR = '\'';
    // Replacement for hidden characters; one char per char keeps masked and original indices aligned.
    private static final char MASK_CHAR = ' ';
    private static final String[] SQL_LINE_CHARS = new String[]{"\r", "\n", "\t"};
    private static final String[] SQL_LINE_REPLACE_CHARS = new String[]{" ", " ", " "};
    private static final String PAGE_SIZE_PLACEHOLDER = "#{pageSize}";
    private static final String OFFSET_PLACEHOLDER = "#{offset}";
    private static final String SQL_COUNT_PREFIX = "SELECT count(1) ";
    private static final String COMMON_COUNT_SQL_TEMPLATE = "select count(1) from (%s) tmp";
    private static final Pattern TOP_SELECT_PATTERN = Pattern.compile("\\bSELECT\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern TOP_FROM_PATTERN = Pattern.compile("\\sFROM\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern ORDER_BY_PATTERN = Pattern.compile("\\bORDER\\s+BY\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern SET_OPERATION_PATTERN =
            Pattern.compile("\\b(UNION|INTERSECT|EXCEPT|MINUS)\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern GROUPING_PATTERN =
            Pattern.compile("\\b(GROUP\\s+BY|HAVING)\\b", Pattern.CASE_INSENSITIVE);
    // Select-list markers whose presence changes the row count, forcing the wrapper form.
    // Extra matches only cost a (still correct) wrapped query, so these err on the wide side.
    private static final List<Pattern> AGGREGATION_KEY_PATTERNS = Arrays.asList(
            Pattern.compile("(?<![\\w.])(COUNT|MIN|MAX|SUM|AVG|GROUP_CONCAT|STRING_AGG|ARRAY_AGG|LISTAGG)\\s*\\(",
                    Pattern.CASE_INSENSITIVE),
            Pattern.compile("\\bDISTINCT\\b", Pattern.CASE_INSENSITIVE));
    // Keyed by DatabaseType#name(); populated once in the static initializer below, read-only afterwards.
    private static final Map<String, String> PAGE_TEMPLATES = new HashMap<String, String>(4);

    /**
     * Wraps {@code sql} in the pagination template of the given database type.
     *
     * <p>The returned SQL still contains the {@code #{offset}} and {@code #{pageSize}}
     * placeholders.
     *
     * @param dbType target database type
     * @param sql    the query to paginate
     * @return the paginated SQL with placeholders
     * @throws IllegalArgumentException if no pagination template exists for {@code dbType}
     */
    public static String getLimitSQL(DatabaseType dbType, String sql) {
        String template = PAGE_TEMPLATES.get(dbType.name());
        if (template == null) {
            throw new IllegalArgumentException("Pagination is not supported for database type: " + dbType);
        }
        return String.format(template, sql);
    }

    /**
     * Wraps {@code sql} in the pagination template of the given database type and fills in the
     * offset and page size from {@code pageParams}.
     *
     * @param dbType     target database type
     * @param sql        the query to paginate
     * @param pageParams supplies {@code offset()} and {@code getPageSize()}
     * @return executable paginated SQL
     * @throws IllegalArgumentException if no pagination template exists for {@code dbType}
     */
    public static String getLimitSQL(DatabaseType dbType, String sql, PageParams pageParams) {
        return PageSqlUtils.getLimitSQL(dbType, sql)
                .replace(OFFSET_PLACEHOLDER, String.valueOf(pageParams.offset()))
                .replace(PAGE_SIZE_PLACEHOLDER, String.valueOf(pageParams.getPageSize()));
    }

    /**
     * Builds the row-count query for {@code sql}.
     *
     * <p>CR/LF/TAB characters are first replaced by spaces and the result is trimmed. The
     * statement is wrapped as {@code select count(1) from (...) tmp} when any of these holds:
     * <ul>
     *   <li>no top-level {@code SELECT ... FROM} can be located;</li>
     *   <li>parentheses or quotes are unbalanced;</li>
     *   <li>the select list contains {@code DISTINCT} or an aggregate call;</li>
     *   <li>the top level contains a set operation or {@code GROUP BY}/{@code HAVING}.</li>
     * </ul>
     * Otherwise the top-level select list is replaced by {@code count(1)}, and a top-level
     * {@code ORDER BY} plus everything after it is dropped. Content inside parentheses or
     * quotes never triggers either rule.
     *
     * @param sql the SELECT statement to count
     * @return the count SQL
     */
    public static String getCountSql(String sql) {
        String formatSql = StringUtils.replaceEach(sql, SQL_LINE_CHARS, SQL_LINE_REPLACE_CHARS).trim();
        String topLevel = maskNestedContent(formatSql);
        int[] head = topLevel == null ? null : locateTopSelectFrom(topLevel);
        // Wrapping counts the rows of any query the database accepts as a derived table, so it is
        // the safe answer whenever the lightweight scan cannot simplify the statement.
        if (head == null || requiresWrapper(formatSql.substring(head[0], head[1]), topLevel)) {
            return String.format(COMMON_COUNT_SQL_TEMPLATE, formatSql);
        }
        Matcher orderBy = ORDER_BY_PATTERN.matcher(topLevel);
        int end = orderBy.find(head[1]) ? orderBy.start() : formatSql.length();
        // Splice by index (not text replacement) so identical text earlier in the SQL is untouched.
        return formatSql.substring(0, head[0]) + SQL_COUNT_PREFIX + formatSql.substring(head[1], end);
    }

    /**
     * Returns the first balanced parenthesised group in {@code sql}, including the outer
     * parentheses. Closing parentheses that appear before the first opening one are ignored.
     *
     * @param sql text to search
     * @return the first {@code (...)} group
     * @throws IllegalArgumentException if {@code sql} has no balanced parentheses group
     */
    public static String matchOutterParenthesesPair(String sql) {
        int start = sql.indexOf(PAIR_OPEN_CHAR);
        if (start >= 0) {
            int depth = 0;
            for (int i = start; i < sql.length(); i++) {
                char c = sql.charAt(i);
                if (c == PAIR_OPEN_CHAR) {
                    depth++;
                } else if (c == PAIR_CLOSE_CHAR && --depth == 0) {
                    return sql.substring(start, i + 1);
                }
            }
        }
        throw new IllegalArgumentException("No balanced parentheses pair found in sql: " + sql);
    }

    /**
     * Returns the head of the top-level SELECT: the text from the outermost {@code SELECT}
     * keyword up to, and including, the whitespace character just before the matching top-level
     * {@code FROM}. Keywords inside parentheses or single-quoted literals are ignored, so
     * subqueries and constructs like {@code EXTRACT(YEAR FROM d)} are skipped.
     *
     * @param sql the SELECT statement
     * @return the select head, e.g. {@code "select * "} for {@code "select * from users"}
     * @throws IllegalArgumentException if parentheses/quotes are unbalanced or no top-level
     *                                  {@code SELECT ... FROM} exists
     */
    public static String matchTopSelectFrom(String sql) {
        String topLevel = maskNestedContent(sql);
        int[] range = topLevel == null ? null : locateTopSelectFrom(topLevel);
        if (range == null) {
            throw new IllegalArgumentException("No top-level SELECT ... FROM found in sql: " + sql);
        }
        return sql.substring(range[0], range[1]);
    }

    /** Decides whether a count query must wrap the statement instead of rewriting its select list. */
    private static boolean requiresWrapper(String selectHead, String topLevel) {
        return AGGREGATION_KEY_PATTERNS.stream().anyMatch(p -> p.matcher(selectHead).find())
                || SET_OPERATION_PATTERN.matcher(topLevel).find()
                || GROUPING_PATTERN.matcher(topLevel).find();
    }

    /**
     * Finds the top-level select head in an already masked statement.
     *
     * @return {@code {selectStart, headEnd}}, where {@code headEnd} is the index of the top-level
     *         FROM keyword (so the head keeps the whitespace before it); {@code null} if not found
     */
    private static int[] locateTopSelectFrom(String topLevel) {
        Matcher select = TOP_SELECT_PATTERN.matcher(topLevel);
        if (!select.find()) {
            return null;
        }
        Matcher from = TOP_FROM_PATTERN.matcher(topLevel);
        if (!from.find(select.end())) {
            return null;
        }
        // The match begins with one whitespace char; +1 points at the 'F' of FROM.
        return new int[]{select.start(), from.start() + 1};
    }

    /**
     * Returns a copy of {@code sql} with the same length in which the content of every top-level
     * parenthesised group and every single-quoted literal is replaced by spaces. The outermost
     * parentheses and top-level quote characters themselves are kept. Keyword searches on the
     * result therefore only see the top level, and indices map 1:1 onto {@code sql}.
     *
     * @return the masked text, or {@code null} if parentheses or quotes are unbalanced
     */
    private static String maskNestedContent(String sql) {
        char[] chars = sql.toCharArray();
        int depth = 0;
        boolean inQuote = false;
        for (int i = 0; i < chars.length; i++) {
            char c = chars[i];
            boolean visible;
            if (inQuote) {
                if (c == QUOTE_CHAR) {
                    // An escaped quote ('') simply closes and immediately reopens the literal.
                    inQuote = false;
                    visible = depth == 0;
                } else {
                    visible = false;
                }
            } else if (c == QUOTE_CHAR) {
                inQuote = true;
                visible = depth == 0;
            } else if (c == PAIR_OPEN_CHAR) {
                visible = depth == 0;
                depth++;
            } else if (c == PAIR_CLOSE_CHAR) {
                depth--;
                if (depth < 0) {
                    return null;
                }
                visible = depth == 0;
            } else {
                visible = depth == 0;
            }
            if (!visible) {
                chars[i] = MASK_CHAR;
            }
        }
        return depth == 0 && !inQuote ? new String(chars) : null;
    }

    /** Prints the count SQL generated for a handful of sample statements (manual smoke check). */
    public static void main(String[] args) throws IOException {
        List<String> sqls = Arrays.asList(
                "select * from users",
                "select * from users where status = 1  \n union all  \n select * from users_1",
                "select a.from,a.to   from    users where status = 1 order by id desc,name asc",
                "SELECT  wid.id,(SELECT  key_lot_code FROM wms_lot wl2 WHERE wl2.lot_code = wid.lot_code LIMIT 1) AS key_lot_code,(wid.shelve_time , wid.created_at) AS shelveDay FROM wms_inventory_detail wid WHERE wid.deleted = '0' AND wid.qty > 0 ORDER BY wid.created_at DESC",
                "select u.* from users u,details d where u.id = d.id and u.status=1",
                "select u.* from users u join details d on u.id = d.id where u.status=1",
                "select t.* from (select t.* from table where aa = '1') t where t.status ",
                "select a.*,\nSUM(a.c) from audited_policy a where 1=1\nand title like CONCAT('%',?,'%')\norder by updated_at desc",
                "select * from \n( \n select MAX(id) from mes_order ) t1 where ( 1 = 1) order by t1.created_at desc",
                "select extract(year from created_at) y from orders order by y",
                "select a, row_number() over (order by b) rn from t");
        for (String sql : sqls) {
            System.out.println("===================================================");
            System.out.println("countSQL:" + PageSqlUtils.getCountSql(sql));
        }
    }

    static {
        PAGE_TEMPLATES.put(DatabaseType.mysql.name(), "%s limit #{offset},#{pageSize}");
        // ROWNUM is 1-based while #{offset} is zero-based (as the MySQL template above requires),
        // so the lower bound must be exclusive; ">=" returned pageSize + 1 rows on later pages.
        PAGE_TEMPLATES.put(DatabaseType.oracle.name(), "select * from (select a1.*,rownum rn from (%s) a1 where rownum <=#{offset} + #{pageSize}) where rn>#{offset}");
        PAGE_TEMPLATES.put(DatabaseType.postgresql.name(), "%s limit #{pageSize} offset #{offset}");
        PAGE_TEMPLATES.put(DatabaseType.h2.name(), "%s limit #{pageSize} offset #{offset}");
    }
}

