package cn.ziyicloud.framework.boot.autoconfigure.data.jpa.util;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.ziyicloud.framework.boot.autoconfigure.data.jpa.annotation.JpaQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.jpa.domain.Specification;

import javax.persistence.criteria.*;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
 * jpa查询工具类
 *
 * @author Li Ruitong 86415270@qq.com
 * @since 1.0.7
 */
@Slf4j
public final class JpaQueryUtils {
    /**
     * sql like连接符
     */
    public static final String LIKE_SYMBOL = "%";

    private JpaQueryUtils() {
    }

    /**
     * 构建Specification
     *
     * @param criteria 查询条件
     * @param <R>      Specification类型
     * @param <Q>      查询条件类型
     * @return Specification
     */
    public static <R, Q> Specification<R> buildSpec(Q criteria) {
        return (root, criteriaQuery, criteriaBuilder) -> JpaQueryUtils.getPredicate(root, criteria, criteriaBuilder);
    }

    /**
     * @param root    /
     * @param query   /
     * @param builder /
     * @param <R>     /
     * @param <Q>     /
     * @return /
     */
    public static <R, Q> Predicate getPredicate(Root<R> root, Q query, CriteriaBuilder builder) {
        //匹配规则列表
        List<Predicate> list = new ArrayList<>();
        if (query == null) {
            return builder.conjunction();
        }
        try {
            List<Field> fields = getAllFields(query.getClass(), new ArrayList<>());
            for (Field field : fields) {
                boolean accessible = field.isAccessible();
                //设置可以通过反射访问私有变量
                field.setAccessible(true);
                //获取注解
                JpaQuery jpaQuery = field.getAnnotation(JpaQuery.class);
                //含有自定义注解
                if (jpaQuery != null) {
                    String propName = jpaQuery.propName();
                    String joinName = jpaQuery.joinName();
                    String blurry = jpaQuery.blurry();
                    String attrName = StringUtils.isBlank(propName) ? field.getName() : propName;
                    //获取字段类型和字段值
                    Class<?> fieldType = field.getType();
                    Object value = field.get(query);
                    if (ObjectUtils.isEmpty(value)) {
                        continue;
                    }
                    //配置表连接
                    Join<?, ?> join = null;
                    // 模糊多字段
                    if (ObjectUtils.isNotEmpty(blurry)) {
                        String[] blurryList = blurry.split(",");
                        List<Predicate> orPredicate = new ArrayList<>();
                        for (String blurryAttr : blurryList) {
                            //添加like模式
                            orPredicate.add(builder.like(root.get(blurryAttr).as(String.class), LIKE_SYMBOL + value + LIKE_SYMBOL));
                        }
                        Predicate[] p = new Predicate[orPredicate.size()];
                        //用or连接多个模糊模式添加到模式列表
                        list.add(builder.or(orPredicate.toArray(p)));
                        //匹配到模糊查询继续
                        continue;
                    }
                    // 配置表连接
                    if (ObjectUtils.isNotEmpty(joinName)) {
                        String[] joinNames = joinName.split(">");
                        for (String name : joinNames) {
                            switch (jpaQuery.join()) {
                                case LEFT:
                                    join = getJoin(root, join, name, JoinType.LEFT, value);
                                    break;
                                case RIGHT:
                                    join = getJoin(root, join, name, JoinType.RIGHT, value);
                                    break;
                                case INNER:
                                    join = getJoin(root, join, name, JoinType.INNER, value);
                                    break;
                                default:
                                    break;
                            }
                        }
                    }
                    Expression<?> expression = getExpression(attrName, join, root);
                    buildPredicate(builder, list, jpaQuery, fieldType, value, expression);
                }
                //还原字段修饰符
                field.setAccessible(accessible);
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }

        int size = list.size();
        return builder.and(list.toArray(new Predicate[size]));
    }

    /**
     * 构建规则
     *
     * @param builder    构建器
     * @param list       列表
     * @param jpaQuery   jpa查询
     * @param fieldType  字段类型
     * @param value      值
     * @param expression 表达式
     */
    @SuppressWarnings("unchecked")
    private static void buildPredicate(CriteriaBuilder builder, List<Predicate> list, JpaQuery jpaQuery, Class<?> fieldType, Object value, Expression<?> expression) {
        switch (jpaQuery.type()) {
            case EQ:
                list.add(builder.equal(expression, value));
                break;
            case GT:
                list.add(builder.greaterThan(expression.as((Class<? extends Comparable>) fieldType), (Comparable) value));
                break;
            case GTE:
                list.add(builder.greaterThanOrEqualTo(expression.as((Class<? extends Comparable>) fieldType), (Comparable) value));
                break;
            case LT:
                list.add(builder.lessThan(expression.as((Class<? extends Comparable>) fieldType), (Comparable) value));
                break;
            case LTE:
                list.add(builder.lessThanOrEqualTo(expression.as((Class<? extends Comparable>) fieldType), (Comparable) value));
                break;
            case INNER_LIKE:
                list.add(builder.like(expression.as(String.class), LIKE_SYMBOL + value + LIKE_SYMBOL));
                break;
            case LEFT_LIKE:
                list.add(builder.like(expression.as(String.class), LIKE_SYMBOL + value));
                break;
            case RIGHT_LIKE:
                list.add(builder.like(expression.as(String.class), value + LIKE_SYMBOL));
                break;
            case IN:
                if (CollUtil.isNotEmpty((Collection<?>) value)) {
                    list.add(expression.in((Collection<?>) value));
                }
                break;
            case NE:
                list.add(builder.notEqual(expression, value));
                break;
            case NOT_NULL:
                list.add(builder.isNotNull(expression));
                break;
            case BETWEEN:
                List<?> between = new ArrayList<>((List<?>) value);
                list.add(builder.between(expression.as((Class<? extends Comparable>) between.get(0).getClass()), (Comparable) between.get(0), (Comparable) between.get(1)));
                break;
            default:
                break;
        }
    }

    /**
     * 获取连接Join
     *
     * @param root     root
     * @param join     join
     * @param name     name
     * @param joinType joinType
     * @param <R>      the entity type referenced by the root
     * @return Join
     */
    private static <R> Join<?, ?> getJoin(Root<R> root, Join<?, ?> join, String name, JoinType joinType, Object value) {
        if (ObjectUtils.allNotNull(join, value)) {
            join = join.join(name, joinType);
        } else {
            join = root.join(name, joinType);
        }
        return join;
    }


    /**
     * 获取查询表达式
     *
     * @param attrName 属性名
     * @param join     join
     * @param root     root
     * @param <R>      the entity type referenced by the root
     * @return expressions
     */
    private static <T, R> Expression<T> getExpression(String attrName, Join<?, ?> join, Root<R> root) {
        if (ObjectUtil.isNotEmpty(join)) {
            return join.get(attrName);
        } else {
            return root.get(attrName);
        }
    }

    /**
     * 通过反射获取一个类及其父类所有有属性
     *
     * @param clazz  clazz
     * @param fields fields
     * @return {@link List}<{@link Field}> Field列表
     */
    private static List<Field> getAllFields(Class<?> clazz, List<Field> fields) {
        if (clazz != null) {
            fields.addAll(Arrays.asList(clazz.getDeclaredFields()));
            getAllFields(clazz.getSuperclass(), fields);
        }
        return fields;
    }
}
