/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.soul.web.plugin.function;

import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.dromara.soul.common.dto.RuleData;
import org.dromara.soul.common.dto.SelectorData;
import org.dromara.soul.common.dto.convert.DivideUpstream;
import org.dromara.soul.common.dto.convert.rule.DivideRuleHandle;
import org.dromara.soul.common.enums.PluginEnum;
import org.dromara.soul.common.enums.PluginTypeEnum;
import org.dromara.soul.common.enums.RpcTypeEnum;
import org.dromara.soul.common.utils.GsonUtils;
import org.dromara.soul.common.utils.LogUtils;
import org.dromara.soul.web.balance.utils.LoadBalanceUtils;
import org.dromara.soul.web.cache.LocalCacheManager;
import org.dromara.soul.web.cache.UpstreamCacheManager;
import org.dromara.soul.web.plugin.AbstractSoulPlugin;
import org.dromara.soul.web.plugin.SoulPluginChain;
import org.dromara.soul.web.request.RequestDTO;
import org.dromara.soul.web.result.SoulResultEnum;
import org.dromara.soul.web.result.SoulResultUtils;
import org.dromara.soul.web.result.SoulResultWarp;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

/**
 * Function plugin that proxies WebSocket requests to an upstream service.
 *
 * <p>For requests whose {@code RequestDTO} RPC type is websocket, the plugin picks an upstream from the
 * selector's cached upstream list via {@link LoadBalanceUtils}, builds a {@code ws://}-style target URI and
 * bridges the client session and the upstream session with a {@link SoulWebSocketHandler}.
 */
public class WebSocketPlugin extends AbstractSoulPlugin {

    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketPlugin.class);

    private static final String SEC_WEB_SOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    /** Prefix (compared case-insensitively) of handshake headers that must not be forwarded upstream. */
    private static final String SEC_WEB_SOCKET_PREFIX = "sec-websocket";

    /** Exchange attribute under which the parsed request DTO is stored. */
    private static final String REQUEST_DTO_ATTRIBUTE = "requestDTO";

    private final UpstreamCacheManager upstreamCacheManager;

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    /**
     * Creates the plugin.
     *
     * @param localCacheManager    cache manager handed to {@link AbstractSoulPlugin}
     * @param upstreamCacheManager source of the upstream list for a selector
     * @param webSocketClient      client used to open the upstream WebSocket connection
     * @param webSocketService     service used to upgrade the incoming request to a WebSocket session
     */
    public WebSocketPlugin(final LocalCacheManager localCacheManager, final UpstreamCacheManager upstreamCacheManager,
                           final WebSocketClient webSocketClient, final WebSocketService webSocketService) {
        super(localCacheManager);
        this.upstreamCacheManager = upstreamCacheManager;
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
    }

    /**
     * Selects an upstream and bridges the incoming WebSocket session to it.
     *
     * <p>Configuration problems (no upstreams for the selector, no request DTO, unusable rule handle) are
     * logged and the exchange is passed on to the rest of the chain. If load balancing yields no upstream,
     * a {@code CANNOT_FIND_URL} error result is written instead.
     *
     * @param exchange the current exchange
     * @param chain    the plugin chain
     * @param selector the matched selector; its id keys the upstream cache
     * @param rule     the matched rule; its handle is parsed as a {@link DivideRuleHandle}
     * @return completion of the WebSocket handling (or of the rest of the chain)
     */
    @Override
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final SoulPluginChain chain,
                                   final SelectorData selector, final RuleData rule) {
        final List<DivideUpstream> upstreamList = this.upstreamCacheManager.findUpstreamListBySelectorId(selector.getId());
        final RequestDTO requestDTO = (RequestDTO) exchange.getAttribute(REQUEST_DTO_ATTRIBUTE);
        if (CollectionUtils.isEmpty(upstreamList) || Objects.isNull(requestDTO)) {
            LogUtils.error(LOGGER, "divide upstream configuration error\uff1a{}", () -> rule.toString());
            return chain.execute(exchange);
        }
        final DivideRuleHandle ruleHandle = (DivideRuleHandle) GsonUtils.getInstance().fromJson(rule.getHandle(), DivideRuleHandle.class);
        if (Objects.isNull(ruleHandle)) {
            // A missing/unparseable handle would otherwise NPE on getLoadBalance(); treat it like the
            // other configuration errors above.
            LOGGER.error("websocket rule handle is missing or invalid: {}", rule.getHandle());
            return chain.execute(exchange);
        }
        // The client IP is passed to the load balancer (e.g. for IP-hash style strategies).
        final String ip = Objects.requireNonNull(exchange.getRequest().getRemoteAddress()).getAddress().getHostAddress();
        final DivideUpstream divideUpstream = LoadBalanceUtils.selector(upstreamList, ruleHandle.getLoadBalance(), ip);
        if (Objects.isNull(divideUpstream)) {
            LOGGER.error("websocket has no upstream");
            final Object error = SoulResultWarp.error(SoulResultEnum.CANNOT_FIND_URL.getCode(), SoulResultEnum.CANNOT_FIND_URL.getMsg(), null);
            return SoulResultUtils.result(exchange, error);
        }
        final URI wsRequestUrl = UriComponentsBuilder.fromUri(URI.create(buildWsRealPath(divideUpstream, requestDTO))).build().toUri();
        LOGGER.info("you websocket urlPath is :{}", wsRequestUrl.toASCIIString());
        final HttpHeaders headers = exchange.getRequest().getHeaders();
        final SoulWebSocketHandler handler = new SoulWebSocketHandler(wsRequestUrl, this.webSocketClient,
                filterHeaders(headers), buildWsProtocols(headers));
        return this.webSocketService.handleRequest(exchange, handler);
    }

    /**
     * Builds the target URL as {@code protocol + upstreamUrl + method}, defaulting the protocol to
     * {@code ws://} when the upstream does not declare one.
     */
    private String buildWsRealPath(final DivideUpstream divideUpstream, final RequestDTO requestDTO) {
        String protocol = divideUpstream.getProtocol();
        if (StringUtils.isEmpty(protocol)) {
            protocol = "ws://";
        }
        return protocol + divideUpstream.getUpstreamUrl() + requestDTO.getMethod();
    }

    /**
     * Extracts the sub-protocols requested by the client.
     *
     * <p>Every {@code Sec-WebSocket-Protocol} header value is split on commas, trimmed, and blank tokens
     * are dropped.
     *
     * @param headers the incoming request headers
     * @return the requested sub-protocols; never {@code null} (empty when the header is absent)
     */
    private List<String> buildWsProtocols(final HttpHeaders headers) {
        final List<String> protocols = headers.get(SEC_WEB_SOCKET_PROTOCOL);
        if (CollectionUtils.isEmpty(protocols)) {
            return Collections.emptyList();
        }
        return protocols.stream()
                .flatMap(header -> Arrays.stream(StringUtils.commaDelimitedListToStringArray(header)))
                .map(String::trim)
                .filter(protocol -> !protocol.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Copies the request headers, dropping every header whose name starts with {@code sec-websocket}
     * (case-insensitive, locale-independent).
     *
     * @param headers the incoming request headers
     * @return a new header set to send on the upstream handshake
     */
    private HttpHeaders filterHeaders(final HttpHeaders headers) {
        final HttpHeaders filtered = new HttpHeaders();
        headers.forEach((name, values) -> {
            if (!name.regionMatches(true, 0, SEC_WEB_SOCKET_PREFIX, 0, SEC_WEB_SOCKET_PREFIX.length())) {
                filtered.addAll(name, values);
            }
        });
        return filtered;
    }

    /**
     * Returns the divide plugin's name rather than the websocket one.
     *
     * <p>NOTE(review): presumably intentional so that this plugin is driven by the divide plugin's
     * selector/rule configuration — confirm against {@link AbstractSoulPlugin} before changing.
     */
    @Override
    public String named() {
        return PluginEnum.DIVIDE.getName();
    }

    /**
     * Skips this plugin unless the request's RPC type is websocket.
     *
     * @throws NullPointerException if no request DTO attribute is present on the exchange
     */
    @Override
    public Boolean skip(final ServerWebExchange exchange) {
        final RequestDTO body = (RequestDTO) exchange.getAttribute(REQUEST_DTO_ATTRIBUTE);
        return !Objects.equals(Objects.requireNonNull(body, REQUEST_DTO_ATTRIBUTE).getRpcType(), RpcTypeEnum.WEB_SOCKET.getName());
    }

    @Override
    public PluginTypeEnum pluginType() {
        return PluginTypeEnum.FUNCTION;
    }

    @Override
    public int getOrder() {
        return PluginEnum.WEB_SOCKET.getCode();
    }

    /**
     * Server-side handler that opens a client connection to the upstream URL and relays messages in
     * both directions between the two sessions.
     */
    private static final class SoulWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        SoulWebSocketHandler(final URI url, final WebSocketClient client, final HttpHeaders headers, final List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            this.subProtocols = protocols != null ? protocols : Collections.emptyList();
        }

        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(final WebSocketSession session) {
            return this.client.execute(this.url, this.headers, new WebSocketHandler() {

                @Override
                public Mono<Void> handle(final WebSocketSession upstreamSession) {
                    // Messages are retained so their buffers stay valid after being handed to the other session.
                    final Mono<Void> toUpstream = upstreamSession.send(session.receive().doOnNext(WebSocketMessage::retain));
                    final Mono<Void> toClient = session.send(upstreamSession.receive().doOnNext(WebSocketMessage::retain));
                    return Mono.zip(toUpstream, toClient).then();
                }

                @Override
                public List<String> getSubProtocols() {
                    return subProtocols;
                }
            });
        }
    }
}

