kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From: jun...@apache.org
Subject: kafka git commit: KAFKA-2534: Fixes and unit tests for SSLTransportLayer buffer overflow
Date: Wed, 07 Oct 2015 16:28:30 GMT
Repository: kafka
Updated Branches:
  refs/heads/trunk 02e103b75 -> 2047a9afe


KAFKA-2534: Fixes and unit tests for SSLTransportLayer buffer overflow

Unit tests that mock buffer overflow and underflow in the SSL transport layer, and fixes for a couple of issues in buffer overflow handling described in the JIRA.

Author: Rajini Sivaram <rajinisivaram@googlemail.com>

Reviewers: Ismael Juma <ismael@juma.me.uk>, Sriharsha Chintalapani <schintalapani@hortonworks.com>, Jun Rao <junrao@gmail.com>

Closes #205 from rajinisivaram/KAFKA-2534


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/2047a9af
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/2047a9af
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/2047a9af

Branch: refs/heads/trunk
Commit: 2047a9afe1ef9cf81e6347fe27f0051c8758d226
Parents: 02e103b
Author: Rajini Sivaram <rajinisivaram@googlemail.com>
Authored: Wed Oct 7 09:28:22 2015 -0700
Committer: Jun Rao <junrao@gmail.com>
Committed: Wed Oct 7 09:28:22 2015 -0700

----------------------------------------------------------------------
 .../kafka/common/network/SSLChannelBuilder.java |  15 +-
 .../kafka/common/network/SSLTransportLayer.java |  89 ++-
 .../kafka/common/network/SSLSelectorTest.java   | 166 +----
 .../common/network/SSLTransportLayerTest.java   | 654 +++++++++++++++++++
 .../kafka/common/network/SelectorTest.java      |  30 +-
 5 files changed, 751 insertions(+), 203 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/2047a9af/clients/src/main/java/org/apache/kafka/common/network/SSLChannelBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SSLChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SSLChannelBuilder.java
index 88c218b..e2cce5c 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SSLChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SSLChannelBuilder.java
@@ -12,6 +12,7 @@
  */
 package org.apache.kafka.common.network;
 
+import java.io.IOException;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
 import java.util.Map;
@@ -48,10 +49,7 @@ public class SSLChannelBuilder implements ChannelBuilder {
     public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize) throws KafkaException {
         KafkaChannel channel = null;
         try {
-            SocketChannel socketChannel = (SocketChannel) key.channel();
-            SSLTransportLayer transportLayer = new SSLTransportLayer(id, key,
-                                                                     sslFactory.createSSLEngine(socketChannel.socket().getInetAddress().getHostName(),
-                                                                                                socketChannel.socket().getPort()));
+            SSLTransportLayer transportLayer = buildTransportLayer(sslFactory, id, key);
             Authenticator authenticator = new DefaultAuthenticator();
             authenticator.configure(transportLayer, this.principalBuilder);
             channel = new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize);
@@ -65,4 +63,13 @@ public class SSLChannelBuilder implements ChannelBuilder {
     public void close()  {
         this.principalBuilder.close();
     }
+
+    protected SSLTransportLayer buildTransportLayer(SSLFactory sslFactory, String id, SelectionKey key) throws IOException {
+        SocketChannel socketChannel = (SocketChannel) key.channel();
+        SSLTransportLayer transportLayer = new SSLTransportLayer(id, key,
+                sslFactory.createSSLEngine(socketChannel.socket().getInetAddress().getHostName(),
+                                           socketChannel.socket().getPort()));
+        transportLayer.startHandshake();
+        return transportLayer;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/2047a9af/clients/src/main/java/org/apache/kafka/common/network/SSLTransportLayer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SSLTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SSLTransportLayer.java
index 8b4bd9f..35ea9aa 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SSLTransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SSLTransportLayer.java
@@ -46,7 +46,7 @@ import org.slf4j.LoggerFactory;
 public class SSLTransportLayer implements TransportLayer {
     private static final Logger log = LoggerFactory.getLogger(SSLTransportLayer.class);
     private final String channelId;
-    protected final SSLEngine sslEngine;
+    private final SSLEngine sslEngine;
     private final SelectionKey key;
     private final SocketChannel socketChannel;
     private HandshakeStatus handshakeStatus;
@@ -63,16 +63,17 @@ public class SSLTransportLayer implements TransportLayer {
         this.key = key;
         this.socketChannel = (SocketChannel) key.channel();
         this.sslEngine = sslEngine;
-        this.netReadBuffer = ByteBuffer.allocate(packetBufferSize());
-        this.netWriteBuffer = ByteBuffer.allocate(packetBufferSize());
-        this.appReadBuffer = ByteBuffer.allocate(applicationBufferSize());
-        startHandshake();
     }
 
     /**
      * starts sslEngine handshake process
      */
-    private void startHandshake() throws IOException {
+    protected void startHandshake() throws IOException {
+
+        this.netReadBuffer = ByteBuffer.allocate(netReadBufferSize());
+        this.netWriteBuffer = ByteBuffer.allocate(netWriteBufferSize());
+        this.appReadBuffer = ByteBuffer.allocate(applicationBufferSize());
+        
         //clear & set netRead & netWrite buffers
         netWriteBuffer.position(0);
         netWriteBuffer.limit(0);
@@ -223,11 +224,13 @@ public class SSLTransportLayer implements TransportLayer {
                               channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
                     handshakeResult = handshakeWrap(write);
                     if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
-                        int currentPacketBufferSize = packetBufferSize();
-                        netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentPacketBufferSize);
-                        if (netWriteBuffer.position() >= currentPacketBufferSize) {
-                            throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.position() +
-                                                            ") >= network buffer size (" + currentPacketBufferSize + ")");
+                        int currentNetWriteBufferSize = netWriteBufferSize();
+                        netWriteBuffer.compact();
+                        netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
+                        netWriteBuffer.flip();
+                        if (netWriteBuffer.limit() >= currentNetWriteBufferSize) {
+                            throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() +
+                                                            ") >= network buffer size (" + currentNetWriteBufferSize + ")");
                         }
                     } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
                         throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP.");
@@ -245,20 +248,23 @@ public class SSLTransportLayer implements TransportLayer {
                 case NEED_UNWRAP:
                     log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
                               channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-                    handshakeResult = handshakeUnwrap(read);
+                    do {
+                        handshakeResult = handshakeUnwrap(read);
+                        if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
+                            int currentAppBufferSize = applicationBufferSize();
+                            appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
+                            if (appReadBuffer.position() > currentAppBufferSize) {
+                                throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
+                                                                ") > packet buffer size (" + currentAppBufferSize + ")");
+                            }
+                        }
+                    } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW);
                     if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
-                        int currentPacketBufferSize = packetBufferSize();
-                        netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentPacketBufferSize);
-                        if (netReadBuffer.position() >= currentPacketBufferSize) {
+                        int currentNetReadBufferSize = netReadBufferSize();
+                        netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
+                        if (netReadBuffer.position() >= currentNetReadBufferSize) {
                             throw new IllegalStateException("Buffer underflow when there is available data");
                         }
-                    } else if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
-                        int currentAppBufferSize = applicationBufferSize();
-                        appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
-                        if (appReadBuffer.position() > currentAppBufferSize) {
-                            throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
-                                                            ") > packet buffer size (" + currentAppBufferSize + ")");
-                        }
                     } else if (handshakeResult.getStatus() == Status.CLOSED) {
                         throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP");
                     }
@@ -285,6 +291,7 @@ public class SSLTransportLayer implements TransportLayer {
                 default:
                     throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus));
             }
+
         } catch (SSLException e) {
             handshakeFailure();
             throw e;
@@ -338,7 +345,7 @@ public class SSLTransportLayer implements TransportLayer {
     * @throws IOException
     */
     private SSLEngineResult handshakeWrap(boolean doWrite) throws IOException {
-        log.trace("SSLHandshake handshakeWrap", channelId);
+        log.trace("SSLHandshake handshakeWrap {}", channelId);
         if (netWriteBuffer.hasRemaining())
             throw new IllegalStateException("handshakeWrap called with netWriteBuffer not empty");
         //this should never be called with a network buffer that contains data
@@ -364,7 +371,7 @@ public class SSLTransportLayer implements TransportLayer {
     * @throws IOException
     */
     private SSLEngineResult handshakeUnwrap(boolean doRead) throws IOException {
-        log.trace("SSLHandshake handshakeUnwrap", channelId);
+        log.trace("SSLHandshake handshakeUnwrap {}", channelId);
         SSLEngineResult result;
         boolean cont = false;
         int read = 0;
@@ -384,7 +391,7 @@ public class SSLTransportLayer implements TransportLayer {
             }
             cont = result.getStatus() == SSLEngineResult.Status.OK &&
                 handshakeStatus == HandshakeStatus.NEED_UNWRAP;
-            log.trace("SSLHandshake handshakeUnwrap: handshakeStatus ", handshakeStatus);
+            log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
         } while (netReadBuffer.position() != 0 && cont);
 
         return result;
@@ -410,7 +417,7 @@ public class SSLTransportLayer implements TransportLayer {
         }
 
         if (dst.remaining() > 0) {
-            netReadBuffer = Utils.ensureCapacity(netReadBuffer, packetBufferSize());
+            netReadBuffer = Utils.ensureCapacity(netReadBuffer, netReadBufferSize());
             if (netReadBuffer.remaining() > 0) {
                 int netread = socketChannel.read(netReadBuffer);
                 if (netread == 0) return netread;
@@ -446,11 +453,11 @@ public class SSLTransportLayer implements TransportLayer {
                     else
                         break;
                 } else if (unwrapResult.getStatus() == Status.BUFFER_UNDERFLOW) {
-                    int currentPacketBufferSize = packetBufferSize();
-                    netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentPacketBufferSize);
-                    if (netReadBuffer.position() >= currentPacketBufferSize) {
+                    int currentNetReadBufferSize = netReadBufferSize();
+                    netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
+                    if (netReadBuffer.position() >= currentNetReadBufferSize) {
                         throw new IllegalStateException("Buffer underflow when available data size (" + netReadBuffer.position() +
-                                                        ") > packet buffer size (" + currentPacketBufferSize + ")");
+                                                        ") > packet buffer size (" + currentNetReadBufferSize + ")");
                     }
                     break;
                 } else if (unwrapResult.getStatus() == Status.CLOSED) {
@@ -536,10 +543,12 @@ public class SSLTransportLayer implements TransportLayer {
             written = wrapResult.bytesConsumed();
             flush(netWriteBuffer);
         } else if (wrapResult.getStatus() == Status.BUFFER_OVERFLOW) {
-            int currentPacketBufferSize = packetBufferSize();
-            netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, packetBufferSize());
-            if (netWriteBuffer.position() >= currentPacketBufferSize)
-                throw new IllegalStateException("SSL BUFFER_OVERFLOW when available data size (" + netWriteBuffer.position() + ") >= network buffer size (" + currentPacketBufferSize + ")");
+            int currentNetWriteBufferSize = netWriteBufferSize();
+            netWriteBuffer.compact();
+            netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
+            netWriteBuffer.flip();
+            if (netWriteBuffer.limit() >= currentNetWriteBufferSize)
+                throw new IllegalStateException("SSL BUFFER_OVERFLOW when available data size (" + netWriteBuffer.limit() + ") >= network buffer size (" + currentNetWriteBufferSize + ")");
         } else if (wrapResult.getStatus() == Status.BUFFER_UNDERFLOW) {
             throw new IllegalStateException("SSL BUFFER_UNDERFLOW during write");
         } else if (wrapResult.getStatus() == Status.CLOSED) {
@@ -668,13 +677,21 @@ public class SSLTransportLayer implements TransportLayer {
         return remaining;
     }
 
-    private int packetBufferSize() {
+    protected int netReadBufferSize() {
+        return sslEngine.getSession().getPacketBufferSize();
+    }
+    
+    protected int netWriteBufferSize() {
         return sslEngine.getSession().getPacketBufferSize();
     }
 
-    private int applicationBufferSize() {
+    protected int applicationBufferSize() {
         return sslEngine.getSession().getApplicationBufferSize();
     }
+    
+    protected ByteBuffer netReadBuffer() {
+        return netReadBuffer;
+    }
 
     private void handshakeFailure() {
         //Release all resources such as internal buffers that SSLEngine is managing

http://git-wip-us.apache.org/repos/asf/kafka/blob/2047a9af/clients/src/test/java/org/apache/kafka/common/network/SSLSelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SSLSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SSLSelectorTest.java
index 5056c71..c28d427 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SSLSelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SSLSelectorTest.java
@@ -16,34 +16,23 @@ import static org.junit.Assert.assertEquals;
 
 import java.util.LinkedHashMap;
 import java.util.Map;
-
-import java.io.IOException;
 import java.io.File;
+import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 
+import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.config.SSLConfigs;
 import org.apache.kafka.common.security.ssl.SSLFactory;
-import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
-import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestSSLUtils;
-import org.apache.kafka.test.TestUtils;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 /**
- * A set of tests for the selector over ssl. These use a test harness that runs a simple socket server that echos back responses.
+ * A set of tests for the selector. These use a test harness that runs a simple socket server that echos back responses.
  */
-
-public class SSLSelectorTest {
-
-    private static final int BUFFER_SIZE = 4 * 1024;
-
-    private EchoServer server;
-    private Selector selector;
-    private ChannelBuilder channelBuilder;
+public class SSLSelectorTest extends SelectorTest {
 
     @Before
     public void setup() throws Exception {
@@ -53,12 +42,13 @@ public class SSLSelectorTest {
         sslServerConfigs.put(SSLConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Class.forName(SSLConfigs.DEFAULT_PRINCIPAL_BUILDER_CLASS));
         this.server = new EchoServer(sslServerConfigs);
         this.server.start();
+        this.time = new MockTime();
         Map<String, Object> sslClientConfigs = TestSSLUtils.createSSLConfig(false, false, SSLFactory.Mode.SERVER, trustStoreFile, "client");
         sslClientConfigs.put(SSLConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Class.forName(SSLConfigs.DEFAULT_PRINCIPAL_BUILDER_CLASS));
 
         this.channelBuilder = new SSLChannelBuilder(SSLFactory.Mode.CLIENT);
         this.channelBuilder.configure(sslClientConfigs);
-        this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", new LinkedHashMap<String, String>(), channelBuilder);
+        this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", new LinkedHashMap<String, String>(), channelBuilder);
     }
 
     @After
@@ -67,90 +57,6 @@ public class SSLSelectorTest {
         this.server.close();
     }
 
-
-    /**
-     * Validate that we can send and receive a message larger than the receive and send buffer size
-     */
-    @Test
-    public void testSendLargeRequest() throws Exception {
-        String node = "0";
-        blockingConnect(node);
-        String big = TestUtils.randomString(10 * BUFFER_SIZE);
-        assertEquals(big, blockingRequest(node, big));
-    }
-
-
-    /**
-     * Validate that when the server disconnects, a client send ends up with that node in the disconnected list.
-     */
-    @Test
-    public void testServerDisconnect() throws Exception {
-        String node = "0";
-        // connect and do a simple request
-        blockingConnect(node);
-        assertEquals("hello", blockingRequest(node, "hello"));
-
-        // disconnect
-        this.server.closeConnections();
-        while (!selector.disconnected().contains(node))
-            selector.poll(1000L);
-
-        // reconnect and do another request
-        blockingConnect(node);
-        assertEquals("hello", blockingRequest(node, "hello"));
-    }
-
-     /**
-     * Tests wrap BUFFER_OVERFLOW  and unwrap BUFFER_UNDERFLOW
-     * @throws Exception
-     */
-    @Test
-    public void testLargeMessageSequence() throws Exception {
-        int bufferSize = 512 * 1024;
-        String node = "0";
-        int reqs = 50;
-        InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
-        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
-        String requestPrefix = TestUtils.randomString(bufferSize);
-        sendAndReceive(node, requestPrefix, 0, reqs);
-    }
-
-
-    /**
-     * Test sending an empty string
-     */
-    @Test
-    public void testEmptyRequest() throws Exception {
-        String node = "0";
-        blockingConnect(node);
-        assertEquals("", blockingRequest(node, ""));
-    }
-
-
-    @Test
-    public void testMute() throws Exception {
-        blockingConnect("0");
-        blockingConnect("1");
-        // wait for handshake to finish
-        while (!selector.isChannelReady("0") && !selector.isChannelReady("1"))
-            selector.poll(5);
-        selector.send(createSend("0", "hello"));
-        selector.send(createSend("1", "hi"));
-        selector.mute("1");
-
-        while (selector.completedReceives().isEmpty())
-            selector.poll(5);
-        assertEquals("We should have only one response", 1, selector.completedReceives().size());
-        assertEquals("The response should not be from the muted node", "0", selector.completedReceives().get(0).source());
-        selector.unmute("1");
-        do {
-            selector.poll(5);
-        } while (selector.completedReceives().isEmpty());
-        assertEquals("We should have only one response", 1, selector.completedReceives().size());
-        assertEquals("The response should be from the previously muted node", "1", selector.completedReceives().get(0).source());
-    }
-
-
     /**
      * Tests that SSL renegotiation initiated by the server are handled correctly by the client
      * @throws Exception
@@ -199,59 +105,13 @@ public class SSLSelectorTest {
         }
     }
 
-    private String blockingRequest(String node, String s) throws IOException {
-        selector.send(createSend(node, s));
-        while (true) {
-            selector.poll(1000L);
-            for (NetworkReceive receive : selector.completedReceives())
-                if (receive.source() == node)
-                    return asString(receive);
-        }
-    }
-
-    private String asString(NetworkReceive receive) {
-        return new String(Utils.toArray(receive.payload()));
-    }
-
-    private NetworkSend createSend(String node, String s) {
-        return new NetworkSend(node, ByteBuffer.wrap(s.getBytes()));
-    }
-
-    /* connect and wait for the connection to complete */
-    private void blockingConnect(String node) throws IOException {
-        selector.connect(node, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE);
-        while (!selector.connected().contains(node))
-            selector.poll(10000L);
-        //finish the handshake as well
-        while (!selector.isChannelReady(node))
-            selector.poll(10000L);
-    }
-
-
-    private void sendAndReceive(String node, String requestPrefix, int startIndex, int endIndex) throws Exception {
-        int requests = startIndex;
-        int responses = startIndex;
-        // wait for handshake to finish
-        while (!selector.isChannelReady(node)) {
-            selector.poll(1000L);
-        }
-        selector.send(createSend(node, requestPrefix + "-" + startIndex));
-        requests++;
-        while (responses < endIndex) {
-            // do the i/o
-            selector.poll(0L);
-            assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());
-
-            // handle requests and responses of the fast node
-            for (NetworkReceive receive : selector.completedReceives()) {
-                assertEquals(requestPrefix + "-" + responses, asString(receive));
-                responses++;
-            }
-
-            for (int i = 0; i < selector.completedSends().size() && requests < endIndex && selector.isChannelReady(node); i++, requests++) {
-                selector.send(createSend(node, requestPrefix + "-" + requests));
-            }
-        }
+    /**
+     * Connects and waits for handshake to complete. This is required since SSLTransportLayer
+     * implementation requires the channel to be ready before send is invoked (unlike plaintext
+     * where send can be invoked straight after connect)
+     */
+    protected void connect(String node, InetSocketAddress serverAddr) throws IOException {
+        blockingConnect(node, serverAddr);
     }
 
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/2047a9af/clients/src/test/java/org/apache/kafka/common/network/SSLTransportLayerTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SSLTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SSLTransportLayerTest.java
new file mode 100644
index 0000000..43da621
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/network/SSLTransportLayerTest.java
@@ -0,0 +1,654 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.kafka.common.network;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.io.IOException;
+import java.io.File;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.config.SSLConfigs;
+import org.apache.kafka.common.security.ssl.SSLFactory;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.test.TestSSLUtils;
+import org.apache.kafka.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests for the SSL transport layer. These use a test harness that runs a simple socket server that echos back responses.
+ */
+
+public class SSLTransportLayerTest {
+
+    private static final int BUFFER_SIZE = 4 * 1024;
+
+    private SSLEchoServer server;
+    private Selector selector;
+    private ChannelBuilder channelBuilder;
+    private CertStores serverCertStores;
+    private CertStores clientCertStores;
+    private Map<String, Object> sslClientConfigs;
+    private Map<String, Object> sslServerConfigs;
+
+    @Before
+    public void setup() throws Exception {
+        // Create certificates for use by client and server. Add server cert to client truststore and vice versa.
+        serverCertStores = new CertStores(true);
+        clientCertStores = new CertStores(false);
+        sslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
+        sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
+
+        this.channelBuilder = new SSLChannelBuilder(SSLFactory.Mode.CLIENT);
+        this.channelBuilder.configure(sslClientConfigs);
+        this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", new LinkedHashMap<String, String>(), channelBuilder);
+    }
+
+    @After
+    public void teardown() throws Exception {
+        if (selector != null)
+            this.selector.close();
+        if (server != null)
+            this.server.close();
+    }
+
+    /**
+     * Tests that server certificate with valid IP address is accepted by
+     * a client that validates server endpoint.
+     */
+    @Test
+    public void testValidEndpointIdentification() throws Exception {
+        String node = "0";
+        createEchoServer(sslServerConfigs);
+        sslClientConfigs.put(SSLConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS");
+        createSelector(sslClientConfigs);
+        InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
+        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+
+        testClientConnection(node, 100, 10);
+    }
+    
+    /**
+     * Tests that server certificate with invalid IP address is not accepted by
+     * a client that validates server endpoint. Certificate uses "localhost" as
+     * common name, test uses host IP to trigger endpoint validation failure.
+     */
+    @Test
+    public void testInvalidEndpointIdentification() throws Exception {
+        String node = "0";
+        String serverHost = InetAddress.getLocalHost().getHostAddress();
+        server = new SSLEchoServer(sslServerConfigs, serverHost);
+        server.start();
+        sslClientConfigs.put(SSLConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS");
+        createSelector(sslClientConfigs);
+        InetSocketAddress addr = new InetSocketAddress(serverHost, server.port);
+        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+
+        waitForChannelClose(node);
+    }
+    
+    /**
+     * Tests that server certificate with invalid IP address is accepted by
+     * a client that has disabled endpoint validation
+     */
+    @Test
+    public void testEndpointIdentificationDisabled() throws Exception {
+        String node = "0";
+        String serverHost = InetAddress.getLocalHost().getHostAddress();
+        server = new SSLEchoServer(sslServerConfigs, serverHost);
+        server.start();
+        sslClientConfigs.remove(SSLConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG);
+        createSelector(sslClientConfigs);
+        InetSocketAddress addr = new InetSocketAddress(serverHost, server.port);
+        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+
+        testClientConnection(node, 100, 10);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "required", the server
+     * completes the handshake with a client presenting a trusted certificate
+     * and echoes its messages.
+     */
+    @Test
+    public void testClientAuthenticationRequiredValidProvided() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 100, 10);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "required", a server using
+     * its untrusting configuration rejects a client whose certificate it does not
+     * trust, closing the channel.
+     */
+    @Test
+    public void testClientAuthenticationRequiredUntrustedProvided() throws Exception {
+        sslServerConfigs = serverCertStores.getUntrustingConfig();
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        waitForChannelClose(nodeId);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "required", the server
+     * rejects a client that doesn't present any certificate at all.
+     */
+    @Test
+    public void testClientAuthenticationRequiredNotProvided() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
+        createEchoServer(sslServerConfigs);
+
+        // Strip the client keystore settings so that no client certificate is offered.
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEY_PASSWORD_CONFIG);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        waitForChannelClose(nodeId);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "none", a server using its
+     * untrusting configuration still accepts a client with an untrusted
+     * certificate and echoes its messages.
+     */
+    @Test
+    public void testClientAuthenticationDisabledUntrustedProvided() throws Exception {
+        sslServerConfigs = serverCertStores.getUntrustingConfig();
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "none");
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 100, 10);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "none", the server accepts
+     * a client that presents no certificate and echoes its messages.
+     */
+    @Test
+    public void testClientAuthenticationDisabledNotProvided() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "none");
+        createEchoServer(sslServerConfigs);
+
+        // Strip the client keystore settings so that no client certificate is offered.
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEY_PASSWORD_CONFIG);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 100, 10);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "requested", the server
+     * accepts a client presenting a valid certificate and echoes its messages.
+     */
+    @Test
+    public void testClientAuthenticationRequestedValidProvided() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "requested");
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 100, 10);
+    }
+    
+    /**
+     * Verifies that, with client authentication set to "requested" (not
+     * "required"), the server accepts a client that presents no certificate.
+     */
+    @Test
+    public void testClientAuthenticationRequestedNotProvided() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_CLIENT_AUTH_CONFIG, "requested");
+        createEchoServer(sslServerConfigs);
+
+        // Strip the client keystore settings so that no client certificate is offered.
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
+        sslClientConfigs.remove(SSLConfigs.SSL_KEY_PASSWORD_CONFIG);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 100, 10);
+    }
+    
+    /**
+     * Verifies that configuring a client channel builder fails with a
+     * KafkaException when the truststore password is wrong.
+     */
+    @Test
+    public void testInvalidTruststorePassword() throws Exception {
+        SSLChannelBuilder clientBuilder = new SSLChannelBuilder(SSLFactory.Mode.CLIENT);
+        sslClientConfigs.put(SSLConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, "invalid");
+        try {
+            clientBuilder.configure(sslClientConfigs);
+            fail("SSL channel configured with invalid truststore password");
+        } catch (KafkaException expected) {
+            // configure() must reject a truststore it cannot load
+        }
+    }
+    
+    /**
+     * Verifies that configuring a client channel builder fails with a
+     * KafkaException when the keystore password is wrong.
+     */
+    @Test
+    public void testInvalidKeystorePassword() throws Exception {
+        SSLChannelBuilder clientBuilder = new SSLChannelBuilder(SSLFactory.Mode.CLIENT);
+        sslClientConfigs.put(SSLConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, "invalid");
+        try {
+            clientBuilder.configure(sslClientConfigs);
+            fail("SSL channel configured with invalid keystore password");
+        } catch (KafkaException expected) {
+            // configure() must reject a keystore it cannot load
+        }
+    }
+    
+    /**
+     * Verifies that a client cannot complete a connection to a server configured
+     * with an incorrect key password; the client channel is expected to close.
+     */
+    @Test
+    public void testInvalidKeyPassword() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_KEY_PASSWORD_CONFIG, "invalid");
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        waitForChannelClose(nodeId);
+    }
+    
+    /**
+     * Verifies that the connection fails when client and server share no enabled
+     * protocol: the server enables only TLSv1.2 and the client only TLSv1.1.
+     */
+    @Test
+    public void testUnsupportedTLSVersion() throws Exception {
+        sslServerConfigs.put(SSLConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.2"));
+        createEchoServer(sslServerConfigs);
+
+        sslClientConfigs.put(SSLConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.1"));
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        waitForChannelClose(nodeId);
+    }
+    
+    /**
+     * Verifies that the connection fails when client and server share no cipher
+     * suite: the server enables only the first JVM-default suite and the client
+     * only the second.
+     */
+    @Test
+    public void testUnsupportedCiphers() throws Exception {
+        String[] defaultSuites = SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites();
+        sslServerConfigs.put(SSLConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(defaultSuites[0]));
+        createEchoServer(sslServerConfigs);
+
+        sslClientConfigs.put(SSLConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(defaultSuites[1]));
+        createSelector(sslClientConfigs);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        waitForChannelClose(nodeId);
+    }
+
+    /**
+     * Exercises BUFFER_UNDERFLOW handling during unwrap. The client's network read
+     * buffer starts at 10 bytes, smaller than the SSL session packet buffer size,
+     * while messages of at least 64000 characters are echoed.
+     */
+    @Test
+    public void testNetReadBufferResize() throws Exception {
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs, 10, null, null);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 64000, 10);
+    }
+    
+    /**
+     * Exercises BUFFER_OVERFLOW handling during wrap. The client's network write
+     * buffer starts at 10 bytes, smaller than the SSL session packet buffer size,
+     * while messages of at least 64000 characters are echoed.
+     */
+    @Test
+    public void testNetWriteBufferResize() throws Exception {
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs, null, 10, null);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 64000, 10);
+    }
+
+    /**
+     * Exercises BUFFER_OVERFLOW handling during unwrap. The client's application
+     * read buffer starts at 10 bytes, smaller than the SSL session application
+     * buffer size, while messages of at least 64000 characters are echoed.
+     */
+    @Test
+    public void testApplicationBufferResize() throws Exception {
+        createEchoServer(sslServerConfigs);
+        createSelector(sslClientConfigs, null, null, 10);
+
+        String nodeId = "0";
+        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.port);
+        selector.connect(nodeId, serverAddress, BUFFER_SIZE, BUFFER_SIZE);
+        testClientConnection(nodeId, 64000, 10);
+    }
+
+    /**
+     * Waits for the handshake on {@code node} to finish, then exchanges
+     * {@code messageCount} echo messages with the server. Each message is a random
+     * prefix of {@code minMessageSize} characters followed by "-" and its index.
+     * Asserts that every response matches its request, in order, and that no
+     * disconnect is observed.
+     */
+    private void testClientConnection(String node, int minMessageSize, int messageCount) throws Exception {
+        String prefix = TestUtils.randomString(minMessageSize);
+
+        // Block until the SSL handshake has completed.
+        while (!selector.isChannelReady(node))
+            selector.poll(1000L);
+
+        selector.send(new NetworkSend(node, ByteBuffer.wrap((prefix + "-0").getBytes())));
+        int sent = 1;
+        int received = 0;
+
+        while (received < messageCount) {
+            selector.poll(0L);
+            assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());
+
+            for (NetworkReceive receive : selector.completedReceives()) {
+                String expected = prefix + "-" + received;
+                assertEquals(expected, new String(Utils.toArray(receive.payload())));
+                received++;
+            }
+
+            // Issue at most one new request per completed send, while requests
+            // remain and the channel is still ready.
+            int issued = 0;
+            while (issued < selector.completedSends().size() && sent < messageCount && selector.isChannelReady(node)) {
+                selector.send(new NetworkSend(node, ByteBuffer.wrap((prefix + "-" + sent).getBytes())));
+                issued++;
+                sent++;
+            }
+        }
+    }
+    
+    /**
+     * Polls the selector (1 s timeout per poll, at most 30 polls) and asserts
+     * that {@code node} gets closed. Closure is detected by
+     * {@code channelForId} throwing IllegalStateException.
+     */
+    private void waitForChannelClose(String node) throws IOException {
+        boolean closed = false;
+        int attempts = 0;
+        while (!closed && attempts < 30) {
+            attempts++;
+            selector.poll(1000L);
+            try {
+                selector.channelForId(node);
+            } catch (IllegalStateException e) {
+                closed = true;
+            }
+        }
+        assertTrue(closed);
+    }
+    
+    /** Creates an SSL echo server bound to "localhost" with the given configs and starts it. */
+    private void createEchoServer(Map<String, Object> sslServerConfigs) throws Exception {
+        SSLEchoServer echoServer = new SSLEchoServer(sslServerConfigs, "localhost");
+        server = echoServer;
+        echoServer.start();
+    }
+    
+    /** Creates the client selector without any buffer-size overrides. */
+    private void createSelector(Map<String, Object> sslClientConfigs) {
+        Integer noOverride = null;
+        createSelector(sslClientConfigs, noOverride, noOverride, noOverride);
+    }
+
+    /**
+     * Creates the client-side selector. Its channel builder produces
+     * {@link TestSSLTransportLayer} instances, so the network read, network write
+     * and application buffer sizes can start artificially small to exercise the
+     * buffer-resize paths. A {@code null} size leaves the session's own size in effect.
+     */
+    private void createSelector(Map<String, Object> sslClientConfigs, final Integer netReadBufSize, final Integer netWriteBufSize, final Integer appBufSize) {
+        SSLChannelBuilder clientBuilder = new SSLChannelBuilder(SSLFactory.Mode.CLIENT) {
+            @Override
+            protected SSLTransportLayer buildTransportLayer(SSLFactory sslFactory, String id, SelectionKey key) throws IOException {
+                SocketChannel channel = (SocketChannel) key.channel();
+                String peerHost = channel.socket().getInetAddress().getHostName();
+                int peerPort = channel.socket().getPort();
+                SSLEngine engine = sslFactory.createSSLEngine(peerHost, peerPort);
+                TestSSLTransportLayer layer = new TestSSLTransportLayer(id, key, engine, netReadBufSize, netWriteBufSize, appBufSize);
+                layer.startHandshake();
+                return layer;
+            }
+        };
+        this.channelBuilder = clientBuilder;
+        clientBuilder.configure(sslClientConfigs);
+        this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", new LinkedHashMap<String, String>(), clientBuilder);
+    }
+    
+    private static class CertStores {
+
+        // SSL configuration generated for this endpoint (server or client).
+        Map<String, Object> sslConfig;
+
+        /**
+         * Generates SSL configuration for one endpoint, writing its truststore to a
+         * fresh temporary ".jks" file.
+         *
+         * @param server true for server-mode configuration, false for client-mode
+         */
+        CertStores(boolean server) throws Exception {
+            String name = server ? "server" : "client";
+            SSLFactory.Mode mode = server ? SSLFactory.Mode.SERVER : SSLFactory.Mode.CLIENT;
+            File truststoreFile = File.createTempFile(name + "TS", ".jks");
+            // Don't leave generated truststores behind in the temp directory.
+            truststoreFile.deleteOnExit();
+            sslConfig = TestSSLUtils.createSSLConfig(!server, true, mode, truststoreFile, name);
+            sslConfig.put(SSLConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Class.forName(SSLConfigs.DEFAULT_PRINCIPAL_BUILDER_CLASS));
+        }
+
+        /**
+         * Returns a copy of this endpoint's config with the truststore location,
+         * password and type taken from {@code truststoreConfig}.
+         */
+        private Map<String, Object> getTrustingConfig(CertStores truststoreConfig) {
+            Map<String, Object> config = new HashMap<String, Object>(sslConfig);
+            config.put(SSLConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, truststoreConfig.sslConfig.get(SSLConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG));
+            config.put(SSLConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, truststoreConfig.sslConfig.get(SSLConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG));
+            config.put(SSLConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, truststoreConfig.sslConfig.get(SSLConfigs.SSL_TRUSTSTORE_TYPE_CONFIG));
+            return config;
+        }
+
+        /**
+         * Returns a copy of this endpoint's own config, with its own truststore
+         * settings. A copy is returned because callers add entries (e.g. client
+         * auth mode) that must not leak into this store's shared config or into
+         * later {@link #getTrustingConfig} results.
+         */
+        private Map<String, Object> getUntrustingConfig() {
+            return new HashMap<String, Object>(sslConfig);
+        }
+    }
+
+    /**
+     * SSLTransportLayer with overrides for packet and application buffer size to test buffer resize
+     * code path. The overridden buffer size starts with a small value and increases in size when the buffer
+     * size is retrieved to handle overflow/underflow, until the actual session buffer size is reached.
+     */
+    private static class TestSSLTransportLayer extends SSLTransportLayer {
+
+        private final ResizeableBufferSize netReadBufSize;
+        private final ResizeableBufferSize netWriteBufSize;
+        private final ResizeableBufferSize appBufSize;
+
+        /**
+         * @param netReadBufSize  initial network read buffer size, or null to use the session size
+         * @param netWriteBufSize initial network write buffer size, or null to use the session size
+         * @param appBufSize      initial application buffer size, or null to use the session size
+         */
+        public TestSSLTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine,
+                Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) throws IOException {
+            super(channelId, key, sslEngine);
+            this.netReadBufSize = new ResizeableBufferSize(netReadBufSize);
+            this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSize);
+            this.appBufSize = new ResizeableBufferSize(appBufSize);
+        }
+
+        @Override
+        protected int netReadBufferSize() {
+            ByteBuffer netReadBuffer = netReadBuffer();
+            // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read
+            // operation. To avoid the read buffer being expanded too early, increase buffer size
+            // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always
+            // triggered in testNetReadBufferResize().
+            boolean updateBufSize = netReadBuffer != null && !netReadBuffer.hasRemaining();
+            return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
+        }
+
+        @Override
+        protected int netWriteBufferSize() {
+            return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
+        }
+
+        @Override
+        protected int applicationBufferSize() {
+            return appBufSize.updateAndGet(super.applicationBufferSize(), true);
+        }
+
+        /**
+         * Optional buffer-size override. With a null override the actual size is
+         * always reported. Otherwise, each update doubles the override, capped at
+         * the actual session size, and the override is reported.
+         */
+        private static final class ResizeableBufferSize {
+            private Integer bufSizeOverride;
+
+            ResizeableBufferSize(Integer bufSizeOverride) {
+                this.bufSizeOverride = bufSizeOverride;
+            }
+
+            int updateAndGet(int actualSize, boolean update) {
+                if (bufSizeOverride == null)
+                    return actualSize;
+                if (update)
+                    bufSizeOverride = Math.min(bufSizeOverride * 2, actualSize);
+                return bufSizeOverride;
+            }
+        }
+    }
+    
+    // Non-blocking EchoServer implementation that uses SSLTransportLayer
+    private class SSLEchoServer extends Thread {
+        private final int port;
+        private final ServerSocketChannel serverSocketChannel;
+        private final List<SocketChannel> newChannels;
+        private final List<SocketChannel> socketChannels;
+        private final AcceptorThread acceptorThread;
+        private SSLFactory sslFactory;
+        private final Selector selector;
+        private final ConcurrentLinkedQueue<NetworkSend> inflightSends = new ConcurrentLinkedQueue<NetworkSend>();
+
+        public SSLEchoServer(Map<String, ?> configs, String serverHost) throws Exception {
+            this.sslFactory = new SSLFactory(SSLFactory.Mode.SERVER);
+            this.sslFactory.configure(configs);
+            serverSocketChannel = ServerSocketChannel.open();
+            serverSocketChannel.configureBlocking(false);
+            serverSocketChannel.socket().bind(new InetSocketAddress(serverHost, 0));
+            this.port = serverSocketChannel.socket().getLocalPort();
+            this.socketChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
+            this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
+            SSLChannelBuilder channelBuilder = new SSLChannelBuilder(SSLFactory.Mode.SERVER);
+            channelBuilder.configure(sslServerConfigs);
+            this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", new LinkedHashMap<String, String>(), channelBuilder);
+            setName("echoserver");
+            setDaemon(true);
+            acceptorThread = new AcceptorThread();
+        }
+
+        @Override
+        public void run() {
+            try {
+                acceptorThread.start();
+                while (serverSocketChannel.isOpen()) {
+                    selector.poll(1000);
+                    for (SocketChannel socketChannel : newChannels) {
+                        String id = id(socketChannel);
+                        selector.register(id, socketChannel);
+                        socketChannels.add(socketChannel);
+                    }
+                    newChannels.clear();
+                    while (true) {
+                        NetworkSend send = inflightSends.peek();
+                        if (send != null && !selector.channelForId(send.destination()).hasSend()) {
+                            send = inflightSends.poll();
+                            selector.send(send);
+                        } else
+                            break;
+                    }
+                    List<NetworkReceive> completedReceives = selector.completedReceives();
+                    for (NetworkReceive rcv : completedReceives) {
+                        NetworkSend send = new NetworkSend(rcv.source(), rcv.payload());
+                        if (!selector.channelForId(send.destination()).hasSend())
+                            selector.send(send);
+                        else
+                            inflightSends.add(send);
+                    }
+                }
+            } catch (IOException e) {
+                // ignore
+            }
+        }
+        
+        private String id(SocketChannel channel) {
+            return channel.socket().getLocalAddress().getHostAddress() + ":" + channel.socket().getLocalPort() + "-" +
+                    channel.socket().getInetAddress().getHostAddress() + ":" + channel.socket().getPort();
+        }
+
+        public void closeConnections() throws IOException {
+            for (SocketChannel channel : socketChannels)
+                channel.close();
+            socketChannels.clear();
+        }
+
+        public void close() throws IOException, InterruptedException {
+            this.serverSocketChannel.close();
+            closeConnections();
+            acceptorThread.interrupt();
+            acceptorThread.join();
+            interrupt();
+            join();
+        }
+        
+        private class AcceptorThread extends Thread {
+            public AcceptorThread() throws IOException {
+                setName("acceptor");
+            }
+            public void run() {
+                try {
+
+                    java.nio.channels.Selector acceptSelector = java.nio.channels.Selector.open();
+                    serverSocketChannel.register(acceptSelector, SelectionKey.OP_ACCEPT);
+                    while (serverSocketChannel.isOpen()) {
+                        if (acceptSelector.select(1000) > 0) {
+                            Iterator<SelectionKey> it = acceptSelector.selectedKeys().iterator();
+                            while (it.hasNext()) {
+                                SelectionKey key = it.next();
+                                if (key.isAcceptable()) {
+                                    SocketChannel socketChannel = ((ServerSocketChannel) key.channel()).accept();
+                                    socketChannel.configureBlocking(false);
+                                    newChannels.add(socketChannel);
+                                    selector.wakeup();
+                                }
+                            }
+                        }
+                    }
+                } catch (IOException e) {
+                    // ignore
+                }
+            }
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/2047a9af/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 66ca530..bfc4be5 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -38,12 +38,12 @@ import org.junit.Test;
  */
 public class SelectorTest {
 
-    private static final int BUFFER_SIZE = 4 * 1024;
+    protected static final int BUFFER_SIZE = 4 * 1024;
 
-    private EchoServer server;
-    private Time time;
-    private Selectable selector;
-    private ChannelBuilder channelBuilder;
+    protected EchoServer server;
+    protected Time time;
+    protected Selectable selector;
+    protected ChannelBuilder channelBuilder;
 
     @Before
     public void setup() throws Exception {
@@ -139,7 +139,7 @@ public class SelectorTest {
         // create connections
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
         for (int i = 0; i < conns; i++)
-            selector.connect(Integer.toString(i), addr, BUFFER_SIZE, BUFFER_SIZE);
+            connect(Integer.toString(i), addr);
         // send echo requests and receive responses
         Map<String, Integer> requests = new HashMap<String, Integer>();
         Map<String, Integer> responses = new HashMap<String, Integer>();
@@ -202,7 +202,7 @@ public class SelectorTest {
         String node = "0";
         int reqs = 50;
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
-        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+        connect(node, addr);
         String requestPrefix = TestUtils.randomString(bufferSize);
         sendAndReceive(node, requestPrefix, 0, reqs);
     }
@@ -260,6 +260,7 @@ public class SelectorTest {
         assertTrue("The idle connection should have been closed", selector.disconnected().contains(id));
     }
 
+    
     private String blockingRequest(String node, String s) throws IOException {
         selector.send(createSend(node, s));
         selector.poll(1000L);
@@ -270,19 +271,28 @@ public class SelectorTest {
                     return asString(receive);
         }
     }
+    
+    protected void connect(String node, InetSocketAddress serverAddr) throws IOException {
+        selector.connect(node, serverAddr, BUFFER_SIZE, BUFFER_SIZE);
+    }
 
     /* connect and wait for the connection to complete */
     private void blockingConnect(String node) throws IOException {
-        selector.connect(node, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE);
+        blockingConnect(node, new InetSocketAddress("localhost", server.port));
+    }
+    protected void blockingConnect(String node, InetSocketAddress serverAddr) throws IOException {
+        selector.connect(node, serverAddr, BUFFER_SIZE, BUFFER_SIZE);
         while (!selector.connected().contains(node))
             selector.poll(10000L);
+        while (!selector.isChannelReady(node))
+            selector.poll(10000L);
     }
 
-    private NetworkSend createSend(String node, String s) {
+    protected NetworkSend createSend(String node, String s) {
         return new NetworkSend(node, ByteBuffer.wrap(s.getBytes()));
     }
 
-    private String asString(NetworkReceive receive) {
+    protected String asString(NetworkReceive receive) {
         return new String(Utils.toArray(receive.payload()));
     }
 


Mime
View raw message