/*
 * Decompiled with CFR 0.152.
 */
package org.infinispan.server.jgroups;

import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.function.Supplier;
import javax.security.auth.callback.CallbackHandler;
import org.infinispan.server.jgroups.TopologyAddressGenerator;
import org.infinispan.server.jgroups.logging.JGroupsLogger;
import org.infinispan.server.jgroups.security.RealmAuthorizationCallbackHandler;
import org.infinispan.server.jgroups.security.SaslClientCallbackHandler;
import org.infinispan.server.jgroups.spi.ChannelFactory;
import org.infinispan.server.jgroups.spi.ProtocolConfiguration;
import org.infinispan.server.jgroups.spi.ProtocolStackConfiguration;
import org.infinispan.server.jgroups.spi.RelayConfiguration;
import org.infinispan.server.jgroups.spi.RemoteSiteConfiguration;
import org.infinispan.server.jgroups.spi.SaslConfiguration;
import org.infinispan.server.jgroups.spi.TransportConfiguration;
import org.jboss.as.domain.management.SecurityRealm;
import org.jboss.as.network.SocketBinding;
import org.jboss.modules.ModuleIdentifier;
import org.jboss.modules.ModuleLoadException;
import org.jgroups.Header;
import org.jgroups.JChannel;
import org.jgroups.Message;
import org.jgroups.annotations.Property;
import org.jgroups.blocks.RequestCorrelator;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.conf.ProtocolStackConfigurator;
import org.jgroups.fork.UnknownForkHandler;
import org.jgroups.protocols.FORK;
import org.jgroups.protocols.SASL;
import org.jgroups.protocols.TP;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.relay.RELAY2;
import org.jgroups.protocols.relay.RouteStatusListener;
import org.jgroups.protocols.relay.config.RelayConfig;
import org.jgroups.stack.AddressGenerator;
import org.jgroups.stack.Configurator;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;
import org.jgroups.util.StackType;
import org.jgroups.util.Util;
import org.wildfly.security.manager.WildFlySecurityManager;

public class JChannelFactory
implements ChannelFactory,
ProtocolStackConfigurator {
    static final ByteBuffer UNKNOWN_FORK_RESPONSE = ByteBuffer.allocate(0);
    private final ProtocolStackConfiguration configuration;

    public JChannelFactory(ProtocolStackConfiguration configuration) {
        this.configuration = configuration;
    }

    @Override
    public ProtocolStackConfiguration getProtocolStackConfiguration() {
        return this.configuration;
    }

    @Override
    public JChannel createChannel(String id) throws Exception {
        JGroupsLogger.ROOT_LOGGER.debugf("Creating channel %s from stack %s", id, this.configuration.getName());
        PrivilegedExceptionAction<JChannel> action = () -> new JChannel((ProtocolStackConfigurator)this);
        final JChannel channel = (JChannel)WildFlySecurityManager.doChecked(action);
        ProtocolStack stack = channel.getProtocolStack();
        TP transport = stack.getTransport();
        this.init(transport);
        RelayConfiguration relayConfig = this.configuration.getRelay();
        StackType ipStackType = Util.getIpStackType();
        if (relayConfig != null) {
            String localSite = relayConfig.getSiteName();
            List<RemoteSiteConfiguration> remoteSites = this.configuration.getRelay().getRemoteSites();
            ArrayList<String> sites = new ArrayList<String>(remoteSites.size() + 1);
            sites.add(localSite);
            HashMap<String, 1> bridges = new HashMap<String, 1>();
            for (final RemoteSiteConfiguration remoteSiteConfiguration : remoteSites) {
                String siteName = remoteSiteConfiguration.getName();
                sites.add(siteName);
                String clusterName = remoteSiteConfiguration.getClusterName();
                RelayConfig.BridgeConfig bridge = new RelayConfig.BridgeConfig(clusterName){

                    public JChannel createChannel() throws Exception {
                        JChannel channel = remoteSiteConfiguration.getChannel();
                        channel.getProtocolStack().removeProtocol(FORK.class);
                        return channel;
                    }
                };
                bridges.put(clusterName, bridge);
            }
            RELAY2 relay = new RELAY2().site(localSite);
            relay.setRouteStatusListener((RouteStatusListener)new DefaultRouteStatusListener());
            for (String site : sites) {
                RelayConfig.SiteConfig siteConfig = new RelayConfig.SiteConfig(site);
                relay.addSite(site, siteConfig);
                if (!site.equals(localSite)) continue;
                for (RelayConfig.BridgeConfig bridge : bridges.values()) {
                    siteConfig.addBridge(bridge);
                }
            }
            org.jgroups.conf.ProtocolConfiguration protocolConfiguration = new org.jgroups.conf.ProtocolConfiguration(relay.getName(), relayConfig.getProperties());
            Configurator.initializeAttrs((Protocol)relay, (org.jgroups.conf.ProtocolConfiguration)protocolConfiguration, (StackType)ipStackType);
            stack.addProtocol((Protocol)relay);
            relay.init();
        }
        UnknownForkHandler unknownForkHandler = new UnknownForkHandler(){
            private final short id = ClassConfigurator.getProtocolId(RequestCorrelator.class);

            public Object handleUnknownForkStack(Message message, String forkStackId) {
                return this.handle(message);
            }

            public Object handleUnknownForkChannel(Message message, String forkChannelId) {
                return this.handle(message);
            }

            private Object handle(Message message) {
                RequestCorrelator.Header header = (RequestCorrelator.Header)message.getHeader(this.id);
                if (header != null && header.type == 0 && header.rspExpected()) {
                    Message response = message.makeReply().setFlag(message.getFlags());
                    response.putHeader(FORK.ID, message.getHeader(FORK.ID));
                    response.putHeader(this.id, (Header)new RequestCorrelator.Header(1, header.req_id, this.id));
                    response.setBuffer(UNKNOWN_FORK_RESPONSE.array());
                    channel.down(response);
                }
                return null;
            }
        };
        FORK fork = new FORK();
        fork.setUnknownForkHandler(unknownForkHandler);
        stack.addProtocol((Protocol)fork);
        fork.init();
        SaslConfiguration saslConfig = this.configuration.getSasl();
        if (saslConfig != null) {
            String clusterRole = saslConfig.getClusterRole();
            SecurityRealm securityRealm = saslConfig.getSecurityRealm();
            String string = saslConfig.getMech();
            SASL sasl = new SASL();
            sasl.setMech(string);
            Map<String, String> props = saslConfig.getProperties();
            if (props.containsKey("client_password")) {
                String credential = props.get("client_password");
                String name = props.get("client_name");
                if (name == null) {
                    sasl.setClientCallbackHandler((CallbackHandler)new SaslClientCallbackHandler(securityRealm.getName(), this.configuration.getNodeName(), credential));
                } else if (name.contains("@")) {
                    sasl.setClientCallbackHandler((CallbackHandler)new SaslClientCallbackHandler(name, credential));
                } else {
                    sasl.setClientCallbackHandler((CallbackHandler)new SaslClientCallbackHandler(securityRealm.getName(), name, credential));
                }
            } else {
                props.put("client_password", "");
            }
            org.jgroups.conf.ProtocolConfiguration protocolConfiguration = new org.jgroups.conf.ProtocolConfiguration(sasl.getName(), props);
            Configurator.initializeAttrs((Protocol)sasl, (org.jgroups.conf.ProtocolConfiguration)protocolConfiguration, (StackType)ipStackType);
            sasl.setServerCallbackHandler((CallbackHandler)new RealmAuthorizationCallbackHandler(securityRealm, string, clusterRole != null ? clusterRole : id, sasl.getSaslProps()));
            channel.getProtocolStack().insertProtocol((Protocol)sasl, ProtocolStack.Position.BELOW, GMS.class);
            sasl.init();
        }
        channel.setName(this.configuration.getNodeName());
        TransportConfiguration.Topology topology = this.configuration.getTransport().getTopology();
        if (topology != null) {
            channel.addAddressGenerator((AddressGenerator)new TopologyAddressGenerator(topology));
        }
        return channel;
    }

    @Override
    public boolean isUnknownForkResponse(ByteBuffer buffer) {
        return UNKNOWN_FORK_RESPONSE.equals(buffer);
    }

    private void init(TP transport) {
        TransportConfiguration transportConfig = this.configuration.getTransport();
        SocketBinding binding = transportConfig.getSocketBinding();
    }

    public String getProtocolStackString() {
        return null;
    }

    public List<org.jgroups.conf.ProtocolConfiguration> getProtocolStack() {
        SocketBinding diagnosticsSocketBinding;
        ArrayList<org.jgroups.conf.ProtocolConfiguration> stack = new ArrayList<org.jgroups.conf.ProtocolConfiguration>(this.configuration.getProtocols().size() + 1);
        TransportConfiguration transport = this.configuration.getTransport();
        org.jgroups.conf.ProtocolConfiguration protocol = JChannelFactory.createProtocol(this.configuration, transport);
        Map properties = protocol.getProperties();
        Introspector introspector = new Introspector(protocol);
        SocketBinding binding = transport.getSocketBinding();
        if (binding != null) {
            JChannelFactory.configureBindAddress(introspector, protocol, binding);
            JChannelFactory.configureServerSocket(introspector, protocol, "bind_port", binding);
            JChannelFactory.configureMulticastSocket(introspector, protocol, "mcast_addr", "mcast_port", binding);
        }
        boolean diagnostics = (diagnosticsSocketBinding = transport.getDiagnosticsSocketBinding()) != null;
        properties.put("enable_diagnostics", String.valueOf(diagnostics));
        if (diagnostics) {
            JChannelFactory.configureMulticastSocket(introspector, protocol, "diagnostics_addr", "diagnostics_port", diagnosticsSocketBinding);
        }
        stack.add(protocol);
        Class<TP> transportClass = introspector.getProtocolClass().asSubclass(TP.class);
        PrivilegedExceptionAction<TP> action = transportClass::newInstance;
        try {
            stack.addAll(JChannelFactory.createProtocols(this.configuration, ((TP)WildFlySecurityManager.doChecked(action)).isMulticastCapable()));
        }
        catch (PrivilegedActionException e) {
            throw new IllegalStateException(e.getCause());
        }
        return stack;
    }

    static List<org.jgroups.conf.ProtocolConfiguration> createProtocols(ProtocolStackConfiguration stack, boolean multicastCapable) {
        List<ProtocolConfiguration> protocols = stack.getProtocols();
        ArrayList<org.jgroups.conf.ProtocolConfiguration> result = new ArrayList<org.jgroups.conf.ProtocolConfiguration>(protocols.size());
        TransportConfiguration transport = stack.getTransport();
        for (ProtocolConfiguration protocol : protocols) {
            org.jgroups.conf.ProtocolConfiguration config = JChannelFactory.createProtocol(stack, protocol);
            Introspector introspector = new Introspector(config);
            SocketBinding binding = protocol.getSocketBinding();
            if (binding != null) {
                JChannelFactory.configureBindAddress(introspector, config, binding);
                JChannelFactory.configureServerSocket(introspector, config, "bind_port", binding);
                JChannelFactory.configureServerSocket(introspector, config, "start_port", binding);
                JChannelFactory.configureMulticastSocket(introspector, config, "mcast_addr", "mcast_port", binding);
            } else if (transport.getSocketBinding() != null) {
                JChannelFactory.configureBindAddress(introspector, config, transport.getSocketBinding());
            }
            if (!multicastCapable) {
                JChannelFactory.setProperty(introspector, config, "use_mcast_xmit", String.valueOf(false));
                JChannelFactory.setProperty(introspector, config, "use_mcast_xmit_req", String.valueOf(false));
            }
            result.add(config);
        }
        return result;
    }

    private static org.jgroups.conf.ProtocolConfiguration createProtocol(ProtocolStackConfiguration stack, ProtocolConfiguration protocol) {
        String protocolName = protocol.getName();
        ModuleIdentifier module = protocol.getModule();
        HashMap<String, String> properties = new HashMap<String, String>(stack.getDefaultProperties(protocolName));
        properties.putAll(protocol.getProperties());
        try {
            return new org.jgroups.conf.ProtocolConfiguration(protocol.getProtocolClassName(), properties, (ClassLoader)stack.getModuleLoader().loadModule(module).getClassLoader());
        }
        catch (ModuleLoadException e) {
            throw new IllegalArgumentException(e);
        }
    }

    private static void configureBindAddress(Introspector introspector, org.jgroups.conf.ProtocolConfiguration config, SocketBinding binding) {
        JChannelFactory.setSocketBindingProperty(introspector, config, "bind_addr", binding.getSocketAddress().getAddress().getHostAddress());
    }

    private static void configureServerSocket(Introspector introspector, org.jgroups.conf.ProtocolConfiguration config, String property, SocketBinding binding) {
        JChannelFactory.setSocketBindingProperty(introspector, config, property, String.valueOf(binding.getSocketAddress().getPort()));
    }

    private static void configureMulticastSocket(Introspector introspector, org.jgroups.conf.ProtocolConfiguration config, String addressProperty, String portProperty, SocketBinding binding) {
        try {
            InetSocketAddress mcastSocketAddress = binding.getMulticastSocketAddress();
            JChannelFactory.setSocketBindingProperty(introspector, config, addressProperty, mcastSocketAddress.getAddress().getHostAddress());
            JChannelFactory.setSocketBindingProperty(introspector, config, portProperty, String.valueOf(mcastSocketAddress.getPort()));
        }
        catch (IllegalStateException e) {
            JGroupsLogger.ROOT_LOGGER.couldNotSetAddressAndPortNoMulticastSocket(e, config.getProtocolName(), addressProperty, config.getProtocolName(), portProperty, binding.getName());
        }
    }

    private static void setSocketBindingProperty(Introspector introspector, org.jgroups.conf.ProtocolConfiguration config, String name, String value) {
        try {
            Map properties = config.getProperties();
            if (properties.containsKey(name)) {
                JGroupsLogger.ROOT_LOGGER.unableToOverrideSocketBindingValue(name, config.getProtocolName(), value, properties.get(name));
            }
            JChannelFactory.setProperty(introspector, config, name, value);
        }
        catch (Exception e) {
            JGroupsLogger.ROOT_LOGGER.unableToAccessProtocolPropertyValue(e, name, config.getProtocolName());
        }
    }

    private static void setProperty(Introspector introspector, org.jgroups.conf.ProtocolConfiguration config, String name, String value) {
        if (introspector.hasProperty(name)) {
            config.getProperties().put(name, value);
        }
    }

    static class DefaultRouteStatusListener
    implements RouteStatusListener,
    Supplier<Set<String>> {
        private final Set<String> view = new ConcurrentSkipListSet<String>();

        DefaultRouteStatusListener() {
        }

        public void sitesUp(String ... sites) {
            JGroupsLogger log = JGroupsLogger.ROOT_LOGGER;
            if (log.isTraceEnabled()) {
                log.tracef("Joined x-site view: %s", Arrays.toString(sites));
            }
            this.view.addAll(Arrays.asList(sites));
            log.receivedXSiteClusterView(this.view);
        }

        public void sitesDown(String ... sites) {
            JGroupsLogger log = JGroupsLogger.ROOT_LOGGER;
            if (log.isTraceEnabled()) {
                log.tracef("Left x-site view: %s", Arrays.toString(sites));
            }
            this.view.removeAll(Arrays.asList(sites));
            log.receivedXSiteClusterView(this.view);
        }

        @Override
        public Set<String> get() {
            return Collections.unmodifiableSet(this.view);
        }
    }

    private static class Introspector {
        final Class<? extends Protocol> protocolClass;
        final Set<String> properties = new HashSet<String>();

        Introspector(org.jgroups.conf.ProtocolConfiguration config) {
            String name = config.getProtocolName();
            try {
                this.protocolClass = config.getClassLoader().loadClass(name).asSubclass(Protocol.class);
                PrivilegedAction<Void> action = () -> {
                    Class<? extends Protocol> targetClass = this.protocolClass;
                    while (Protocol.class.isAssignableFrom(targetClass)) {
                        String property;
                        for (Method method : targetClass.getDeclaredMethods()) {
                            if (!method.isAnnotationPresent(Property.class) || (property = method.getAnnotation(Property.class).name()).isEmpty()) continue;
                            this.properties.add(property);
                        }
                        for (AccessibleObject accessibleObject : targetClass.getDeclaredFields()) {
                            if (!accessibleObject.isAnnotationPresent(Property.class)) continue;
                            property = ((Field)accessibleObject).getAnnotation(Property.class).name();
                            this.properties.add(!property.isEmpty() ? property : ((Field)accessibleObject).getName());
                        }
                        targetClass = targetClass.getSuperclass();
                    }
                    return null;
                };
                WildFlySecurityManager.doChecked(action);
            }
            catch (ClassNotFoundException e) {
                throw new IllegalArgumentException(e);
            }
        }

        Class<? extends Protocol> getProtocolClass() {
            return this.protocolClass;
        }

        boolean hasProperty(String property) {
            return this.properties.contains(property);
        }
    }
}

