/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.network.sasl;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import java.io.File;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SaslEncryption;
import org.apache.spark.network.sasl.SaslEncryptionBackend;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SparkSaslClient;
import org.apache.spark.network.sasl.SparkSaslServer;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;

public class SparkSaslSuite {
    private SecretKeyHolder secretKeyHolder = new SecretKeyHolder(){

        public String getSaslUser(String appId) {
            return "user";
        }

        public String getSecretKey(String appId) {
            return appId;
        }
    };

    @Test
    public void testMatching() {
        SparkSaslClient client = new SparkSaslClient("shared-secret", this.secretKeyHolder, false);
        SparkSaslServer server = new SparkSaslServer("shared-secret", this.secretKeyHolder, false);
        Assert.assertFalse((boolean)client.isComplete());
        Assert.assertFalse((boolean)server.isComplete());
        byte[] clientMessage = client.firstToken();
        while (!client.isComplete()) {
            clientMessage = client.response(server.response(clientMessage));
        }
        Assert.assertTrue((boolean)server.isComplete());
        server.dispose();
        Assert.assertFalse((boolean)server.isComplete());
        client.dispose();
        Assert.assertFalse((boolean)client.isComplete());
    }

    @Test
    public void testNonMatching() {
        SparkSaslClient client = new SparkSaslClient("my-secret", this.secretKeyHolder, false);
        SparkSaslServer server = new SparkSaslServer("your-secret", this.secretKeyHolder, false);
        Assert.assertFalse((boolean)client.isComplete());
        Assert.assertFalse((boolean)server.isComplete());
        byte[] clientMessage = client.firstToken();
        try {
            while (!client.isComplete()) {
                clientMessage = client.response(server.response(clientMessage));
            }
            Assert.fail((String)"Should not have completed");
        }
        catch (Exception e) {
            Assert.assertTrue((boolean)e.getMessage().contains("Mismatched response"));
            Assert.assertFalse((boolean)client.isComplete());
            Assert.assertFalse((boolean)server.isComplete());
        }
    }

    @Test
    public void testSaslAuthentication() throws Throwable {
        SparkSaslSuite.testBasicSasl(false);
    }

    @Test
    public void testSaslEncryption() throws Throwable {
        SparkSaslSuite.testBasicSasl(true);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static void testBasicSasl(boolean encrypt) throws Throwable {
        RpcHandler rpcHandler = (RpcHandler)Mockito.mock(RpcHandler.class);
        ((RpcHandler)Mockito.doAnswer(invocation -> {
            ByteBuffer message = (ByteBuffer)invocation.getArguments()[1];
            RpcResponseCallback cb = (RpcResponseCallback)invocation.getArguments()[2];
            Assert.assertEquals((Object)"Ping", (Object)JavaUtils.bytesToString((ByteBuffer)message));
            cb.onSuccess(JavaUtils.stringToBytes((String)"Pong"));
            return null;
        }).when((Object)rpcHandler)).receive((TransportClient)Mockito.any(TransportClient.class), (ByteBuffer)Mockito.any(ByteBuffer.class), (RpcResponseCallback)Mockito.any(RpcResponseCallback.class));
        SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
        try {
            ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes((String)"Ping"), TimeUnit.SECONDS.toMillis(10L));
            Assert.assertEquals((Object)"Pong", (Object)JavaUtils.bytesToString((ByteBuffer)response));
        }
        finally {
            ctx.close();
            Throwable error = null;
            long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10L, TimeUnit.SECONDS);
            while (deadline > System.nanoTime()) {
                try {
                    ((RpcHandler)Mockito.verify((Object)rpcHandler, (VerificationMode)Mockito.times((int)2))).channelInactive((TransportClient)Mockito.any(TransportClient.class));
                    error = null;
                    break;
                }
                catch (Throwable t) {
                    error = t;
                    TimeUnit.MILLISECONDS.sleep(10L);
                }
            }
            if (error != null) {
                throw error;
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testEncryptedMessage() throws Exception {
        SaslEncryptionBackend backend = (SaslEncryptionBackend)Mockito.mock(SaslEncryptionBackend.class);
        byte[] data = new byte[1024];
        new Random().nextBytes(data);
        Mockito.when((Object)backend.wrap((byte[])Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn((Object)data);
        ByteBuf msg = Unpooled.buffer();
        try {
            msg.writeBytes(data);
            ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
            SaslEncryption.EncryptedMessage emsg = new SaslEncryption.EncryptedMessage(backend, (Object)msg, 1024);
            long count = emsg.transferTo((WritableByteChannel)channel, emsg.transfered());
            Assert.assertTrue((count < (long)data.length ? 1 : 0) != 0);
            Assert.assertTrue((count > 0L ? 1 : 0) != 0);
            Assert.assertEquals((long)0L, (long)emsg.transferTo((WritableByteChannel)channel, emsg.transfered()));
            channel.reset();
            Assert.assertEquals((long)1L, (long)emsg.transferTo((WritableByteChannel)channel, emsg.transfered()));
            for (int i = 0; i < data.length / 32 - 2; ++i) {
                channel.reset();
                Assert.assertEquals((long)1L, (long)emsg.transferTo((WritableByteChannel)channel, emsg.transfered()));
            }
            channel.reset();
            count = emsg.transferTo((WritableByteChannel)channel, emsg.transfered());
            Assert.assertTrue((String)("Unexpected count: " + count), (count > 1L && count < (long)data.length ? 1 : 0) != 0);
            Assert.assertEquals((long)data.length, (long)emsg.transfered());
        }
        finally {
            msg.release();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testEncryptedMessageChunking() throws Exception {
        File file = File.createTempFile("sasltest", ".txt");
        try {
            TransportConf conf = new TransportConf("shuffle", (ConfigProvider)MapConfigProvider.EMPTY);
            byte[] data = new byte[8192];
            new Random().nextBytes(data);
            Files.write((byte[])data, (File)file);
            SaslEncryptionBackend backend = (SaslEncryptionBackend)Mockito.mock(SaslEncryptionBackend.class);
            Mockito.when((Object)backend.wrap((byte[])Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn((Object)data);
            FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0L, file.length());
            SaslEncryption.EncryptedMessage emsg = new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
            ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
            while (emsg.transfered() < emsg.count()) {
                channel.reset();
                emsg.transferTo((WritableByteChannel)channel, emsg.transfered());
            }
            ((SaslEncryptionBackend)Mockito.verify((Object)backend, (VerificationMode)Mockito.times((int)8))).wrap((byte[])Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt());
        }
        finally {
            file.delete();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testFileRegionEncryption() throws Exception {
        ImmutableMap testConf = ImmutableMap.of((Object)"spark.network.sasl.maxEncryptedBlockSize", (Object)"1k");
        AtomicReference response = new AtomicReference();
        File file = File.createTempFile("sasltest", ".txt");
        SaslTestCtx ctx = null;
        try {
            TransportConf conf = new TransportConf("shuffle", (ConfigProvider)new MapConfigProvider((Map)testConf));
            StreamManager sm = (StreamManager)Mockito.mock(StreamManager.class);
            Mockito.when((Object)sm.getChunk(Mockito.anyLong(), Mockito.anyInt())).thenAnswer(invocation -> new FileSegmentManagedBuffer(conf, file, 0L, file.length()));
            RpcHandler rpcHandler = (RpcHandler)Mockito.mock(RpcHandler.class);
            Mockito.when((Object)rpcHandler.getStreamManager()).thenReturn((Object)sm);
            byte[] data = new byte[8192];
            new Random().nextBytes(data);
            Files.write((byte[])data, (File)file);
            ctx = new SaslTestCtx(rpcHandler, true, false, (Map<String, String>)testConf);
            CountDownLatch lock = new CountDownLatch(1);
            ChunkReceivedCallback callback = (ChunkReceivedCallback)Mockito.mock(ChunkReceivedCallback.class);
            ((ChunkReceivedCallback)Mockito.doAnswer(invocation -> {
                response.set((ManagedBuffer)invocation.getArguments()[1]);
                ((ManagedBuffer)response.get()).retain();
                lock.countDown();
                return null;
            }).when((Object)callback)).onSuccess(Mockito.anyInt(), (ManagedBuffer)Mockito.any(ManagedBuffer.class));
            ctx.client.fetchChunk(0L, 0, callback);
            lock.await(10L, TimeUnit.SECONDS);
            ((ChunkReceivedCallback)Mockito.verify((Object)callback, (VerificationMode)Mockito.times((int)1))).onSuccess(Mockito.anyInt(), (ManagedBuffer)Mockito.any(ManagedBuffer.class));
            ((ChunkReceivedCallback)Mockito.verify((Object)callback, (VerificationMode)Mockito.never())).onFailure(Mockito.anyInt(), (Throwable)Mockito.any(Throwable.class));
            byte[] received = ByteStreams.toByteArray((InputStream)((ManagedBuffer)response.get()).createInputStream());
            Assert.assertTrue((boolean)Arrays.equals(data, received));
        }
        finally {
            file.delete();
            if (ctx != null) {
                ctx.close();
            }
            if (response.get() != null) {
                ((ManagedBuffer)response.get()).release();
            }
        }
    }

    @Test
    public void testServerAlwaysEncrypt() throws Exception {
        try (SaslTestCtx ctx = null;){
            ctx = new SaslTestCtx((RpcHandler)Mockito.mock(RpcHandler.class), false, false, (Map<String, String>)ImmutableMap.of((Object)"spark.network.sasl.serverAlwaysEncrypt", (Object)"true"));
            Assert.fail((String)"Should have failed to connect without encryption.");
        }
    }

    @Test
    public void testDataEncryptionIsActuallyEnabled() throws Exception {
        try (SaslTestCtx ctx = null;){
            ctx = new SaslTestCtx((RpcHandler)Mockito.mock(RpcHandler.class), true, true);
            ctx.client.sendRpcSync(JavaUtils.stringToBytes((String)"Ping"), TimeUnit.SECONDS.toMillis(10L));
            Assert.fail((String)"Should have failed to send RPC to server.");
        }
    }

    @Test
    public void testRpcHandlerDelegate() throws Exception {
        RpcHandler handler = (RpcHandler)Mockito.mock(RpcHandler.class);
        SaslRpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
        saslHandler.getStreamManager();
        ((RpcHandler)Mockito.verify((Object)handler)).getStreamManager();
        saslHandler.channelInactive(null);
        ((RpcHandler)Mockito.verify((Object)handler)).channelInactive((TransportClient)Mockito.isNull());
        saslHandler.exceptionCaught(null, null);
        ((RpcHandler)Mockito.verify((Object)handler)).exceptionCaught((Throwable)Mockito.isNull(), (TransportClient)Mockito.isNull());
    }

    @Test
    public void testDelegates() throws Exception {
        Method[] rpcHandlerMethods;
        for (Method m : rpcHandlerMethods = RpcHandler.class.getDeclaredMethods()) {
            Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
            Assert.assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
        }
    }

    private static class EncryptionDisablerBootstrap
    implements TransportClientBootstrap {
        private EncryptionDisablerBootstrap() {
        }

        public void doBootstrap(TransportClient client, Channel channel) {
            channel.pipeline().remove("saslEncryption");
        }
    }

    private static class EncryptionCheckerBootstrap
    extends ChannelOutboundHandlerAdapter
    implements TransportServerBootstrap {
        boolean foundEncryptionHandler;
        String encryptHandlerName;

        EncryptionCheckerBootstrap(String encryptHandlerName) {
            this.encryptHandlerName = encryptHandlerName;
        }

        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            if (!this.foundEncryptionHandler) {
                this.foundEncryptionHandler = ctx.channel().pipeline().get(this.encryptHandlerName) != null;
            }
            ctx.write(msg, promise);
        }

        public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
            channel.pipeline().addFirst("encryptionChecker", (ChannelHandler)this);
            return rpcHandler;
        }
    }

    private static class SaslTestCtx {
        final TransportClient client;
        final TransportServer server;
        final TransportContext ctx;
        private final boolean encrypt;
        private final boolean disableClientEncryption;
        private final EncryptionCheckerBootstrap checker;

        SaslTestCtx(RpcHandler rpcHandler, boolean encrypt, boolean disableClientEncryption) throws Exception {
            this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap());
        }

        SaslTestCtx(RpcHandler rpcHandler, boolean encrypt, boolean disableClientEncryption, Map<String, String> extraConf) throws Exception {
            ImmutableMap testConf = ImmutableMap.builder().putAll(extraConf).put((Object)"spark.authenticate.enableSaslEncryption", (Object)String.valueOf(encrypt)).build();
            TransportConf conf = new TransportConf("shuffle", (ConfigProvider)new MapConfigProvider((Map)testConf));
            SecretKeyHolder keyHolder = (SecretKeyHolder)Mockito.mock(SecretKeyHolder.class);
            Mockito.when((Object)keyHolder.getSaslUser(Mockito.anyString())).thenReturn((Object)"user");
            Mockito.when((Object)keyHolder.getSecretKey(Mockito.anyString())).thenReturn((Object)"secret");
            this.ctx = new TransportContext(conf, rpcHandler);
            this.checker = new EncryptionCheckerBootstrap("saslEncryption");
            this.server = this.ctx.createServer(Arrays.asList(new TransportServerBootstrap[]{new SaslServerBootstrap(conf, keyHolder), this.checker}));
            try {
                ArrayList<Object> clientBootstraps = new ArrayList<Object>();
                clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
                if (disableClientEncryption) {
                    clientBootstraps.add(new EncryptionDisablerBootstrap());
                }
                this.client = this.ctx.createClientFactory(clientBootstraps).createClient(TestUtils.getLocalHost(), this.server.getPort());
            }
            catch (Exception e) {
                this.close();
                throw e;
            }
            this.encrypt = encrypt;
            this.disableClientEncryption = disableClientEncryption;
        }

        void close() {
            if (!this.disableClientEncryption) {
                Assert.assertEquals((Object)this.encrypt, (Object)this.checker.foundEncryptionHandler);
            }
            if (this.client != null) {
                this.client.close();
            }
            if (this.server != null) {
                this.server.close();
            }
            if (this.ctx != null) {
                this.ctx.close();
            }
        }
    }
}

