/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.afrunt.jpa.powerdao;

import javax.persistence.LockModeType;
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaDelete;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Root;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Default implementations of the {@link ExtendedApiDao} convenience operations: lookup by id,
 * existence checks, counting, lookup/delete by id collections (optionally partitioned), page
 * counting and random selection. They are built on the JPA Criteria API and on query helpers
 * inherited from {@link AbstractExtendedQueryApiDao}.
 *
 * @author Andrii Frunt
 */
public abstract class AbstractExtendedApiDao extends AbstractExtendedQueryApiDao implements ExtendedApiDao {
    /**
     * Looks up an entity by primary key.
     *
     * @param entityClass entity type to load
     * @param primaryKey  primary key value
     * @return the entity, or {@link Optional#empty()} when the inherited {@code find} returns {@code null}
     */
    @Override
    public <ET> Optional<ET> findById(Class<ET> entityClass, Object primaryKey) {
        return Optional.ofNullable(find(entityClass, primaryKey));
    }

    /**
     * Looks up an entity by primary key, passing provider/standard properties to {@code find}.
     *
     * @param entityType entity type to load
     * @param primaryKey primary key value
     * @param properties properties forwarded unchanged to {@code find}
     * @return the entity, or {@link Optional#empty()} when {@code find} returns {@code null}
     */
    @Override
    public <ET> Optional<ET> findById(Class<ET> entityType, Object primaryKey,
                                      Map<String, Object> properties) {
        return Optional.ofNullable(find(entityType, primaryKey, properties));
    }

    /**
     * Looks up an entity by primary key with the given lock mode.
     *
     * @param entityType entity type to load
     * @param primaryKey primary key value
     * @param lockMode   lock mode forwarded unchanged to {@code find}
     * @return the entity, or {@link Optional#empty()} when {@code find} returns {@code null}
     */
    @Override
    public <ET> Optional<ET> findById(Class<ET> entityType, Object primaryKey,
                                      LockModeType lockMode) {
        return Optional.ofNullable(find(entityType, primaryKey, lockMode));
    }

    /**
     * Looks up an entity by primary key with the given lock mode and properties.
     *
     * @param entityType entity type to load
     * @param primaryKey primary key value
     * @param lockMode   lock mode forwarded unchanged to {@code find}
     * @param properties properties forwarded unchanged to {@code find}
     * @return the entity, or {@link Optional#empty()} when {@code find} returns {@code null}
     */
    @Override
    public <ET> Optional<ET> findById(Class<ET> entityType, Object primaryKey,
                                      LockModeType lockMode,
                                      Map<String, Object> properties) {
        return Optional.ofNullable(find(entityType, primaryKey, lockMode, properties));
    }

    /**
     * Loads every entity of the given type.
     *
     * @param entityClass entity type
     * @return all entities, as returned by {@code getResultList()}
     */
    @Override
    public <ET> List<ET> findAll(Class<ET> entityClass) {
        return createQuery(findAllQuery(entityClass)).getResultList();
    }

    /**
     * Returns every entity of the given type as a stream, using the inherited
     * {@code queryResultStream} helper.
     * NOTE(review): whether the stream is lazy or needs closing depends on that helper — confirm.
     *
     * @param entityClass entity type
     * @return stream of all entities
     */
    @Override
    public <ET> Stream<ET> findAllStream(Class<ET> entityClass) {
        return queryResultStream(createQuery(findAllQuery(entityClass)));
    }

    /**
     * Builds an unrestricted criteria query selecting all entities of the given type.
     *
     * @param entityClass entity type (used both as result type and query root)
     * @return the new criteria query
     */
    @Override
    public <ET> CriteriaQuery<ET> findAllQuery(Class<ET> entityClass) {
        CriteriaBuilder cb = getCriteriaBuilder();
        CriteriaQuery<ET> query = cb.createQuery(entityClass);
        query.from(entityClass);
        return query;
    }

    /**
     * Checks whether an entity with the given primary key exists.
     * It issues a {@code COUNT} criteria query rather than loading the entity.
     *
     * @param entityType entity type
     * @param primaryKey primary key value compared against the entity's id attribute
     * @return {@code true} if at least one matching row exists
     */
    @Override
    public <ET> boolean exists(Class<ET> entityType, Object primaryKey) {
        String idFieldName = getIdFieldName(entityType);
        CriteriaBuilder cb = getCriteriaBuilder();
        CriteriaQuery<Long> query = cb.createQuery(Long.class);
        Root<ET> from = query.from(entityType);
        query.select(cb.count(from));
        query.where(cb.equal(from.get(idFieldName), primaryKey));
        // "> 0" rather than "== 1": a primary key should match at most one row, but existence
        // must never be reported as false just because the count is larger than one.
        return createQuery(query).getSingleResult() > 0L;
    }

    /**
     * Builds a {@code SELECT COUNT(e)} criteria query over all entities of the given type.
     *
     * @param entityType entity type
     * @return the new count query
     */
    @Override
    public <ET> CriteriaQuery<Long> createCountQuery(Class<ET> entityType) {
        CriteriaBuilder cb = getCriteriaBuilder();
        CriteriaQuery<Long> query = cb.createQuery(Long.class);
        query.select(cb.count(query.from(entityType)));
        return query;
    }

    /**
     * Counts all entities of the given type.
     *
     * @param entityType entity type
     * @return number of rows
     */
    @Override
    public <ET> long count(Class<ET> entityType) {
        return createQuery(createCountQuery(entityType)).getSingleResult();
    }

    /**
     * Loads entities by id, splitting {@code ids} into chunks of at most {@code maxInSize}.
     * Each chunk is queried with {@link #findByIds(Class, Collection)}.
     * The chunking and result merging are done by the inherited {@code partitionsTo} helper.
     *
     * @param entityType entity type
     * @param ids        ids to load
     * @param maxInSize  maximum number of ids per IN clause
     * @return the entities found
     */
    @Override
    @SuppressWarnings("unchecked")
    public <ET, KT> List<ET> findByIds(Class<ET> entityType, Collection<KT> ids, int maxInSize) {
        return partitionsTo(ids, maxInSize, p -> findByIds(entityType, p));
    }

    /**
     * Loads entities whose id attribute is contained in {@code ids}, using a single IN clause.
     *
     * @param entityType entity type
     * @param ids        ids to load; an empty collection yields an empty (mutable) list without a query
     * @return the entities found
     */
    @Override
    public <ET, KT> List<ET> findByIds(Class<ET> entityType, Collection<KT> ids) {
        if (ids.isEmpty()) {
            // Nothing can match. Empty IN lists are also handled inconsistently across providers,
            // so skip the round-trip.
            return new ArrayList<>();
        }
        CriteriaBuilder cb = getCriteriaBuilder();
        CriteriaQuery<ET> find = cb.createQuery(entityType);
        Root<ET> from = find.from(entityType);
        find.where(from.get(getIdFieldName(entityType)).in(ids));
        TypedQuery<ET> query = createQuery(find);
        return queryResultList(query);
    }

    /**
     * Deletes the entity with the given id via a bulk criteria delete.
     * Bulk deletes act directly on the database: they do not cascade and do not update
     * already-managed instances.
     *
     * @param entityType entity type
     * @param id         id of the entity to delete
     * @return number of deleted rows
     */
    @Override
    @SuppressWarnings("unchecked")
    public <ET, KT> int deleteById(Class<ET> entityType, KT id) {
        return deleteByIds(entityType, listOf(id));
    }

    /**
     * Bulk-deletes entities by id, issuing one delete per chunk of at most {@code maxInSize} ids.
     *
     * @param entityType entity type
     * @param ids        ids to delete
     * @param maxInSize  maximum number of ids per IN clause
     * @return total number of deleted rows across all chunks
     */
    @Override
    public <ET, KT> int deleteByIds(Class<ET> entityType, Collection<KT> ids, int maxInSize) {
        // Each delete must use only its own partition. Passing the full collection would send an
        // unbounded IN list (defeating maxInSize) and repeat the whole delete once per partition.
        return partitions(ids, maxInSize).stream()
                .mapToInt(partition -> deleteByIds(entityType, partition))
                .sum();
    }

    /**
     * Bulk-deletes entities whose id attribute is contained in {@code ids}, using a single IN
     * clause. Bulk deletes do not cascade and do not update already-managed instances.
     *
     * @param entityType entity type
     * @param ids        ids to delete; an empty collection deletes nothing and returns 0 without a query
     * @return number of deleted rows
     */
    @Override
    public <ET, KT> int deleteByIds(Class<ET> entityType, Collection<KT> ids) {
        if (ids.isEmpty()) {
            // Nothing to delete; avoid sending an empty IN list to the provider.
            return 0;
        }
        CriteriaBuilder cb = getCriteriaBuilder();
        CriteriaDelete<ET> delete = cb.createCriteriaDelete(entityType);
        Root<ET> from = delete.from(entityType);
        delete.where(from.get(getIdFieldName(entityType)).in(ids));
        return createQuery(delete).executeUpdate();
    }

    /**
     * Computes the number of pages needed for {@code recordsCount} records (ceiling division).
     *
     * @param recordsCount total number of records
     * @param perPage      records per page; must be positive ({@code 0} raises {@link ArithmeticException})
     * @return number of pages
     */
    @Override
    public long pageCount(long recordsCount, long perPage) {
        return recordsCount / perPage + (recordsCount % perPage == 0 ? 0 : 1);
    }

    /**
     * Returns the subset of {@code ids} that exist for the given entity type.
     * It queries in partitions of 1000 ids.
     *
     * @param entityType entity type
     * @param ids        candidate ids
     * @return ids that exist in the database
     */
    @Override
    public <ET, KT> Set<KT> findExistingEntityIdsIn(Class<ET> entityType, Collection<KT> ids) {
        return findExistingEntityIdsIn(entityType, ids, 1000);
    }

    /**
     * Returns the subset of {@code ids} that exist for the given entity type.
     * It runs a JPQL id projection in partitions of at most {@code partitionSize} ids.
     * NOTE(review): the JPQL uses {@code getSimpleName()} as the entity name, which is only
     * correct when {@code @Entity(name = ...)} is not overridden.
     *
     * @param entityType    entity type
     * @param ids           candidate ids; an empty collection yields an empty (mutable) set without a query
     * @param partitionSize maximum number of ids per IN clause
     * @return ids that exist in the database
     */
    @Override
    @SuppressWarnings("unchecked")
    public <ET, KT> Set<KT> findExistingEntityIdsIn(Class<ET> entityType, Collection<KT> ids, int partitionSize) {
        if (ids.isEmpty()) {
            return new HashSet<>();
        }
        String idFieldName = getIdFieldName(entityType);

        String query = "SELECT e." + idFieldName + " FROM " + entityType.getSimpleName() + " e WHERE e." + idFieldName + " IN ?1";
        return new HashSet<>((Collection<? extends KT>) partitionsToQueryResultList(query, getEntityIdType(entityType), ids, partitionSize));
    }

    /**
     * Maps {@code objects} to ids with {@code idMapper} (de-duplicating them) and returns the ids
     * that exist, using the default partition size of
     * {@link #findExistingEntityIdsIn(Class, Collection)}.
     *
     * @param entityType entity type
     * @param objects    source objects
     * @param idMapper   extracts the id from each object
     * @return ids that exist in the database
     */
    @Override
    public <ET, VT, KT> Set<KT> findExistingEntityIdsIn(Class<ET> entityType, Collection<VT> objects, Function<VT, KT> idMapper) {
        return findExistingEntityIdsIn(
                entityType,
                objects.stream()
                        .map(idMapper)
                        .collect(Collectors.toSet())
        );
    }

    /**
     * Maps {@code objects} to ids with {@code idMapper} (de-duplicating them) and returns the ids
     * that exist, querying in partitions of at most {@code partitionSize} ids.
     *
     * @param entityType    entity type
     * @param objects       source objects
     * @param idMapper      extracts the id from each object
     * @param partitionSize maximum number of ids per IN clause
     * @return ids that exist in the database
     */
    @Override
    public <ET, VT, KT> Set<KT> findExistingEntityIdsIn(Class<ET> entityType, Collection<VT> objects, Function<VT, KT> idMapper, int partitionSize) {
        return findExistingEntityIdsIn(
                entityType,
                objects.stream()
                        .map(idMapper)
                        .collect(Collectors.toSet()),
                partitionSize
        );
    }

    /**
     * Picks a pseudo-random entity of the given type.
     * It counts the rows, draws a random index in {@code [0, count)}, and fetches a one-element page.
     * NOTE(review): assumes {@code queryPage} treats its first argument as a zero-based page
     * index (equivalent to an offset with page size 1) — confirm against the helper. The JPQL
     * also uses {@code getSimpleName()} as the entity name.
     *
     * @param entityType entity type
     * @return a random entity, or {@link Optional#empty()} if there are none
     */
    @Override
    public <ET> Optional<ET> random(Class<ET> entityType) {
        long total = count(entityType);

        if (total == 0) {
            return Optional.empty();
        }

        long number = ThreadLocalRandom.current().nextLong(0L, total);

        List<ET> entities = queryPage(number, 1, "SELECT e FROM " + entityType.getSimpleName() + " e");

        // The count and the page fetch are separate queries, so rows may disappear in between.
        return entities.isEmpty() ? Optional.empty() : Optional.of(entities.get(0));
    }
}
