Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions packages/api/internal/handlers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,25 +272,25 @@ func (a *APIStore) GetHealth(c *gin.Context) {
c.String(http.StatusServiceUnavailable, "Service is unavailable")
}

func (a *APIStore) GetTeamFromAPIKey(ctx context.Context, _ *gin.Context, apiKey string) (*types.Team, *api.APIError) {
func (a *APIStore) GetTeamFromAPIKey(ctx context.Context, ginCtx *gin.Context, apiKey string) (*types.Team, *api.APIError) {
ctx, span := tracer.Start(ctx, "get team from api key")
defer span.End()

return a.authService.ValidateAPIKey(ctx, apiKey)
return a.authService.ValidateAPIKey(ctx, ginCtx, apiKey)
}

func (a *APIStore) GetUserFromAccessToken(ctx context.Context, _ *gin.Context, accessToken string) (uuid.UUID, *api.APIError) {
func (a *APIStore) GetUserFromAccessToken(ctx context.Context, ginCtx *gin.Context, accessToken string) (uuid.UUID, *api.APIError) {
ctx, span := tracer.Start(ctx, "get user from access token")
defer span.End()

return a.authService.ValidateAccessToken(ctx, accessToken)
return a.authService.ValidateAccessToken(ctx, ginCtx, accessToken)
}

func (a *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, _ *gin.Context, supabaseToken string) (uuid.UUID, *api.APIError) {
func (a *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *api.APIError) {
ctx, span := tracer.Start(ctx, "get user id from supabase token")
defer span.End()

return a.authService.ValidateSupabaseToken(ctx, supabaseToken)
return a.authService.ValidateSupabaseToken(ctx, ginCtx, supabaseToken)
}

func (a *APIStore) GetTeamFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, teamID string) (*types.Team, *api.APIError) {
Expand Down
40 changes: 34 additions & 6 deletions packages/auth/pkg/auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@ import (
"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/shared/pkg/keys"
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
)

type TeamItem interface {
TeamID() string
}

// AuthStore abstracts the DB operations needed for auth validation.
type AuthStore[T any] interface {
type AuthStore[T TeamItem] interface {
GetTeamByHashedAPIKey(ctx context.Context, hashedKey string) (T, error)
GetTeamByIDAndUserID(ctx context.Context, userID uuid.UUID, teamID string) (T, error)
GetUserIDByHashedAccessToken(ctx context.Context, hashedToken string) (uuid.UUID, error)
}

// AuthService encapsulates the cache, store, and JWT secrets for auth validation.
type AuthService[T any] struct {
type AuthService[T TeamItem] struct {
store AuthStore[T]
teamCache *AuthCache[T]
jwtSecrets []string
}

// NewAuthService creates an AuthService with the given store, cache, and JWT secrets.
func NewAuthService[T any](store AuthStore[T], teamCache *AuthCache[T], jwtSecrets []string) *AuthService[T] {
func NewAuthService[T TeamItem](store AuthStore[T], teamCache *AuthCache[T], jwtSecrets []string) *AuthService[T] {
return &AuthService[T]{
store: store,
teamCache: teamCache,
Expand All @@ -36,7 +41,7 @@ func NewAuthService[T any](store AuthStore[T], teamCache *AuthCache[T], jwtSecre
}

// ValidateAPIKey verifies the API key format and fetches the associated team via cache + store.
func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, apiKey string) (T, *APIError) {
func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, ginCtx *gin.Context, apiKey string) (T, *APIError) {
hashedKey, err := keys.VerifyKey(keys.ApiKeyPrefix, apiKey)
if err != nil {
var zero T
Expand Down Expand Up @@ -79,11 +84,17 @@ func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, apiKey string) (T,
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithMaskedAPIKey(keys.MaskToken(keys.ApiKeyPrefix, apiKey)),
telemetry.WithTeamID(result.TeamID()),
)

return result, nil
}

// ValidateAccessToken verifies the access token format and fetches the associated user ID.
func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, accessToken string) (uuid.UUID, *APIError) {
func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, ginCtx *gin.Context, accessToken string) (uuid.UUID, *APIError) {
hashedToken, err := keys.VerifyKey(keys.AccessTokenPrefix, accessToken)
if err != nil {
return uuid.UUID{}, &APIError{
Expand All @@ -102,11 +113,17 @@ func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, accessToken st
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithMaskedAccessToken(keys.MaskToken(keys.AccessTokenPrefix, accessToken)),
telemetry.WithUserID(userID.String()),
)

return userID, nil
}

// ValidateSupabaseToken parses a Supabase JWT and extracts the user ID.
func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, supabaseToken string) (uuid.UUID, *APIError) {
func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *APIError) {
userID, err := ParseUserIDFromToken(ctx, s.jwtSecrets, supabaseToken)
if err != nil {
return uuid.UUID{}, &APIError{
Expand All @@ -116,6 +133,11 @@ func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, supabaseToke
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithUserID(userID.String()),
)

return userID, nil
}

Expand Down Expand Up @@ -165,6 +187,12 @@ func (s *AuthService[T]) ValidateSupabaseTeam(ctx context.Context, ginCtx *gin.C
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithUserID(userID.String()),
telemetry.WithTeamID(result.TeamID()),
)

return result, nil
}

Expand Down
6 changes: 5 additions & 1 deletion packages/auth/pkg/types/teams.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package types

import (
"github.com/e2b-dev/infra/packages/db/pkg/auth/queries"
authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries"
)

type Team struct {
Expand All @@ -10,6 +10,10 @@ type Team struct {
Limits *TeamLimits
}

func (t *Team) TeamID() string {
return t.Team.ID.String()
}

func newTeamLimits(
teamLimits *authqueries.TeamLimit,
) *TeamLimits {
Expand Down
11 changes: 1 addition & 10 deletions packages/client-proxy/internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,7 @@ func NewClientProxy(meterProvider metric.MeterProvider, serviceName string, port
return nil, err
}

l := logger.L().With(
zap.String("origin_host", r.Host),
logger.WithSandboxID(sandboxId),
zap.Uint64("sandbox_req_port", port),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
zap.Int64("content_length", r.ContentLength),
)
l := logger.L().With(logger.ProxyRequestFields(r, sandboxId, port)...)

trafficAccessToken := r.Header.Get(proxygrpc.MetadataTrafficAccessToken)
envdAccessToken := r.Header.Get(proxygrpc.MetadataEnvdHTTPAccessToken)
Expand Down
4 changes: 2 additions & 2 deletions packages/dashboard-api/internal/handlers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (s *APIStore) GetHealth(c *gin.Context) {
})
}

func (s *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, _ *gin.Context, supabaseToken string) (uuid.UUID, *sharedauth.APIError) {
return s.authService.ValidateSupabaseToken(ctx, supabaseToken)
func (s *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *sharedauth.APIError) {
return s.authService.ValidateSupabaseToken(ctx, ginCtx, supabaseToken)
}

func (s *APIStore) GetTeamFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, teamID string) (*types.Team, *sharedauth.APIError) {
Expand Down
16 changes: 5 additions & 11 deletions packages/orchestrator/internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"time"

"go.opentelemetry.io/otel/metric"
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox"
"github.com/e2b-dev/infra/packages/shared/pkg/connlimit"
Expand Down Expand Up @@ -103,16 +102,11 @@ func NewSandboxProxy(meterProvider metric.MeterProvider, port uint16, sandboxes
}

logger := logger.L().With(
zap.String("origin_host", r.Host),
logger.WithSandboxID(sbx.Runtime.SandboxID),
logger.WithTeamID(sbx.Runtime.TeamID),
logger.WithSandboxIP(sbx.Slot.HostIPString()),
zap.Uint64("sandbox_req_port", port),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
zap.Int64("content_length", r.ContentLength),
append(
logger.ProxyRequestFields(r, sbx.Runtime.SandboxID, port),
logger.WithTeamID(sbx.Runtime.TeamID),
logger.WithSandboxIP(sbx.Slot.HostIPString()),
)...,
)

return &pool.Destination{
Expand Down
12 changes: 12 additions & 0 deletions packages/shared/pkg/keys/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ func GenerateKey(prefix string) (Key, error) {
}, nil
}

// MaskToken masks a prefixed token (e.g. API key or access token) for safe logging.
// The raw token must never be passed to logging or telemetry libraries directly.
func MaskToken(prefix, token string) string {
tokenWithoutPrefix := strings.TrimPrefix(token, prefix)
masked, err := MaskKey(prefix, tokenWithoutPrefix)
if err != nil {
return "invalid_token_format"
}

return fmt.Sprintf("%s%s...%s", masked.Prefix, masked.MaskedValuePrefix, masked.MaskedValueSuffix)
}

func VerifyKey(prefix string, key string) (string, error) {
if !strings.HasPrefix(key, prefix) {
return "", fmt.Errorf("invalid key prefix")
Expand Down
56 changes: 56 additions & 0 deletions packages/shared/pkg/logger/fields.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package logger

import (
"net"
"net/http"
"strings"

"github.com/google/uuid"
"go.uber.org/zap"
)
Expand All @@ -21,6 +25,10 @@ func WithExecutionID(executionID string) zap.Field {
return zap.String("execution.id", executionID)
}

func WithUserID(userID string) zap.Field {
return zap.String("user.id", userID)
}

func WithTeamID(teamID string) zap.Field {
return zap.String("team.id", teamID)
}
Expand All @@ -44,3 +52,51 @@ func WithSandboxIP(sandboxIP string) zap.Field {
func WithEnvdVersion(envdVersion string) zap.Field {
return zap.String("envd.version", envdVersion)
}

func WithClientIP(clientIP string) zap.Field {
return zap.String("http.client_ip", clientIP)
}

func WithMaskedAPIKey(maskedAPIKey string) zap.Field {
return zap.String("auth.api_key", maskedAPIKey)
}

func WithMaskedAccessToken(maskedAccessToken string) zap.Field {
return zap.String("auth.access_token", maskedAccessToken)
}

// ProxyRequestFields returns the common logger fields for a proxied HTTP request.
func ProxyRequestFields(r *http.Request, sandboxID string, sandboxPort uint64) []zap.Field {
return []zap.Field{
zap.String("origin_host", r.Host),
WithSandboxID(sandboxID),
zap.Uint64("sandbox_req_port", sandboxPort),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
WithClientIP(clientIP(r)),
zap.Int64("content_length", r.ContentLength),
}
}

// clientIP extracts the real client IP from the request.
// It reads the first entry from X-Forwarded-For, falling back to RemoteAddr with the port stripped.
//
// This assumes a trusted upstream proxy overwrites the
// X-Forwarded-For header with the real client IP. The header value is NOT
// client-controllable in this setup because the LB always replaces it.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if ip := strings.TrimSpace(strings.SplitN(xff, ",", 2)[0]); ip != "" {
return ip
}
}

host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}

return host
}
12 changes: 12 additions & 0 deletions packages/shared/pkg/telemetry/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func WithClusterID(clusterID uuid.UUID) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithClusterID(clusterID))
}

func WithUserID(userID string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithUserID(userID))
}

func WithTeamID(teamID string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithTeamID(teamID))
}
Expand All @@ -44,6 +48,14 @@ func WithEnvdVersion(envdVersion string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithEnvdVersion(envdVersion))
}

func WithMaskedAPIKey(maskedAPIKey string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithMaskedAPIKey(maskedAPIKey))
}

func WithMaskedAccessToken(maskedAccessToken string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithMaskedAccessToken(maskedAccessToken))
}

func zapFieldToOTELAttribute(f zap.Field) attribute.KeyValue {
e := &ZapFieldToOTELAttributeEncoder{}
f.AddTo(e)
Expand Down
Loading