From 55af91cb953d0dd8f2bf7705359c7ecc87be4aac Mon Sep 17 00:00:00 2001 From: Pavol Ipoth Date: Thu, 21 May 2020 09:09:13 +0200 Subject: [PATCH] Created new oidc token provider plugin --- Makefile | 4 +- cmd/plugin-oidc-provider/main.go | 28 +++ pkg/libs/oidc-provider/factory.go | 51 ++++++ pkg/libs/oidc-provider/plugin.go | 251 ++++++++++++++++++++++++++ pkg/libs/oidc-provider/ticker.go | 117 ++++++++++++ pkg/libs/oidc/certs.go | 74 ++++++++ pkg/libs/oidc/certs_test.go | 32 ++++ pkg/libs/oidc/jwt.go | 75 ++++++++ pkg/libs/oidc/jwt_test.go | 70 +++++++ pkg/libs/oidc/service_account.go | 79 ++++++++ pkg/libs/oidc/service_account_test.go | 24 +++ proxy/client.go | 7 +- 12 files changed, 808 insertions(+), 4 deletions(-) create mode 100644 cmd/plugin-oidc-provider/main.go create mode 100644 pkg/libs/oidc-provider/factory.go create mode 100644 pkg/libs/oidc-provider/plugin.go create mode 100644 pkg/libs/oidc-provider/ticker.go create mode 100644 pkg/libs/oidc/certs.go create mode 100644 pkg/libs/oidc/certs_test.go create mode 100644 pkg/libs/oidc/jwt.go create mode 100644 pkg/libs/oidc/jwt_test.go create mode 100644 pkg/libs/oidc/service_account.go create mode 100644 pkg/libs/oidc/service_account_test.go diff --git a/Makefile b/Makefile index 9b8ae255..f4ab05ee 100644 --- a/Makefile +++ b/Makefile @@ -71,8 +71,10 @@ plugin.unsecured-jwt-info: plugin.unsecured-jwt-provider: CGO_ENABLED=0 go build -o build/unsecured-jwt-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-unsecured-jwt-provider/main.go +plugin.oidc-provider: + CGO_ENABLED=0 go build -o build/oidc-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-oidc-provider/main.go -all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider +all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider plugin.oidc-provider clean: @rm -rf build diff --git a/cmd/plugin-oidc-provider/main.go b/cmd/plugin-oidc-provider/main.go new file mode 100644 index 00000000..1f15344c --- /dev/null +++ b/cmd/plugin-oidc-provider/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "os" + + oidcprovider "github.com/grepplabs/kafka-proxy/pkg/libs/oidc-provider" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" + "github.com/hashicorp/go-plugin" + "github.com/sirupsen/logrus" +) + +func main() { + tokenProvider, err := new(oidcprovider.Factory).New(os.Args[1:]) + + if err != nil { + logrus.Errorf("cannot initialize oidc-token provider: %v", err) + os.Exit(1) + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: shared.Handshake, + Plugins: map[string]plugin.Plugin{ + "tokenProvider": &shared.TokenProviderPlugin{Impl: tokenProvider}, + }, + // A non-nil value here enables gRPC serving for this plugin... + GRPCServer: plugin.DefaultGRPCServer, + }) +} diff --git a/pkg/libs/oidc-provider/factory.go b/pkg/libs/oidc-provider/factory.go new file mode 100644 index 00000000..fd715ffd --- /dev/null +++ b/pkg/libs/oidc-provider/factory.go @@ -0,0 +1,51 @@ +package oidcprovider + +import ( + "flag" + + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/pkg/registry" +) + +func init() { + registry.NewComponentInterface(new(apis.TokenProviderFactory)) + registry.Register(new(Factory), "oidc-provider") +} + +func (f *pluginMeta) flagSet() *flag.FlagSet { + fs := flag.NewFlagSet("oidc provider settings", flag.ContinueOnError) + return fs +} + +type pluginMeta struct { + timeout int + + credentialsWatch bool + credentialsFile string + targetAudience string +} + +// Factory type +type Factory struct { +} + +// New implements apis.TokenProviderFactory +func (t *Factory) New(params []string) (apis.TokenProvider, error) { + pluginMeta := &pluginMeta{} + fs := pluginMeta.flagSet() + fs.IntVar(&pluginMeta.timeout, "timeout", 10, "Request timeout in seconds") + fs.StringVar(&pluginMeta.credentialsFile, "credentials-file", "", "Location of the JSON file with the application credentials") + fs.BoolVar(&pluginMeta.credentialsWatch, "credentials-watch", true, "Watch credential for reload") + fs.StringVar(&pluginMeta.targetAudience, "target-audience", "", "URI of audience claim") + + fs.Parse(params) + + options := TokenProviderOptions{ + Timeout: pluginMeta.timeout, + CredentialsWatch: pluginMeta.credentialsWatch, + CredentialsFile: pluginMeta.credentialsFile, + TargetAudience: pluginMeta.targetAudience, + } + + return NewTokenProvider(options) +} diff --git a/pkg/libs/oidc-provider/plugin.go b/pkg/libs/oidc-provider/plugin.go new file mode 100644 index 00000000..e319731c --- /dev/null +++ b/pkg/libs/oidc-provider/plugin.go @@ -0,0 +1,251 @@ +package oidcprovider + +import ( + "context" + "fmt" + "os" + "sync" + "time" + + "github.com/cenkalti/backoff" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/pkg/libs/oidc" + "github.com/grepplabs/kafka-proxy/pkg/libs/util" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + StatusOK = 0 + StatusGetTokenFailed = 1 + StatusParseTokenFailed = 2 +) + +var ( + clockSkew = 1 * time.Minute + nowFn = time.Now +) + +// TokenProvider - represents TokenProvider +type TokenProvider struct { + timeout time.Duration + idTokenSource idTokenSource + + idToken *oidc.Token + l sync.RWMutex +} + +// TokenProviderOptions - options specific for oidc +type TokenProviderOptions struct { + Timeout int + + CredentialsWatch bool + CredentialsFile string + TargetAudience string +} + +// NewTokenProvider - Generate new OIDC token provider +func NewTokenProvider(options TokenProviderOptions) (*TokenProvider, error) { + os.Setenv("OIDC_CREDENTIALS", options.CredentialsFile) + stopChannel := make(chan bool, 1) + var idTokenSource idTokenSource + + serviceAccountSource, err := NewServiceAccountSource( + options.CredentialsFile, + options.TargetAudience) + + idTokenSource = serviceAccountSource + + if err != nil { + return nil, errors.Wrap(err, "creation of service account source failed") + } + + if options.CredentialsWatch { + action := func() { + logrus.Infof("reloading credential file %s", options.CredentialsFile) + + newConfig, errServ := oidc.NewServiceAccountTokenSource( + options.CredentialsFile, + options.TargetAudience) + + if errServ != nil { + logrus.Errorf("error while reloading credentials files: %s", err) + return + } + + serviceAccountSource.setServiceAccountTokenSource(newConfig) + } + + err = util.WatchForUpdates(options.CredentialsFile, stopChannel, action) + + if err != nil { + return nil, errors.Wrap(err, "cannot watch credentials file") + } + } + + tokenProvider := &TokenProvider{ + timeout: time.Duration(options.Timeout) * time.Second, + idTokenSource: idTokenSource} + + op := func() error { + return initToken(tokenProvider) + } + + err = backoff.Retry( + op, + backoff.WithMaxTries(backoff.NewConstantBackOff(1*time.Second), 3)) + + if err != nil { + return nil, errors.Wrap(err, "getting of initial oidc token failed") + } + + tokenRefresher := &TokenRefresher{ + tokenProvider: tokenProvider, + stopChannel: stopChannel, + } + + go tokenRefresher.refreshLoop() + + return tokenProvider, nil +} + +func initToken(tokenProvider *TokenProvider) error { + ctx, cancel := context.WithTimeout(context.Background(), tokenProvider.timeout) + defer cancel() + + resp, err := tokenProvider.GetToken(ctx, apis.TokenRequest{}) + + if err != nil { + return err + } + + if !resp.Success { + return fmt.Errorf("get token failed with status: %d", resp.Status) + } + + return nil +} + +func (p *TokenProvider) getCurrentToken() string { + p.l.RLock() + defer p.l.RUnlock() + + if p.idToken == nil { + return "" + } + + if renewLatest(p.idToken) { + return "" + } + + return p.idToken.Raw +} + +func (p *TokenProvider) getCurrentClaimSet() *oidc.ClaimSet { + p.l.RLock() + defer p.l.RUnlock() + + if p.idToken == nil { + return nil + } + + return p.idToken.ClaimSet +} + +func (p *TokenProvider) setCurrentToken(idToken *oidc.Token) { + p.l.Lock() + defer p.l.Unlock() + + p.idToken = idToken +} + +func renewLatest(token *oidc.Token) bool { + if token == nil { + return true + } + // renew before expiry + advExp := token.ClaimSet.Exp - int64(clockSkew.Seconds()) + + if nowFn().Unix() > advExp { + return true + } + + return false +} + +// GetToken implements apis.TokenProvider.GetToken method +func (p *TokenProvider) GetToken(parent context.Context, _ apis.TokenRequest) (apis.TokenResponse, error) { + + currentToken := p.getCurrentToken() + + if currentToken != "" { + return getTokenResponse(currentToken, StatusOK) + } + + ctx, cancel := context.WithTimeout(parent, p.timeout) + defer cancel() + + token, err := p.idTokenSource.GetIDToken(ctx) + + if err != nil { + logrus.Errorf("GetIDToken failed %v", err) + return getTokenResponse("", StatusGetTokenFailed) + } + + idToken, err := oidc.ParseJWT(token) + + if err != nil { + logrus.Error(err) + return getTokenResponse("", StatusParseTokenFailed) + } + + p.setCurrentToken(idToken) + + logrus.Infof("New token expiry %d (%v)", idToken.ClaimSet.Exp, time.Unix(idToken.ClaimSet.Exp, 0)) + + return getTokenResponse(token, StatusOK) +} + +func getTokenResponse(token string, status int) (apis.TokenResponse, error) { + success := status == StatusOK + return apis.TokenResponse{Success: success, Status: int32(status), Token: token}, nil +} + +type idTokenSource interface { + GetIDToken(ctx context.Context) (string, error) +} + +type serviceAccountSource struct { + source *oidc.ServiceAccountTokenSource + l sync.RWMutex +} + +func (p *serviceAccountSource) getServiceAccountTokenSource() *oidc.ServiceAccountTokenSource { + p.l.RLock() + defer p.l.RUnlock() + + return p.source +} + +func (p *serviceAccountSource) setServiceAccountTokenSource(source *oidc.ServiceAccountTokenSource) { + p.l.Lock() + defer p.l.Unlock() + + p.source = source +} + +func NewServiceAccountSource(credentialsFile string, targetAudience string) (*serviceAccountSource, error) { + source, err := oidc.NewServiceAccountTokenSource(credentialsFile, targetAudience) + + if err != nil { + return nil, err + } + + return &serviceAccountSource{ + source: source, + }, nil +} + +func (s *serviceAccountSource) GetIDToken(parent context.Context) (string, error) { + return s.getServiceAccountTokenSource().GetIDToken(parent) +} diff --git a/pkg/libs/oidc-provider/ticker.go b/pkg/libs/oidc-provider/ticker.go new file mode 100644 index 00000000..e79ff76a --- /dev/null +++ b/pkg/libs/oidc-provider/ticker.go @@ -0,0 +1,117 @@ +package oidcprovider + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff" + "github.com/grepplabs/kafka-proxy/pkg/libs/oidc" + "github.com/sirupsen/logrus" +) + +// TokenRefresher - struct providing refreshing of tokens +type TokenRefresher struct { + tokenProvider *TokenProvider + stopChannel chan bool +} + +func (p *TokenRefresher) refreshLoop() { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok := r.(error) + + if ok { + logrus.Error(fmt.Sprintf("token refresh loop error %v", err)) + } + } + }() + + syncTicker := time.NewTicker(2 * time.Second) + + for { + select { + case <-syncTicker.C: + p.refreshTick() + case <-p.stopChannel: + return + } + } +} + +func (p *TokenRefresher) newToken() (*oidc.Token, error) { + ctx, cancel := context.WithTimeout(context.Background(), p.tokenProvider.timeout) + defer cancel() + + token, err := p.tokenProvider.idTokenSource.GetIDToken(ctx) + + if err != nil { + return nil, err + } + + return oidc.ParseJWT(token) +} + +func (p *TokenRefresher) tryRefresh() error { + var idToken *oidc.Token + + op := func() error { + var err error + idToken, err = p.newToken() + return err + } + + backOff := backoff.NewExponentialBackOff() + backOff.MaxElapsedTime = 60 * time.Second + backOff.MaxInterval = 10 * time.Second + err := backoff.Retry(op, backOff) + + if err != nil { + return err + } + + logrus.Infof( + "Refreshed token expiry %d (%v)", + idToken.ClaimSet.Exp, + time.Unix(idToken.ClaimSet.Exp, 0), + ) + + p.tokenProvider.setCurrentToken(idToken) + + return nil +} + +func (p *TokenRefresher) refreshTick() { + claimSet := p.tokenProvider.getCurrentClaimSet() + + if renewEarliest(claimSet) { + err := p.tryRefresh() + + if err != nil { + logrus.Errorf("refreshing of google-id-token failed : %v", err) + } + } +} + +func renewEarliest(claimSet *oidc.ClaimSet) bool { + if claimSet == nil { + return true + } + + validity := claimSet.Exp - claimSet.Iat + + if validity <= 0 { + // should would be invalid claim + return true + } + + refreshAfter := int64(validity / 2) + refreshTime := claimSet.Exp - int64(clockSkew.Seconds()) - refreshAfter + // logrus.Debugf("New refresh time %d (%v)", refreshTime, time.Unix(refreshTime, 0)) + if nowFn().Unix() > refreshTime { + return true + } + + return false +} diff --git a/pkg/libs/oidc/certs.go b/pkg/libs/oidc/certs.go new file mode 100644 index 00000000..45507cbd --- /dev/null +++ b/pkg/libs/oidc/certs.go @@ -0,0 +1,74 @@ +package oidc + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "golang.org/x/net/context/ctxhttp" + "io/ioutil" + "math/big" + "net/http" + "time" +) + +const ( + // https://accounts.google.com/.well-known/openid-configuration + jwksUri = "https://www.googleapis.com/oauth2/v3/certs" +) + +type Certs struct { + Keys []Keys `json:"keys"` +} + +// https://tools.ietf.org/html/rfc7517#appendix-A +type Keys struct { + Kty string `json:"kty"` + Alg string `json:"alg"` + Use string `json:"use"` + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` +} + +func (key *Keys) GetPublicKey() (*rsa.PublicKey, error) { + b, err := base64.RawURLEncoding.DecodeString(key.N) + if err != nil { + return nil, err + } + n := new(big.Int).SetBytes(b) + + b, err = base64.RawURLEncoding.DecodeString(key.E) + if err != nil { + return nil, err + } + e := new(big.Int).SetBytes(b) + + pub := &rsa.PublicKey{N: n, E: int(e.Uint64())} + return pub, nil +} + +func GetCerts(ctx context.Context) (*Certs, error) { + client := &http.Client{ + Timeout: time.Second * 10, + } + resp, err := ctxhttp.Get(ctx, client, jwksUri) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if c := resp.StatusCode; c < 200 || c > 299 { + return nil, fmt.Errorf("cannot fetch certs: %v\nResponse: %s", resp.Status, body) + } + var certs *Certs + if err = json.Unmarshal(body, &certs); err != nil { + return nil, err + } + return certs, nil +} diff --git a/pkg/libs/oidc/certs_test.go b/pkg/libs/oidc/certs_test.go new file mode 100644 index 00000000..9535c85a --- /dev/null +++ b/pkg/libs/oidc/certs_test.go @@ -0,0 +1,32 @@ +package oidc + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestGetCerts(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + certs, err := GetCerts(ctx) + + a := assert.New(t) + a.Nil(err) + a.NotNil(certs) + a.True(len(certs.Keys) > 0) + + for _, key := range certs.Keys { + a.NotEmpty(key.Kty) + a.NotEmpty(key.Alg) + a.NotEmpty(key.Use) + a.NotEmpty(key.E) + a.NotEmpty(key.N) + a.NotEmpty(key.Kid) + + pk, err := key.GetPublicKey() + a.Nil(err) + a.NotNil(pk) + } +} diff --git a/pkg/libs/oidc/jwt.go b/pkg/libs/oidc/jwt.go new file mode 100644 index 00000000..901af4ed --- /dev/null +++ b/pkg/libs/oidc/jwt.go @@ -0,0 +1,75 @@ +package oidc + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "strings" + + "golang.org/x/oauth2/jws" +) + +var ( + ErrInvalidToken = errors.New("invalid token") +) + +// ClaimSet represents claim set +type ClaimSet struct { + Iss string `json:"iss"` // email address of the client_id of the application making the access token request + Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests + Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional). + Azp string `json:"azp"` + Exp int64 `json:"exp"` // the expiration time of the assertion (seconds since Unix epoch) + Iat int64 `json:"iat"` // the time the assertion was issued (seconds since Unix epoch) + Typ string `json:"typ,omitempty"` // token type (Optional). + Sub string `json:"sub,omitempty"` // Email for which the application is requesting delegated access (Optional). + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` +} + +type Token struct { + Raw string + Header *jws.Header + ClaimSet *ClaimSet +} + +func ParseJWT(token string) (*Token, error) { + args := strings.Split(token, ".") + + if len(args) != 3 { + return nil, ErrInvalidToken + } + + decodedHeader, err := base64.RawURLEncoding.DecodeString(args[0]) + + if err != nil { + return nil, err + } + + decodedPayload, err := base64.RawURLEncoding.DecodeString(args[1]) + + if err != nil { + return nil, err + } + + header := &jws.Header{} + err = json.NewDecoder(bytes.NewBuffer(decodedHeader)).Decode(header) + + if err != nil { + return nil, err + } + + claimSet := &ClaimSet{} + err = json.NewDecoder(bytes.NewBuffer(decodedPayload)).Decode(claimSet) + + if err != nil { + return nil, err + } + + return &Token{ + Raw: token, + Header: header, + ClaimSet: claimSet, + }, nil +} diff --git a/pkg/libs/oidc/jwt_test.go b/pkg/libs/oidc/jwt_test.go new file mode 100644 index 00000000..c1f4c0c3 --- /dev/null +++ b/pkg/libs/oidc/jwt_test.go @@ -0,0 +1,70 @@ +package oidc + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseJWT(t *testing.T) { + a := assert.New(t) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + a.Nil(err) + + testHeader := `{ + "alg": "RS256", + "kid": "978ca4118bf1883b316bbca6ce9044d9977f2027" + }` + testClaims := `{ + "azp": "4712.apps.googleusercontent.com", + "aud": "4711.apps.googleusercontent.com", + "sub": "100004711", + "hd": "grepplabs.com", + "email": "info@grepplabs.com", + "email_verified": true, + "exp": 2114380800, + "iss": "accounts.google.com", + "iat": 1516304351 + }` + + tokenString, err := encodeTestToken(testHeader, testClaims, privateKey) + a.Nil(err) + a.NotEmpty(tokenString) + + token, err := ParseJWT(tokenString) + a.Nil(err) + a.NotNil(token) + + a.Equal(token.Raw, tokenString) + a.Equal(token.Header.Algorithm, "RS256") + a.Equal(token.Header.KeyID, "978ca4118bf1883b316bbca6ce9044d9977f2027") + + a.Equal(token.ClaimSet.Azp, "4712.apps.googleusercontent.com") + a.Equal(token.ClaimSet.Aud, "4711.apps.googleusercontent.com") + a.Equal(token.ClaimSet.Sub, "100004711") + a.Equal(token.ClaimSet.Email, "info@grepplabs.com") + a.Equal(token.ClaimSet.EmailVerified, true) + a.Equal(token.ClaimSet.Exp, int64(2114380800)) + a.Equal(token.ClaimSet.Iss, "accounts.google.com") + a.Equal(token.ClaimSet.Iat, int64(1516304351)) + +} + +func encodeTestToken(headerJSON string, claimsJSON string, key *rsa.PrivateKey) (string, error) { + sg := func(data []byte) (sig []byte, err error) { + h := sha256.New() + h.Write(data) + return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil)) + } + ss := fmt.Sprintf("%s.%s", base64.RawURLEncoding.EncodeToString([]byte(headerJSON)), base64.RawURLEncoding.EncodeToString([]byte(claimsJSON))) + sig, err := sg([]byte(ss)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil +} diff --git a/pkg/libs/oidc/service_account.go b/pkg/libs/oidc/service_account.go new file mode 100644 index 00000000..3183fa96 --- /dev/null +++ b/pkg/libs/oidc/service_account.go @@ -0,0 +1,79 @@ +package oidc + +import ( + "context" + "encoding/json" + "errors" + "io/ioutil" + + "github.com/sirupsen/logrus" + "golang.org/x/oauth2/clientcredentials" +) + +const ( + // https://accounts.google.com/.well-known/openid-configuration + tokenEndpoint = "https://www.googleapis.com/oauth2/v4/token" +) + +type ServiceAccountTokenSource struct { + // ClientID is the application's ID. + ClientID string `json:"client_id"` + + // ClientSecret is the application's secret. + ClientSecret string `json:"client_secret"` + + // TokenURL is the resource server's token endpoint + // URL. This is a constant specific to each server. + TokenURL string `json:"token_url"` + + // Scope specifies optional requested permissions. + Scopes []string `json:"scopes"` +} + +func NewServiceAccountTokenSource(credentialsFile string, targetAudience string) (*ServiceAccountTokenSource, error) { + data, err := ioutil.ReadFile(credentialsFile) + source := &ServiceAccountTokenSource{} + + if err != nil { + return nil, err + } + + err = json.Unmarshal(data, source) + + if err != nil { + return nil, err + } + + return source, nil +} + +// GetIDToken - retrieve token from endpoint +func (s *ServiceAccountTokenSource) GetIDToken(parent context.Context) (string, error) { + conf := &clientcredentials.Config{ + ClientID: s.ClientID, + ClientSecret: s.ClientSecret, + Scopes: s.Scopes, + TokenURL: s.TokenURL, + } + + logrus.Debugf("Configuration is %v", conf) + + token, err := conf.Token(parent) + + if err != nil { + logrus.Errorf("Retrieving oidc token failed %v", err) + return "", err + } + + var idTokenRes struct { + IDToken string `json:"id_token"` + } + + if token.AccessToken == "" { + return "", errors.New("oidc access_token is missing") + } + + idTokenRes.IDToken = token.AccessToken + + return idTokenRes.IDToken, nil +} diff --git a/pkg/libs/oidc/service_account_test.go b/pkg/libs/oidc/service_account_test.go new file mode 100644 index 00000000..419157e7 --- /dev/null +++ b/pkg/libs/oidc/service_account_test.go @@ -0,0 +1,24 @@ +package oidc + +import ( + "context" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" +) + +func TestGetServiceAccountIDToken(t *testing.T) { + t.Skip() // Uncomment to execute + + credentialsFile := filepath.Join(os.Getenv("HOME"), "kafka-gateway-service-account.json") + src, err := NewServiceAccountTokenSource(credentialsFile, "tcp://kafka-gateway.grepplabs.com") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + token, err := src.GetIDToken(ctx) + + a := assert.New(t) + a.Nil(err) + a.NotEmpty(token) +} diff --git a/proxy/client.go b/proxy/client.go index 1a42076c..89b6ed69 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -3,13 +3,14 @@ package proxy import ( "crypto/tls" "fmt" + "net" + "sync" + "time" + "github.com/grepplabs/kafka-proxy/config" "github.com/grepplabs/kafka-proxy/pkg/apis" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "net" - "sync" - "time" ) // Conn represents a connection from a client to a specific instance.