kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rsiva...@apache.org
Subject [kafka] branch trunk updated: KAFKA-7169: Custom SASL extensions for OAuthBearer authentication mechanism (KIP-342) (#5379)
Date Mon, 06 Aug 2018 16:22:08 GMT
This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 518e9d3  KAFKA-7169: Custom SASL extensions for OAuthBearer authentication mechanism (KIP-342) (#5379)
518e9d3 is described below

commit 518e9d3eee1b1d2c8c76d68e043edd3cf49139fa
Author: Stanislav Kozlovski <familyguyuser192@windowslive.com>
AuthorDate: Mon Aug 6 17:22:04 2018 +0100

    KAFKA-7169: Custom SASL extensions for OAuthBearer authentication mechanism (KIP-342) (#5379)
    
    Reviewers: Ron Dagostino <rndgstn@gmail.com>, Rajini Sivaram <rajinisivaram@googlemail.com>
---
 .../kafka/common/security/auth/SaslExtensions.java |  57 ++++++++
 .../SaslExtensionsCallback.java}                   |  25 ++--
 .../authenticator/SaslClientCallbackHandler.java   |  19 ++-
 .../oauthbearer/OAuthBearerLoginModule.java        |  78 +++++++++--
 .../OAuthBearerClientInitialResponse.java          |  64 +++++++--
 .../internals/OAuthBearerSaslClient.java           |  24 +++-
 .../OAuthBearerSaslClientCallbackHandler.java      |  18 ++-
 .../internals/OAuthBearerSaslServer.java           |  16 ++-
 .../OAuthBearerUnsecuredLoginCallbackHandler.java  |  51 ++++++-
 .../security/scram/ScramExtensionsCallback.java    |  10 +-
 .../security/scram/internals/ScramExtensions.java  |  24 +---
 .../security/scram/internals/ScramMessages.java    |   5 +-
 .../security/scram/internals/ScramSaslServer.java  |   6 +-
 .../java/org/apache/kafka/common/utils/Utils.java  |  13 ++
 .../kafka/common/security/SaslExtensionsTest.java  |  52 ++++++++
 .../oauthbearer/OAuthBearerLoginModuleTest.java    | 147 +++++++++++++++++++--
 .../OAuthBearerClientInitialResponseTest.java      |  47 ++++++-
 .../internals/OAuthBearerSaslClientTest.java       | 125 ++++++++++++++++++
 .../internals/OAuthBearerSaslServerTest.java       |  40 +++++-
 ...uthBearerUnsecuredLoginCallbackHandlerTest.java |  34 +++++
 docs/security.html                                 |   7 +
 21 files changed, 765 insertions(+), 97 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java
new file mode 100644
index 0000000..75cac05
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.auth;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A simple immutable value object class holding customizable SASL extensions
+ */
+public class SaslExtensions {
+    private final Map<String, String> extensionsMap;
+
+    public SaslExtensions(Map<String, String> extensionsMap) {
+        this.extensionsMap = Collections.unmodifiableMap(new HashMap<>(extensionsMap));
+    }
+
+    /**
+     * Returns an <strong>immutable</strong> map of the extension names and their values
+     */
+    public Map<String, String> map() {
+        return extensionsMap;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return extensionsMap.equals(((SaslExtensions) o).extensionsMap);
+    }
+
+    @Override
+    public String toString() {
+        return extensionsMap.toString();
+    }
+
+    @Override
+    public int hashCode() {
+        return extensionsMap.hashCode();
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java
similarity index 58%
copy from clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
copy to clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java
index debe163..d07be32 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java
@@ -15,32 +15,29 @@
  * limitations under the License.
  */
 
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.auth;
 
 import javax.security.auth.callback.Callback;
-import java.util.Collections;
-import java.util.Map;
 
 /**
- * Optional callback used for SCRAM mechanisms if any extensions need to be set
- * in the SASL/SCRAM exchange.
+ * Optional callback used for SASL mechanisms if any extensions need to be set
+ * in the SASL exchange.
  */
-public class ScramExtensionsCallback implements Callback {
-    private Map<String, String> extensions = Collections.emptyMap();
+public class SaslExtensionsCallback implements Callback {
+    private SaslExtensions extensions;
 
     /**
-     * Returns the extension names and values that are sent by the client to
-     * the server in the initial client SCRAM authentication message.
-     * Default is an empty map.
+     * Returns a {@link SaslExtensions} consisting of the extension names and values that are sent by the client to
+     * the server in the initial client SASL authentication message.
      */
-    public Map<String, String> extensions() {
+    public SaslExtensions extensions() {
         return extensions;
     }
 
     /**
-     * Sets the SCRAM extensions on this callback.
+     * Sets the SASL extensions on this callback.
      */
-    public void extensions(Map<String, String> extensions) {
+    public void extensions(SaslExtensions extensions) {
         this.extensions = extensions;
     }
-}
\ No newline at end of file
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
index 5b2a281..8b830c0 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
@@ -30,14 +30,19 @@ import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 
 import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.scram.internals.ScramMechanism;
 
 /**
  * Default callback handler for Sasl clients. The callbacks required for the SASL mechanism
  * configured for the client should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
+ *
+ * For adding custom SASL extensions, a {@link SaslExtensions} may be added to the subject's public credentials
  */
 public class SaslClientCallbackHandler implements AuthenticateCallbackHandler {
 
@@ -78,9 +83,15 @@ public class SaslClientCallbackHandler implements AuthenticateCallbackHandler {
                 if (ac.isAuthorized())
                     ac.setAuthorizedID(authzId);
             } else if (callback instanceof ScramExtensionsCallback) {
-                ScramExtensionsCallback sc = (ScramExtensionsCallback) callback;
-                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
-                    sc.extensions((Map<String, String>) subject.getPublicCredentials(Map.class).iterator().next());
+                if (ScramMechanism.isScram(mechanism) && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
+                    Map<String, String> extensions = (Map<String, String>) subject.getPublicCredentials(Map.class).iterator().next();
+                    ((ScramExtensionsCallback) callback).extensions(extensions);
+                }
+            } else if (callback instanceof SaslExtensionsCallback) {
+                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) &&
+                        subject != null && !subject.getPublicCredentials(SaslExtensions.class).isEmpty()) {
+                    SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+                    ((SaslExtensionsCallback) callback).extensions(extensions);
                 }
             }  else {
                 throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback");
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
index 07382a8..57fa5d2 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.common.security.oauthbearer;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Objects;
@@ -31,6 +32,8 @@ import javax.security.auth.spi.LoginModule;
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientProvider;
 import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslServerProvider;
 import org.slf4j.Logger;
@@ -91,6 +94,16 @@ import org.slf4j.LoggerFactory;
  * </tr>
  * </table>
  * <p>
+ * <p>
+ * You can also add custom unsecured SASL extensions when using the default, builtin {@link AuthenticateCallbackHandler}
+ * implementation through using the configurable option {@code unsecuredLoginExtension_<extensionname>}. Note that there
+ * are validations for the key/values in order to conform to the OAuth standard, including the reserved key at
+ * {@link org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse#AUTH_KEY}.
+ * The {@code OAuthBearerLoginModule} instance also asks its configured {@link AuthenticateCallbackHandler}
+ * implementation to handle an instance of {@link SaslExtensionsCallback} and return an instance of {@link SaslExtensions}.
+ * The configured callback handler does not need to handle this callback, though -- any {@code UnsupportedCallbackException}
+ * that is thrown is ignored, and no SASL extensions will be associated with the login.
+ * <p>
  * Production use cases will require writing an implementation of
  * {@link AuthenticateCallbackHandler} that can handle an instance of
  * {@link OAuthBearerTokenCallback} and declaring it via either the
@@ -227,10 +240,13 @@ public class OAuthBearerLoginModule implements LoginModule {
      */
     public static final String OAUTHBEARER_MECHANISM = "OAUTHBEARER";
     private static final Logger log = LoggerFactory.getLogger(OAuthBearerLoginModule.class);
+    private static final SaslExtensions EMPTY_EXTENSIONS = new SaslExtensions(Collections.emptyMap());
     private Subject subject = null;
     private AuthenticateCallbackHandler callbackHandler = null;
     private OAuthBearerToken tokenRequiringCommit = null;
     private OAuthBearerToken myCommittedToken = null;
+    private SaslExtensions extensionsRequiringCommit = null;
+    private SaslExtensions myCommittedExtensions = null;
 
     static {
         OAuthBearerSaslClientProvider.initialize(); // not part of public API
@@ -256,22 +272,51 @@ public class OAuthBearerLoginModule implements LoginModule {
             throw new IllegalStateException(String.format(
                     "Already have a committed token with private credential token count=%d; must login on another login context or logout here first before reusing the same login context",
                     committedTokenCount()));
-        OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
+
+        identifyToken();
+        identifyExtensions();
+
+        log.info("Login succeeded; invoke commit() to commit it; current committed token count={}",
+                committedTokenCount());
+        return true;
+    }
+
+    private void identifyToken() throws LoginException {
+        OAuthBearerTokenCallback tokenCallback = new OAuthBearerTokenCallback();
         try {
-            callbackHandler.handle(new Callback[] {callback});
+            callbackHandler.handle(new Callback[] {tokenCallback});
         } catch (IOException | UnsupportedCallbackException e) {
             log.error(e.getMessage(), e);
-            throw new LoginException("An internal error occurred");
+            throw new LoginException("An internal error occurred while retrieving token from callback handler");
         }
-        tokenRequiringCommit = callback.token();
+
+        tokenRequiringCommit = tokenCallback.token();
         if (tokenRequiringCommit == null) {
-            log.info(String.format("Login failed: %s : %s (URI=%s)", callback.errorCode(), callback.errorDescription(),
-                    callback.errorUri()));
-            throw new LoginException(callback.errorDescription());
+            log.info("Login failed: {} : {} (URI={})", tokenCallback.errorCode(), tokenCallback.errorDescription(),
+                    tokenCallback.errorUri());
+            throw new LoginException(tokenCallback.errorDescription());
+        }
+    }
+
+    /**
+     * Attaches SASL extensions to the Subject
+     */
+    private void identifyExtensions() throws LoginException {
+        SaslExtensionsCallback extensionsCallback = new SaslExtensionsCallback();
+        try {
+            callbackHandler.handle(new Callback[] {extensionsCallback});
+            extensionsRequiringCommit = extensionsCallback.extensions();
+        } catch (IOException e) {
+            log.error(e.getMessage(), e);
+            throw new LoginException("An internal error occurred while retrieving SASL extensions from callback handler");
+        } catch (UnsupportedCallbackException e) {
+            extensionsRequiringCommit = EMPTY_EXTENSIONS;
+            log.info("CallbackHandler {} does not support SASL extensions. No extensions will be added", callbackHandler.getClass().getName());
+        }
+        if (extensionsRequiringCommit ==  null) {
+            log.error("SASL Extensions cannot be null. Check whether your callback handler is explicitly setting them as null.");
+            throw new LoginException("Extensions cannot be null.");
         }
-        log.info("Login succeeded; invoke commit() to commit it; current committed token count={}",
-                committedTokenCount());
-        return true;
     }
 
     @Override
@@ -294,6 +339,12 @@ public class OAuthBearerLoginModule implements LoginModule {
             }
         }
         log.info("Done logging out my token; committed token count is now {}", committedTokenCount());
+
+        log.info("Logging out my extensions");
+        if (subject.getPublicCredentials().removeIf(e -> myCommittedExtensions == e))
+            myCommittedExtensions = null;
+        log.info("Done logging out my extensions");
+
         return true;
     }
 
@@ -304,11 +355,17 @@ public class OAuthBearerLoginModule implements LoginModule {
                 log.debug("Nothing here to commit");
             return false;
         }
+
         log.info("Committing my token; current committed token count = {}", committedTokenCount());
         subject.getPrivateCredentials().add(tokenRequiringCommit);
         myCommittedToken = tokenRequiringCommit;
         tokenRequiringCommit = null;
         log.info("Done committing my token; committed token count is now {}", committedTokenCount());
+
+        subject.getPublicCredentials().add(extensionsRequiringCommit);
+        myCommittedExtensions = extensionsRequiringCommit;
+        extensionsRequiringCommit = null;
+
         return true;
     }
 
@@ -317,6 +374,7 @@ public class OAuthBearerLoginModule implements LoginModule {
         if (tokenRequiringCommit != null) {
             log.info("Login aborted");
             tokenRequiringCommit = null;
+            extensionsRequiringCommit = null;
             return true;
         }
         if (log.isDebugEnabled())
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
index 8d4b18a..ef16ea2 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
@@ -16,11 +16,11 @@
  */
 package org.apache.kafka.common.security.oauthbearer.internals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.utils.Utils;
 
 import javax.security.sasl.SaslException;
 import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -31,15 +31,19 @@ public class OAuthBearerClientInitialResponse {
     private static final String SASLNAME = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+";
     private static final String KEY = "[A-Za-z]+";
     private static final String VALUE = "[\\x21-\\x7E \t\r\n]+";
+
     private static final String KVPAIRS = String.format("(%s=%s%s)*", KEY, VALUE, SEPARATOR);
     private static final Pattern AUTH_PATTERN = Pattern.compile("(?<scheme>[\\w]+)[ ]+(?<token>[-_\\.a-zA-Z0-9]+)");
     private static final Pattern CLIENT_INITIAL_RESPONSE_PATTERN = Pattern.compile(
             String.format("n,(a=(?<authzid>%s))?,%s(?<kvpairs>%s)%s", SASLNAME, SEPARATOR, KVPAIRS, SEPARATOR));
-    private static final String AUTH_KEY = "auth";
+    public static final String AUTH_KEY = "auth";
 
     private final String tokenValue;
     private final String authorizationId;
-    private final Map<String, String> properties;
+    private SaslExtensions saslExtensions;
+
+    public static final Pattern EXTENSION_KEY_PATTERN = Pattern.compile(KEY);
+    public static final Pattern EXTENSION_VALUE_PATTERN = Pattern.compile(VALUE);
 
     public OAuthBearerClientInitialResponse(byte[] response) throws SaslException {
         String responseMsg = new String(response, StandardCharsets.UTF_8);
@@ -49,10 +53,12 @@ public class OAuthBearerClientInitialResponse {
         String authzid = matcher.group("authzid");
         this.authorizationId = authzid == null ? "" : authzid;
         String kvPairs = matcher.group("kvpairs");
-        this.properties = Utils.parseMap(kvPairs, "=", SEPARATOR);
+        Map<String, String> properties = Utils.parseMap(kvPairs, "=", SEPARATOR);
         String auth = properties.get(AUTH_KEY);
         if (auth == null)
             throw new SaslException("Invalid OAUTHBEARER client first message: 'auth' not specified");
+        properties.remove(AUTH_KEY);
+        this.saslExtensions = validateExtensions(new SaslExtensions(properties));
 
         Matcher authMatcher = AUTH_PATTERN.matcher(auth);
         if (!authMatcher.matches())
@@ -65,20 +71,29 @@ public class OAuthBearerClientInitialResponse {
         this.tokenValue = authMatcher.group("token");
     }
 
-    public OAuthBearerClientInitialResponse(String tokenValue) {
-        this(tokenValue, "", new HashMap<>());
+    public OAuthBearerClientInitialResponse(String tokenValue, SaslExtensions extensions) throws SaslException {
+        this(tokenValue, "", extensions);
     }
 
-    public OAuthBearerClientInitialResponse(String tokenValue, String authorizationId, Map<String, String> props) {
+    public OAuthBearerClientInitialResponse(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException {
         this.tokenValue = tokenValue;
         this.authorizationId = authorizationId == null ? "" : authorizationId;
-        this.properties = new HashMap<>(props);
+        this.saslExtensions = validateExtensions(extensions);
+    }
+
+    public SaslExtensions extensions() {
+        return saslExtensions;
     }
 
     public byte[] toBytes() {
         String authzid = authorizationId.isEmpty() ? "" : "a=" + authorizationId;
-        String message = String.format("n,%s,%sauth=Bearer %s%s%s", authzid,
-                SEPARATOR, tokenValue, SEPARATOR, SEPARATOR);
+        String extensions = extensionsMessage();
+        if (extensions.length() > 0)
+            extensions = SEPARATOR + extensions;
+
+        String message = String.format("n,%s,%sauth=Bearer %s%s%s%s", authzid,
+                SEPARATOR, tokenValue, extensions, SEPARATOR, SEPARATOR);
+
         return message.getBytes(StandardCharsets.UTF_8);
     }
 
@@ -90,7 +105,32 @@ public class OAuthBearerClientInitialResponse {
         return authorizationId;
     }
 
-    public String propertyValue(String name) {
-        return properties.get(name);
+    /**
+     * Validates that the given extensions conform to the standard. They should also not contain the reserve key name {@link OAuthBearerClientInitialResponse#AUTH_KEY}
+     *
+     * @see <a href="https://tools.ietf.org/html/rfc7628#section-3.1">RFC 7628,
+     *  Section 3.1</a>
+     */
+    public static SaslExtensions validateExtensions(SaslExtensions extensions) throws SaslException {
+        if (extensions.map().containsKey(OAuthBearerClientInitialResponse.AUTH_KEY))
+            throw new SaslException("Extension name " + OAuthBearerClientInitialResponse.AUTH_KEY + " is invalid");
+
+        for (Map.Entry<String, String> entry : extensions.map().entrySet()) {
+            String extensionName = entry.getKey();
+            String extensionValue = entry.getValue();
+
+            if (!EXTENSION_KEY_PATTERN.matcher(extensionName).matches())
+                throw new SaslException("Extension name " + extensionName + " is invalid");
+            if (!EXTENSION_VALUE_PATTERN.matcher(extensionValue).matches())
+                throw new SaslException("Extension value (" + extensionValue + ") for extension " + extensionName + " is invalid");
+        }
+        return extensions;
+    }
+
+    /**
+     * Converts the SASLExtensions to an OAuth protocol-friendly string
+     */
+    private String extensionsMessage() {
+        return Utils.mkString(saslExtensions.map(), "", "", "=", SEPARATOR);
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
index 4d4ee57..16db3c8 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
@@ -30,6 +30,8 @@ import javax.security.sasl.SaslClientFactory;
 import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.errors.IllegalSaslStateException;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
@@ -42,7 +44,8 @@ import org.slf4j.LoggerFactory;
  * implementation requires an instance of {@code AuthenticateCallbackHandler}
  * that can handle an instance of {@link OAuthBearerTokenCallback} and return
  * the {@link OAuthBearerToken} generated by the {@code login()} event on the
- * {@code LoginContext}.
+ * {@code LoginContext}. Said handler can also optionally handle an instance of {@link SaslExtensionsCallback}
+ * to return any extensions generated by the {@code login()} event on the {@code LoginContext}.
  *
  * @see <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750,
  *      Section 2.1</a>
@@ -87,8 +90,11 @@ public class OAuthBearerSaslClient implements SaslClient {
                     if (challenge != null && challenge.length != 0)
                         throw new SaslException("Expected empty challenge");
                     callbackHandler().handle(new Callback[] {callback});
+                    SaslExtensions extensions = retrieveCustomExtensions();
+
                     setState(State.RECEIVE_SERVER_FIRST_MESSAGE);
-                    return new OAuthBearerClientInitialResponse(callback.token().value()).toBytes();
+
+                    return new OAuthBearerClientInitialResponse(callback.token().value(), extensions).toBytes();
                 case RECEIVE_SERVER_FIRST_MESSAGE:
                     if (challenge != null && challenge.length != 0) {
                         String jsonErrorResponse = new String(challenge, StandardCharsets.UTF_8);
@@ -150,6 +156,20 @@ public class OAuthBearerSaslClient implements SaslClient {
         this.state = state;
     }
 
+    private SaslExtensions retrieveCustomExtensions() throws SaslException {
+        SaslExtensionsCallback extensionsCallback = new SaslExtensionsCallback();
+        try {
+            callbackHandler().handle(new Callback[] {extensionsCallback});
+        } catch (UnsupportedCallbackException e) {
+            log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added",
+                    callbackHandler());
+        } catch (Exception e) {
+            throw new SaslException("SASL extensions could not be obtained", e);
+        }
+
+        return extensionsCallback.extensions();
+    }
+
     public static class OAuthBearerSaslClientFactory implements SaslClientFactory {
         @Override
         public SaslClient createSaslClient(String[] mechanisms, String authorizationId, String protocol,
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
index 586c523..ab2b716 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
@@ -28,7 +28,9 @@ import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.AppConfigurationEntry;
 
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -38,7 +40,9 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
  * {@link OAuthBearerTokenCallback} and retrieves OAuth 2 Bearer Token that was
  * created when the {@code OAuthBearerLoginModule} logged in by looking for an
  * instance of {@link OAuthBearerToken} in the {@code Subject}'s private
- * credentials.
+ * credentials. This class also recognizes {@link SaslExtensionsCallback} and retrieves any SASL extensions that were
+ * created when the {@code OAuthBearerLoginModule} logged in by looking for an instance of {@link SaslExtensions}
+ * in the {@code Subject}'s public credentials
  * <p>
  * Use of this class is configured automatically and does not need to be
  * explicitly set via the {@code sasl.client.callback.handler.class}
@@ -70,6 +74,8 @@ public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbac
         for (Callback callback : callbacks) {
             if (callback instanceof OAuthBearerTokenCallback)
                 handleCallback((OAuthBearerTokenCallback) callback);
+            else if (callback instanceof SaslExtensionsCallback)
+                handleCallback((SaslExtensionsCallback) callback, Subject.getSubject(AccessController.getContext()));
             else
                 throw new UnsupportedCallbackException(callback);
         }
@@ -93,4 +99,14 @@ public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbac
                             privateCredentials.size()));
         callback.token(privateCredentials.iterator().next());
     }
+
+    /**
+     * Attaches the first {@link SaslExtensions} found in the public credentials of the Subject
+     */
+    private static void handleCallback(SaslExtensionsCallback extensionsCallback, Subject subject) {
+        if (subject != null && !subject.getPublicCredentials(SaslExtensions.class).isEmpty()) {
+            SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+            extensionsCallback.extensions(extensions);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
index aacc8fa..6573f69 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
@@ -31,6 +31,7 @@ import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
 
 import org.apache.kafka.common.errors.SaslAuthenticationException;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
@@ -46,6 +47,7 @@ import org.slf4j.LoggerFactory;
  * for example).
  */
 public class OAuthBearerSaslServer implements SaslServer {
+
     private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslServer.class);
     private static final String NEGOTIATED_PROPERTY_KEY_TOKEN = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM + ".token";
     private static final String INTERNAL_ERROR_ON_SERVER = "Authentication could not be performed due to an internal error on the server";
@@ -55,6 +57,7 @@ public class OAuthBearerSaslServer implements SaslServer {
     private boolean complete;
     private OAuthBearerToken tokenForNegotiatedProperty = null;
     private String errorMessage = null;
+    private SaslExtensions extensions;
 
     public OAuthBearerSaslServer(CallbackHandler callbackHandler) {
         if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler))
@@ -84,6 +87,7 @@ public class OAuthBearerSaslServer implements SaslServer {
             throw new SaslAuthenticationException(errorMessage);
         }
         errorMessage = null;
+
         OAuthBearerClientInitialResponse clientResponse;
         try {
             clientResponse = new OAuthBearerClientInitialResponse(response);
@@ -91,7 +95,8 @@ public class OAuthBearerSaslServer implements SaslServer {
             log.debug(e.getMessage());
             throw e;
         }
-        return process(clientResponse.tokenValue(), clientResponse.authorizationId());
+
+        return process(clientResponse.tokenValue(), clientResponse.authorizationId(), clientResponse.extensions());
     }
 
     @Override
@@ -110,7 +115,10 @@ public class OAuthBearerSaslServer implements SaslServer {
     public Object getNegotiatedProperty(String propName) {
         if (!complete)
             throw new IllegalStateException("Authentication exchange has not completed");
-        return NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName) ? tokenForNegotiatedProperty : null;
+        if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName))
+            return tokenForNegotiatedProperty;
+        // Any other property name is looked up in the SASL extensions the client sent (stored by process())
+        return extensions.map().get(propName);
     }
 
     @Override
@@ -136,9 +144,10 @@ public class OAuthBearerSaslServer implements SaslServer {
     public void dispose() throws SaslException {
         complete = false;
         tokenForNegotiatedProperty = null;
+        extensions = null;
     }
 
-    private byte[] process(String tokenValue, String authorizationId) throws SaslException {
+    private byte[] process(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException {
         OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(tokenValue);
         try {
             callbackHandler.handle(new Callback[] {callback});
@@ -165,6 +174,7 @@ public class OAuthBearerSaslServer implements SaslServer {
                     "Authentication failed: Client requested an authorization id (%s) that is different from the token's principal name (%s)",
                     authorizationId, token.principalName()));
         tokenForNegotiatedProperty = token;
+        this.extensions = extensions;
         complete = true;
         if (log.isDebugEnabled())
             log.debug("Successfully authenticate User={}", token.principalName());
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
index 67a75ae..88399ac 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
@@ -22,6 +22,7 @@ import java.util.Arrays;
 import java.util.Base64;
 import java.util.Base64.Encoder;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -31,18 +32,23 @@ import java.util.Set;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
+import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
 import org.apache.kafka.common.utils.Time;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
  * A {@code CallbackHandler} that recognizes {@link OAuthBearerTokenCallback}
- * and returns an unsecured OAuth 2 bearer token.
+ * to return an unsecured OAuth 2 bearer token and {@link SaslExtensionsCallback} to return SASL extensions.
  * <p>
  * Claims and their values on the returned token can be specified using
  * {@code unsecuredLoginStringClaim_<claimname>},
@@ -52,6 +58,11 @@ import org.slf4j.LoggerFactory;
  * name and value except '{@code iat}' and '{@code exp}', both of which are
  * calculated automatically.
  * <p>
+ * Custom unsecured SASL extensions can also be added using
+ * {@code unsecuredLoginExtension_<extensionname>}. Extension keys and values
+ * are subject to regex validation, and an extension key must not be equal to
+ * the reserved key {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
+ * <p>
  * This implementation also accepts the following options:
  * <ul>
  * <li>{@code unsecuredLoginPrincipalClaimName} set to a custom claim name if
@@ -72,7 +83,8 @@ import org.slf4j.LoggerFactory;
  *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required
  *      unsecuredLoginStringClaim_sub="thePrincipalName"
  *      unsecuredLoginListClaim_scope="|scopeValue1|scopeValue2"
- *      unsecuredLoginLifetimeSeconds="60";
+ *      unsecuredLoginLifetimeSeconds="60"
+ *      unsecuredLoginExtension_traceId="123";
  * };
  * </pre>
  * 
@@ -96,6 +108,7 @@ public class OAuthBearerUnsecuredLoginCallbackHandler implements AuthenticateCal
     private static final String STRING_CLAIM_PREFIX = OPTION_PREFIX + "StringClaim_";
     private static final String NUMBER_CLAIM_PREFIX = OPTION_PREFIX + "NumberClaim_";
     private static final String LIST_CLAIM_PREFIX = OPTION_PREFIX + "ListClaim_";
+    private static final String EXTENSION_PREFIX = OPTION_PREFIX + "Extension_";
     private static final String QUOTE = "\"";
     private Time time = Time.SYSTEM;
     private Map<String, String> moduleOptions = null;
@@ -140,7 +153,13 @@ public class OAuthBearerUnsecuredLoginCallbackHandler implements AuthenticateCal
         for (Callback callback : callbacks) {
             if (callback instanceof OAuthBearerTokenCallback)
                 try {
-                    handleCallback((OAuthBearerTokenCallback) callback);
+                    handleTokenCallback((OAuthBearerTokenCallback) callback);
+                } catch (KafkaException e) {
+                    throw new IOException(e.getMessage(), e);
+                }
+            else if (callback instanceof SaslExtensionsCallback)
+                try {
+                    handleExtensionsCallback((SaslExtensionsCallback) callback);
                 } catch (KafkaException e) {
                     throw new IOException(e.getMessage(), e);
                 }
@@ -154,7 +173,7 @@ public class OAuthBearerUnsecuredLoginCallbackHandler implements AuthenticateCal
         // empty
     }
 
-    private void handleCallback(OAuthBearerTokenCallback callback) throws IOException {
+    private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException {
         if (callback.token() != null)
             throw new IllegalArgumentException("Callback had a token already");
         String principalClaimNameValue = optionValue(PRINCIPAL_CLAIM_NAME_OPTION);
@@ -190,6 +209,30 @@ public class OAuthBearerUnsecuredLoginCallbackHandler implements AuthenticateCal
         }
     }
 
+    /**
+     * Collects every {@code unsecuredLoginExtension_<name>} module option into {@link SaslExtensions},
+     * validates them and sets them on the callback. Extension keys must pass regex validation and must not
+     * equal the reserved key {@link OAuthBearerClientInitialResponse#AUTH_KEY}; otherwise a {@link ConfigException} is thrown.
+     */
+    private void handleExtensionsCallback(SaslExtensionsCallback callback) {
+        Map<String, String> extensions = new HashMap<>();
+        for (Map.Entry<String, String> configEntry : this.moduleOptions.entrySet()) {
+            String key = configEntry.getKey();
+            if (key.startsWith(EXTENSION_PREFIX))
+                extensions.put(key.substring(EXTENSION_PREFIX.length()), configEntry.getValue());
+        }
+
+        SaslExtensions saslExtensions = new SaslExtensions(extensions);
+        try {
+            OAuthBearerClientInitialResponse.validateExtensions(saslExtensions);
+        } catch (SaslException e) {
+            ConfigException configException = new ConfigException(e.getMessage());
+            configException.initCause(e); // ConfigException has no (message, cause) constructor; keep the root cause
+            throw configException;
+        }
+        callback.extensions(saslExtensions);
+    }
+
     private String commaPrependedStringNumberAndListClaimsJsonText() throws OAuthBearerConfigException {
         StringBuilder sb = new StringBuilder();
         for (String key : moduleOptions.keySet()) {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
index debe163..b83c94e 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
@@ -14,13 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.kafka.common.security.scram;
 
 import javax.security.auth.callback.Callback;
 import java.util.Collections;
 import java.util.Map;
 
+
 /**
  * Optional callback used for SCRAM mechanisms if any extensions need to be set
  * in the SASL/SCRAM exchange.
@@ -29,18 +29,18 @@ public class ScramExtensionsCallback implements Callback {
     private Map<String, String> extensions = Collections.emptyMap();
 
     /**
-     * Returns the extension names and values that are sent by the client to
+     * Returns a map of the extension names and values that are sent by the client to
      * the server in the initial client SCRAM authentication message.
-     * Default is an empty map.
+     * Default is an empty unmodifiable map.
      */
     public Map<String, String> extensions() {
         return extensions;
     }
 
     /**
-     * Sets the SCRAM extensions on this callback.
+     * Sets the SCRAM extensions on this callback. The map passed in should be unmodifiable.
      */
     public void extensions(Map<String, String> extensions) {
         this.extensions = extensions;
     }
-}
\ No newline at end of file
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
index 5028329..7b51890 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
@@ -16,15 +16,14 @@
  */
 package org.apache.kafka.common.security.scram.internals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
 import org.apache.kafka.common.utils.Utils;
 
 import java.util.Collections;
 import java.util.Map;
-import java.util.Set;
 
-public class ScramExtensions {
-    private final Map<String, String> extensionMap;
+public class ScramExtensions extends SaslExtensions {
 
     public ScramExtensions() {
         this(Collections.<String, String>emptyMap());
@@ -35,23 +34,10 @@ public class ScramExtensions {
     }
 
     public ScramExtensions(Map<String, String> extensionMap) {
-        this.extensionMap = extensionMap;
-    }
-
-    public String extensionValue(String name) {
-        return extensionMap.get(name);
-    }
-
-    public Set<String> extensionNames() {
-        return extensionMap.keySet();
+        super(extensionMap);
     }
 
     public boolean tokenAuthenticated() {
-        return Boolean.parseBoolean(extensionMap.get(ScramLoginModule.TOKEN_AUTH_CONFIG));
-    }
-
-    @Override
-    public String toString() {
-        return Utils.mkString(extensionMap, "", "", "=", ",");
+        return Boolean.parseBoolean(map().get(ScramLoginModule.TOKEN_AUTH_CONFIG));
     }
-}
\ No newline at end of file
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
index b56d759..0551296 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.scram.internals;
 
+import org.apache.kafka.common.utils.Utils;
+
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
 import java.util.Map;
@@ -112,7 +114,8 @@ public class ScramMessages {
         }
 
         public String clientFirstMessageBare() {
-            String extensionStr = extensions.toString();
+            String extensionStr = Utils.mkString(extensions.map(), "", "", "=", ",");
+
             if (extensionStr.isEmpty())
                 return String.format("n=%s,r=%s", saslName, nonce);
             else
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
index d464e89..b11300a 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
@@ -98,9 +98,9 @@ public class ScramSaslServer implements SaslServer {
                 case RECEIVE_CLIENT_FIRST_MESSAGE:
                     this.clientFirstMessage = new ClientFirstMessage(response);
                     this.scramExtensions = clientFirstMessage.extensions();
-                    if (!SUPPORTED_EXTENSIONS.containsAll(scramExtensions.extensionNames())) {
+                    if (!SUPPORTED_EXTENSIONS.containsAll(scramExtensions.map().keySet())) {
                         log.debug("Unsupported extensions will be ignored, supported {}, provided {}",
-                                SUPPORTED_EXTENSIONS, scramExtensions.extensionNames());
+                                SUPPORTED_EXTENSIONS, scramExtensions.map().keySet());
                     }
                     String serverNonce = formatter.secureRandomString();
                     try {
@@ -183,7 +183,7 @@ public class ScramSaslServer implements SaslServer {
             throw new IllegalStateException("Authentication exchange has not completed");
 
         if (SUPPORTED_EXTENSIONS.contains(propName))
-            return scramExtensions.extensionValue(propName);
+            return scramExtensions.map().get(propName);
         else
             return null;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 07f91a7..6e0b693 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -498,6 +498,12 @@ public final class Utils {
         return sb.toString();
     }
 
+    /**
+     *  Converts a {@code Map} into a string by concatenating its keys and values.
+     *  Example:
+     *      {@code mkString({ key: "hello", keyTwo: "hi" }, "|START|", "|END|", "=", ",")
+     *          => "|START|key=hello,keyTwo=hi|END|"}
+     */
     public static <K, V> String mkString(Map<K, V> map, String begin, String end,
                                          String keyValueSeparator, String elementSeparator) {
         StringBuilder bld = new StringBuilder();
@@ -512,6 +518,13 @@ public final class Utils {
         return bld.toString();
     }
 
+    /**
+     *  Parses a string of key-value pairs (such as SASL extensions) into a {@code Map<String, String>}.
+     *
+     *  Example:
+     *      {@code parseMap("key=hey,keyTwo=hi,keyThree=hello", "=", ",") => { key: "hey", keyTwo: "hi", keyThree: "hello" }}
+     *
+     */
     public static Map<String, String> parseMap(String mapStr, String keyValueSeparator, String elementSeparator) {
         Map<String, String> map = new HashMap<>();
 
diff --git a/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java
new file mode 100644
index 0000000..77a4523
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security;
+
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertNull;
+
+/** Checks that {@link SaslExtensions} neither exposes a mutable map nor reflects later changes to its source map. */
+public class SaslExtensionsTest {
+    private Map<String, String> original;
+
+    @Before
+    public void setUp() {
+        original = new HashMap<>();
+        original.put("what", "42");
+        original.put("who", "me");
+    }
+
+    @Test(expected = UnsupportedOperationException.class)
+    public void testReturnedMapIsImmutable() {
+        new SaslExtensions(original).map().put("hello", "test");
+    }
+
+    @Test
+    public void testCannotAddValueToMapReferenceAndGetFromExtensions() {
+        SaslExtensions saslExtensions = new SaslExtensions(original);
+        assertNull(saslExtensions.map().get("hello"));
+        // mutating the source map after construction must not affect the extensions
+        original.put("hello", "42");
+        assertNull(saslExtensions.map().get("hello"));
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
index d883e5e..a9620fa 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
@@ -19,6 +19,8 @@ package org.apache.kafka.common.security.oauthbearer;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNotNull;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -36,16 +38,24 @@ import javax.security.auth.login.LoginException;
 
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.easymock.EasyMock;
 import org.junit.Test;
 
 public class OAuthBearerLoginModuleTest {
-    private static class TestTokenCallbackHandler implements AuthenticateCallbackHandler {
+
+    public static final SaslExtensions RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG = null;
+
+    private static class TestCallbackHandler implements AuthenticateCallbackHandler {
         private final OAuthBearerToken[] tokens;
         private int index = 0;
+        private int extensionsIndex = 0;
+        private final SaslExtensions[] extensions;
 
-        public TestTokenCallbackHandler(OAuthBearerToken[] tokens) {
+        public TestCallbackHandler(OAuthBearerToken[] tokens, SaslExtensions[] extensions) {
             this.tokens = Objects.requireNonNull(tokens);
+            this.extensions = extensions;
         }
 
         @Override
@@ -57,7 +67,13 @@ public class OAuthBearerLoginModuleTest {
                     } catch (KafkaException e) {
                         throw new IOException(e.getMessage(), e);
                     }
-                else
+                else if (callback instanceof SaslExtensionsCallback) {
+                    try {
+                        handleExtensionsCallback((SaslExtensionsCallback) callback);
+                    } catch (KafkaException e) {
+                        throw new IOException(e.getMessage(), e);
+                    }
+                } else
                     throw new UnsupportedCallbackException(callback);
             }
         }
@@ -81,6 +97,19 @@ public class OAuthBearerLoginModuleTest {
             else
                 throw new IOException("no more tokens");
         }
+
+        // Hands out the configured extensions one at a time, in order. An entry equal to
+        // RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG (null) makes the handler raise UnsupportedCallbackException.
+        private void handleExtensionsCallback(SaslExtensionsCallback callback) throws IOException, UnsupportedCallbackException {
+            if (extensionsIndex >= extensions.length)
+                throw new IOException("no more extensions");
+
+            SaslExtensions nextExtensions = extensions[extensionsIndex++];
+            if (nextExtensions == RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG)
+                throw new UnsupportedCallbackException(callback);
+
+            callback.extensions(nextExtensions);
+        }
     }
 
     @Test
@@ -92,12 +121,16 @@ public class OAuthBearerLoginModuleTest {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] {EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class), EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] {EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class), EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[2]);
+        TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -112,47 +145,68 @@ public class OAuthBearerLoginModuleTest {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
-        // Now we should have the first token
+        // Now we should have the first token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now login on loginModule2 to get the second token
+        // loginModule2 likewise receives its extensions (the second set) from the callback handler
         loginModule2.login();
-        // Should still have just the first token
+        // Should still have just the first token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.commit();
         // Should have the first and second tokens at this point
         assertEquals(2, privateCredentials.size());
+        assertEquals(2, publicCredentials.size());
         Iterator<Object> iterator = privateCredentials.iterator();
+        Iterator<Object> publicIterator = publicCredentials.iterator();
         assertNotSame(tokens[2], iterator.next());
         assertNotSame(tokens[2], iterator.next());
+        assertNotSame(extensions[2], publicIterator.next());
+        assertNotSame(extensions[2], publicIterator.next());
         // finally logout() on loginModule1
         loginModule1.logout();
-        // Now we should have just the second token
+        // Now we should have just the second token and extension
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
 
         // Now login on loginModule3 to get the third token
         loginModule3.login();
-        // Should still have just the second token
+        // Should still have just the second token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule3.commit();
         // Should have the second and third tokens at this point
         assertEquals(2, privateCredentials.size());
+        assertEquals(2, publicCredentials.size());
         iterator = privateCredentials.iterator();
+        publicIterator = publicCredentials.iterator();
         assertNotSame(tokens[0], iterator.next());
         assertNotSame(tokens[0], iterator.next());
+        assertNotSame(extensions[0], publicIterator.next());
+        assertNotSame(extensions[0], publicIterator.next());
         // finally logout() on loginModule2
         loginModule2.logout();
         // Now we should have just the third token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[2], privateCredentials.iterator().next());
+        assertSame(extensions[2], publicCredentials.iterator().next());
     }
 
     @Test
@@ -163,12 +217,16 @@ public class OAuthBearerLoginModuleTest {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] {EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] {EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1]);
+        TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -180,27 +238,36 @@ public class OAuthBearerLoginModuleTest {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
         // Now we should have the first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule1.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
 
         loginModule2.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule2.commit();
         // Now we should have the second token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule2.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
     }
 
     @Test
@@ -210,12 +277,16 @@ public class OAuthBearerLoginModuleTest {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] {EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] {EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1]);
+        TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions);
 
         // Create login module
         OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule();
@@ -224,23 +295,30 @@ public class OAuthBearerLoginModuleTest {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.abort();
         // Should still have nothing since we aborted
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
 
         loginModule.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.commit();
         // Now we should have the second token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
     }
 
     @Test
@@ -251,12 +329,16 @@ public class OAuthBearerLoginModuleTest {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] {EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class), EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] {EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class), EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1], extensions[2]);
+        TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -271,38 +353,81 @@ public class OAuthBearerLoginModuleTest {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
         // Now we should have the first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now go get the second token
         loginModule2.login();
         // Should still have first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.abort();
         // Should still have just the first token because we aborted
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now go get the third token
         loginModule2.login();
         // Should still have first token
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.commit();
         // Should have first and third tokens at this point
         assertEquals(2, privateCredentials.size());
         Iterator<Object> iterator = privateCredentials.iterator();
         assertNotSame(tokens[1], iterator.next());
         assertNotSame(tokens[1], iterator.next());
+        assertEquals(2, publicCredentials.size());
+        Iterator<Object> publicIterator = publicCredentials.iterator();
+        assertNotSame(extensions[1], publicIterator.next());
+        assertNotSame(extensions[1], publicIterator.next());
         loginModule1.logout();
         // Now we should have just the third token
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[2], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[2], publicCredentials.iterator().next());
+    }
+
+    /**
+     * 2.1.0 added customizable SASL extensions and a new callback type.
+     * Ensure that old, custom-written callbackHandlers that do not handle the callback work
+     */
+    @Test
+    public void commitDoesNotThrowOnUnsupportedExtensionsCallback() throws LoginException {
+        Subject subject = new Subject();
+
+        // Create callback handler
+        OAuthBearerToken[] tokens = new OAuthBearerToken[] {EasyMock.mock(OAuthBearerToken.class),
+                EasyMock.mock(OAuthBearerToken.class), EasyMock.mock(OAuthBearerToken.class)};
+        EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
+        TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, new SaslExtensions[] {RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG});
+
+        // Create login modules
+        OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
+        loginModule1.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(),
+                Collections.emptyMap());
+
+        loginModule1.login();
+        // Should populate public credentials with SaslExtensions and not throw an exception
+        loginModule1.commit();
+        SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+        assertNotNull(extensions);
+        assertTrue(extensions.map().isEmpty());
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
index eccf2dd..3de6408 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
@@ -18,12 +18,49 @@ package org.apache.kafka.common.security.oauthbearer.internals;
 
 import static org.junit.Assert.assertEquals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.junit.Test;
 
+import javax.security.sasl.SaslException;
 import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
 
 public class OAuthBearerClientInitialResponseTest {
 
+    /*
+        Verifies the exact serialized bytes of a client-built initial response carrying one extension
+     */
+    @Test
+    public void testBuildClientResponseToBytes() throws Exception {
+        String expectedMessage = "n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001";
+
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("nineteen", "42");
+        OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions));
+
+        String message = new String(response.toBytes(), StandardCharsets.UTF_8);
+
+        assertEquals(expectedMessage, message);
+    }
+
+    @Test
+    public void testBuildServerResponseToBytes() throws Exception {
+        String serverMessage = "n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001";
+        OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(serverMessage.getBytes(StandardCharsets.UTF_8));
+
+        String message = new String(response.toBytes(), StandardCharsets.UTF_8);
+
+        assertEquals(serverMessage, message);
+    }
+
+    @Test(expected = SaslException.class)
+    public void testThrowsSaslExceptionOnInvalidExtensionKey() throws Exception {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("19", "42"); // keys must be alphabetic (a-z, A-Z); a digit-only key is rejected
+        new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions));
+    }
+
     @Test
     public void testToken() throws Exception {
         String message = "n,,\u0001auth=Bearer 123.345.567\u0001\u0001";
@@ -41,13 +78,13 @@ public class OAuthBearerClientInitialResponseTest {
     }
 
     @Test
-    public void testProperties() throws Exception {
+    public void testExtensions() throws Exception {
         String message = "n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001";
         OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8));
         assertEquals("567", response.tokenValue());
         assertEquals("", response.authorizationId());
-        assertEquals("valueA1, valueA2", response.propertyValue("propA"));
-        assertEquals("valueB", response.propertyValue("propB"));
+        assertEquals("valueA1, valueA2", response.extensions().map().get("propA"));
+        assertEquals("valueB", response.extensions().map().get("propB"));
     }
 
     // The example in the RFC uses `vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==` as the token
@@ -59,7 +96,7 @@ public class OAuthBearerClientInitialResponseTest {
         OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8));
         assertEquals("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", response.tokenValue());
         assertEquals("user@example.com", response.authorizationId());
-        assertEquals("server.example.com", response.propertyValue("host"));
-        assertEquals("143", response.propertyValue("port"));
+        assertEquals("server.example.com", response.extensions().map().get("host"));
+        assertEquals("143", response.extensions().map().get("port"));
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java
new file mode 100644
index 0000000..55a8624
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.oauthbearer.internals;
+
+import org.apache.kafka.common.config.ConfigException;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
+import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws;
+import org.easymock.EasyMockSupport;
+import org.junit.Test;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.sasl.SaslException;
+import java.nio.charset.StandardCharsets;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class OAuthBearerSaslClientTest extends EasyMockSupport {
+
+    private static final Map<String, String> TEST_PROPERTIES = new LinkedHashMap<String, String>() {
+        {
+            put("One", "1");
+            put("Two", "2");
+            put("Three", "3");
+        }
+    };
+    private SaslExtensions testExtensions = new SaslExtensions(TEST_PROPERTIES);
+    private final String errorMessage = "Error as expected!";
+
+    public class ExtensionsCallbackHandler implements AuthenticateCallbackHandler {
+        private boolean configured = false;
+        private boolean toThrow;
+
+        ExtensionsCallbackHandler(boolean toThrow) {
+            this.toThrow = toThrow;
+        }
+
+        public boolean configured() {
+            return configured;
+        }
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
+            for (Callback callback : callbacks) {
+                if (callback instanceof OAuthBearerTokenCallback)
+                    ((OAuthBearerTokenCallback) callback).token(createMock(OAuthBearerUnsecuredJws.class));
+                else if (callback instanceof SaslExtensionsCallback) {
+                    if (toThrow)
+                        throw new ConfigException(errorMessage);
+                    else
+                        ((SaslExtensionsCallback) callback).extensions(testExtensions);
+                } else
+                    throw new UnsupportedCallbackException(callback);
+            }
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    @Test
+    public void testAttachesExtensionsToFirstClientMessage() throws Exception {
+        String expectedToken = new String(new OAuthBearerClientInitialResponse(null, testExtensions).toBytes(), StandardCharsets.UTF_8);
+
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false));
+
+        String message = new String(client.evaluateChallenge("".getBytes()), StandardCharsets.UTF_8);
+
+        assertEquals(expectedToken, message);
+    }
+
+    @Test
+    public void testNoExtensionsDoesNotAttachAnythingToFirstClientMessage() throws Exception {
+        // Fresh empty map: clearing the shared static TEST_PROPERTIES would leak into the other tests
+        testExtensions = new SaslExtensions(new LinkedHashMap<>());
+        String expectedToken = new String(new OAuthBearerClientInitialResponse(null, testExtensions).toBytes(), StandardCharsets.UTF_8);
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false));
+
+        String message = new String(client.evaluateChallenge("".getBytes()), StandardCharsets.UTF_8);
+
+        assertEquals(expectedToken, message);
+    }
+
+    @Test
+    public void testWrapsExtensionsCallbackHandlingErrorInSaslExceptionInFirstClientMessage() {
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(true));
+        try {
+            client.evaluateChallenge("".getBytes());
+            fail("Should have failed with " + SaslException.class.getName());
+        } catch (SaslException e) {
+            // assert it has caught our expected exception
+            assertEquals(ConfigException.class, e.getCause().getClass());
+            assertEquals(errorMessage, e.getCause().getMessage());
+        }
+
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
index 6b53e96..fc96f9f 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
@@ -18,6 +18,7 @@ package org.apache.kafka.common.security.oauthbearer.internals;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNull;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -34,6 +35,7 @@ import org.apache.kafka.common.config.types.Password;
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -68,7 +70,7 @@ public class OAuthBearerSaslServerTest {
     private OAuthBearerSaslServer saslServer;
 
     @Before
-    public void setUp() throws Exception {
+    public void setUp() {
         saslServer = new OAuthBearerSaslServer(VALIDATOR_CALLBACK_HANDLER);
     }
 
@@ -80,6 +82,32 @@ public class OAuthBearerSaslServerTest {
     }
 
     @Test
+    public void savesCustomExtensionAsNegotiatedProperty() throws Exception {
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value1");
+        customExtensions.put("secondKey", "value2");
+
+        byte[] nextChallenge = saslServer
+                .evaluateResponse(clientInitialResponse(null, false, customExtensions));
+
+        assertTrue("Next challenge is not empty", nextChallenge.length == 0);
+        assertEquals("value1", saslServer.getNegotiatedProperty("firstKey"));
+        assertEquals("value2", saslServer.getNegotiatedProperty("secondKey"));
+    }
+
+    @Test
+    public void returnsNullForNonExistentProperty() throws Exception {
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value1");
+
+        byte[] nextChallenge = saslServer
+                .evaluateResponse(clientInitialResponse(null, false, customExtensions));
+
+        assertTrue("Next challenge is not empty", nextChallenge.length == 0);
+        assertNull(saslServer.getNegotiatedProperty("secondKey"));
+    }
+
+    @Test
     public void authorizatonIdEqualsAuthenticationId() throws Exception {
         byte[] nextChallenge = saslServer
                 .evaluateResponse(clientInitialResponse(USER));
@@ -93,7 +121,7 @@ public class OAuthBearerSaslServerTest {
 
     @Test
     public void illegalToken() throws Exception {
-        byte[] bytes = saslServer.evaluateResponse(clientInitialResponse(null, true));
+        byte[] bytes = saslServer.evaluateResponse(clientInitialResponse(null, true, Collections.emptyMap()));
         String challenge = new String(bytes, StandardCharsets.UTF_8);
         assertEquals("{\"status\":\"invalid_token\"}", challenge);
     }
@@ -105,11 +133,17 @@ public class OAuthBearerSaslServerTest {
 
     private byte[] clientInitialResponse(String authorizationId, boolean illegalToken)
             throws OAuthBearerConfigException, IOException, UnsupportedCallbackException, LoginException {
+        return clientInitialResponse(authorizationId, illegalToken, Collections.emptyMap());
+    }
+
+    private byte[] clientInitialResponse(String authorizationId, boolean illegalToken, Map<String, String> customExtensions)
+            throws OAuthBearerConfigException, IOException, UnsupportedCallbackException {
         OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
         LOGIN_CALLBACK_HANDLER.handle(new Callback[] {callback});
         OAuthBearerToken token = callback.token();
         String compactSerialization = token.value();
+        // When illegalToken is set, appended junk makes the server reject the token (see illegalToken test)
         String tokenValue = compactSerialization + (illegalToken ? "AB" : "");
-        return new OAuthBearerClientInitialResponse(tokenValue, authorizationId, Collections.emptyMap()).toBytes();
+        return new OAuthBearerClientInitialResponse(tokenValue, authorizationId, new SaslExtensions(customExtensions)).toBytes();
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
index a5c216d..be01fe3 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
@@ -31,6 +31,7 @@ import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.LoginException;
 
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.authenticator.TestJaasConfig;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -38,6 +39,39 @@ import org.apache.kafka.common.utils.MockTime;
 import org.junit.Test;
 
 public class OAuthBearerUnsecuredLoginCallbackHandlerTest {
+
+    @Test
+    public void addsExtensions() throws IOException, UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_testId", "1");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+
+        assertEquals("1", callback.extensions().map().get("testId"));
+    }
+
+    @Test(expected = IOException.class)
+    public void throwsErrorOnInvalidExtensionName() throws IOException, UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_test.Id", "1");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+    }
+
+    @Test(expected = IOException.class)
+    public void throwsErrorOnInvalidExtensionValue() throws IOException, UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_testId", "Çalifornia");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+    }
+
     @Test
     public void minimalToken() throws IOException, UnsupportedCallbackException {
         Map<String, String> options = new HashMap<>();
diff --git a/docs/security.html b/docs/security.html
index 7f76509..743d673 100644
--- a/docs/security.html
+++ b/docs/security.html
@@ -750,6 +750,13 @@
                  automatically generated).</td>
                  </tr>
                  <tr>
+                 <td><tt>unsecuredLoginExtension_&lt;extensionname&gt;="value"</tt></td>
+                 <td>Creates a <tt>String</tt> extension with the given name and value.
+                 For example: <tt>unsecuredLoginExtension_traceId="123"</tt>. A valid extension name
+                 is any sequence of lowercase or uppercase alphabetic characters. In addition, the "auth" extension name is reserved.
+                 A valid extension value is any combination of characters with ASCII codes 1-127.</td>
+                 </tr>
+                 <tr>
                  <td><tt>unsecuredLoginPrincipalClaimName</tt></td>
                  <td>Set to a custom claim name if you wish the name of the <tt>String</tt>
                  claim holding the principal name to be something other than '<tt>sub</tt>'.</td>


Mime
View raw message