/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.afrunt.jpa.powerdao;

import javax.persistence.Entity;
import javax.persistence.EntityManager;
import javax.persistence.LockModeType;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * DAO bound to a single entity type {@code ET} whose primary key type is {@code KT}.
 * <p>
 * Each entity-scoped operation in this class delegates to the overload that takes the entity
 * class explicitly, passing the class returned by {@link #getEntityType()}. The entity and id
 * classes are either supplied up front (see {@link #instance(Class, Class, EntityManager)}) or
 * inferred reflectively from the concrete type arguments of a subclass declaration such as
 * {@code class UserDao extends PowerDao<User, Long>}; in both cases they are cached in fields
 * after the first lookup.
 *
 * @param <ET> entity type; when inferred reflectively it must be a class annotated with
 *             {@link Entity}
 * @param <KT> primary key type
 * @author Andrii Frunt
 */
public abstract class PowerDao<ET, KT> extends SimpleDao implements ExtendedEntityApiDao<ET, KT> {
    /** Default maximum number of values placed in a single SQL {@code IN (...)} clause. */
    private static final int DEFAULT_IN_CLAUSE_LIMIT = 1000;

    private Class<ET> entityType;
    private Class<KT> idType;

    /**
     * Creates a DAO without an {@link EntityManager}.
     * <p>
     * When using this constructor, you need to override the {@code getEntityManager()} method.
     */
    public PowerDao() {
    }

    /**
     * Creates a DAO backed by the given entity manager.
     *
     * @param em entity manager handed to the {@link SimpleDao} constructor
     */
    public PowerDao(EntityManager em) {
        super(em);
    }

    /**
     * Creates a DAO for explicitly supplied entity and id classes.
     * <p>
     * The returned object is an anonymous subclass whose type arguments are type variables, so
     * reflective inference cannot work for it; the supplied classes are therefore cached up front.
     *
     * @param actualEntityType entity class the DAO operates on; must not be {@code null}
     * @param actualIdType     primary key class of the entity
     * @param em               entity manager backing the DAO
     * @param <AET>            entity type
     * @param <AKT>            primary key type
     * @return a new DAO instance
     * @throws NullPointerException if {@code actualEntityType} is {@code null}
     */
    public static <AET, AKT> PowerDao<AET, AKT> instance(Class<AET> actualEntityType, Class<AKT> actualIdType, EntityManager em) {
        // Fail fast: without a cached entity class, lazy inference on the anonymous subclass below
        // cannot find a concrete @Entity type argument and would fail on first use instead.
        Objects.requireNonNull(actualEntityType, "actualEntityType");

        PowerDao<AET, AKT> powerDao = new PowerDao<AET, AKT>(em) {
        };

        powerDao.cacheEntityType(actualEntityType);
        powerDao.cacheIdType(actualIdType);

        return powerDao;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This implementation returns {@value #DEFAULT_IN_CLAUSE_LIMIT}.
     */
    @Override
    public int getDefaultInClauseLimit() {
        return DEFAULT_IN_CLAUSE_LIMIT;
    }

    // ---------------------------------------------------------------------------------------
    // Entity-scoped operations: each one forwards to the overload taking the entity class.
    // ---------------------------------------------------------------------------------------

    @Override
    public long count() {
        return count(getEntityType());
    }

    @Override
    public ET find(Object primaryKey) {
        return find(getEntityType(), primaryKey);
    }

    @Override
    public ET find(Object primaryKey,
                   Map<String, Object> properties) {
        return find(getEntityType(), primaryKey, properties);
    }

    @Override
    public ET find(Object primaryKey, LockModeType lockMode) {
        return find(getEntityType(), primaryKey, lockMode);
    }

    @Override
    public ET find(Object primaryKey, LockModeType lockMode, Map<String, Object> properties) {
        return find(getEntityType(), primaryKey, lockMode, properties);
    }

    @Override
    public List<ET> findAll() {
        return findAll(getEntityType());
    }

    @Override
    public Stream<ET> findAllStream() {
        return findAllStream(getEntityType());
    }

    @Override
    public Optional<ET> findById(KT primaryKey) {
        return findById(getEntityType(), primaryKey);
    }

    @Override
    public Optional<ET> findById(KT primaryKey, Map<String, Object> properties) {
        return findById(getEntityType(), primaryKey, properties);
    }

    @Override
    public Optional<ET> findById(KT primaryKey, LockModeType lockMode) {
        return findById(getEntityType(), primaryKey, lockMode);
    }

    @Override
    public Optional<ET> findById(KT primaryKey, LockModeType lockMode, Map<String, Object> properties) {
        return findById(getEntityType(), primaryKey, lockMode, properties);
    }

    @Override
    public List<ET> findByIds(Collection<KT> ids, int maxInSize) {
        return findByIds(getEntityType(), ids, maxInSize);
    }

    @Override
    public List<ET> findByIds(Collection<KT> ids) {
        return findByIds(ids, getDefaultInClauseLimit());
    }

    @Override
    public boolean exists(KT primaryKey) {
        return exists(getEntityType(), primaryKey);
    }

    @Override
    public int deleteById(KT id) {
        return deleteById(getEntityType(), id);
    }

    @Override
    public int deleteByIds(Collection<KT> ids) {
        return deleteByIds(getEntityType(), ids, getDefaultInClauseLimit());
    }

    @Override
    public int deleteByIds(Collection<KT> ids, int maxInSize) {
        return deleteByIds(getEntityType(), ids, maxInSize);
    }

    // The findExistingEntityIdsIn overloads must go through getEntityType(): the entityType field
    // is only populated lazily, so reading it directly could pass null as the entity class.

    @Override
    public Set<KT> findExistingEntityIdsIn(Collection<KT> ids) {
        return findExistingEntityIdsIn(getEntityType(), ids);
    }

    @Override
    public Set<KT> findExistingEntityIdsIn(Collection<KT> ids, int partitionSize) {
        return findExistingEntityIdsIn(getEntityType(), ids, partitionSize);
    }

    @Override
    public <VT> Set<KT> findExistingEntityIdsIn(Collection<VT> objects, Function<VT, KT> idMapper) {
        return findExistingEntityIdsIn(getEntityType(), objects, idMapper);
    }

    @Override
    public <VT> Set<KT> findExistingEntityIdsIn(Collection<VT> objects, Function<VT, KT> idMapper, int partitionSize) {
        return findExistingEntityIdsIn(getEntityType(), objects, idMapper, partitionSize);
    }

    /**
     * Returns the cached entity class.
     *
     * @return the cached entity class, or {@code null} if it has not been resolved yet
     */
    protected Class<ET> getCachedEntityType() {
        return entityType;
    }

    /**
     * Returns the cached primary key class.
     *
     * @return the cached id class, or {@code null} if it has not been resolved yet
     */
    protected Class<KT> getCachedIdType() {
        return idType;
    }

    /**
     * Stores the entity class.
     *
     * @param entityType entity class to cache
     * @return the same {@code entityType}, for chaining
     */
    protected Class<ET> cacheEntityType(Class<ET> entityType) {
        return this.entityType = entityType;
    }

    /**
     * Stores the primary key class.
     *
     * @param idType id class to cache
     * @return the same {@code idType}, for chaining
     */
    protected Class<KT> cacheIdType(Class<KT> idType) {
        return this.idType = idType;
    }

    /**
     * Returns the entity class.
     * <p>
     * Uses the cached value if present; otherwise it reads the first type argument of the
     * supertype found by {@link #getCorrectGenericSuperclass()} and caches it.
     *
     * @return the entity class
     */
    @SuppressWarnings("unchecked")
    protected Class<ET> getEntityType() {
        Class<ET> cachedEntityType = getCachedEntityType();
        if (cachedEntityType != null) {
            return cachedEntityType;
        }
        ParameterizedType genericSuperclass = getCorrectGenericSuperclass();
        // Safe: isCorrectGenericSuperType guarantees argument 0 is a Class.
        return cacheEntityType((Class<ET>) genericSuperclass.getActualTypeArguments()[0]);
    }

    /**
     * Returns the primary key class.
     * <p>
     * Uses the cached value if present; otherwise it reads the second type argument of the
     * supertype found by {@link #getCorrectGenericSuperclass()} and caches it.
     *
     * @return the id class
     */
    @SuppressWarnings("unchecked")
    protected Class<KT> getIdType() {
        Class<KT> cachedIdType = getCachedIdType();
        if (cachedIdType != null) {
            return cachedIdType;
        }

        ParameterizedType genericSuperclass = getCorrectGenericSuperclass();
        // NOTE(review): argument 1 is assumed to be a plain Class; a parameterized id type would
        // fail this cast with ClassCastException.
        return cacheIdType((Class<KT>) genericSuperclass.getActualTypeArguments()[1]);
    }

    /**
     * Tests whether {@code type} is a usable DAO supertype declaration.
     * <p>
     * A usable declaration is parameterized, has exactly {@link #getNumberOfParameters()} type
     * arguments, and has a first argument that is a class annotated with {@link Entity}.
     */
    private boolean isCorrectGenericSuperType(Type type) {
        if (!(type instanceof ParameterizedType)) {
            return false;
        }

        Type[] actualTypeArguments = ((ParameterizedType) type).getActualTypeArguments();

        return actualTypeArguments.length == getNumberOfParameters()
                && actualTypeArguments[0] instanceof Class
                && ((Class<?>) actualTypeArguments[0]).isAnnotationPresent(Entity.class);
    }

    /**
     * Returns the number of type arguments a supertype must have to qualify for inference.
     *
     * @return the expected number of type arguments; {@code 2} here
     */
    protected int getNumberOfParameters() {
        return 2;
    }

    /**
     * Finds the parameterized supertype of this DAO's runtime class.
     * <p>
     * The entity and id classes are read from the supertype returned here.
     *
     * @return the supertype declaration
     * @throws RuntimeException if the supertype is ambiguous or cannot be found
     */
    protected ParameterizedType getCorrectGenericSuperclass() {
        return getCorrectGenericSuperclass(getClass());
    }

    /**
     * Searches {@code type}'s direct generic superclass and interfaces for a usable declaration.
     * <p>
     * If there is no match in {@code type} itself, the search walks up the class hierarchy while
     * the superclass is still a {@code PowerDao}.
     */
    private ParameterizedType getCorrectGenericSuperclass(Class<?> type) {
        List<Type> candidates = new ArrayList<>();
        candidates.add(type.getGenericSuperclass());
        candidates.addAll(Arrays.asList(type.getGenericInterfaces()));

        List<Type> matches = candidates.stream()
                .filter(this::isCorrectGenericSuperType)
                .collect(Collectors.toList());

        if (matches.isEmpty()) {
            Class<?> superclass = type.getSuperclass();
            if (superclass != null && PowerDao.class.isAssignableFrom(superclass)) {
                return getCorrectGenericSuperclass(superclass);
            }
            // Previously this fell through to iterator().next() and threw NoSuchElementException.
            throw new IllegalStateException("Cannot infer generic supertype of " + getClass().getName()
                    + ": no parameterized supertype with " + getNumberOfParameters()
                    + " type arguments and an @Entity first argument was found");
        }

        if (matches.size() > 1) {
            throw new RuntimeException("Cannot infer generic supertype");
        }

        return (ParameterizedType) matches.get(0);
    }

    @Override
    public Optional<ET> random() {
        return random(getEntityType());
    }
}
