Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion internal/controller/task_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,23 @@ func (r *TaskReconciler) resolveGitHubAppToken(ctx context.Context, task *kelosv
}
}

tokenResp, err := tc.GenerateInstallationToken(ctx, creds)
// Scope the installation token to repos declared in the workspace.
// This includes the primary repo and any additional remotes.
var repos []string
if workspace.Repo != "" {
if _, _, repoName := parseGitHubRepo(workspace.Repo); repoName != "" {
repos = append(repos, repoName)
}
}
for _, remote := range workspace.Remotes {
if _, _, repoName := parseGitHubRepo(remote.URL); repoName != "" {
repos = append(repos, repoName)
}
}
opts := &githubapp.TokenOptions{
Repositories: repos,
}
tokenResp, err := tc.GenerateInstallationToken(ctx, creds, opts)
if err != nil {
return nil, fmt.Errorf("generating installation token: %w", err)
}
Expand Down
29 changes: 26 additions & 3 deletions internal/githubapp/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package githubapp

import (
"bytes"
"context"
"crypto"
"crypto/rand"
Expand Down Expand Up @@ -51,6 +52,14 @@ type TokenResponse struct {
ExpiresAt time.Time
}

// TokenOptions configures optional scoping for installation tokens.
type TokenOptions struct {
// Repositories limits the token to these repository names.
// Names are relative to the installation owner (e.g., "my-repo",
// not "org/my-repo").
Repositories []string `json:"repositories,omitempty"`
}

// TokenClient generates GitHub App installation tokens.
type TokenClient struct {
BaseURL string
Expand Down Expand Up @@ -125,19 +134,33 @@ func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
}

// GenerateInstallationToken exchanges GitHub App credentials for an installation token.
func (tc *TokenClient) GenerateInstallationToken(ctx context.Context, creds *Credentials) (*TokenResponse, error) {
// When opts is non-nil and contains Repositories, the token is scoped to those repositories.
func (tc *TokenClient) GenerateInstallationToken(ctx context.Context, creds *Credentials, opts *TokenOptions) (*TokenResponse, error) {
jwt, err := generateJWT(creds)
if err != nil {
return nil, fmt.Errorf("generating JWT: %w", err)
}

url := fmt.Sprintf("%s/app/installations/%s/access_tokens", tc.baseURL(), creds.InstallationID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)

var body io.Reader
if opts != nil && len(opts.Repositories) > 0 {
payload, err := json.Marshal(opts)
if err != nil {
return nil, fmt.Errorf("marshaling token options: %w", err)
}
body = bytes.NewReader(payload)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
Comment thread
omercnet marked this conversation as resolved.
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+jwt)
req.Header.Set("Accept", "application/vnd.github.v3+json")
if body != nil {
req.Header.Set("Content-Type", "application/json")
}

resp, err := tc.httpClient().Do(req)
if err != nil {
Expand Down Expand Up @@ -242,7 +265,7 @@ func (tp *TokenProvider) Token(ctx context.Context) (string, error) {
return tp.token, nil
}

resp, err := tp.client.GenerateInstallationToken(ctx, tp.creds)
resp, err := tp.client.GenerateInstallationToken(ctx, tp.creds, nil)
if err != nil {
// Fall back to cached token if it has not actually expired yet
if tp.token != "" && now.Before(tp.expiresAt) {
Expand Down
133 changes: 131 additions & 2 deletions internal/githubapp/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -230,7 +231,7 @@ func TestGenerateInstallationToken(t *testing.T) {
Client: server.Client(),
}

resp, err := tc.GenerateInstallationToken(context.Background(), creds)
resp, err := tc.GenerateInstallationToken(context.Background(), creds, nil)
if err != nil {
t.Fatalf("GenerateInstallationToken: %v", err)
}
Expand Down Expand Up @@ -417,8 +418,136 @@ func TestGenerateInstallationToken_Error(t *testing.T) {
Client: server.Client(),
}

_, err = tc.GenerateInstallationToken(context.Background(), creds)
_, err = tc.GenerateInstallationToken(context.Background(), creds, nil)
if err == nil {
t.Error("expected error for 401 response")
}
}

func TestGenerateInstallationToken_WithRepositories(t *testing.T) {
_, keyPEM := generateTestKey(t)

expiresAt := time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second)

var receivedBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type = %q, want application/json", ct)
}
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading request body: %v", err)
}
if len(body) > 0 {
if err := json.Unmarshal(body, &receivedBody); err != nil {
t.Fatalf("unmarshaling request body: %v", err)
}
}

w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"token": "ghs_scoped_token",
"expires_at": expiresAt.Format(time.RFC3339),
})
}))
defer server.Close()

creds, err := ParseCredentials(map[string][]byte{
"appID": []byte("12345"),
"installationID": []byte("67890"),
"privateKey": keyPEM,
})
if err != nil {
t.Fatalf("ParseCredentials: %v", err)
}

tc := &TokenClient{
BaseURL: server.URL,
Client: server.Client(),
}

opts := &TokenOptions{
Repositories: []string{"my-repo", "other-repo"},
}
resp, err := tc.GenerateInstallationToken(context.Background(), creds, opts)
if err != nil {
t.Fatalf("GenerateInstallationToken: %v", err)
}

if resp.Token != "ghs_scoped_token" {
t.Errorf("Token = %q, want %q", resp.Token, "ghs_scoped_token")
}

// Verify the request body contained the repositories
repos, ok := receivedBody["repositories"]
if !ok {
t.Fatal("request body missing 'repositories' field")
}
repoList, ok := repos.([]interface{})
if !ok {
t.Fatalf("repositories is not an array: %T", repos)
}
if len(repoList) != 2 {
t.Errorf("repositories length = %d, want 2", len(repoList))
}
if repoList[0] != "my-repo" || repoList[1] != "other-repo" {
t.Errorf("repositories = %v, want [my-repo other-repo]", repoList)
}
}

func TestGenerateInstallationToken_NilOpts(t *testing.T) {
_, keyPEM := generateTestKey(t)

expiresAt := time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second)

var requestBodyEmpty bool
var contentType string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
requestBodyEmpty = len(body) == 0
contentType = r.Header.Get("Content-Type")

w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"token": "ghs_unscoped",
"expires_at": expiresAt.Format(time.RFC3339),
})
}))
defer server.Close()

creds, err := ParseCredentials(map[string][]byte{
"appID": []byte("12345"),
"installationID": []byte("67890"),
"privateKey": keyPEM,
})
if err != nil {
t.Fatalf("ParseCredentials: %v", err)
}

tc := &TokenClient{BaseURL: server.URL, Client: server.Client()}

// nil opts should send no body (backward compatible)
_, err = tc.GenerateInstallationToken(context.Background(), creds, nil)
if err != nil {
t.Fatalf("GenerateInstallationToken: %v", err)
}
if !requestBodyEmpty {
t.Error("expected empty request body for nil opts")
}
if contentType != "" {
t.Errorf("expected no Content-Type header for nil opts, got %q", contentType)
}

// Empty repositories should also send no body
requestBodyEmpty = false
_, err = tc.GenerateInstallationToken(context.Background(), creds, &TokenOptions{})
if err != nil {
t.Fatalf("GenerateInstallationToken with empty opts: %v", err)
}
if !requestBodyEmpty {
t.Error("expected empty request body for empty opts")
}
if contentType != "" {
t.Errorf("expected no Content-Type header for empty opts, got %q", contentType)
}
}
Loading