/*
 * Click nbfs://nbhost/SystemFileSystem/Templates/Licenses/license-default.txt to change this license
 * Click nbfs://nbhost/SystemFileSystem/Templates/Classes/Class.java to edit this template
 */
package rocks.imsofa.ai.puppychatter;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.slf4j.LoggerFactory;
import rocks.imsofa.ai.puppychatter.cache.CacheService;

/**
 * Skeleton {@link PuppyChatter} implementation that keeps the conversation
 * history of every session in memory and, when a {@link CacheService} is
 * configured, consults the cache before calling the underlying chat service.
 *
 * <p>Subclasses supply the actual service calls through the two abstract
 * {@code _bark} methods and may tune caching via {@link #isCacheable}.
 *
 * <p>The session map is a plain {@link HashMap} and this class performs no
 * synchronization of its own.
 *
 * @param <T> type of the prompt parameters
 * @param <S> type of the response
 * @author USER
 */
public abstract class AbstractPuppyChatter<T extends PromptParameters, S extends Response> implements PuppyChatter<T, S> {

    /** Conversation history of every open session, keyed by session id. */
    protected Map<String, List<Conversation>> sessionHistory = new HashMap<>();
    /** Optional reply cache; {@code null} disables every cache interaction. */
    protected CacheService cacheService = null;
    /** Role given to reply messages that are appended to the history. */
    protected String replyRole = "assistant";

    /**
     * Creates a chatter that uses the given cache and reply role.
     *
     * @param cacheService cache to consult and populate, or {@code null} for no caching
     * @param replyRole    role assigned to reply messages stored in the history
     */
    public AbstractPuppyChatter(CacheService cacheService, String replyRole) {
        this.cacheService = cacheService;
        this.replyRole = replyRole;
    }

    /**
     * Creates a chatter without a cache and with the reply role {@code "assistant"}.
     */
    public AbstractPuppyChatter() {
    }

    /**
     * Opens a new session with an empty history.
     *
     * @return the random UUID string identifying the new session
     */
    @Override
    public String createSession() {
        String sessionId = UUID.randomUUID().toString();
        sessionHistory.put(sessionId, new ArrayList<>());
        return sessionId;
    }

    /**
     * Sends a prompt and delivers the reply through a callback.
     * If it is possible, the call to the underlying service is made in a
     * streaming way, but this is not guaranteed: when a verified reply is
     * found in the cache, the callback is invoked exactly once with the cached
     * response and {@code true}.
     *
     * @param sessionId    the session identifier, as returned by {@link #createSession()}
     * @param prompt       the input prompt to process
     * @param parameters   additional parameters for the bark operation
     * @param barkCallback receiver of the response
     * @throws BarkException if the session is unknown, or the cache or the
     *                       underlying service fails
     */
    @Override
    public void bark(String sessionId, String prompt, T parameters, BarkCallback<S> barkCallback) throws BarkException {
        this._bark(sessionId, prompt, parameters, barkCallback);
    }

    /**
     * Sends a prompt and blocks until a verified reply is available.
     *
     * @param sessionId  the session identifier, as returned by {@link #createSession()}
     * @param prompt     the input prompt to process
     * @param parameters additional parameters for the bark operation
     * @return the generated (or cached) response
     * @throws BarkException if the session is unknown, the cache or the
     *                       underlying service fails, or the response is
     *                       rejected by the response verifier
     */
    @Override
    public S bark(String sessionId, String prompt, T parameters) throws BarkException {
        return this._bark(sessionId, prompt, parameters, null);
    }

    /**
     * Callback variant of {@code bark} that uses the parameters returned by
     * {@link #createDefaultPromptParameter()}.
     *
     * @param sessionId    the session identifier
     * @param prompt       the input prompt to process
     * @param barkCallback receiver of the response
     * @throws BarkException if the bark operation fails
     */
    @Override
    public void bark(String sessionId, String prompt, BarkCallback<S> barkCallback) throws BarkException {
        this.bark(sessionId, prompt, this.createDefaultPromptParameter(), barkCallback);
    }

    /**
     * Blocking variant of {@code bark} that uses the parameters returned by
     * {@link #createDefaultPromptParameter()}.
     *
     * @param sessionId the session identifier
     * @param prompt    the input prompt to process
     * @return the generated (or cached) response
     * @throws BarkException if the bark operation fails
     */
    @Override
    public S bark(String sessionId, String prompt) throws BarkException {
        return this.bark(sessionId, prompt, this.createDefaultPromptParameter());
    }

    /**
     * Discards the history of the given session; unknown ids are ignored.
     *
     * @param sessionId the session identifier
     */
    @Override
    public void closeSession(String sessionId) {
        sessionHistory.remove(sessionId);
    }

    /**
     * Shared implementation behind every public {@code bark} overload.
     *
     * <p>The prompt is appended to the session history, then (when caching
     * applies) a cached reply is looked up and verified. On a cache miss the
     * underlying service is called: through the callback-based {@code _bark}
     * when {@code barkCallback} is non-null, otherwise through the blocking
     * {@code _bark}, which is repeated for as long as the verifier answers
     * {@code TRY_AGAIN}.
     *
     * @param sessionId    the session identifier
     * @param prompt       the input prompt to process
     * @param parameters   parameters for the bark operation
     * @param barkCallback receiver of the response, or {@code null} to block
     *                     and return the response
     * @return the response when {@code barkCallback} is {@code null};
     *         otherwise always {@code null}
     * @throws BarkException if the session is unknown, the cache or the
     *                       underlying service fails, or the final response
     *                       is not verified as {@code GOOD}
     */
    private S _bark(String sessionId, String prompt, T parameters, BarkCallback<S> barkCallback) throws BarkException {
        List<Conversation> messages = sessionHistory.get(sessionId);
        if (messages == null) {
            // Map.get yields null for ids that were never created or already closed;
            // report that explicitly instead of failing with a NullPointerException below.
            throw new BarkException(new IllegalArgumentException("unknown or closed session: " + sessionId));
        }
        messages.add(this.createConversationFromPrompt(prompt, parameters));

        boolean cacheable = isCacheable(sessionId, messages, parameters);
        // visit cache in advance
        if (cacheService != null && cacheable) {
            try {
                // NOTE(review): presumably swaps the messages for their cache-bound
                // counterparts; confirm against CacheService.
                messages = cacheService.sync2BoundConversations(messages);
            } catch (Exception ex) {
                throw new BarkException(ex);
            }
        }
        sessionHistory.put(sessionId, messages);

        if (cacheService != null && cacheable) {
            try {
                Conversation cachedReply = cacheService.getCachedReply(messages);
                if (cachedReply != null) {
                    S cachedResponse = createResponseFromConversation(messages, cachedReply);
                    // A cached reply is only used when the verifier accepts it;
                    // anything else falls through to the underlying service.
                    if (parameters.getResponseVerifier().verify(cachedResponse).equals(VerificationResult.GOOD)) {
                        LoggerFactory.getLogger(getClass()).info("reply obtained from cache");
                        messages.add(cachedReply);
                        if (barkCallback == null) {
                            return cachedResponse;
                        }
                        barkCallback.responseChunkReceived(cachedResponse, true);
                        return null;
                    }
                }
            } catch (Exception ex) {
                throw new BarkException(ex);
            }
        }

        LoggerFactory.getLogger(getClass()).info("cache not available, bark at the underlying service");
        if (barkCallback != null) {
            // then use async way
            try {
                this._bark(sessionId, messages, parameters, barkCallback);
            } catch (Exception ex) {
                // Surface the failure to the caller: merely logging it would leave
                // the caller with neither a callback nor an exception.
                throw new BarkException(ex);
            }
            // The reply is not appended to the history here.
            // NOTE(review): presumably the streaming implementation is responsible; confirm.
            return null;
        }

        // then use sync way
        S response;
        try {
            response = this._bark(sessionId, messages, parameters);
        } catch (Exception ex) {
            LoggerFactory.getLogger(getClass()).error("something wrong", ex);
            throw new BarkException(ex);
        }
        VerificationResult responseQuality = parameters.getResponseVerifier().verify(response);
        LoggerFactory.getLogger(getClass()).info("verification result: {}", responseQuality);
        // There is no retry limit here: the loop ends only when the verifier
        // stops answering TRY_AGAIN or an exception is thrown.
        while (responseQuality.equals(VerificationResult.TRY_AGAIN)) {
            try {
                response = this._bark(sessionId, messages, parameters);
                responseQuality = parameters.getResponseVerifier().verify(response);
                LoggerFactory.getLogger(getClass()).info("verification result: {}", responseQuality);
            } catch (Exception ex) {
                throw new BarkException(ex);
            }
        }

        if (!responseQuality.equals(VerificationResult.GOOD)) {
            response.setError(true);
            throw new BarkException(response);
        }

        Conversation replyMessage = new Conversation(replyRole, response.getMessage());
        // save to cache
        if (cacheService != null && cacheable) {
            try {
                messages = cacheService.sync2BoundConversations(messages);
                cacheService.cacheReply(messages, replyMessage);
            } catch (Exception ex) {
                throw new BarkException(ex);
            }
        }
        messages.add(replyMessage);
        sessionHistory.put(sessionId, messages);
        return response;
    }

    /**
     * Calls the underlying service and blocks until its response is available.
     *
     * @param sessionId  the session identifier
     * @param messages   the session history, ending with the newest prompt
     * @param parameters parameters for the bark operation
     * @return the response produced by the underlying service
     * @throws Exception if the call fails; the caller wraps it in a {@code BarkException}
     */
    protected abstract S _bark(String sessionId, List<Conversation> messages, T parameters) throws Exception;

    /**
     * Calls the underlying service and delivers the response through the callback.
     *
     * @param sessionId  the session identifier
     * @param messages   the session history, ending with the newest prompt
     * @param parameters parameters for the bark operation
     * @param callback   receiver of the response
     * @throws Exception if the call fails; the caller wraps it in a {@code BarkException}
     */
    protected abstract void _bark(String sessionId, List<Conversation> messages, T parameters, BarkCallback<S> callback) throws Exception;

    /**
     * Supplies the parameters used by the {@code bark} overloads that take none.
     *
     * @return the default prompt parameters
     */
    protected abstract T createDefaultPromptParameter();

    /**
     * Builds a response object from a reply found in the cache.
     *
     * @param lastPrompt   the session history the cached reply was looked up with
     * @param lastResponse the cached reply
     * @return the response representing the cached reply
     */
    protected abstract S createResponseFromConversation(List<Conversation> lastPrompt, Conversation lastResponse);

    /**
     * Creates the history entry for a prompt.
     * By default this is a plain {@code Conversation} carrying the role from
     * the parameters and the prompt text.
     *
     * @param prompt     the input prompt
     * @param parameters parameters for the bark operation; supplies the role
     * @return the conversation to append to the session history
     */
    protected Conversation createConversationFromPrompt(String prompt, T parameters) {
        return new Conversation(parameters.getRole(), prompt);
    }

    /**
     * Checks whether the current conversation may be served from and stored in
     * the cache. By default this returns {@code true}; child classes can
     * override it to customize the behavior.
     *
     * @param sessionId  the session identifier
     * @param messages   the session history, ending with the newest prompt
     * @param parameters parameters for the bark operation
     * @return {@code true} if the cache may be used for this call
     */
    protected boolean isCacheable(String sessionId, List<Conversation> messages, T parameters) {
        return true;
    }
}
