package org.kink_lang.kink.internal.program.itreeoptimize;

import java.util.List;
import java.util.function.Function;
import java.util.function.UnaryOperator;

import org.kink_lang.kink.internal.program.itree.*;

/**
 * A recursive composite optimizer.
 *
 * <p>Sub itrees are optimized recursively via {@code DeepTransformer}, and each itree
 * is optimized by the optimizers made by the factories, composed in list order.</p>
 *
 * <p>The body of a slow fun is optimized by this same optimizer. The body of a fast
 * (ssa) fun is optimized by a derived optimizer, whose factories are made by calling
 * {@code makeFactory} on each factory of this optimizer with the enclosing fun.</p>
 */
public class RecursiveOptimizer implements UnaryOperator<Itree> {

    /** Factories of the optimizers applied to each itree; an immutable snapshot. */
    private final List<OptimizerFactory> optimizerFactories;

    /**
     * Constructs a composite optimizer.
     *
     * @param optimizerFactories the factories of optimizers applied to the argument itree;
     *     the list is copied, so later changes to it do not affect this optimizer.
     * @throws NullPointerException if the list or any of its elements is null.
     */
    public RecursiveOptimizer(List<OptimizerFactory> optimizerFactories) {
        this.optimizerFactories = List.copyOf(optimizerFactories);
    }

    /**
     * Optimizes the given itree.
     *
     * @param org the itree to optimize.
     * @return the optimized itree.
     */
    @Override
    public Itree apply(Itree org) {
        // Local variables pass through untouched; only itrees are rewritten.
        DeepTransformer.Callback callback = new DeepTransformer.Callback() {
            @Override
            public LocalVar derefLvar(LocalVar lvar) {
                return lvar;
            }

            @Override
            public LocalVar storeLvar(LocalVar lvar) {
                return lvar;
            }

            @Override
            public Itree itree(Itree itree) {
                // A new visitor, and thus newly made optimizers, for every itree handed in.
                OptimizingVisitor visitor = new OptimizingVisitor();
                return itree.accept(visitor);
            }
        };
        return DeepTransformer.deepTransform(org, callback);
    }

    /**
     * Visitor which hands the composite optimizer to {@code SkeltonItreeVisitor}.
     *
     * <p>For fun itrees, it first rebuilds the fun with an optimized body, then
     * defers to the superclass.</p>
     */
    private class OptimizingVisitor extends SkeltonItreeVisitor<Itree> {

        /**
         * Constructs a visitor backed by a newly composed optimizer.
         */
        OptimizingVisitor() {
            super(makeCompositeOptimizer());
        }

        /**
         * Optimizes the body with this very optimizer, then defers to the superclass.
         */
        @Override
        public Itree visit(SlowFunItree fun) {
            Itree optimizedBody = RecursiveOptimizer.this.apply(fun.body());
            return super.visit(new SlowFunItree(optimizedBody, fun.pos()));
        }

        /**
         * Optimizes the body with an optimizer derived for this fun, then defers to the superclass.
         */
        @Override
        public Itree visit(FastFunItree fun) {
            // The derived optimizer is made before the body is fetched, as in the original order.
            Itree optimizedBody = makeBodyOptimizer(fun).apply(fun.body());
            return super.visit(new FastFunItree(optimizedBody, fun.pos()));
        }
    }

    /**
     * Makes an optimizer for the body of the ssa (fast) fun.
     *
     * @param enclosing the fast fun whose body is to be optimized.
     * @return an optimizer whose factories are {@code factory.makeFactory(enclosing)}
     *     for each factory of this optimizer, in the same order.
     */
    private RecursiveOptimizer makeBodyOptimizer(FastFunItree enclosing) {
        List<OptimizerFactory> derived = this.optimizerFactories.stream()
            .map(factory -> factory.makeFactory(enclosing))
            .toList();
        return new RecursiveOptimizer(derived);
    }

    /**
     * Composes all the optimizers made by the factories.
     *
     * @return a function applying each newly made optimizer in factory order,
     *     feeding the output of each into the next; the identity if there are no factories.
     */
    private Function<Itree, Itree> makeCompositeOptimizer() {
        Function<Itree, Itree> composite = Function.identity();
        for (var factory : this.optimizerFactories) {
            composite = composite.andThen(factory.makeOptimizer());
        }
        return composite;
    }

}

// vim: et sw=4 sts=4 fdm=marker
