Skip to content

Commit

Permalink
Created new oidc token provider plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavol Ipoth committed May 21, 2020
1 parent efbf3bc commit 55af91c
Show file tree
Hide file tree
Showing 12 changed files with 808 additions and 4 deletions.
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ plugin.unsecured-jwt-info:
plugin.unsecured-jwt-provider:
CGO_ENABLED=0 go build -o build/unsecured-jwt-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-unsecured-jwt-provider/main.go

plugin.oidc-provider:
CGO_ENABLED=0 go build -o build/oidc-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-oidc-provider/main.go

all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider
all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider plugin.oidc-provider

clean:
@rm -rf build
28 changes: 28 additions & 0 deletions cmd/plugin-oidc-provider/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package main

import (
"os"

oidcprovider "github.com/grepplabs/kafka-proxy/pkg/libs/oidc-provider"
"github.com/grepplabs/kafka-proxy/plugin/token-provider/shared"
"github.com/hashicorp/go-plugin"
"github.com/sirupsen/logrus"
)

func main() {
tokenProvider, err := new(oidcprovider.Factory).New(os.Args[1:])

if err != nil {
logrus.Errorf("cannot initialize oidc-token provider: %v", err)
os.Exit(1)
}

plugin.Serve(&plugin.ServeConfig{
HandshakeConfig: shared.Handshake,
Plugins: map[string]plugin.Plugin{
"tokenProvider": &shared.TokenProviderPlugin{Impl: tokenProvider},
},
// A non-nil value here enables gRPC serving for this plugin...
GRPCServer: plugin.DefaultGRPCServer,
})
}
51 changes: 51 additions & 0 deletions pkg/libs/oidc-provider/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package oidcprovider

import (
"flag"

"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/pkg/registry"
)

func init() {
registry.NewComponentInterface(new(apis.TokenProviderFactory))
registry.Register(new(Factory), "oidc-provider")
}

func (f *pluginMeta) flagSet() *flag.FlagSet {
fs := flag.NewFlagSet("oidc provider settings", flag.ContinueOnError)
return fs
}

type pluginMeta struct {
timeout int

credentialsWatch bool
credentialsFile string
targetAudience string
}

// Factory type
type Factory struct {
}

// New implements apis.TokenProviderFactory
func (t *Factory) New(params []string) (apis.TokenProvider, error) {
pluginMeta := &pluginMeta{}
fs := pluginMeta.flagSet()
fs.IntVar(&pluginMeta.timeout, "timeout", 10, "Request timeout in seconds")
fs.StringVar(&pluginMeta.credentialsFile, "credentials-file", "", "Location of the JSON file with the application credentials")
fs.BoolVar(&pluginMeta.credentialsWatch, "credentials-watch", true, "Watch credential for reload")
fs.StringVar(&pluginMeta.targetAudience, "target-audience", "", "URI of audience claim")

fs.Parse(params)

options := TokenProviderOptions{
Timeout: pluginMeta.timeout,
CredentialsWatch: pluginMeta.credentialsWatch,
CredentialsFile: pluginMeta.credentialsFile,
TargetAudience: pluginMeta.targetAudience,
}

return NewTokenProvider(options)
}
251 changes: 251 additions & 0 deletions pkg/libs/oidc-provider/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package oidcprovider

import (
"context"
"fmt"
"os"
"sync"
"time"

"github.com/cenkalti/backoff"
"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/pkg/libs/oidc"
"github.com/grepplabs/kafka-proxy/pkg/libs/util"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

const (
StatusOK = 0
StatusGetTokenFailed = 1
StatusParseTokenFailed = 2
)

var (
clockSkew = 1 * time.Minute
nowFn = time.Now
)

// TokenProvider - represents TokenProvider
type TokenProvider struct {
timeout time.Duration
idTokenSource idTokenSource

idToken *oidc.Token
l sync.RWMutex
}

// TokenProviderOptions - options specific for oidc
type TokenProviderOptions struct {
Timeout int

CredentialsWatch bool
CredentialsFile string
TargetAudience string
}

// NewTokenProvider - Generate new OIDC token provider
func NewTokenProvider(options TokenProviderOptions) (*TokenProvider, error) {
os.Setenv("OIDC_CREDENTIALS", options.CredentialsFile)
stopChannel := make(chan bool, 1)
var idTokenSource idTokenSource

serviceAccountSource, err := NewServiceAccountSource(
options.CredentialsFile,
options.TargetAudience)

idTokenSource = serviceAccountSource

if err != nil {
return nil, errors.Wrap(err, "creation of service account source failed")
}

if options.CredentialsWatch {
action := func() {
logrus.Infof("reloading credential file %s", options.CredentialsFile)

newConfig, errServ := oidc.NewServiceAccountTokenSource(
options.CredentialsFile,
options.TargetAudience)

if errServ != nil {
logrus.Errorf("error while reloading credentials files: %s", err)
return
}

serviceAccountSource.setServiceAccountTokenSource(newConfig)
}

err = util.WatchForUpdates(options.CredentialsFile, stopChannel, action)

if err != nil {
return nil, errors.Wrap(err, "cannot watch credentials file")
}
}

tokenProvider := &TokenProvider{
timeout: time.Duration(options.Timeout) * time.Second,
idTokenSource: idTokenSource}

op := func() error {
return initToken(tokenProvider)
}

err = backoff.Retry(
op,
backoff.WithMaxTries(backoff.NewConstantBackOff(1*time.Second), 3))

if err != nil {
return nil, errors.Wrap(err, "getting of initial oidc token failed")
}

tokenRefresher := &TokenRefresher{
tokenProvider: tokenProvider,
stopChannel: stopChannel,
}

go tokenRefresher.refreshLoop()

return tokenProvider, nil
}

func initToken(tokenProvider *TokenProvider) error {
ctx, cancel := context.WithTimeout(context.Background(), tokenProvider.timeout)
defer cancel()

resp, err := tokenProvider.GetToken(ctx, apis.TokenRequest{})

if err != nil {
return err
}

if !resp.Success {
return fmt.Errorf("get token failed with status: %d", resp.Status)
}

return nil
}

func (p *TokenProvider) getCurrentToken() string {
p.l.RLock()
defer p.l.RUnlock()

if p.idToken == nil {
return ""
}

if renewLatest(p.idToken) {
return ""
}

return p.idToken.Raw
}

func (p *TokenProvider) getCurrentClaimSet() *oidc.ClaimSet {
p.l.RLock()
defer p.l.RUnlock()

if p.idToken == nil {
return nil
}

return p.idToken.ClaimSet
}

func (p *TokenProvider) setCurrentToken(idToken *oidc.Token) {
p.l.Lock()
defer p.l.Unlock()

p.idToken = idToken
}

func renewLatest(token *oidc.Token) bool {
if token == nil {
return true
}
// renew before expiry
advExp := token.ClaimSet.Exp - int64(clockSkew.Seconds())

if nowFn().Unix() > advExp {
return true
}

return false
}

// GetToken implements apis.TokenProvider.GetToken method
func (p *TokenProvider) GetToken(parent context.Context, _ apis.TokenRequest) (apis.TokenResponse, error) {

currentToken := p.getCurrentToken()

if currentToken != "" {
return getTokenResponse(currentToken, StatusOK)
}

ctx, cancel := context.WithTimeout(parent, p.timeout)
defer cancel()

token, err := p.idTokenSource.GetIDToken(ctx)

if err != nil {
logrus.Errorf("GetIDToken failed %v", err)
return getTokenResponse("", StatusGetTokenFailed)
}

idToken, err := oidc.ParseJWT(token)

if err != nil {
logrus.Error(err)
return getTokenResponse("", StatusParseTokenFailed)
}

p.setCurrentToken(idToken)

logrus.Infof("New token expiry %d (%v)", idToken.ClaimSet.Exp, time.Unix(idToken.ClaimSet.Exp, 0))

return getTokenResponse(token, StatusOK)
}

func getTokenResponse(token string, status int) (apis.TokenResponse, error) {
success := status == StatusOK
return apis.TokenResponse{Success: success, Status: int32(status), Token: token}, nil
}

type idTokenSource interface {
GetIDToken(ctx context.Context) (string, error)
}

type serviceAccountSource struct {
source *oidc.ServiceAccountTokenSource
l sync.RWMutex
}

func (p *serviceAccountSource) getServiceAccountTokenSource() *oidc.ServiceAccountTokenSource {
p.l.RLock()
defer p.l.RUnlock()

return p.source
}

func (p *serviceAccountSource) setServiceAccountTokenSource(source *oidc.ServiceAccountTokenSource) {
p.l.Lock()
defer p.l.Unlock()

p.source = source
}

func NewServiceAccountSource(credentialsFile string, targetAudience string) (*serviceAccountSource, error) {
source, err := oidc.NewServiceAccountTokenSource(credentialsFile, targetAudience)

if err != nil {
return nil, err
}

return &serviceAccountSource{
source: source,
}, nil
}

func (s *serviceAccountSource) GetIDToken(parent context.Context) (string, error) {
return s.getServiceAccountTokenSource().GetIDToken(parent)
}
Loading

0 comments on commit 55af91c

Please sign in to comment.