/*
 * Decompiled with CFR 0.152.
 */
package org.miaixz.bus.pager.plugin;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.miaixz.bus.core.lang.Charset;
import org.miaixz.bus.core.lang.exception.InternalException;
import org.miaixz.bus.core.xyz.StringKit;
import org.miaixz.bus.crypto.Builder;
import org.miaixz.bus.logger.Logger;
import org.miaixz.bus.pager.plugin.SqlParserHandler;

/**
 * MyBatis interceptor that rejects statements considered "illegal" before they are prepared.
 *
 * <p>It hooks {@code StatementHandler.prepare(Connection, Integer)}. For every statement that is
 * not an INSERT (and is not exempted by {@code getSqlParserInfo}) the SQL text is parsed with
 * JSqlParser and the following rules are enforced by throwing {@link InternalException}:
 * <ul>
 *   <li>a WHERE clause must be present;</li>
 *   <li>the inspected conditions must not be an {@code OR} or a {@code !=};</li>
 *   <li>the left side of a comparison must not be a database function;</li>
 *   <li>the right side of a comparison / {@code IN} must not be a sub-query;</li>
 *   <li>every column inspected in WHERE / JOIN ON must be the leading column of some index.</li>
 * </ul>
 *
 * <p>SQL texts that passed validation are remembered (by MD5 digest) so they are checked only once.
 */
@Intercepts(value={@Signature(type=StatementHandler.class, method="prepare", args={Connection.class, Integer.class})})
public class IllegalSqlHandler
extends SqlParserHandler
implements Interceptor {
    /**
     * Digests (Base64 of MD5) of SQL texts that already passed validation.
     * A concurrent set is used because {@link #intercept(Invocation)} mutates this static state
     * and one interceptor instance serves every thread that prepares statements.
     */
    private static final Set<String> cacheValidResult = ConcurrentHashMap.newKeySet();
    /** Index metadata cache, keyed by the caller-supplied key of the four-argument {@code getIndexInfos}. */
    private static final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();

    /**
     * Checks one expression node (not its children) against the forbidden constructs.
     *
     * @param expression the node to check; {@code null} is accepted and passes
     * @throws InternalException if the node is an OR, a {@code !=}, a comparison whose left side is
     *                           a function, or a comparison / IN whose right side is a sub-query
     */
    private static void validExpression(Expression expression) {
        if (expression instanceof OrExpression) {
            throw new InternalException("\u975e\u6cd5SQL\uff0cwhere\u6761\u4ef6\u4e2d\u4e0d\u80fd\u4f7f\u7528\u3010or\u3011\u5173\u952e\u5b57\uff0c\u9519\u8befor\u4fe1\u606f\uff1a" + expression.toString());
        }
        if (expression instanceof NotEqualsTo) {
            throw new InternalException("\u975e\u6cd5SQL\uff0cwhere\u6761\u4ef6\u4e2d\u4e0d\u80fd\u4f7f\u7528\u3010!=\u3011\u5173\u952e\u5b57\uff0c\u9519\u8bef!=\u4fe1\u606f\uff1a" + expression.toString());
        }
        if (expression instanceof BinaryExpression) {
            BinaryExpression binary = (BinaryExpression) expression;
            Expression left = binary.getLeftExpression();
            if (left instanceof Function) {
                throw new InternalException("\u975e\u6cd5SQL\uff0cwhere\u6761\u4ef6\u4e2d\u4e0d\u80fd\u4f7f\u7528\u6570\u636e\u5e93\u51fd\u6570\uff0c\u9519\u8bef\u51fd\u6570\u4fe1\u606f\uff1a" + left.toString());
            }
            Expression right = binary.getRightExpression();
            if (right instanceof ParenthesedSelect) {
                throw new InternalException("\u975e\u6cd5SQL\uff0cwhere\u6761\u4ef6\u4e2d\u4e0d\u80fd\u4f7f\u7528\u5b50\u67e5\u8be2\uff0c\u9519\u8bef\u5b50\u67e5\u8be2SQL\u4fe1\u606f\uff1a" + right.toString());
            }
        } else if (expression instanceof InExpression) {
            Expression right = ((InExpression) expression).getRightExpression();
            if (right instanceof ParenthesedSelect) {
                throw new InternalException("\u975e\u6cd5SQL\uff0cwhere\u6761\u4ef6\u4e2d\u4e0d\u80fd\u4f7f\u7528\u5b50\u67e5\u8be2\uff0c\u9519\u8bef\u5b50\u67e5\u8be2SQL\u4fe1\u606f\uff1a" + right.toString());
            }
        }
    }

    /**
     * Validates the ON condition of every join against the main table and the joined table.
     *
     * @param joins      joins of the statement, may be {@code null}
     * @param table      the statement's main table
     * @param connection connection used to look up index metadata
     */
    private static void validJoins(List<Join> joins, Table table, Connection connection) {
        if (joins == null) {
            return;
        }
        for (Join join : joins) {
            // NOTE(review): assumes the joined item is a plain table; a joined sub-select
            // would fail this cast exactly as it did before.
            Table rightTable = (Table) join.getRightItem();
            IllegalSqlHandler.validWhere(join.getOnExpression(), table, rightTable, connection);
        }
    }

    /**
     * Ensures {@code columnName} is reported as an indexed column of {@code table}.
     *
     * @param table      table whose indexes are consulted; a dotted name is split into
     *                   database part and table part
     * @param columnName column that must appear in the index list
     * @param connection connection used to look up index metadata
     * @throws InternalException if no index entry matches the column name
     */
    private static void validUseIndex(Table table, String columnName, Connection connection) {
        String[] nameParts = table.getName().split("\\.");
        String dbName = null;
        String tableName;
        if (nameParts.length == 1) {
            tableName = nameParts[0];
        } else {
            dbName = nameParts[0];
            tableName = nameParts[1];
        }
        for (IndexInfo indexInfo : IllegalSqlHandler.getIndexInfos(dbName, tableName, connection)) {
            if (Objects.equals(columnName, indexInfo.getColumnName())) {
                return;
            }
        }
        throw new InternalException("\u975e\u6cd5SQL\uff0cSQL\u672a\u4f7f\u7528\u5230\u7d22\u5f15, table:" + String.valueOf(table) + ", columnName:" + columnName);
    }

    /**
     * Validates a WHERE condition of a statement without a join context.
     *
     * @param expression the condition, may be {@code null}
     * @param table      the statement's main table
     * @param connection connection used to look up index metadata
     */
    private static void validWhere(Expression expression, Table table, Connection connection) {
        IllegalSqlHandler.validWhere(expression, table, null, connection);
    }

    /**
     * Validates a condition: forbidden constructs first, then index usage of the columns found
     * while descending through the left-hand side of nested binary expressions.
     *
     * @param expression the condition, may be {@code null}
     * @param table      the statement's main table
     * @param joinTable  the joined table when validating a JOIN ON condition, otherwise {@code null}
     * @param connection connection used to look up index metadata
     * @throws InternalException if a forbidden construct or a non-indexed column is found
     */
    private static void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
        IllegalSqlHandler.validExpression(expression);
        if (!(expression instanceof BinaryExpression)) {
            return;
        }
        BinaryExpression binary = (BinaryExpression) expression;
        Expression leftExpression = binary.getLeftExpression();
        Expression rightExpression = binary.getRightExpression();
        IllegalSqlHandler.validExpression(leftExpression);
        if (leftExpression instanceof Column) {
            String leftColumnName = ((Column) leftExpression).getColumnName();
            if (joinTable != null && rightExpression instanceof Column) {
                Column rightColumn = (Column) rightExpression;
                // Decide which side of "a.x = b.y" belongs to the main table by comparing the
                // right column's qualifier with the name the main table is referenced by.
                if (Objects.equals(IllegalSqlHandler.qualifierOf(rightColumn), IllegalSqlHandler.referenceNameOf(table))) {
                    IllegalSqlHandler.validUseIndex(table, rightColumn.getColumnName(), connection);
                    IllegalSqlHandler.validUseIndex(joinTable, leftColumnName, connection);
                } else {
                    IllegalSqlHandler.validUseIndex(joinTable, rightColumn.getColumnName(), connection);
                    IllegalSqlHandler.validUseIndex(table, leftColumnName, connection);
                }
            } else {
                IllegalSqlHandler.validUseIndex(table, leftColumnName, connection);
            }
        } else if (leftExpression instanceof BinaryExpression) {
            IllegalSqlHandler.validWhere(leftExpression, table, joinTable, connection);
        }
        IllegalSqlHandler.validExpression(rightExpression);
    }

    /**
     * Returns the table qualifier written in front of a column, or {@code null} when the column
     * carries no table reference.
     */
    private static String qualifierOf(Column column) {
        Table owner = column.getTable();
        return owner == null ? null : owner.getName();
    }

    /**
     * Returns the name a table is referenced by in the statement: its alias when one is set,
     * otherwise the table name itself (previously a missing alias caused a NullPointerException).
     */
    private static String referenceNameOf(Table table) {
        return table.getAlias() == null ? table.getName() : table.getAlias().getName();
    }

    /**
     * Loads index information without a cache key.
     * Because the key passed on is {@code null}, the lookup goes to the database metadata on
     * every call (the cache in {@code indexInfoMap} is only consulted for non-blank keys).
     *
     * @param dbName    database name, or {@code null}/blank to use the connection's catalog and schema
     * @param tableName table whose indexes are requested
     * @param conn      connection providing the metadata
     * @return the leading index columns found; never {@code null}
     */
    private static List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
        return IllegalSqlHandler.getIndexInfos(null, dbName, tableName, conn);
    }

    /**
     * Loads the leading column of every index of a table, optionally through the cache.
     *
     * @param key       cache key; when blank the cache is neither read nor written
     * @param dbName    database name used as both catalog and schema, or blank to use the
     *                  connection's own catalog and schema
     * @param tableName table whose indexes are requested
     * @param conn      connection providing the metadata
     * @return the index entries found; never {@code null}, empty when nothing was found or the
     *         metadata lookup failed (the failure is logged)
     */
    private static List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
        boolean cacheable = StringKit.isNotBlank(key);
        if (cacheable) {
            List<IndexInfo> cached = indexInfoMap.get(key);
            // An empty cached list is treated as "unknown" and looked up again.
            if (cached != null && !cached.isEmpty()) {
                return cached;
            }
        }
        List<IndexInfo> indexInfos = new ArrayList<>();
        try {
            DatabaseMetaData metadata = conn.getMetaData();
            boolean useConnectionDefaults = StringKit.isBlank(dbName);
            String catalog = useConnectionDefaults ? conn.getCatalog() : dbName;
            String schema = useConnectionDefaults ? conn.getSchema() : dbName;
            // try-with-resources: the metadata ResultSet was previously never closed.
            try (ResultSet rs = metadata.getIndexInfo(catalog, schema, tableName, false, true)) {
                while (rs.next()) {
                    // JDBC getIndexInfo columns: 1 = TABLE_CAT, 3 = TABLE_NAME,
                    // 8 = ORDINAL_POSITION, 9 = COLUMN_NAME. Only the first column of an
                    // index (ordinal position 1) is kept.
                    if (!Objects.equals(rs.getString(8), "1")) {
                        continue;
                    }
                    IndexInfo indexInfo = new IndexInfo();
                    indexInfo.setDbName(rs.getString(1));
                    indexInfo.setTableName(rs.getString(3));
                    indexInfo.setColumnName(rs.getString(9));
                    indexInfos.add(indexInfo);
                }
            }
            if (cacheable) {
                indexInfoMap.put(key, indexInfos);
            }
        }
        catch (SQLException e) {
            // Best effort, as before: keep whatever was collected, but report through the
            // logger instead of printStackTrace and never hand back null.
            Logger.error("Failed to load index metadata for table " + tableName + " : " + e);
        }
        return indexInfos;
    }

    /**
     * Validates the SQL about to be prepared and then lets the invocation proceed.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call; its first argument
     *                   is the JDBC connection
     * @return the result of {@code invocation.proceed()}
     * @throws InternalException if the statement violates one of the rules
     * @throws Throwable         anything thrown by SQL parsing or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) IllegalSqlHandler.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = IllegalSqlHandler.getMappedStatement(metaObject);
        // INSERTs and statements flagged by getSqlParserInfo are not validated.
        if (SqlCommandType.INSERT == mappedStatement.getSqlCommandType() || IllegalSqlHandler.getSqlParserInfo(metaObject)) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();
        Logger.debug("Check for SQL : " + originalSql);
        String md5Base64 = Base64.getEncoder().encodeToString(Builder.md5().digest(originalSql.getBytes(Charset.UTF_8)));
        if (cacheValidResult.contains(md5Base64)) {
            Logger.debug("The SQL has been checked : " + originalSql);
            return invocation.proceed();
        }
        Connection connection = (Connection) invocation.getArgs()[0];
        Statement statement = CCJSqlParserUtil.parse(originalSql);
        Expression where = null;
        Table table = null;
        List<Join> joins = null;
        if (statement instanceof Select) {
            // NOTE(review): assumes a plain SELECT reading from a table; set operations or a
            // sub-select in FROM would fail these casts exactly as they did before.
            PlainSelect plainSelect = (PlainSelect) ((Select) statement).getSelectBody();
            where = plainSelect.getWhere();
            table = (Table) plainSelect.getFromItem();
            joins = plainSelect.getJoins();
        } else if (statement instanceof Update) {
            Update update = (Update) statement;
            where = update.getWhere();
            table = update.getTable();
            joins = update.getJoins();
        } else if (statement instanceof Delete) {
            Delete delete = (Delete) statement;
            where = delete.getWhere();
            table = delete.getTable();
            joins = delete.getJoins();
        }
        if (where == null) {
            throw new InternalException("\u975e\u6cd5SQL\uff0c\u5fc5\u987b\u8981\u6709where\u6761\u4ef6");
        }
        IllegalSqlHandler.validWhere(where, table, connection);
        IllegalSqlHandler.validJoins(joins, table, connection);
        // Remember the digest only after every check passed.
        cacheValidResult.add(md5Base64);
        return invocation.proceed();
    }

    /**
     * Wraps {@link StatementHandler} targets with this interceptor and returns every other
     * object untouched.
     *
     * @param object the candidate target
     * @return the proxied handler, or {@code object} itself
     */
    @Override
    public Object plugin(Object object) {
        if (object instanceof StatementHandler) {
            return Plugin.wrap(object, this);
        }
        return object;
    }

    /** One index entry read from the database metadata: catalog, table and indexed column. */
    private static class IndexInfo {
        private String dbName;
        private String tableName;
        private String columnName;

        private IndexInfo() {
        }

        public String getDbName() {
            return this.dbName;
        }

        public void setDbName(String dbName) {
            this.dbName = dbName;
        }

        public String getTableName() {
            return this.tableName;
        }

        public void setTableName(String tableName) {
            this.tableName = tableName;
        }

        public String getColumnName() {
            return this.columnName;
        }

        public void setColumnName(String columnName) {
            this.columnName = columnName;
        }
    }
}

