package cn.ziyicloud.framework.boot.autoconfigure.data.jpa.util;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.ziyicloud.framework.boot.autoconfigure.data.jpa.annotation.ZiyiJpaQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.jpa.domain.Specification;

import javax.persistence.criteria.*;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
 * JPA query utility: turns an annotated query object into a JPA {@link Specification}.
 * <p>
 * Every field of the query object (including inherited fields) that carries {@link ZiyiJpaQuery}
 * and holds a value other than {@code null}, {@code ""} or an empty collection contributes one
 * predicate; all predicates are combined with AND.
 *
 * @author Li Ruitong
 * @date 2020/5/25
 */
@Slf4j
public final class ZiyiJpaQueryUtils {

    /**
     * Wildcard prepended/appended to values for LIKE-style matching.
     */
    public static final String LIKE_SYMBOL = "%";

    private ZiyiJpaQueryUtils() {
    }

    /**
     * Builds a {@link Specification} whose predicate is derived from the annotated fields of {@code criteria}.
     *
     * @param criteria query object whose {@link ZiyiJpaQuery}-annotated fields drive the filter; may be {@code null}
     * @param <R>      entity type
     * @param <Q>      query object type
     * @return a specification delegating to {@link #getPredicate(Root, Object, CriteriaBuilder)}
     */
    public static <R, Q> Specification<R> getSpec(Q criteria) {
        return (root, criteriaQuery, criteriaBuilder) -> getPredicate(root, criteria, criteriaBuilder);
    }

    /**
     * Builds the AND-combined predicate for every {@link ZiyiJpaQuery}-annotated field of {@code query}.
     * Fields whose value is {@code null}, {@code ""} or an empty collection are ignored.
     *
     * @param root    query root of the entity being filtered
     * @param query   query object; {@code null} yields {@code builder.conjunction()}
     * @param builder criteria builder used to create the predicates
     * @param <R>     entity type
     * @param <Q>     query object type
     * @return conjunction of all field predicates
     * @throws IllegalArgumentException if a {@code BETWEEN} field does not hold two non-null bounds
     */
    public static <R, Q> Predicate getPredicate(Root<R> root, Q query, CriteriaBuilder builder) {
        if (query == null) {
            return builder.conjunction();
        }
        List<Predicate> predicates = new ArrayList<>();
        for (Field field : getAllFields(query.getClass(), new ArrayList<>())) {
            // Annotations can be read without touching the accessibility flag.
            ZiyiJpaQuery ziyiJpaQuery = field.getAnnotation(ZiyiJpaQuery.class);
            if (ziyiJpaQuery == null) {
                continue;
            }
            Object val = readFieldValue(field, query);
            if (isAbsent(val)) {
                continue;
            }
            Predicate predicate = buildFieldPredicate(root, builder, field, ziyiJpaQuery, val);
            if (predicate != null) {
                predicates.add(predicate);
            }
        }
        return builder.and(predicates.toArray(new Predicate[0]));
    }

    /**
     * Reads {@code field} from {@code target}, temporarily making it accessible.
     * The previous accessibility flag is restored in every case, including on failure.
     *
     * @param field  field to read
     * @param target object to read the field from
     * @return the field value, or {@code null} if it could not be read (the error is logged)
     */
    private static Object readFieldValue(Field field, Object target) {
        boolean accessible = field.isAccessible();
        try {
            field.setAccessible(true);
            return field.get(target);
        } catch (IllegalAccessException e) {
            log.error(e.getMessage(), e);
            return null;
        } finally {
            field.setAccessible(accessible);
        }
    }

    /**
     * Whether a field value means "no filter": {@code null}, the empty string, or an empty collection.
     * Treating empty collections as absent also avoids adding an unused join to the query.
     */
    private static boolean isAbsent(Object val) {
        return ObjectUtil.isNull(val)
            || "".equals(val)
            || (val instanceof Collection && ((Collection<?>) val).isEmpty());
    }

    /**
     * Builds the predicate for one annotated field with a present value.
     *
     * @return the predicate, or {@code null} when the configured query type adds no constraint
     */
    private static <R> Predicate buildFieldPredicate(Root<R> root, CriteriaBuilder builder, Field field,
                                                     ZiyiJpaQuery ziyiJpaQuery, Object val) {
        // Multi-attribute fuzzy search takes precedence and always targets the root entity (no join).
        String blurry = ziyiJpaQuery.blurry();
        if (StringUtils.isNotBlank(blurry)) {
            return buildBlurryPredicate(root, builder, blurry, val);
        }
        String propName = ziyiJpaQuery.propName();
        String attributeName = StringUtils.isBlank(propName) ? field.getName() : propName;
        Join<?, ?> join = buildJoin(root, ziyiJpaQuery);
        return buildTypedPredicate(ziyiJpaQuery, attributeName, field.getType(), val, join, root, builder);
    }

    /**
     * ORs together a {@code LIKE %val%} match on every comma-separated attribute in {@code blurry}.
     * Attribute names are trimmed; blank entries are skipped.
     */
    private static <R> Predicate buildBlurryPredicate(Root<R> root, CriteriaBuilder builder, String blurry,
                                                      Object val) {
        String pattern = LIKE_SYMBOL + val.toString() + LIKE_SYMBOL;
        List<Predicate> orPredicates = new ArrayList<>();
        for (String rawAttr : blurry.split(",")) {
            String attr = rawAttr.trim();
            if (!attr.isEmpty()) {
                orPredicates.add(builder.like(root.get(attr).as(String.class), pattern));
            }
        }
        return builder.or(orPredicates.toArray(new Predicate[0]));
    }

    /**
     * Builds the join chain described by {@code joinName} (segments separated by {@code >}).
     *
     * @return the innermost join, or {@code null} if no join is configured or the join kind is
     * neither LEFT nor RIGHT (the attribute is then resolved on the root entity)
     */
    private static <R> Join<?, ?> buildJoin(Root<R> root, ZiyiJpaQuery ziyiJpaQuery) {
        String joinName = ziyiJpaQuery.joinName();
        if (StringUtils.isBlank(joinName)) {
            return null;
        }
        JoinType joinType;
        switch (ziyiJpaQuery.join()) {
            case LEFT:
                joinType = JoinType.LEFT;
                break;
            case RIGHT:
                joinType = JoinType.RIGHT;
                break;
            default:
                // Other join kinds are not handled here.
                return null;
        }
        Join<?, ?> join = null;
        for (String rawSegment : joinName.split(">")) {
            String segment = rawSegment.trim();
            if (!segment.isEmpty()) {
                join = getJoin(root, join, segment, joinType);
            }
        }
        return join;
    }

    /**
     * Builds the predicate for the comparison configured by {@code ziyiJpaQuery.type()}.
     *
     * @return the predicate, or {@code null} for query types that add no constraint
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <R> Predicate buildTypedPredicate(ZiyiJpaQuery ziyiJpaQuery, String attributeName,
                                                     Class<?> fieldType, Object val, Join<?, ?> join,
                                                     Root<R> root, CriteriaBuilder builder) {
        switch (ziyiJpaQuery.type()) {
            case EQ:
                return builder.equal(getExpression(attributeName, join, root), val);
            case GT:
                return builder.greaterThan(getExpression(attributeName, join, root)
                    .as((Class<? extends Comparable>) fieldType), (Comparable) val);
            case GTE:
                return builder.greaterThanOrEqualTo(getExpression(attributeName, join, root)
                    .as((Class<? extends Comparable>) fieldType), (Comparable) val);
            case LT:
                return builder.lessThan(getExpression(attributeName, join, root)
                    .as((Class<? extends Comparable>) fieldType), (Comparable) val);
            case LTE:
                return builder.lessThanOrEqualTo(getExpression(attributeName, join, root)
                    .as((Class<? extends Comparable>) fieldType), (Comparable) val);
            case INNER_LIKE:
                return builder.like(getLikeExpression(attributeName, join, root),
                    LIKE_SYMBOL + val.toString() + LIKE_SYMBOL);
            case LEFT_LIKE:
                return builder.like(getLikeExpression(attributeName, join, root),
                    LIKE_SYMBOL + val.toString());
            case RIGHT_LIKE:
                return builder.like(getLikeExpression(attributeName, join, root),
                    val.toString() + LIKE_SYMBOL);
            case IN:
                // Empty collections were already filtered out as "absent".
                return getExpression(attributeName, join, root).in((Collection<?>) val);
            case NE:
                return builder.notEqual(getExpression(attributeName, join, root), val);
            case NOT_NULL:
                // The field value only switches this filter on; its content is not compared.
                return builder.isNotNull(getExpression(attributeName, join, root));
            case BETWEEN:
                return buildBetweenPredicate(attributeName, (Collection<?>) val, join, root, builder);
            default:
                return null;
        }
    }

    /**
     * Builds {@code attribute BETWEEN bounds[0] AND bounds[1]}; extra elements are ignored.
     *
     * @throws IllegalArgumentException if fewer than two bounds are given or either bound is {@code null}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <R> Predicate buildBetweenPredicate(String attributeName, Collection<?> values,
                                                       Join<?, ?> join, Root<R> root, CriteriaBuilder builder) {
        List<Object> bounds = new ArrayList<>(values);
        Object lower = bounds.isEmpty() ? null : bounds.get(0);
        Object upper = bounds.size() < 2 ? null : bounds.get(1);
        if (lower == null || upper == null) {
            throw new IllegalArgumentException(
                "BETWEEN on '" + attributeName + "' requires two non-null bounds but got " + bounds);
        }
        return builder.between(getExpression(attributeName, join, root)
                .as((Class<? extends Comparable>) lower.getClass()),
            (Comparable) lower, (Comparable) upper);
    }

    /**
     * Extends the join chain by one association: joins from {@code join} if present, otherwise from {@code root}.
     *
     * @param root     query root
     * @param join     current innermost join, or {@code null}
     * @param name     association attribute to join
     * @param joinType join type to use
     * @param <R>      entity type
     * @return the new innermost join
     */
    private static <R> Join<?, ?> getJoin(Root<R> root, Join<?, ?> join, String name, JoinType joinType) {
        if (join != null) {
            return join.join(name, joinType);
        }
        return root.join(name, joinType);
    }

    /**
     * Resolves {@code attributeName} on the join if present, otherwise on the root.
     *
     * @param attributeName attribute to resolve
     * @param join          innermost join, or {@code null}
     * @param root          query root
     * @return path expression for the attribute
     */
    private static <T, R> Expression<T> getExpression(String attributeName, Join<?, ?> join, Root<R> root) {
        if (join != null) {
            return join.get(attributeName);
        }
        return root.get(attributeName);
    }

    /**
     * Like {@link #getExpression(String, Join, Root)}, but typed as {@code String} for LIKE matching.
     *
     * @param attributeName attribute to resolve
     * @param join          innermost join, or {@code null}
     * @param root          query root
     * @param <R>           entity type
     * @return the attribute expression cast to {@code String}
     */
    private static <R> Expression<String> getLikeExpression(String attributeName, Join<?, ?> join, Root<R> root) {
        return getExpression(attributeName, join, root).as(String.class);
    }

    /**
     * Collects the declared fields of {@code clazz} and of all its superclasses via reflection,
     * with subclass fields first.
     *
     * @param clazz  class to start from; {@code null} adds nothing
     * @param fields list the fields are appended to
     * @return {@code fields}
     */
    private static List<Field> getAllFields(Class<?> clazz, List<Field> fields) {
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            fields.addAll(Arrays.asList(current.getDeclaredFields()));
        }
        return fields;
    }
}
