/*
 * Decompiled with CFR 0.152.
 */
package org.intocps.maestro.plugin;

import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.xml.xpath.XPathExpressionException;
import org.intocps.maestro.ast.LexIdentifier;
import org.intocps.maestro.ast.MableAstFactory;
import org.intocps.maestro.ast.MableBuilder;
import org.intocps.maestro.ast.node.AExpressionStm;
import org.intocps.maestro.ast.node.AIdentifierExp;
import org.intocps.maestro.ast.node.PExp;
import org.intocps.maestro.ast.node.PStm;
import org.intocps.maestro.ast.node.PType;
import org.intocps.maestro.core.Framework;
import org.intocps.maestro.framework.fmi2.ComponentInfo;
import org.intocps.maestro.framework.fmi2.Fmi2SimulationEnvironment;
import org.intocps.maestro.framework.fmi2.RelationVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class DerivativesHandler {
    static final Logger logger = LoggerFactory.getLogger(DerivativesHandler.class);
    final BiPredicate<Fmi2SimulationEnvironment, Fmi2SimulationEnvironment.Variable> canInterpolateInputsFilter = (env, v) -> {
        try {
            return ((ComponentInfo)env.getUnitInfo((LexIdentifier)v.scalarVariable.instance, (Framework)Framework.FMI2)).modelDescription.getCanInterpolateInputs();
        }
        catch (XPathExpressionException e) {
            return false;
        }
    };
    final BiFunction<Fmi2SimulationEnvironment, LexIdentifier, Integer> maxOutputDerivativeOrder = (env, id) -> {
        try {
            return ((ComponentInfo)env.getUnitInfo((LexIdentifier)id, (Framework)Framework.FMI2)).modelDescription.getMaxOutputDerivativeOrder();
        }
        catch (XPathExpressionException e) {
            e.printStackTrace();
            return 0;
        }
    };
    String globalCacheName = "derivatives";
    String globalDerInputBuffer = "der_input_buffer";
    Map<LexIdentifier, GetDerivativesInfo> derivativesGetInfo = new HashMap<LexIdentifier, GetDerivativesInfo>();
    private boolean allocated = false;
    private boolean requireArrayUtilUnload = false;
    private Map<Map.Entry<LexIdentifier, List<Fmi2SimulationEnvironment.Relation>>, LinkedHashMap<Fmi2SimulationEnvironment.Variable, Map.Entry<Fmi2SimulationEnvironment.Variable, GetDerivativesInfo>>> resolvedInputData;

    DerivativesHandler() {
    }

    public List<PStm> allocateMemory(List<LexIdentifier> componentNames, Set<Fmi2SimulationEnvironment.Relation> inputRelations, Fmi2SimulationEnvironment env) {
        this.allocated = true;
        Set tmp = inputRelations.stream().filter(r -> this.canInterpolateInputsFilter.test(env, r.getSource())).map(Fmi2SimulationEnvironment.Relation::getTargets).collect(Collectors.toSet());
        HashMap<LexIdentifier, List> vars = new HashMap<LexIdentifier, List>();
        for (Map map : tmp) {
            for (Map.Entry entry : map.entrySet()) {
                vars.computeIfAbsent((LexIdentifier)entry.getKey(), key -> new Vector()).add((Fmi2SimulationEnvironment.Variable)entry.getValue());
            }
        }
        for (Map.Entry entry : vars.entrySet()) {
            ((List)entry.getValue()).sort(Comparator.comparing(v -> v.getScalarVariable().getScalarVariable().valueReference));
        }
        if (vars.isEmpty()) {
            return new Vector<PStm>();
        }
        Vector<PStm> statements = new Vector<PStm>();
        Vector vector = new Vector();
        for (LexIdentifier name : componentNames) {
            vars.entrySet().stream().filter(f -> ((LexIdentifier)f.getKey()).getText().equals(name.getText())).findFirst().ifPresent(f -> {
                LexIdentifier id = (LexIdentifier)f.getKey();
                int order = this.maxOutputDerivativeOrder.apply(env, id);
                GetDerivativesInfo varDerInfo = new GetDerivativesInfo();
                varDerInfo.varMaxOrder = order;
                for (int i = 0; i < ((List)f.getValue()).size(); ++i) {
                    Fmi2SimulationEnvironment.Variable var = (Fmi2SimulationEnvironment.Variable)((List)f.getValue()).get(i);
                    varDerInfo.varStartIndex.put(var, varDerInfo.varStartIndex.size() * order);
                }
                int size = ((List)f.getValue()).size() * order;
                perInstanceSizes.add(size);
                varDerInfo.valueDestIdentifier = MableAstFactory.newAArrayIndexExp((PExp)MableAstFactory.newAIdentifierExp((LexIdentifier)MableAstFactory.newAIdentifier((String)this.globalCacheName)), Arrays.asList(MableAstFactory.newAIntLiteralExp((Integer)componentNames.indexOf(name))));
                String orderArrayName = id.getText() + "_der_order";
                statements.add(MableBuilder.newVariable((String)orderArrayName, (PType)MableAstFactory.newAIntNumericPrimitiveType(), IntStream.range(0, ((List)f.getValue()).size()).mapToObj(v -> IntStream.range(1, order + 1).mapToObj(MableAstFactory::newAIntLiteralExp)).flatMap(Function.identity()).collect(Collectors.toList())));
                varDerInfo.orderArrayId = orderArrayName;
                String varSelectName = id.getText() + "_der_select";
                statements.add(MableBuilder.newVariable((String)varSelectName, (PType)MableAstFactory.newUIntType(), ((List)f.getValue()).stream().flatMap(v -> IntStream.range(1, order + 1).mapToObj(o -> MableAstFactory.newAIntLiteralExp((Integer)((int)v.getScalarVariable().getScalarVariable().valueReference)))).collect(Collectors.toList())));
                varDerInfo.valueSelectArrayId = varSelectName;
                this.derivativesGetInfo.put(id, varDerInfo);
            });
        }
        statements.add(0, MableBuilder.newVariable((String)this.globalCacheName, (PType)MableAstFactory.newARealNumericPrimitiveType(), (int[])new int[]{componentNames.size() - 1, vector.stream().mapToInt(i -> i).max().orElse(0)}));
        statements.addAll(this.allocateForInput(inputRelations, env));
        return statements;
    }

    private List<PStm> allocateForInput(Set<Fmi2SimulationEnvironment.Relation> inputRelations, Fmi2SimulationEnvironment env) {
        this.resolvedInputData = inputRelations.stream().filter(r -> this.canInterpolateInputsFilter.test(env, r.getSource())).collect(Collectors.groupingBy(s -> s.getSource().getScalarVariable().instance)).entrySet().stream().collect(Collectors.toMap(Function.identity(), mapped -> ((List)mapped.getValue()).stream().sorted(Comparator.comparing(map -> map.getSource().getScalarVariable().getScalarVariable().valueReference)).collect(Collectors.toMap(Fmi2SimulationEnvironment.Relation::getSource, map -> {
            Fmi2SimulationEnvironment.Variable next = (Fmi2SimulationEnvironment.Variable)map.getTargets().values().iterator().next();
            RelationVariable fromVar = next.scalarVariable;
            GetDerivativesInfo fromVarDerivativeInfo = this.derivativesGetInfo.get(fromVar.instance);
            if (fromVarDerivativeInfo != null) {
                logger.trace("Derivative mapping {}.{} to {}.{}", new Object[]{fromVar.instance, fromVar.scalarVariable.name, map.getSource().getScalarVariable().instance, map.getSource().getScalarVariable().scalarVariable.name});
            }
            return Map.entry(next, fromVarDerivativeInfo);
        }, (e1, e2) -> e1, LinkedHashMap::new))));
        int size = this.resolvedInputData.values().stream().mapToInt(variableEntryLinkedHashMap -> variableEntryLinkedHashMap.values().stream().mapToInt(v -> ((GetDerivativesInfo)v.getValue()).varMaxOrder).sum()).sum();
        List<PStm> allocationStatements = Stream.concat(Stream.of(MableBuilder.newVariable((String)"der_input_buffer", (PType)MableAstFactory.newARealNumericPrimitiveType(), (int)size)), this.resolvedInputData.entrySet().stream().flatMap(map -> {
            LinkedHashMap resolved = (LinkedHashMap)map.getValue();
            List inputSelectIndices = resolved.entrySet().stream().flatMap(m -> IntStream.range(1, ((GetDerivativesInfo)((Map.Entry)m.getValue()).getValue()).varMaxOrder + 1).mapToObj(i -> Long.valueOf(((Fmi2SimulationEnvironment.Variable)m.getKey()).getScalarVariable().scalarVariable.valueReference).intValue())).collect(Collectors.toList());
            List inputOrders = resolved.entrySet().stream().flatMap(m -> IntStream.range(1, ((GetDerivativesInfo)((Map.Entry)m.getValue()).getValue()).varMaxOrder + 1).mapToObj(i -> i)).collect(Collectors.toList());
            LexIdentifier name = (LexIdentifier)((Map.Entry)map.getKey()).getKey();
            return Stream.of(MableBuilder.newVariable((String)("der_input_select_" + name.getText()), (PType)MableAstFactory.newAIntNumericPrimitiveType(), inputSelectIndices.stream().map(MableAstFactory::newAIntLiteralExp).collect(Collectors.toList())), MableBuilder.newVariable((String)("der_input_order_" + name.getText()), (PType)MableAstFactory.newAIntNumericPrimitiveType(), inputOrders.stream().map(MableAstFactory::newAIntLiteralExp).collect(Collectors.toList())));
        })).collect(Collectors.toList());
        this.requireArrayUtilUnload = true;
        allocationStatements.add(0, MableBuilder.newVariable((String)"util", (PType)MableAstFactory.newANameType((String)"ArrayUtil"), (PExp)MableAstFactory.newALoadExp(Arrays.asList(MableAstFactory.newAStringLiteralExp((String)"ArrayUtil")))));
        return allocationStatements;
    }

    public List<PStm> deallocate() {
        if (this.requireArrayUtilUnload) {
            return Collections.singletonList(MableAstFactory.newExpressionStm((PExp)MableAstFactory.newUnloadExp(Collections.singletonList(MableAstFactory.newAIdentifierExp((String)"util")))));
        }
        return Collections.emptyList();
    }

    public List<PStm> get(String errorStateLocation) throws InstantiationException {
        if (!this.allocated) {
            throw new InstantiationException("Must be allocated first");
        }
        return this.get(errorStateLocation, null);
    }

    public List<PStm> get(String errorStateLocation, List<LexIdentifier> componentNamesFilter) throws InstantiationException {
        if (!this.allocated) {
            throw new InstantiationException("Must be allocated first");
        }
        if (this.derivativesGetInfo == null) {
            return new Vector<PStm>();
        }
        Vector<PStm> stmts = new Vector<PStm>();
        for (Map.Entry<LexIdentifier, GetDerivativesInfo> map : this.derivativesGetInfo.entrySet()) {
            LexIdentifier id = map.getKey();
            if (componentNamesFilter != null && !componentNamesFilter.contains(id)) continue;
            GetDerivativesInfo info = map.getValue();
            AIdentifierExp object = MableAstFactory.newAIdentifierExp((LexIdentifier)((LexIdentifier)id.clone()));
            stmts.add((PStm)MableAstFactory.newExpressionStm((PExp)MableBuilder.call((PExp)object, (String)"getRealOutputDerivatives", (PExp[])new PExp[]{MableAstFactory.newAIdentifierExp((String)info.valueSelectArrayId), MableAstFactory.newAIntLiteralExp((Integer)(info.varStartIndex.size() * info.varMaxOrder)), MableAstFactory.newAIdentifierExp((String)info.orderArrayId), info.valueDestIdentifier.clone()})));
        }
        return stmts;
    }

    public List<PStm> set(String errorStateLocation) throws InstantiationException {
        return this.set(errorStateLocation, null);
    }

    public List<PStm> set(String errorStateLocation, List<LexIdentifier> componentNamesFilter) throws InstantiationException {
        if (!this.allocated) {
            throw new InstantiationException("Must be allocated first");
        }
        if (this.resolvedInputData == null) {
            return new Vector<PStm>();
        }
        return this.resolvedInputData.entrySet().stream().filter(m -> componentNamesFilter == null || componentNamesFilter.contains(((Map.Entry)m.getKey()).getKey())).flatMap(map -> {
            AtomicInteger inputOffset = new AtomicInteger(0);
            LexIdentifier name = (LexIdentifier)((Map.Entry)map.getKey()).getKey();
            List inputOrders = ((LinkedHashMap)map.getValue()).entrySet().stream().map(m -> IntStream.range(1, ((GetDerivativesInfo)((Map.Entry)m.getValue()).getValue()).varMaxOrder + 1).boxed()).flatMap(Function.identity()).collect(Collectors.toList());
            Stream<AExpressionStm> copyStatements = ((LinkedHashMap)map.getValue()).entrySet().stream().map(pair -> {
                GetDerivativesInfo from = (GetDerivativesInfo)((Map.Entry)pair.getValue()).getValue();
                Integer index = from.varStartIndex.get(((Map.Entry)pair.getValue()).getKey());
                logger.debug("Copying {} from index {} in ders", ((Map.Entry)pair.getValue()).getKey(), (Object)index);
                PExp c = MableBuilder.call((PExp)MableAstFactory.newAIdentifierExp((String)"util"), (String)"copyRealArray", (PExp[])new PExp[]{from.valueDestIdentifier, MableAstFactory.newAIntLiteralExp((Integer)index), MableAstFactory.newAIntLiteralExp((Integer)from.varMaxOrder), MableAstFactory.newARefExp((PExp)MableAstFactory.newAIdentifierExp((String)this.globalDerInputBuffer)), MableAstFactory.newAIntLiteralExp((Integer)inputOffset.getAndAdd(from.varMaxOrder))});
                logger.debug("{}", (Object)c);
                return MableAstFactory.newExpressionStm((PExp)c);
            });
            PExp set = MableBuilder.call((PExp)MableAstFactory.newAIdentifierExp((LexIdentifier)((LexIdentifier)((LexIdentifier)((Map.Entry)map.getKey()).getKey()).clone())), (String)"setRealInputDerivatives", (PExp[])new PExp[]{MableAstFactory.newAIdentifierExp((String)("der_input_select_" + name.getText())), MableAstFactory.newAIntLiteralExp((Integer)inputOrders.size()), MableAstFactory.newAIdentifierExp((String)("der_input_order_" + name.getText())), MableAstFactory.newAIdentifierExp((String)this.globalDerInputBuffer)});
            return Stream.concat(copyStatements, Stream.of(MableAstFactory.newExpressionStm((PExp)set)));
        }).collect(Collectors.toList());
    }

    class GetDerivativesInfo {
        String orderArrayId;
        String valueSelectArrayId;
        PExp valueDestIdentifier;
        Map<Fmi2SimulationEnvironment.Variable, Integer> varStartIndex = new HashMap<Fmi2SimulationEnvironment.Variable, Integer>();
        Integer varMaxOrder;

        GetDerivativesInfo() {
        }
    }
}

