/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.northstar.strategy.trainer;

import com.alibaba.fastjson2.JSONObject;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Generated;
import org.dromara.northstar.common.IGatewayService;
import org.dromara.northstar.common.IModuleService;
import org.dromara.northstar.common.ObjectManager;
import org.dromara.northstar.common.constant.ChannelType;
import org.dromara.northstar.common.constant.DateTimeConstant;
import org.dromara.northstar.common.constant.GatewayUsage;
import org.dromara.northstar.common.model.ComponentAndParamsPair;
import org.dromara.northstar.common.model.ContractSimpleInfo;
import org.dromara.northstar.common.model.GatewayDescription;
import org.dromara.northstar.common.model.Identifier;
import org.dromara.northstar.common.model.ModuleAccountDescription;
import org.dromara.northstar.common.model.ModuleRuntimeDescription;
import org.dromara.northstar.gateway.Gateway;
import org.dromara.northstar.gateway.IContractManager;
import org.dromara.northstar.gateway.MarketGateway;
import org.dromara.northstar.gateway.TradeGateway;
import org.dromara.northstar.strategy.IModule;
import org.dromara.northstar.strategy.tester.ModuleTesterContext;
import org.dromara.northstar.strategy.trainer.AbstractTrainer;
import org.dromara.northstar.strategy.trainer.RLAgentTrainingContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trainer that trains the configured strategy against each test contract in turn (serially).
 *
 * <p>For every contract returned by {@code testContracts()} this class:
 * <ol>
 *   <li>creates a dedicated playback market gateway and a dedicated simulated trade gateway,</li>
 *   <li>funds the simulated account with the symbol's amount from {@code symbolTestAmount()},</li>
 *   <li>creates the trainee modules via {@code createModules},</li>
 *   <li>runs up to {@code maxTrainingEpisodes()} replay episodes, stopping early once more than
 *       half of the modules report converged performance compared with the previous episode,</li>
 *   <li>removes the contract's modules before moving on to the next contract.</li>
 * </ol>
 */
public abstract class AbstractSerialTrainer
extends AbstractTrainer
implements RLAgentTrainingContext {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(AbstractSerialTrainer.class);

    protected AbstractSerialTrainer(ObjectManager<Gateway> gatewayMgr, ObjectManager<IModule> moduleMgr, IContractManager contractMgr, IGatewayService gatewayService, IModuleService moduleService) {
        super(gatewayMgr, moduleMgr, contractMgr, gatewayService, moduleService);
    }

    /**
     * Trains on every test contract, one contract after another.
     *
     * @throws IllegalStateException if a contract's symbol has no configured test amount
     */
    @Override
    public void start() {
        for (ContractSimpleInfo csi : this.testContracts()) {
            this.trainOnContract(csi);
        }
    }

    /**
     * Runs the full training cycle for a single contract: gateway setup, funding, module creation,
     * the episode loop with early stopping, and module teardown.
     *
     * @param csi the contract to train on
     * @throws IllegalStateException if no test amount is configured for the contract's symbol
     */
    private void trainOnContract(ContractSimpleInfo csi) {
        // Drop everything from the first digit run onward to get the symbol prefix,
        // e.g. "rb2405@SHFE" -> "rb"; this is the key into symbolTestAmount().
        String symbol = csi.getUnifiedSymbol().replaceAll("\\d+.+$", "");
        // Fail fast with a descriptive error instead of a null amount reaching simMoneyIO.
        if (this.symbolTestAmount().get(symbol) == null) {
            throw new IllegalStateException("No test amount configured for symbol [" + symbol + "] of contract " + csi.getName());
        }

        MarketGateway mktGateway = this.createPlaybackGateway(this, csi.getName());
        // One simulated account per contract, named after the contract just like its playback
        // gateway, so successive contracts neither share an account balance nor reuse a gateway id.
        TradeGateway tdGateway = this.createSimGateway(mktGateway, csi.getName());
        tdGateway.connect();
        this.gatewayService.simMoneyIO(tdGateway.gatewayId(), this.symbolTestAmount().get(symbol));

        ModuleAccountDescription mad = ModuleAccountDescription.builder()
                .accountGatewayId(tdGateway.gatewayId())
                .bindedContracts(List.of(csi))
                .build();
        ComponentAndParamsPair strategySettings = ComponentAndParamsPair.builder()
                .componentMeta(this.strategy())
                .initParams(this.convertParams(this.strategyParams(csi)))
                .build();
        this.createModules(mad, strategySettings, csi);

        // Convergence is judged episode-over-episode for THIS contract only: discard runtime
        // descriptions left over from the previously trained contract.
        this.mrdMap = new HashMap<>();
        String mktGatewayId = mktGateway.gatewayId();
        for (int i = 0; i < this.maxTrainingEpisodes(); ++i) {
            int episode = i + 1;
            List<IModule> traineeModules = this.runEpisode(csi, mktGatewayId, episode);
            if (!this.mrdMap.isEmpty()) {
                long numOfConverged = this.countConverged(traineeModules);
                // Stop early once a strict majority of the modules has converged.
                if (numOfConverged > (long) (traineeModules.size() / 2)) {
                    // Log text: "[{} contract] total modules {}, of which {} have converged"
                    log.info("\u3010{}\u5408\u7ea6\u3011 \u6a21\u7ec4\u603b\u6570\u4e3a{}\uff0c\u5176\u4e2d\u6709{}\u4e2a\u6a21\u7ec4\u5df2\u7ecf\u6536\u655b", csi.getName(), traineeModules.size(), numOfConverged);
                    break;
                }
            }
            // Remember this episode's results as the baseline for the next episode.
            this.mrdMap = traineeModules.stream().collect(Collectors.toMap(IModule::getName, m -> m.getModuleContext().getRuntimeDescription(false)));
        }

        // Iterate over a snapshot: removing modules mutates the manager that findAll() reads from.
        List.copyOf(this.moduleMgr.findAll()).forEach(m -> this.moduleService.removeModule(m.getName()));
    }

    /**
     * Runs one training episode for a contract.
     *
     * <p>Resets all modules and the playback gateway, replays the warm-up data, then enables the
     * modules and replays the training period. The replay ends when the gateway is no longer
     * active or when no trainee module is still enabled.
     *
     * @param csi          the contract being trained (used for logging)
     * @param mktGatewayId id of the contract's playback market gateway
     * @param episode      1-based episode number (used for logging)
     * @return the modules that took part in this episode
     */
    private List<IModule> runEpisode(ContractSimpleInfo csi, String mktGatewayId, int episode) {
        // Log text: "[{} contract] starting training episode {}"
        log.info("\u3010{}\u5408\u7ea6\u3011 \u5f00\u59cb\u7b2c{}\u4e2a\u56de\u5408\u8bad\u7ec3", csi.getName(), episode);
        this.moduleMgr.findAll().forEach(this::resetModule);
        // Re-read the manager after the reset. NOTE(review): presumably resetModule may replace
        // module instances; confirm.
        List<IModule> traineeModules = this.moduleMgr.findAll();
        this.gatewayService.resetPlayback(mktGatewayId);
        // Look the gateway up again in case resetPlayback replaced the registered instance.
        MarketGateway mktGateway = (MarketGateway) this.gatewayMgr.get(Identifier.of(mktGatewayId));

        // Warm-up replay; the modules are only enabled after it finishes.
        mktGateway.connect();
        while (mktGateway.isActive()) {
            // Log text: "data warming up"
            log.info("\u6570\u636e\u9884\u70ed\u4e2d");
            this.pause(5);
        }
        this.pause(5 * traineeModules.size());
        // Log text: "data warm-up complete"
        log.info("\u6570\u636e\u9884\u70ed\u5b8c\u6210");

        traineeModules.forEach(m -> m.setEnabled(true));
        this.pause(1);
        // Second connect replays the training period with the modules enabled.
        mktGateway.connect();
        while (mktGateway.isActive() && traineeModules.stream().anyMatch(IModule::isEnabled)) {
            this.pause(30);
            // Log text: "[{} contract] episode {} training in progress"
            log.info("\u3010{}\u5408\u7ea6\u3011 \u7b2c{}\u4e2a\u56de\u5408\u8bad\u7ec3\u4e2d", csi.getName(), episode);
        }
        this.pause(30);
        // Log text: "[{} contract] episode {} training finished"
        log.info("\u3010{}\u5408\u7ea6\u3011 \u7b2c{}\u4e2a\u56de\u5408\u8bad\u7ec3\u7ed3\u675f", csi.getName(), episode);
        return traineeModules;
    }

    /**
     * Counts the modules whose current runtime description is judged converged by
     * {@code hasPerformanceConverged} against the description recorded in {@code mrdMap} for the
     * same module name in the previous episode.
     *
     * @param traineeModules the modules of the episode that just finished
     * @return number of converged modules
     */
    private long countConverged(List<IModule> traineeModules) {
        return traineeModules.stream()
                .map(m -> m.getModuleContext().getRuntimeDescription(false))
                .filter(mrd -> this.hasPerformanceConverged((ModuleRuntimeDescription) this.mrdMap.get(mrd.getModuleName()), (ModuleRuntimeDescription) mrd))
                .count();
    }

    /**
     * Creates a playback (historical replay) market gateway dedicated to one contract.
     *
     * <p>The gateway id is {@code "历史回放_" + symbolName} ("history playback_"). Dates, precision
     * and speed are taken from {@code ctx}.
     *
     * <p>NOTE(review): {@code playContracts} and {@code subscribedContracts} are set to ALL test
     * contracts, not just the contract this gateway is named after. Confirm whether replaying only
     * that contract is intended.
     *
     * @param ctx        tester context supplying the replay window and playback settings
     * @param symbolName contract name used to make the gateway id unique
     * @return the created gateway as registered in the gateway manager
     */
    private MarketGateway createPlaybackGateway(ModuleTesterContext ctx, String symbolName) {
        String gatewayId = "\u5386\u53f2\u56de\u653e_" + symbolName;
        List<ContractSimpleInfo> contracts = this.testContracts();
        JSONObject settings = new JSONObject();
        settings.put("preStartDate", ctx.preStartDate().format(DateTimeConstant.D_FORMAT_INT_FORMATTER));
        settings.put("startDate", ctx.startDate().format(DateTimeConstant.D_FORMAT_INT_FORMATTER));
        settings.put("endDate", ctx.endDate().format(DateTimeConstant.D_FORMAT_INT_FORMATTER));
        settings.put("precision", ctx.precision());
        settings.put("speed", ctx.speed());
        settings.put("playContracts", contracts);
        GatewayDescription gd = GatewayDescription.builder()
                .gatewayId(gatewayId)
                .gatewayUsage(GatewayUsage.MARKET_DATA)
                .channelType(ChannelType.PLAYBACK)
                .subscribedContracts(contracts)
                .settings(settings)
                .build();
        this.gatewayService.createGateway(gd);
        return (MarketGateway) this.gatewayMgr.get(Identifier.of(gatewayId));
    }

    /**
     * Creates a simulated trade gateway bound to the given market gateway.
     *
     * <p>The gateway id is {@code "模拟账户_" + symbolName} ("simulated account_"), so each
     * contract gets its own account.
     *
     * @param mktGateway market gateway whose data drives the simulated account
     * @param symbolName contract name used to make the gateway id unique
     * @return the created gateway as registered in the gateway manager
     */
    protected TradeGateway createSimGateway(MarketGateway mktGateway, String symbolName) {
        String gatewayId = "\u6a21\u62df\u8d26\u6237_" + symbolName;
        GatewayDescription gd = GatewayDescription.builder()
                .gatewayId(gatewayId)
                .gatewayUsage(GatewayUsage.TRADE)
                .channelType(ChannelType.SIM)
                .bindedMktGatewayId(mktGateway.gatewayId())
                .settings(new JSONObject())
                .build();
        this.gatewayService.createGateway(gd);
        return (TradeGateway) this.gatewayMgr.get(Identifier.of(gatewayId));
    }
}

