package org.kink_lang.kink.internal.compile.javaclassir;

import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.kink_lang.kink.internal.function.Function4;
import org.kink_lang.kink.internal.intrinsicsupport.PreloadedFuns;
import org.kink_lang.kink.internal.program.itree.LocalVar;
import org.kink_lang.kink.internal.program.itree.LocalVarContent;
import org.kink_lang.kink.internal.program.itree.UsedDefinedVars;

/**
 * Allocation analysis of localvars in a fast fun.
 *
 * <p>Every collection component is copied into an unmodifiable collection on construction,
 * so an instance is immutable. Null collections and null elements/values are rejected
 * with {@link NullPointerException} by the copying factories.</p>
 *
 * @param recv the set of local vars containing recv.
 * @param args the mapping from local vars to arg indexes.
 * @param control the local vars resolved to preloaded control funs;
 *     empty when a control fun is overridden on the top level.
 * @param field vars stored in fields of the fun val, in the order of the field indexes.
 * @param stack vars stored in stack, in the order of the stack indexes.
 */
public record AllocationSet(
        Set<LocalVar> recv,
        Map<LocalVar, Integer> args,
        Set<LocalVar> control,
        List<LocalVar> field,
        List<LocalVar> stack) {

    /** Lvars of preloaded control funs. */
    private static final Set<LocalVar> CONTROL_LVARS = PreloadedFuns.controlSyms().stream()
        .map(LocalVar.Original::new)
        .collect(Collectors.toUnmodifiableSet());

    /**
     * Makes an immutable allocation set.
     *
     * <p>The components are defensively copied, so the record neither exposes the mutable
     * collections built during the analysis nor changes when the caller later mutates
     * the collections it passed in.</p>
     */
    public AllocationSet {
        recv = Set.copyOf(recv);
        args = Map.copyOf(args);
        control = Set.copyOf(control);
        field = List.copyOf(field);
        stack = List.copyOf(stack);
    }

    /**
     * Gets the allocation of the lvar.
     *
     * <p>The lookup order is: recv, args, field slots, stack slots, preloaded control funs.
     * A var found nowhere is reported as unused.</p>
     *
     * @param lvar the local var.
     * @return the allocation of the lvar.
     */
    public Allocation get(LocalVar lvar) {
        if (recv().contains(lvar)) {
            return new Allocation.Recv();
        }

        // args never holds null values (Map.copyOf rejects them),
        // so a null result means the lvar is not an arg.
        Integer argInd = args().get(lvar);
        if (argInd != null) {
            return new Allocation.Arg(argInd);
        }

        // Slots are checked before control funs, so that a var which has been given
        // its own slot is never resolved to the preloaded control fun of the same lvar.
        int fieldInd = field().indexOf(lvar);
        if (fieldInd >= 0) {
            return new Allocation.Field(fieldInd, isNonNull(lvar));
        }

        int stackInd = stack().indexOf(lvar);
        if (stackInd >= 0) {
            return new Allocation.Stack(stackInd, isNonNull(lvar));
        }

        if (control().contains(lvar)) {
            return new Allocation.Preloaded();
        }

        return new Allocation.Unused();
    }

    /**
     * Whether the local var is treated as nonnull.
     *
     * <p>Generated local vars and the local vars of preloaded control funs are treated as
     * nonnull. This holds regardless of whether the control funs are resolved as preloaded
     * by this allocation set.</p>
     *
     * <p>NOTE(review): this also answers true for a control-fun lvar that the fun itself
     * defines; confirm that such a var cannot be observed unassigned.</p>
     */
    private boolean isNonNull(LocalVar lvar) {
        return lvar instanceof LocalVar.Generated
            || CONTROL_LVARS.contains(lvar);
    }

    /**
     * Analyzes the vars as of the val-capture fast fun for the case that a local var
     * of a control fun is overridden on the top level.
     *
     * <p>No var is resolved as a preloaded control fun.</p>
     *
     * @param vars of the val-capture fast fun.
     * @return the analysis.
     */
    public static AllocationSet valCaptureControlOverridden(UsedDefinedVars vars) {
        return valCapture(Set.of(), vars);
    }

    /**
     * Analyzes the vars as of the val-capture fast fun for the case that no local var
     * of a control fun is overridden on the top level.
     *
     * <p>Free vars of control funs are resolved as preloaded control funs.</p>
     *
     * @param vars of the val-capture fast fun.
     * @return the analysis.
     */
    public static AllocationSet valCaptureControlUnchanged(UsedDefinedVars vars) {
        return valCapture(CONTROL_LVARS, vars);
    }

    /**
     * Makes analysis for a val-capture fast fun.
     *
     * <p>Free vars are stored in fields and bound vars in the stack, each sorted by name.</p>
     */
    private static AllocationSet valCapture(Set<LocalVar> control, UsedDefinedVars vars) {
        return analyze(control, vars, (recv, args, bound, free) -> {
            List<LocalVar> field = free.stream()
                .sorted(Comparator.comparing(LocalVar::name))
                .toList();
            List<LocalVar> stack = bound.stream()
                .sorted(Comparator.comparing(LocalVar::name))
                .toList();
            return new AllocationSet(recv, args, control, field, stack);
        });
    }

    /**
     * Analyzes the vars as of the binding-capture fast fun for the case that a local var
     * of a control fun is overridden on the top level.
     *
     * <p>No var is resolved as a preloaded control fun.</p>
     *
     * @param vars of the binding-capture fast fun.
     * @return the analysis.
     */
    public static AllocationSet bindingCaptureControlOverridden(UsedDefinedVars vars) {
        return bindingCapture(Set.of(), vars);
    }

    /**
     * Analyzes the vars as of the binding-capture fast fun for the case that no local var
     * of a control fun is overridden on the top level.
     *
     * <p>Free vars of control funs are resolved as preloaded control funs.</p>
     *
     * @param vars of the binding-capture fast fun.
     * @return the analysis.
     */
    public static AllocationSet bindingCaptureControlUnchanged(UsedDefinedVars vars) {
        return bindingCapture(CONTROL_LVARS, vars);
    }

    /**
     * Makes analysis for binding-capture fast fun.
     *
     * <p>Both bound and free vars are stored in the stack, sorted by name; no fields are used.</p>
     */
    private static AllocationSet bindingCapture(Set<LocalVar> control, UsedDefinedVars vars) {
        return analyze(control, vars, (recv, args, bound, free) -> {
            List<LocalVar> stack = Stream.concat(bound.stream(), free.stream())
                .sorted(Comparator.comparing(LocalVar::name))
                .toList();
            return new AllocationSet(recv, args, control, List.of(), stack);
        });
    }

    /**
     * Does analysis.
     *
     * <p>Each defined lvar is classified as follows:</p>
     * <ul>
     *   <li>recv, when its content is the receiver;</li>
     *   <li>arg, when its content is an argument;</li>
     *   <li>bound, when it is otherwise used;</li>
     *   <li>dropped, when it is unused.</li>
     * </ul>
     *
     * <p>Free lvars which are in {@code control} are excluded from the free set.</p>
     *
     * @param control the lvars resolved as preloaded control funs.
     * @param vars the used/defined vars of the fun.
     * @param emit builds the result from (recv, args, bound, free).
     * @return the result of {@code emit}.
     */
    private static AllocationSet analyze(
            Set<LocalVar> control,
            UsedDefinedVars vars,
            Function4<
                Set<LocalVar>,
                Map<LocalVar, Integer>,
                Set<LocalVar>,
                Set<LocalVar>,
                AllocationSet> emit) {
        Set<LocalVar> recv = new HashSet<>();
        Map<LocalVar, Integer> args = new HashMap<>();
        // NOTE(review): the TreeSet orders by name only, so two distinct lvars
        // with the same name would collapse into one; this assumes defined lvars
        // have unique names.
        SortedSet<LocalVar> bound = new TreeSet<>(Comparator.comparing(LocalVar::name));
        for (LocalVar lvar : vars.definedLvars()) {
            LocalVarContent content = vars.getContent(lvar);
            if (content instanceof LocalVarContent.Recv) {
                recv.add(lvar);
            } else if (content instanceof LocalVarContent.Arg arg) {
                args.put(lvar, arg.index());
            } else if (vars.isUsed(lvar)) {
                bound.add(lvar);
            }
        }

        Set<LocalVar> free = vars.freeLvars()
            .stream()
            .filter(Predicate.not(control::contains))
            .collect(Collectors.toUnmodifiableSet());

        return emit.apply(recv, args, bound, free);
    }

}

// vim: et sw=4 sts=4 fdm=marker
