/*
 * Copyright 2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.seppiko.commons.utils.jdbc;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Objects;
import org.seppiko.commons.utils.StringUtil;

/**
 * SQL executor.
 *
 * <p>Runs plain and parameterized SQL statements against a single {@link Connection} supplied by
 * the caller. {@link #close()} closes that connection, so an executor can be used in a
 * try-with-resources statement.
 *
 * @author Leonard Woo
 */
public class SQLExecutor implements AutoCloseable {

  private final Connection conn;

  /**
   * Initialize SQL executor
   *
   * @param conn the connection.
   * @throws NullPointerException connection is null
   */
  public SQLExecutor(Connection conn) throws NullPointerException {
    this.conn = Objects.requireNonNull(conn, "Connection must be not null.");
  }

  /**
   * Query
   * Execute {@code SELECT} statement
   *
   * <p>The caller owns the returned {@code ResultSet} and must close it. Closing it also closes
   * the statement created for this query.
   *
   * @param sql SQL Statement.
   * @return a {@code ResultSet} object that contains the data produced by the given query; never {@code null}
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public ResultSet query(final String sql) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    Statement stmt = conn.createStatement();
    try {
      // Tie the statement's lifetime to the result set so neither leaks.
      stmt.closeOnCompletion();
      return stmt.executeQuery(sqlFilter(sql));
    } catch (SQLException | RuntimeException e) {
      closeOnFailure(stmt, e);
      throw e;
    }
  }

  /**
   * Query
   * Execute {@code SELECT} statement
   *
   * <p>The caller owns the returned {@code ResultSet} and must close it. Closing it also closes
   * the prepared statement created for this query.
   *
   * @param sql SQL Statement.
   * @param params parameters, bound in order to the {@code ?} markers; a {@code null} array binds nothing.
   * @return a {@code ResultSet} object that contains the data produced by the given query; never {@code null}
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public ResultSet query(final String sql, Object... params) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    PreparedStatement pstmt = conn.prepareStatement(sqlFilter(sql));
    try {
      // Closing the statement here would also close the returned ResultSet. Defer the close
      // until the caller closes the result set instead.
      pstmt.closeOnCompletion();
      setParameters(pstmt, params);
      return pstmt.executeQuery();
    } catch (SQLException | RuntimeException e) {
      closeOnFailure(pstmt, e);
      throw e;
    }
  }

  /**
   *  Execute {@code INSERT} {@code UPDATE} {@code DELETE} or other statement
   *
   * @param sql SQL Statement.
   * @return SQL execute row count.
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public int execute(final String sql) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    try (Statement stmt = conn.createStatement()) {
      return stmt.executeUpdate(sqlFilter(sql));
    }
  }

  /**
   * Execute {@code INSERT} {@code UPDATE} {@code DELETE} or other statement
   *
   * @param sql SQL Statement.
   * @param params parameters, bound in order to the {@code ?} markers; a {@code null} array binds nothing.
   * @return SQL execute row count.
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public int execute(final String sql, Object... params) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    try (PreparedStatement pstmt = conn.prepareStatement(sqlFilter(sql))) {
      setParameters(pstmt, params);
      return pstmt.executeUpdate();
    }
  }

  /**
   * Execute {@code INSERT} {@code UPDATE} {@code DELETE} or other statement
   *
   * @param sql SQL Statement.
   * @return the first auto-generated key; {@code -1} if no row was affected; {@code 0} if rows were
   *     affected but the driver returned no generated key.
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public long executeWithGeneratedKey(final String sql) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    try (Statement stmt = conn.createStatement()) {
      // Without RETURN_GENERATED_KEYS the driver is not required to make keys available.
      int rowCount = stmt.executeUpdate(sqlFilter(sql), Statement.RETURN_GENERATED_KEYS);
      return readGeneratedKey(stmt, rowCount);
    }
  }

  /**
   * Execute {@code INSERT} {@code UPDATE} {@code DELETE} or other statement
   *
   * @param sql SQL Statement.
   * @param params parameters, bound in order to the {@code ?} markers; a {@code null} array binds nothing.
   * @return the first auto-generated key; {@code -1} if no row was affected; {@code 0} if rows were
   *     affected but the driver returned no generated key.
   * @throws SQLException if a database access error occurs.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  public long executeWithGeneratedKey(final String sql, Object... params) throws SQLException, IllegalArgumentException {
    checkSQL(sql);

    try (PreparedStatement pstmt = conn.prepareStatement(sqlFilter(sql), Statement.RETURN_GENERATED_KEYS)) {
      setParameters(pstmt, params);
      return readGeneratedKey(pstmt, pstmt.executeUpdate());
    }
  }

  /**
   * Close Connection
   *
   * <p>Does nothing if the connection is already closed.
   *
   * @throws SQLException if a database access error occurs.
   */
  @Override
  public void close() throws SQLException {
    if (!conn.isClosed()) {
      conn.close();
    }
  }

  /**
   * Check SQL
   *
   * @param sql SQL Statement.
   * @throws IllegalArgumentException SQL is empty or null.
   */
  private void checkSQL(String sql) throws IllegalArgumentException {
    if (StringUtil.isNullOrEmpty(sql)) {
      throw new IllegalArgumentException("SQL must be not null.");
    }
  }

  /**
   * This is for filtering SQL statement, avoid SQL statement will error on executing.
   *
   * <p>Trims surrounding whitespace, then removes a single trailing {@code ;}. Many JDBC drivers
   * reject statement terminators.
   *
   * @param srcSQL Source SQL Statement.
   * @return Execution SQL Statement.
   */
  private String sqlFilter(String srcSQL) {
    // Trim first so "SELECT 1;\n" is handled like "SELECT 1;".
    String sql = srcSQL.trim();
    return sql.endsWith(";") ? sql.substring(0, sql.length() - 1) : sql;
  }

  /**
   * Closes a statement after a failure while keeping the original exception primary.
   *
   * @param stmt statement to close.
   * @param cause the exception that triggered the cleanup; a close failure is attached to it as suppressed.
   */
  private static void closeOnFailure(Statement stmt, Exception cause) {
    try {
      stmt.close();
    } catch (SQLException suppressed) {
      cause.addSuppressed(suppressed);
    }
  }

  /**
   * Reads the first generated key after an update.
   *
   * @param stmt the statement that executed the update with generated keys requested.
   * @param rowCount the update count returned by the execution.
   * @return {@code -1} if {@code rowCount <= 0}; otherwise the first column of the first key row,
   *     or {@code 0} when no key row is available.
   * @throws SQLException if a database access error occurs.
   */
  private static long readGeneratedKey(Statement stmt, int rowCount) throws SQLException {
    if (rowCount <= 0) {
      return -1L;
    }
    // try-with-resources skips close() for a null resource, so a null key set is safe here.
    try (ResultSet keys = stmt.getGeneratedKeys()) {
      return (keys != null && keys.next()) ? keys.getLong(1) : 0L;
    }
  }

  /**
   * Binds all parameters in order, starting at JDBC index 1.
   *
   * @param pstmt {@link PreparedStatement} instance.
   * @param params parameters to bind; {@code null} is treated as "no parameters".
   * @throws SQLException if a parameter cannot be bound.
   */
  private static void setParameters(PreparedStatement pstmt, Object... params) throws SQLException {
    if (params == null) {
      return;
    }
    for (int i = 0; i < params.length; i++) {
      setParameter(pstmt, i + 1, params[i]);
    }
  }

  /**
   * set parameter object
   *
   * @param pstmt {@link PreparedStatement} instance.
   * @param i parameter index.
   * @param obj parameter object.
   * @throws SQLException if parameterIndex does not correspond to a parameter
   *     marker in the SQL statement; if a database access error occurs or
   *     this method is called on a closed {@code PreparedStatement}.
   */
  private static void setParameter(PreparedStatement pstmt, int i, Object obj)
      throws SQLException {
    if (obj instanceof java.sql.Date) {
      pstmt.setDate(i, (java.sql.Date) obj);
    } else if (obj instanceof java.sql.Time) {
      pstmt.setTime(i, (java.sql.Time) obj);
    } else if (obj instanceof java.sql.Timestamp) {
      // Already a JDBC timestamp: bind directly so nanosecond precision is kept.
      pstmt.setTimestamp(i, (java.sql.Timestamp) obj);
    } else if (obj instanceof java.util.Date) {
      pstmt.setTimestamp(i, SqlTypeUtil.toSqlTimestamp((java.util.Date) obj));
    } else if (obj instanceof BigDecimal) {
      pstmt.setBigDecimal(i, (BigDecimal) obj);
    } else if (obj instanceof BigInteger) {
      pstmt.setBigDecimal(i, new BigDecimal((BigInteger) obj));
    } else if (obj instanceof String) {
      pstmt.setString(i, (String) obj);
    } else {
      // Other Numbers (Integer, Long, Double, ...), Boolean, byte[], null, etc. use the driver's
      // default Java-to-SQL mapping. Previously other Numbers were silently left unbound.
      pstmt.setObject(i, obj);
    }
  }
}
