forked from grepplabs/kafka-proxy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Created new oidc token provider plugin
- Loading branch information
Pavol Ipoth
committed
May 21, 2020
1 parent
efbf3bc
commit 55af91c
Showing
12 changed files
with
808 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package main | ||
|
||
import ( | ||
"os" | ||
|
||
oidcprovider "github.com/grepplabs/kafka-proxy/pkg/libs/oidc-provider" | ||
"github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" | ||
"github.com/hashicorp/go-plugin" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
func main() { | ||
tokenProvider, err := new(oidcprovider.Factory).New(os.Args[1:]) | ||
|
||
if err != nil { | ||
logrus.Errorf("cannot initialize oidc-token provider: %v", err) | ||
os.Exit(1) | ||
} | ||
|
||
plugin.Serve(&plugin.ServeConfig{ | ||
HandshakeConfig: shared.Handshake, | ||
Plugins: map[string]plugin.Plugin{ | ||
"tokenProvider": &shared.TokenProviderPlugin{Impl: tokenProvider}, | ||
}, | ||
// A non-nil value here enables gRPC serving for this plugin... | ||
GRPCServer: plugin.DefaultGRPCServer, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package oidcprovider | ||
|
||
import ( | ||
"flag" | ||
|
||
"github.com/grepplabs/kafka-proxy/pkg/apis" | ||
"github.com/grepplabs/kafka-proxy/pkg/registry" | ||
) | ||
|
||
func init() { | ||
registry.NewComponentInterface(new(apis.TokenProviderFactory)) | ||
registry.Register(new(Factory), "oidc-provider") | ||
} | ||
|
||
func (f *pluginMeta) flagSet() *flag.FlagSet { | ||
fs := flag.NewFlagSet("oidc provider settings", flag.ContinueOnError) | ||
return fs | ||
} | ||
|
||
type pluginMeta struct { | ||
timeout int | ||
|
||
credentialsWatch bool | ||
credentialsFile string | ||
targetAudience string | ||
} | ||
|
||
// Factory type | ||
type Factory struct { | ||
} | ||
|
||
// New implements apis.TokenProviderFactory | ||
func (t *Factory) New(params []string) (apis.TokenProvider, error) { | ||
pluginMeta := &pluginMeta{} | ||
fs := pluginMeta.flagSet() | ||
fs.IntVar(&pluginMeta.timeout, "timeout", 10, "Request timeout in seconds") | ||
fs.StringVar(&pluginMeta.credentialsFile, "credentials-file", "", "Location of the JSON file with the application credentials") | ||
fs.BoolVar(&pluginMeta.credentialsWatch, "credentials-watch", true, "Watch credential for reload") | ||
fs.StringVar(&pluginMeta.targetAudience, "target-audience", "", "URI of audience claim") | ||
|
||
fs.Parse(params) | ||
|
||
options := TokenProviderOptions{ | ||
Timeout: pluginMeta.timeout, | ||
CredentialsWatch: pluginMeta.credentialsWatch, | ||
CredentialsFile: pluginMeta.credentialsFile, | ||
TargetAudience: pluginMeta.targetAudience, | ||
} | ||
|
||
return NewTokenProvider(options) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
package oidcprovider | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"os" | ||
"sync" | ||
"time" | ||
|
||
"github.com/cenkalti/backoff" | ||
"github.com/grepplabs/kafka-proxy/pkg/apis" | ||
"github.com/grepplabs/kafka-proxy/pkg/libs/oidc" | ||
"github.com/grepplabs/kafka-proxy/pkg/libs/util" | ||
"github.com/pkg/errors" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
const ( | ||
StatusOK = 0 | ||
StatusGetTokenFailed = 1 | ||
StatusParseTokenFailed = 2 | ||
) | ||
|
||
var ( | ||
clockSkew = 1 * time.Minute | ||
nowFn = time.Now | ||
) | ||
|
||
// TokenProvider - represents TokenProvider | ||
type TokenProvider struct { | ||
timeout time.Duration | ||
idTokenSource idTokenSource | ||
|
||
idToken *oidc.Token | ||
l sync.RWMutex | ||
} | ||
|
||
// TokenProviderOptions - options specific for oidc | ||
type TokenProviderOptions struct { | ||
Timeout int | ||
|
||
CredentialsWatch bool | ||
CredentialsFile string | ||
TargetAudience string | ||
} | ||
|
||
// NewTokenProvider - Generate new OIDC token provider | ||
func NewTokenProvider(options TokenProviderOptions) (*TokenProvider, error) { | ||
os.Setenv("OIDC_CREDENTIALS", options.CredentialsFile) | ||
stopChannel := make(chan bool, 1) | ||
var idTokenSource idTokenSource | ||
|
||
serviceAccountSource, err := NewServiceAccountSource( | ||
options.CredentialsFile, | ||
options.TargetAudience) | ||
|
||
idTokenSource = serviceAccountSource | ||
|
||
if err != nil { | ||
return nil, errors.Wrap(err, "creation of service account source failed") | ||
} | ||
|
||
if options.CredentialsWatch { | ||
action := func() { | ||
logrus.Infof("reloading credential file %s", options.CredentialsFile) | ||
|
||
newConfig, errServ := oidc.NewServiceAccountTokenSource( | ||
options.CredentialsFile, | ||
options.TargetAudience) | ||
|
||
if errServ != nil { | ||
logrus.Errorf("error while reloading credentials files: %s", err) | ||
return | ||
} | ||
|
||
serviceAccountSource.setServiceAccountTokenSource(newConfig) | ||
} | ||
|
||
err = util.WatchForUpdates(options.CredentialsFile, stopChannel, action) | ||
|
||
if err != nil { | ||
return nil, errors.Wrap(err, "cannot watch credentials file") | ||
} | ||
} | ||
|
||
tokenProvider := &TokenProvider{ | ||
timeout: time.Duration(options.Timeout) * time.Second, | ||
idTokenSource: idTokenSource} | ||
|
||
op := func() error { | ||
return initToken(tokenProvider) | ||
} | ||
|
||
err = backoff.Retry( | ||
op, | ||
backoff.WithMaxTries(backoff.NewConstantBackOff(1*time.Second), 3)) | ||
|
||
if err != nil { | ||
return nil, errors.Wrap(err, "getting of initial oidc token failed") | ||
} | ||
|
||
tokenRefresher := &TokenRefresher{ | ||
tokenProvider: tokenProvider, | ||
stopChannel: stopChannel, | ||
} | ||
|
||
go tokenRefresher.refreshLoop() | ||
|
||
return tokenProvider, nil | ||
} | ||
|
||
func initToken(tokenProvider *TokenProvider) error { | ||
ctx, cancel := context.WithTimeout(context.Background(), tokenProvider.timeout) | ||
defer cancel() | ||
|
||
resp, err := tokenProvider.GetToken(ctx, apis.TokenRequest{}) | ||
|
||
if err != nil { | ||
return err | ||
} | ||
|
||
if !resp.Success { | ||
return fmt.Errorf("get token failed with status: %d", resp.Status) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (p *TokenProvider) getCurrentToken() string { | ||
p.l.RLock() | ||
defer p.l.RUnlock() | ||
|
||
if p.idToken == nil { | ||
return "" | ||
} | ||
|
||
if renewLatest(p.idToken) { | ||
return "" | ||
} | ||
|
||
return p.idToken.Raw | ||
} | ||
|
||
func (p *TokenProvider) getCurrentClaimSet() *oidc.ClaimSet { | ||
p.l.RLock() | ||
defer p.l.RUnlock() | ||
|
||
if p.idToken == nil { | ||
return nil | ||
} | ||
|
||
return p.idToken.ClaimSet | ||
} | ||
|
||
func (p *TokenProvider) setCurrentToken(idToken *oidc.Token) { | ||
p.l.Lock() | ||
defer p.l.Unlock() | ||
|
||
p.idToken = idToken | ||
} | ||
|
||
func renewLatest(token *oidc.Token) bool { | ||
if token == nil { | ||
return true | ||
} | ||
// renew before expiry | ||
advExp := token.ClaimSet.Exp - int64(clockSkew.Seconds()) | ||
|
||
if nowFn().Unix() > advExp { | ||
return true | ||
} | ||
|
||
return false | ||
} | ||
|
||
// GetToken implements apis.TokenProvider.GetToken method | ||
func (p *TokenProvider) GetToken(parent context.Context, _ apis.TokenRequest) (apis.TokenResponse, error) { | ||
|
||
currentToken := p.getCurrentToken() | ||
|
||
if currentToken != "" { | ||
return getTokenResponse(currentToken, StatusOK) | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(parent, p.timeout) | ||
defer cancel() | ||
|
||
token, err := p.idTokenSource.GetIDToken(ctx) | ||
|
||
if err != nil { | ||
logrus.Errorf("GetIDToken failed %v", err) | ||
return getTokenResponse("", StatusGetTokenFailed) | ||
} | ||
|
||
idToken, err := oidc.ParseJWT(token) | ||
|
||
if err != nil { | ||
logrus.Error(err) | ||
return getTokenResponse("", StatusParseTokenFailed) | ||
} | ||
|
||
p.setCurrentToken(idToken) | ||
|
||
logrus.Infof("New token expiry %d (%v)", idToken.ClaimSet.Exp, time.Unix(idToken.ClaimSet.Exp, 0)) | ||
|
||
return getTokenResponse(token, StatusOK) | ||
} | ||
|
||
func getTokenResponse(token string, status int) (apis.TokenResponse, error) { | ||
success := status == StatusOK | ||
return apis.TokenResponse{Success: success, Status: int32(status), Token: token}, nil | ||
} | ||
|
||
type idTokenSource interface { | ||
GetIDToken(ctx context.Context) (string, error) | ||
} | ||
|
||
type serviceAccountSource struct { | ||
source *oidc.ServiceAccountTokenSource | ||
l sync.RWMutex | ||
} | ||
|
||
func (p *serviceAccountSource) getServiceAccountTokenSource() *oidc.ServiceAccountTokenSource { | ||
p.l.RLock() | ||
defer p.l.RUnlock() | ||
|
||
return p.source | ||
} | ||
|
||
func (p *serviceAccountSource) setServiceAccountTokenSource(source *oidc.ServiceAccountTokenSource) { | ||
p.l.Lock() | ||
defer p.l.Unlock() | ||
|
||
p.source = source | ||
} | ||
|
||
func NewServiceAccountSource(credentialsFile string, targetAudience string) (*serviceAccountSource, error) { | ||
source, err := oidc.NewServiceAccountTokenSource(credentialsFile, targetAudience) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &serviceAccountSource{ | ||
source: source, | ||
}, nil | ||
} | ||
|
||
func (s *serviceAccountSource) GetIDToken(parent context.Context) (string, error) { | ||
return s.getServiceAccountTokenSource().GetIDToken(parent) | ||
} |
Oops, something went wrong.