package org.nkjmlab.sorm4j.internal.mapping.multirow;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Function;

import org.nkjmlab.sorm4j.context.logging.LogContext;
import org.nkjmlab.sorm4j.internal.OrmConnectionImpl;
import org.nkjmlab.sorm4j.internal.context.PreparedStatementSupplier;
import org.nkjmlab.sorm4j.internal.context.SqlParametersSetter;
import org.nkjmlab.sorm4j.internal.context.logging.LogPoint;
import org.nkjmlab.sorm4j.internal.mapping.ContainerToTableMapper;
import org.nkjmlab.sorm4j.sql.TableSql;
import org.nkjmlab.sorm4j.util.function.exception.Try;

/**
 * Common base for {@link MultiRowProcessor} implementations.
 *
 * <p>Holds the collaborators shared by all multi-row strategies (statement supplier, parameter
 * setter, table mapper, logging context and batch size) and provides the JDBC batch execution path
 * plus multi-row logging. Subclasses decide how {@link #multiRowInsert} and {@link #multiRowMerge}
 * are carried out.
 *
 * @param <T> the type of objects mapped to table rows
 */
public abstract class MultiRowProcessorBase<T> implements MultiRowProcessor<T> {

  private final int batchSize;
  private final PreparedStatementSupplier statementSupplier;
  private final SqlParametersSetter sqlParametersSetter;
  private final ContainerToTableMapper<T> tableMapping;
  private final LogContext loggerContext;

  /**
   * Creates a processor with the given collaborators.
   *
   * @param loggerContext source of log points for multi-row execution
   * @param sqlParametersSetter binds parameter values to prepared statements
   * @param statementSupplier creates prepared statements for a connection and SQL string
   * @param tableMapping maps objects of type {@code T} to table parameters and SQL
   * @param batchSize passed to {@code BatchHelper} as its batch threshold
   */
  MultiRowProcessorBase(
      LogContext loggerContext,
      SqlParametersSetter sqlParametersSetter,
      PreparedStatementSupplier statementSupplier,
      ContainerToTableMapper<T> tableMapping,
      int batchSize) {
    this.loggerContext = loggerContext;
    this.statementSupplier = statementSupplier;
    this.sqlParametersSetter = sqlParametersSetter;
    this.tableMapping = tableMapping;
    this.batchSize = batchSize;
  }

  @Override
  public abstract int[] multiRowInsert(Connection con, T[] objects);

  @Override
  public abstract int[] multiRowMerge(Connection con, T[] objects);

  /** Returns the SQL set provided by the table mapping. */
  protected final TableSql getSql() {
    return tableMapping.getSql();
  }

  /** Returns the insert parameters the table mapping produces for {@code obj}. */
  protected Object[] getInsertParameters(T obj) {
    return tableMapping.getInsertParameters(obj);
  }

  /** Returns the merge parameters the table mapping produces for {@code obj}. */
  protected Object[] getMergeParameters(T obj) {
    return tableMapping.getMergeParameters(obj);
  }

  /**
   * Binds the parameters of all {@code objects} to {@code stmt} as one flat parameter list.
   *
   * <p>For each object, in array order, the mapper's parameters without auto-generated columns are
   * appended. The concatenated array is then handed to the parameter setter in a single call.
   */
  @Override
  public final void setPrametersOfMultiRow(PreparedStatement stmt, T[] objects)
      throws SQLException {
    Object[] parameters =
        Arrays.stream(objects)
            .flatMap(
                obj -> Arrays.stream(tableMapping.getParametersWithoutAutoGeneratedColumns(obj)))
            .toArray(Object[]::new);
    sqlParametersSetter.setParameters(stmt, parameters);
  }

  /**
   * Executes {@code sql} once per object using JDBC batching.
   *
   * <p>Auto-commit is switched off for the duration of the batch. Afterwards, {@code
   * OrmConnectionImpl.commitOrRollback} is called with the original auto-commit flag, and the
   * original auto-commit mode is restored. Both steps run in a {@code finally} block, so they also
   * run when an exception is thrown.
   *
   * @param con connection to execute on
   * @param sql the SQL to prepare and execute for every object
   * @param parameterCreator produces the parameter values for a single object
   * @param objects objects to process; {@code null} or empty yields an empty result
   * @return the result of {@code BatchHelper.finish()}
   */
  @Override
  public final int[] batch(
      Connection con, String sql, Function<T, Object[]> parameterCreator, T[] objects) {
    return execMultiRowProcIfValidObjects(
        con,
        objects,
        nonNullObjects -> {
          boolean origAutoCommit = OrmConnectionImpl.getAutoCommit(con);

          try (PreparedStatement stmt = statementSupplier.prepareStatement(con, sql)) {
            OrmConnectionImpl.setAutoCommit(con, false);
            final BatchHelper batchHelper = new BatchHelper(batchSize, stmt);
            // Iterate over the array handed to the lambda, not the captured outer reference.
            for (T obj : nonNullObjects) {
              this.sqlParametersSetter.setParameters(stmt, parameterCreator.apply(obj));
              // BatchHelper decides when accumulated rows are flushed (based on batchSize).
              batchHelper.addBatchAndExecuteIfReachedThreshold();
            }
            return batchHelper.finish();
          } catch (SQLException e) {
            throw Try.rethrow(e);
          } finally {
            // NOTE(review): commit vs. rollback policy is defined by OrmConnectionImpl; confirm there.
            OrmConnectionImpl.commitOrRollback(con, origAutoCommit);
            OrmConnectionImpl.setAutoCommit(con, origAutoCommit);
          }
        });
  }

  /**
   * Runs {@code exec} on {@code objects}, surrounded by multi-row logging.
   *
   * <p>If {@code objects} is {@code null} or empty, {@code exec} is not called and an empty array
   * is returned. Otherwise, a log point is requested for {@link
   * LogContext.Category#EXECUTE_MULTI_ROW_UPDATE}. If one is present, the call is logged before and
   * after execution.
   *
   * <p>When logging is enabled, the class of {@code objects[0]} is logged. A {@code null} first
   * element therefore throws {@link NullPointerException} in that case.
   *
   * @param con connection, used only for logging here
   * @param objects objects to process
   * @param exec the actual multi-row operation; receives {@code objects} unchanged
   * @return the value returned by {@code exec}, or an empty array for null/empty input
   */
  final int[] execMultiRowProcIfValidObjects(
      Connection con, T[] objects, Function<T[], int[]> exec) {
    if (objects == null || objects.length == 0) {
      return new int[0];
    }
    Optional<LogPoint> lp =
        loggerContext.createLogPoint(
            LogContext.Category.EXECUTE_MULTI_ROW_UPDATE, MultiRowProcessorBase.class);
    lp.ifPresent(
        _lp ->
            _lp.logBeforeMultiRow(
                con,
                objects[0].getClass(),
                objects.length,
                tableMapping.getTableMetaData().getTableName()));

    final int[] result = exec.apply(objects);

    lp.ifPresent(_lp -> _lp.logAfterMultiRow(result));
    return result;
  }

  /** Prepares {@code sql} on {@code con} via the configured statement supplier. */
  protected PreparedStatement prepareStatement(Connection con, String sql) throws SQLException {
    return statementSupplier.prepareStatement(con, sql);
  }
}
