/*
 * Decompiled with CFR 0.152.
 */
package org.cryptimeleon.craco.protocols.base;

import java.lang.reflect.Field;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import org.cryptimeleon.craco.commitment.CommitmentScheme;
import org.cryptimeleon.craco.protocols.CommonInput;
import org.cryptimeleon.craco.protocols.SecretInput;
import org.cryptimeleon.craco.protocols.arguments.damgardtechnique.DamgardTechnique;
import org.cryptimeleon.craco.protocols.arguments.fiatshamir.FiatShamirProofSystem;
import org.cryptimeleon.craco.protocols.arguments.sigma.ZnChallengeSpace;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.DelegateProtocol;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.LinearExponentStatementFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.LinearStatementFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.SchnorrFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.SendThenDelegateFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.setmembership.SetMembershipPublicParameters;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.setmembership.SmallerThanPowerFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.variables.SchnorrZnVariable;
import org.cryptimeleon.math.expressions.Substitution;
import org.cryptimeleon.math.expressions.VariableExpression;
import org.cryptimeleon.math.expressions.bool.ExponentEqualityExpr;
import org.cryptimeleon.math.expressions.bool.GroupEqualityExpr;
import org.cryptimeleon.math.expressions.exponent.BasicNamedExponentVariableExpr;
import org.cryptimeleon.math.expressions.exponent.ExponentExpr;
import org.cryptimeleon.math.expressions.group.BasicNamedGroupVariableExpr;
import org.cryptimeleon.math.structures.groups.Group;
import org.cryptimeleon.math.structures.groups.GroupElement;
import org.cryptimeleon.math.structures.rings.zn.Zn;

public class AdHocSchnorrProof
extends DelegateProtocol {
    protected final Zn zn;
    protected Map<String, FragmentCreator> fragmentCreators;
    private HashSet<BasicNamedExponentVariableExpr> exponentVars;
    private HashMap<BasicNamedGroupVariableExpr, Group> groupElemVars;

    protected AdHocSchnorrProof(Zn zn, Map<String, FragmentCreator> fragmentCreators) {
        this.zn = zn;
        this.fragmentCreators = fragmentCreators;
        this.init();
    }

    private void init() {
        this.exponentVars = new HashSet();
        this.groupElemVars = new HashMap();
        this.fragmentCreators.forEach((name, creator) -> creator.forEachVariable(v -> {
            if (v instanceof BasicNamedExponentVariableExpr) {
                this.exponentVars.add((BasicNamedExponentVariableExpr)v);
            }
            if (v instanceof BasicNamedGroupVariableExpr) {
                this.groupElemVars.put((BasicNamedGroupVariableExpr)v, creator.getGroupOfVariable((VariableExpression)v));
            }
        }));
    }

    @Override
    protected SendThenDelegateFragment.ProverSpec provideProverSpecWithNoSendFirst(CommonInput commonInput, SecretInput secretInput, SendThenDelegateFragment.ProverSpecBuilder builder) {
        Object witness;
        Function<String, ?> witnessSource = ((BaseSchnorrProofInput)secretInput).witnessSource;
        for (BasicNamedExponentVariableExpr basicNamedExponentVariableExpr : this.exponentVars) {
            witness = witnessSource.apply(basicNamedExponentVariableExpr.getName());
            if (witness instanceof BigInteger) {
                witness = this.zn.valueOf((BigInteger)witness);
            }
            builder.putWitnessValue(basicNamedExponentVariableExpr.getName(), (Zn.ZnElement)witness);
        }
        for (BasicNamedGroupVariableExpr basicNamedGroupVariableExpr : this.groupElemVars.keySet()) {
            witness = witnessSource.apply(basicNamedGroupVariableExpr.getName());
            builder.putWitnessValue(basicNamedGroupVariableExpr.getName(), (GroupElement)witness);
        }
        return builder.build();
    }

    @Override
    protected SendThenDelegateFragment.SubprotocolSpec provideSubprotocolSpec(CommonInput commonInput, SendThenDelegateFragment.SubprotocolSpecBuilder builder) {
        HashMap<BasicNamedExponentVariableExpr, SchnorrZnVariable> znVars = new HashMap<BasicNamedExponentVariableExpr, SchnorrZnVariable>();
        HashMap groupVars = new HashMap();
        for (BasicNamedExponentVariableExpr var : this.exponentVars) {
            znVars.put(var, builder.addZnVariable(var.getName(), this.zn));
        }
        this.groupElemVars.forEach((v, group) -> groupVars.put(v, builder.addGroupElemVariable(v.getName(), (Group)group)));
        Substitution[] substitutionArray = new Substitution[2];
        substitutionArray[0] = znVars::get;
        substitutionArray[1] = groupVars::get;
        Substitution substitution = Substitution.join((Substitution[])substitutionArray);
        this.fragmentCreators.forEach((name, creator) -> {
            try {
                builder.addSubprotocol((String)name, creator.createFragment(substitution));
            }
            catch (RuntimeException e) {
                throw new RuntimeException("Error instantiating fragment " + name, e);
            }
        });
        return builder.build();
    }

    @Override
    public ZnChallengeSpace getChallengeSpace(CommonInput commonInput) {
        return new ZnChallengeSpace(this.zn);
    }

    public static BaseSchnorrProofInput witnessOf(Object witnessSource) {
        return new BaseSchnorrProofInput(name -> {
            try {
                Field field = witnessSource.getClass().getDeclaredField((String)name);
                field.setAccessible(true);
                return field.get(witnessSource);
            }
            catch (IllegalAccessException | NoSuchFieldException e) {
                throw new IllegalArgumentException(e);
            }
        });
    }

    public static BaseSchnorrProofInput witnessOf(Function<String, ?> witnessSource) {
        return new BaseSchnorrProofInput(witnessSource);
    }

    public static BaseSchnorrProofBuilder builder(Zn zn) {
        return new BaseSchnorrProofBuilder(zn);
    }

    private static class SmallerThanPowerFragmentCreator
    implements FragmentCreator {
        public final ExponentExpr expr;
        public final int base;
        public final int power;
        public final SetMembershipPublicParameters setMembershipPp;

        public SmallerThanPowerFragmentCreator(ExponentExpr expr, int base, int power, SetMembershipPublicParameters setMembershipPp) {
            this.expr = expr;
            this.base = base;
            this.power = power;
            this.setMembershipPp = setMembershipPp;
        }

        @Override
        public SchnorrFragment createFragment(Substitution substitution) {
            return new SmallerThanPowerFragment(this.expr.substitute(substitution), this.base, this.power, this.setMembershipPp);
        }

        @Override
        public void forEachVariable(Consumer<VariableExpression> action) {
            this.expr.getVariables().forEach(action);
        }
    }

    private static class LinearExponentFragmentCreator
    implements FragmentCreator {
        public final ExponentEqualityExpr expr;
        public final Zn zn;

        public LinearExponentFragmentCreator(ExponentEqualityExpr expr, Zn zn) {
            this.expr = expr;
            this.zn = zn;
        }

        @Override
        public SchnorrFragment createFragment(Substitution substitution) {
            return new LinearExponentStatementFragment(this.expr.substitute(substitution), this.zn);
        }

        @Override
        public void forEachVariable(Consumer<VariableExpression> action) {
            this.expr.getVariables().forEach(action);
        }
    }

    private static class LinearFragmentCreator
    implements FragmentCreator {
        public final GroupEqualityExpr expr;

        public LinearFragmentCreator(GroupEqualityExpr expr) {
            this.expr = expr;
        }

        @Override
        public SchnorrFragment createFragment(Substitution substitution) {
            return new LinearStatementFragment(this.expr.substitute(substitution));
        }

        @Override
        public void forEachVariable(Consumer<VariableExpression> action) {
            this.expr.getVariables().forEach(action);
        }

        @Override
        public Group getGroupOfVariable(VariableExpression var) {
            return this.expr.getGroup();
        }
    }

    private static interface FragmentCreator {
        public SchnorrFragment createFragment(Substitution var1);

        public void forEachVariable(Consumer<VariableExpression> var1);

        default public Group getGroupOfVariable(VariableExpression var) {
            throw new IllegalArgumentException("Cannot infer group type for var");
        }
    }

    public static class BaseSchnorrProofBuilder {
        public final Zn zn;
        protected Map<String, FragmentCreator> fragmentCreators = new HashMap<String, FragmentCreator>();

        public BaseSchnorrProofBuilder(Zn zn) {
            this.zn = zn;
        }

        public BaseSchnorrProofBuilder addLinearStatement(String name, GroupEqualityExpr statement) {
            this.fragmentCreators.put(name, new LinearFragmentCreator(statement));
            return this;
        }

        public BaseSchnorrProofBuilder addLinearExponentStatement(String name, ExponentEqualityExpr statement) {
            this.fragmentCreators.put(name, new LinearExponentFragmentCreator(statement, this.zn));
            return this;
        }

        public BaseSchnorrProofBuilder addSmallerThanPowerStatement(String name, ExponentExpr smallValue, int base, int power, SetMembershipPublicParameters setMembershipPp) {
            this.fragmentCreators.put(name, new SmallerThanPowerFragmentCreator(smallValue, base, power, setMembershipPp));
            return this;
        }

        public AdHocSchnorrProof build() {
            return new AdHocSchnorrProof(this.zn, this.fragmentCreators);
        }

        public FiatShamirProofSystem buildFiatShamir() {
            return new FiatShamirProofSystem(this.build());
        }

        public DamgardTechnique buildInteractiveDamgard(CommitmentScheme commitmentSchemeForDamgard) {
            return new DamgardTechnique(this.build(), commitmentSchemeForDamgard);
        }
    }

    public static class BaseSchnorrProofInput
    implements SecretInput {
        public final Function<String, ?> witnessSource;

        public BaseSchnorrProofInput(Function<String, ?> witnessSource) {
            this.witnessSource = witnessSource;
        }
    }
}

