package com.github.cschen1205.navigator.minefield.agents;

import com.github.cschen1205.falcon.FalconConfig;
import com.github.cschen1205.falcon.QValue;
import com.github.cschen1205.falcon.QValueProvider;
import com.github.cschen1205.falcon.TDFalcon;
import com.github.cschen1205.falcon.TDLambdaFalcon;
import com.github.cschen1205.falcon.TDMethod;
import com.github.cschen1205.navigator.minefield.env.MineField;

/**
 * Mine-field navigation agent whose learning and action selection are delegated to a
 * {@link TDFalcon} network.
 *
 * <p>Known outcomes from the environment (hitting a mine or reaching the target) are injected
 * into the network's Q-value queries through the provider built by
 * {@link #createQInject(MineField)}.
 */
public class TDFalconNavAgent extends FalconNavAgent {

    /**
     * Offset subtracted from an action id before it is passed to
     * {@code MineField.willHitMine} / {@code MineField.willHitTarget}.
     * NOTE(review): presumably maps the network's action ids onto the mine field's own action
     * numbering (centred on zero) — confirm against MineField.
     */
    private static final int ACTION_ID_OFFSET = 2;

    /** Underlying network; replaced by a {@link TDLambdaFalcon} in {@link #enableEligibilityTrace()}. */
    private TDFalcon ai;

    /**
     * When {@code true}, the injected Q-value provider answers every query with this agent's
     * {@code reward} value at query time instead of consulting the mine field.
     */
    public boolean useImmediateRewardAsQ;

    /**
     * Creates an agent backed by a {@link TDFalcon} built from {@code falconConfig}.
     *
     * @param falconConfig configuration for the underlying network
     * @param i  first argument forwarded to the {@link FalconNavAgent} constructor
     * @param i2 second argument forwarded to the {@link FalconNavAgent} constructor
     * @param i3 third argument forwarded to the {@link FalconNavAgent} constructor
     * @param i4 fourth argument forwarded to the {@link FalconNavAgent} constructor
     * @param i5 fifth argument forwarded to the {@link FalconNavAgent} constructor
     */
    public TDFalconNavAgent(FalconConfig falconConfig, int i, int i2, int i3, int i4, int i5) {
        super(i, i2, i3, i4, i5);
        this.ai = new TDFalcon(falconConfig);
    }

    /**
     * Creates an agent backed by a {@link TDFalcon} built from {@code falconConfig} and the
     * given temporal-difference method.
     *
     * @param falconConfig configuration for the underlying network
     * @param i  first argument forwarded to the {@link FalconNavAgent} constructor
     * @param tDMethod TD method handed to the {@link TDFalcon} constructor
     * @param i2 second argument forwarded to the {@link FalconNavAgent} constructor
     * @param i3 third argument forwarded to the {@link FalconNavAgent} constructor
     * @param i4 fourth argument forwarded to the {@link FalconNavAgent} constructor
     * @param i5 fifth argument forwarded to the {@link FalconNavAgent} constructor
     */
    public TDFalconNavAgent(FalconConfig falconConfig, int i, TDMethod tDMethod, int i2, int i3, int i4, int i5) {
        super(i, i2, i3, i4, i5);
        this.ai = new TDFalcon(falconConfig, tDMethod);
    }

    /** Delegates to the underlying network's {@code decayQEpsilon()}. */
    public void decayQEpsilon() {
        this.ai.decayQEpsilon();
    }

    /**
     * Performs one learning update on the underlying network.
     *
     * <p>Passes this agent's {@code state}, {@code actions}, {@code newState} and {@code reward}
     * fields, the actions currently feasible in {@code mineField}, and the domain-knowledge
     * provider from {@link #createQInject(MineField)} to {@code TDFalcon.learnQ}.
     *
     * @param mineField environment the agent is navigating
     */
    @Override
    public void learn(MineField mineField) {
        this.ai.learnQ(this.state, this.actions, this.newState, getFeasibleActions(mineField), this.reward, createQInject(mineField));
    }

    /**
     * Builds a provider that injects known Q-values into the network's queries.
     *
     * <p>If {@link #useImmediateRewardAsQ} is set, every query returns the current
     * {@code reward}. Otherwise a query returns {@code 0.0} for a mine hit, {@code 1.0} for a
     * target hit, and {@code QValue.Invalid()} when neither applies (presumably telling the
     * network to fall back on its own estimate — confirm in TDFalcon).
     *
     * @param mineField environment consulted for mine/target outcomes
     * @return a new provider bound to this agent and {@code mineField}
     */
    protected QValueProvider createQInject(final MineField mineField) {
        return new QValueProvider() {
            /**
             * @param inputs    input vector supplied by the network (not used here)
             * @param actionId  action being evaluated; only used when {@code lookAhead} is true
             * @param lookAhead when true, predict the outcome of taking {@code actionId} via
             *                  {@code willHit*}; when false, check the agent's current
             *                  {@code isHit*} status
             */
            @Override
            public QValue queryQValue(double[] inputs, int actionId, boolean lookAhead) {
                // Outer members are qualified so that any same-named member of QValueProvider
                // cannot shadow them inside this anonymous class.
                if (TDFalconNavAgent.this.useImmediateRewardAsQ) {
                    return new QValue(TDFalconNavAgent.this.reward);
                }
                final int agentId = TDFalconNavAgent.this.getId();
                if (lookAhead) {
                    final int mineFieldAction = actionId - ACTION_ID_OFFSET;
                    if (mineField.willHitMine(agentId, mineFieldAction)) {
                        return new QValue(0.0d);
                    }
                    if (mineField.willHitTarget(agentId, mineFieldAction)) {
                        return new QValue(1.0d);
                    }
                } else {
                    if (mineField.isHitMine(agentId)) {
                        return new QValue(0.0d);
                    }
                    if (mineField.isHitTarget(agentId)) {
                        return new QValue(1.0d);
                    }
                }
                return QValue.Invalid();
            }
        };
    }

    /**
     * Asks the underlying network to choose an action for the current {@code state}, restricted
     * to the actions feasible in {@code mineField}.
     *
     * @param mineField environment the agent is navigating
     * @return the action id chosen by {@code TDFalcon.selectActionId}
     */
    @Override
    public int selectValidAction(MineField mineField) {
        return this.ai.selectActionId(this.state, getFeasibleActions(mineField), createQInject(mineField));
    }

    /** @return the size of the underlying network's {@code nodes} collection */
    @Override
    public int getNodeCount() {
        return this.ai.nodes.size();
    }

    /**
     * Sets the {@code QGamma} field of the underlying network (presumably the TD discount
     * factor — confirm in TDFalcon).
     *
     * @param d new value for {@code QGamma}
     */
    public void setQGamma(double d) {
        this.ai.QGamma = d;
    }

    /**
     * Replaces the underlying network with a {@link TDLambdaFalcon} built from the current
     * network's config and TD method.
     *
     * <p>NOTE(review): only the config and method are passed to the new instance. State held by
     * the previous network (e.g. its {@code nodes}, or a {@code QGamma} set through
     * {@link #setQGamma(double)}) is presumably not carried over, so call this before training
     * and before {@code setQGamma}. Confirm against TDLambdaFalcon.
     */
    public void enableEligibilityTrace() {
        this.ai = new TDLambdaFalcon(this.ai.getConfig(), this.ai.method);
    }
}
