package org.elsfs.tool.sql.mybatisplus.injector;

import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.injector.DefaultSqlInjector;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import java.util.List;
import java.util.function.Predicate;
import org.elsfs.tool.sql.mybatisplus.method.ExistsById;
import org.elsfs.tool.sql.mybatisplus.method.ExistsByJoinWrapper;
import org.elsfs.tool.sql.mybatisplus.method.ExistsByWrapper;
import org.elsfs.tool.sql.mybatisplus.method.SelectJoinCount;
import org.elsfs.tool.sql.mybatisplus.method.SelectJoinList;
import org.elsfs.tool.sql.mybatisplus.method.SelectJoinOne;
import org.elsfs.tool.sql.mybatisplus.method.SelectJoinPage;
import org.elsfs.tool.sql.mybatisplus.utils.PredicateHolder;

/**
 * Enhanced SQL injector.
 *
 * <p>Registers all methods provided by {@link DefaultSqlInjector} and additionally injects the
 * {@code Exists*} and {@code SelectJoin*} methods from {@code
 * org.elsfs.tool.sql.mybatisplus.method}.
 *
 * @author zeng
 * @since 0.0.4
 */
public class EnhancedSqlInjector extends DefaultSqlInjector {

  /**
   * Field filter that delegates to the predicate currently held by {@link PredicateHolder}.
   *
   * <p>The held predicate is looked up on every {@link #test} call rather than captured at
   * construction time, so a single instance always reflects whatever predicate is held when it is
   * evaluated. When {@link PredicateHolder#get()} returns {@code null}, every field is accepted.
   *
   * <p>It can be passed to MyBatis-Plus methods that take a {@code Predicate<TableFieldInfo>}
   * column filter (for example {@code AlwaysUpdateSomeColumnById} or {@code
   * InsertBatchSomeColumn} from mybatis-plus-extension); {@link #getMethodList} does not currently
   * register any such method.
   *
   * <p>NOTE(review): the name suggests {@link PredicateHolder} is thread-bound — confirm against
   * its implementation.
   */
  static class ThreadLocalPredicate implements Predicate<TableFieldInfo> {

    /**
     * Tests the given field against the predicate currently held by {@link PredicateHolder}.
     *
     * @param tableFieldInfo the field to test
     * @return {@code true} if no predicate is held, otherwise the result of the held predicate
     */
    @Override
    public boolean test(TableFieldInfo tableFieldInfo) {
      Predicate<TableFieldInfo> predicate = PredicateHolder.get();
      // No predicate held: do not filter out any field.
      return predicate == null || predicate.test(tableFieldInfo);
    }
  }

  /**
   * Returns the list of methods to inject for the given mapper.
   *
   * <p>Starts from the default MyBatis-Plus method list and appends this library's {@code
   * Exists*} and {@code SelectJoin*} methods.
   *
   * @param mapperClass the current mapper
   * @param tableInfo table information for the mapper
   * @return the list returned by the superclass, with the additional methods appended
   */
  @Override
  public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) {
    List<AbstractMethod> methods = super.getMethodList(mapperClass, tableInfo);

    methods.add(new ExistsById());
    methods.add(new ExistsByWrapper());
    methods.add(new ExistsByJoinWrapper());
    methods.add(new SelectJoinCount());
    methods.add(new SelectJoinOne());
    methods.add(new SelectJoinList());
    methods.add(new SelectJoinPage());

    return methods;
  }
}
