Skip to content

Commit

Permalink
Proxy initiated oauthbearer auth
Browse files Browse the repository at this point in the history
  • Loading branch information
everesio committed Nov 25, 2018
1 parent 43640ba commit 5edf5c4
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ sudo: false
language: go

go:
- "1.10.x"
- "1.11.x"

env:
global:
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.build
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM golang:1.10 as builder
FROM golang:1.11 as builder

ARG GOOS=linux
ARG GOARCH=amd64
Expand Down
20 changes: 10 additions & 10 deletions proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type Client struct {
stopRun chan struct{}
stopOnce sync.Once

saslPlainAuth *SASLPlainAuth
authClient *AuthClient
saslAuthByProxy SASLAuthByProxy
authClient *AuthClient
}

func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.NetAddressMappingFunc, localPasswordAuthenticator apis.PasswordAuthenticator, localTokenAuthenticator apis.TokenInfo, gatewayTokenProvider apis.TokenProvider, gatewayTokenInfo apis.TokenInfo) (*Client, error) {
Expand Down Expand Up @@ -72,7 +72,7 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne
}

return &Client{conns: conns, config: c, dialer: dialer, tcpConnOptions: tcpConnOptions, stopRun: make(chan struct{}, 1),
saslPlainAuth: &SASLPlainAuth{
saslAuthByProxy: &SASLPlainAuth{
clientID: c.Kafka.ClientID,
writeTimeout: c.Kafka.WriteTimeout,
readTimeout: c.Kafka.ReadTimeout,
Expand Down Expand Up @@ -193,7 +193,7 @@ func (c *Client) handleConn(conn Conn) {
server, err := c.DialAndAuth(conn.BrokerAddress)
if err != nil {
logrus.Infof("couldn't connect to %s: %v", conn.BrokerAddress, err)
conn.LocalConnection.Close()
_ = conn.LocalConnection.Close()
return
}
if tcpConn, ok := server.(*net.TCPConn); ok {
Expand All @@ -215,7 +215,7 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) {
return nil, err
}
if err := conn.SetDeadline(time.Time{}); err != nil {
conn.Close()
_ = conn.Close()
return nil, err
}
err = c.auth(conn)
Expand All @@ -228,22 +228,22 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) {
func (c *Client) auth(conn net.Conn) error {
if c.config.Auth.Gateway.Client.Enable {
if err := c.authClient.sendAndReceiveGatewayAuth(conn); err != nil {
conn.Close()
_ = conn.Close()
return err
}
if err := conn.SetDeadline(time.Time{}); err != nil {
conn.Close()
_ = conn.Close()
return err
}
}
if c.config.Kafka.SASL.Enable {
err := c.saslPlainAuth.sendAndReceiveSASLPlainAuth(conn)
err := c.saslAuthByProxy.sendAndReceiveSASLAuth(conn)
if err != nil {
conn.Close()
_ = conn.Close()
return err
}
if err := conn.SetDeadline(time.Time{}); err != nil {
conn.Close()
_ = conn.Close()
return err
}
}
Expand Down
128 changes: 123 additions & 5 deletions proxy/sasl.go → proxy/sasl_by_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package proxy

import (
"bytes"
"context"
"encoding/binary"
"fmt"
"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/pkg/errors"
"io"
Expand All @@ -15,6 +17,24 @@ const (
SASLOAuthBearer = "OAUTHBEARER"
)

type SASLHandshake struct {
clientID string
version int16
mechanism string

writeTimeout time.Duration
readTimeout time.Duration
}

type SASLOAuthBearerAuth struct {
clientID string

writeTimeout time.Duration
readTimeout time.Duration

tokenProvider apis.TokenProvider
}

type SASLPlainAuth struct {
clientID string

Expand All @@ -25,6 +45,10 @@ type SASLPlainAuth struct {
password string
}

type SASLAuthByProxy interface {
sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error
}

// In SASL Plain, Kafka expects the auth header to be in the following format
// Message format (from https://tools.ietf.org/html/rfc4616):
//
Expand All @@ -40,9 +64,16 @@ type SASLPlainAuth struct {
// When credentials are valid, Kafka returns a 4 byte array of null characters.
// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
// of responding to bad credentials but thats how its being done today.
func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) error {

handshakeErr := b.sendAndReceiveSASLPlainHandshake(conn)
func (b *SASLPlainAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error {

saslHandshake := &SASLHandshake{
clientID: b.clientID,
version: 0,
mechanism: SASLPlain,
writeTimeout: b.writeTimeout,
readTimeout: b.readTimeout,
}
handshakeErr := saslHandshake.sendAndReceiveHandshake(conn)
if handshakeErr != nil {
return handshakeErr
}
Expand Down Expand Up @@ -78,11 +109,11 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) e
return nil
}

func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWriter) error {
func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error {

req := &protocol.Request{
ClientID: b.clientID,
Body: &protocol.SaslHandshakeRequestV0orV1{Version: 0, Mechanism: SASLPlain},
Body: &protocol.SaslHandshakeRequestV0orV1{Version: b.version, Mechanism: b.mechanism},
}
reqBuf, err := protocol.Encode(req)
if err != nil {
Expand Down Expand Up @@ -128,3 +159,90 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWrit
}
return nil
}

func (b *SASLOAuthBearerAuth) getOAuthBearerToken() (string, error) {
resp, err := b.tokenProvider.GetToken(context.Background(), apis.TokenRequest{})
if err != nil {
return "", err
}
if !resp.Success {
return "", fmt.Errorf("get sasl token failed with status: %d", resp.Status)
}
if resp.Token == "" {
return "", errors.New("get sasl token returned empty token")
}
return resp.Token, nil
}

func (b *SASLOAuthBearerAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error {

token, err := b.getOAuthBearerToken()
if err != nil {
return err
}
saslHandshake := &SASLHandshake{
clientID: b.clientID,
version: 1,
mechanism: SASLOAuthBearer,
writeTimeout: b.writeTimeout,
readTimeout: b.readTimeout,
}
handshakeErr := saslHandshake.sendAndReceiveHandshake(conn)
if handshakeErr != nil {
return handshakeErr
}
return b.sendSaslAuthenticateRequest(token, conn)
}

func (b *SASLOAuthBearerAuth) sendSaslAuthenticateRequest(token string, conn DeadlineReaderWriter) error {
saslAuthReqV0 := protocol.SaslAuthenticateRequestV0{SaslAuthBytes: SaslOAuthBearer{}.ToBytes(token, "", make(map[string]string, 0))}

req := &protocol.Request{
ClientID: b.clientID,
Body: &saslAuthReqV0,
}
reqBuf, err := protocol.Encode(req)
if err != nil {
return err
}
sizeBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf)))

err = conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err != nil {
return err
}

_, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil))
if err != nil {
return errors.Wrap(err, "Failed to send SASL auth request")
}

err = conn.SetReadDeadline(time.Now().Add(b.readTimeout))
if err != nil {
return err
}

//wait for the response
header := make([]byte, 8) // response header
_, err = io.ReadFull(conn, header)
if err != nil {
return errors.Wrap(err, "Failed to read SASL auth header")
}
length := binary.BigEndian.Uint32(header[:4])
payload := make([]byte, length-4)
_, err = io.ReadFull(conn, payload)
if err != nil {
return errors.Wrap(err, "Failed to read SASL auth payload")
}

res := &protocol.SaslAuthenticateResponseV0{}
err = protocol.Decode(payload, res)
if err != nil {
return errors.Wrap(err, "Failed to parse SASL auth response")
}
if res.Err != protocol.ErrNoError {
return errors.Wrapf(res.Err, "SASL authentication failed, error message is '%v'", res.ErrMsg)
}
return nil
}
2 changes: 1 addition & 1 deletion proxy/sasl_local_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewLocalSaslOauth(tokenAuthenticator apis.TokenInfo) *LocalSaslOauth {

// implements LocalSaslAuth
func (p *LocalSaslOauth) doLocalAuth(saslAuthBytes []byte) (err error) {
token, err := p.saslOAuthBearer.GetToken(saslAuthBytes)
token, _, _, err := p.saslOAuthBearer.GetClientInitialResponse(saslAuthBytes)
if err != nil {
return err
}
Expand Down
46 changes: 43 additions & 3 deletions proxy/sasl_oauthbearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (

// https://tools.ietf.org/html/rfc7628#section-3.1
// https://tools.ietf.org/html/rfc5801#section-4
// https://tools.ietf.org/html/rfc5801 (UTF8-1-safe)
const (
saslOauthSeparator = "\u0001"
saslOauthSaslName = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+"
saslOauthSaslName = "(?:[\x01-\x2b]|[\x2d-\x3c]|[\x3e-\x7F]|=2C|=3D)+"
saslOauthKey = "[A-Za-z]+"
saslOauthValue = "[\\x21-\\x7E \t\r\n]+"
saslOauthAuthKey = "auth"
Expand All @@ -25,18 +26,32 @@ var (

type SaslOAuthBearer struct{}

func (p SaslOAuthBearer) GetToken(saslAuthBytes []byte) (string, error) {
func (p SaslOAuthBearer) GetClientInitialResponse(saslAuthBytes []byte) (token string, authzid string, extensions map[string]string, err error) {
match := saslOauthClientInitialResponsePattern.FindSubmatch(saslAuthBytes)
if len(match) == 0 {
return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'saslAuthBytes' parse error")
}

result := make(map[string][]byte)
for i, name := range saslOauthClientInitialResponsePattern.SubexpNames() {
if i != 0 && name != "" {
if i >= len(match) {
return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'SubexpNames' range error")
}
result[name] = match[i]
}
}

authzid = string(result["authzid"])
kvpairs := result["kvpairs"]
properties := p.parseMap(string(kvpairs), "=", saslOauthSeparator)
return p.parseToken(properties[saslOauthAuthKey])

token, err = p.parseToken(properties[saslOauthAuthKey])
if err != nil {
return "", "", nil, err
}
delete(properties, saslOauthAuthKey)
return token, authzid, properties, nil
}

func (SaslOAuthBearer) parseToken(auth string) (string, error) {
Expand Down Expand Up @@ -73,3 +88,28 @@ func (SaslOAuthBearer) parseMap(mapStr string, keyValueSeparator string, element
}
return result
}

func (SaslOAuthBearer) mkString(mapValues map[string]string, keyValueSeparator string, elementSeparator string) string {
if len(mapValues) == 0 {
return ""
}
elements := make([]string, 0, len(mapValues))
for k, v := range mapValues {
elements = append(elements, strings.Join([]string{k, v}, keyValueSeparator))
}
return strings.Join(elements, elementSeparator)
}

func (p SaslOAuthBearer) ToBytes(tokenValue string, authorizationId string, saslExtensions map[string]string) []byte {
authzid := authorizationId
if authzid != "" {
authzid = "a=" + authorizationId
}
extensions := p.mkString(saslExtensions, "=", saslOauthSeparator)
if extensions != "" {
extensions = saslOauthSeparator + extensions
}
message := fmt.Sprintf("n,%s,%sauth=Bearer %s%s%s%s", authzid,
saslOauthSeparator, tokenValue, extensions, saslOauthSeparator, saslOauthSeparator)
return []byte(message)
}
55 changes: 54 additions & 1 deletion proxy/sasl_oauthbearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,60 @@ func TestSaslOauthParseToken(t *testing.T) {
saslAuthBytes, err := hex.DecodeString(saslBytes)
a.Nil(err)

token, err := SaslOAuthBearer{}.GetToken(saslAuthBytes)
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse(saslAuthBytes)
a.Nil(err)
a.Empty(authzid)
a.Empty(extensions)
a.Equal("eyJhbGciOiJub25lIn0.eyJleHAiOjEuNTM5NTE2Njk0NDE4RTksImlhdCI6MS41Mzk1MTMwOTQ0MThFOSwic3ViIjoiYWxpY2UyIn0.", token)

a.Equal(saslAuthBytes, SaslOAuthBearer{}.ToBytes(token, authzid, extensions))

}
func TestSaslOAuthBearerToBytes(t *testing.T) {
a := assert.New(t)
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001"))
a.Nil(err)
a.Equal("123.345.567", token)
a.Empty(authzid)
a.Equal(map[string]string{"nineteen": "42"}, extensions)
}

func TestSaslOAuthBearerAuthorizationId(t *testing.T) {
a := assert.New(t)
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,a=myuser,\u0001auth=Bearer 345\u0001\u0001"))
a.Nil(err)
a.Equal("345", token)
a.Equal("myuser", authzid)
a.Empty(extensions)
}

func TestSaslOAuthBearerExtensions(t *testing.T) {
a := assert.New(t)
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001"))
a.Nil(err)
a.Equal("567", token)
a.Empty(authzid)
a.Equal(map[string]string{"propA": "valueA1, valueA2", "propB": "valueB"}, extensions)
}

func TestSaslOAuthBearerRfc7688Example(t *testing.T) {
a := assert.New(t)
message := "n,[email protected],\u0001host=server.example.com\u0001port=143\u0001" +
"auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message))
a.Nil(err)
a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token)
a.Equal("[email protected]", authzid)
a.Equal(map[string]string{"host": "server.example.com", "port": "143"}, extensions)
}

func TestSaslOAuthBearerNoExtensionsFromByteArray(t *testing.T) {
a := assert.New(t)
message := "n,[email protected],\u0001" +
"auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message))
a.Nil(err)
a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token)
a.Equal("[email protected]", authzid)
a.Empty(extensions)
}

0 comments on commit 5edf5c4

Please sign in to comment.