Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions packages/api/internal/handlers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,25 +272,25 @@ func (a *APIStore) GetHealth(c *gin.Context) {
c.String(http.StatusServiceUnavailable, "Service is unavailable")
}

func (a *APIStore) GetTeamFromAPIKey(ctx context.Context, _ *gin.Context, apiKey string) (*types.Team, *api.APIError) {
func (a *APIStore) GetTeamFromAPIKey(ctx context.Context, ginCtx *gin.Context, apiKey string) (*types.Team, *api.APIError) {
ctx, span := tracer.Start(ctx, "get team from api key")
defer span.End()

return a.authService.ValidateAPIKey(ctx, apiKey)
return a.authService.ValidateAPIKey(ctx, ginCtx, apiKey)
}

func (a *APIStore) GetUserFromAccessToken(ctx context.Context, _ *gin.Context, accessToken string) (uuid.UUID, *api.APIError) {
func (a *APIStore) GetUserFromAccessToken(ctx context.Context, ginCtx *gin.Context, accessToken string) (uuid.UUID, *api.APIError) {
ctx, span := tracer.Start(ctx, "get user from access token")
defer span.End()

return a.authService.ValidateAccessToken(ctx, accessToken)
return a.authService.ValidateAccessToken(ctx, ginCtx, accessToken)
}

func (a *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, _ *gin.Context, supabaseToken string) (uuid.UUID, *api.APIError) {
func (a *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *api.APIError) {
ctx, span := tracer.Start(ctx, "get user id from supabase token")
defer span.End()

return a.authService.ValidateSupabaseToken(ctx, supabaseToken)
return a.authService.ValidateSupabaseToken(ctx, ginCtx, supabaseToken)
}

func (a *APIStore) GetTeamFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, teamID string) (*types.Team, *api.APIError) {
Expand Down
40 changes: 34 additions & 6 deletions packages/auth/pkg/auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@ import (
"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/shared/pkg/keys"
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
)

type TeamItem interface {
TeamID() string
}

// AuthStore abstracts the DB operations needed for auth validation.
type AuthStore[T any] interface {
type AuthStore[T TeamItem] interface {
GetTeamByHashedAPIKey(ctx context.Context, hashedKey string) (T, error)
GetTeamByIDAndUserID(ctx context.Context, userID uuid.UUID, teamID string) (T, error)
GetUserIDByHashedAccessToken(ctx context.Context, hashedToken string) (uuid.UUID, error)
}

// AuthService encapsulates the cache, store, and JWT secrets for auth validation.
type AuthService[T any] struct {
type AuthService[T TeamItem] struct {
store AuthStore[T]
teamCache *AuthCache[T]
jwtSecrets []string
}

// NewAuthService creates an AuthService with the given store, cache, and JWT secrets.
func NewAuthService[T any](store AuthStore[T], teamCache *AuthCache[T], jwtSecrets []string) *AuthService[T] {
func NewAuthService[T TeamItem](store AuthStore[T], teamCache *AuthCache[T], jwtSecrets []string) *AuthService[T] {
return &AuthService[T]{
store: store,
teamCache: teamCache,
Expand All @@ -36,7 +41,7 @@ func NewAuthService[T any](store AuthStore[T], teamCache *AuthCache[T], jwtSecre
}

// ValidateAPIKey verifies the API key format and fetches the associated team via cache + store.
func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, apiKey string) (T, *APIError) {
func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, ginCtx *gin.Context, apiKey string) (T, *APIError) {
hashedKey, err := keys.VerifyKey(keys.ApiKeyPrefix, apiKey)
if err != nil {
var zero T
Expand Down Expand Up @@ -79,11 +84,17 @@ func (s *AuthService[T]) ValidateAPIKey(ctx context.Context, apiKey string) (T,
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithMaskedAPIKey(apiKey),
telemetry.WithTeamID(result.TeamID()),
)

return result, nil
}

// ValidateAccessToken verifies the access token format and fetches the associated user ID.
func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, accessToken string) (uuid.UUID, *APIError) {
func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, ginCtx *gin.Context, accessToken string) (uuid.UUID, *APIError) {
hashedToken, err := keys.VerifyKey(keys.AccessTokenPrefix, accessToken)
if err != nil {
return uuid.UUID{}, &APIError{
Expand All @@ -102,11 +113,17 @@ func (s *AuthService[T]) ValidateAccessToken(ctx context.Context, accessToken st
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithMaskedAccessToken(accessToken),
telemetry.WithUserID(userID.String()),
)

return userID, nil
}

// ValidateSupabaseToken parses a Supabase JWT and extracts the user ID.
func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, supabaseToken string) (uuid.UUID, *APIError) {
func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *APIError) {
userID, err := ParseUserIDFromToken(ctx, s.jwtSecrets, supabaseToken)
if err != nil {
return uuid.UUID{}, &APIError{
Expand All @@ -116,6 +133,11 @@ func (s *AuthService[T]) ValidateSupabaseToken(ctx context.Context, supabaseToke
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithUserID(userID.String()),
)

return userID, nil
}

Expand Down Expand Up @@ -165,6 +187,12 @@ func (s *AuthService[T]) ValidateSupabaseTeam(ctx context.Context, ginCtx *gin.C
}
}

//nolint:contextcheck // We use the gin request context to set attributes on the parent span.
telemetry.SetAttributes(ginCtx.Request.Context(),
telemetry.WithUserID(userID.String()),
telemetry.WithTeamID(result.TeamID()),
)

return result, nil
}

Expand Down
6 changes: 5 additions & 1 deletion packages/auth/pkg/types/teams.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package types

import (
"github.com/e2b-dev/infra/packages/db/pkg/auth/queries"
authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries"
)

type Team struct {
Expand All @@ -10,6 +10,10 @@ type Team struct {
Limits *TeamLimits
}

func (t *Team) TeamID() string {
return t.Team.ID.String()
}

func newTeamLimits(
teamLimits *authqueries.TeamLimit,
) *TeamLimits {
Expand Down
11 changes: 1 addition & 10 deletions packages/client-proxy/internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,7 @@ func NewClientProxy(meterProvider metric.MeterProvider, serviceName string, port
return nil, err
}

l := logger.L().With(
zap.String("origin_host", r.Host),
logger.WithSandboxID(sandboxId),
zap.Uint64("sandbox_req_port", port),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
zap.Int64("content_length", r.ContentLength),
)
l := logger.L().With(logger.ProxyRequestFields(r, sandboxId, port)...)

trafficAccessToken := r.Header.Get(proxygrpc.MetadataTrafficAccessToken)
envdAccessToken := r.Header.Get(proxygrpc.MetadataEnvdHTTPAccessToken)
Expand Down
4 changes: 2 additions & 2 deletions packages/dashboard-api/internal/handlers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (s *APIStore) GetHealth(c *gin.Context) {
})
}

func (s *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, _ *gin.Context, supabaseToken string) (uuid.UUID, *sharedauth.APIError) {
return s.authService.ValidateSupabaseToken(ctx, supabaseToken)
func (s *APIStore) GetUserIDFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, supabaseToken string) (uuid.UUID, *sharedauth.APIError) {
return s.authService.ValidateSupabaseToken(ctx, ginCtx, supabaseToken)
}

func (s *APIStore) GetTeamFromSupabaseToken(ctx context.Context, ginCtx *gin.Context, teamID string) (*types.Team, *sharedauth.APIError) {
Expand Down
16 changes: 5 additions & 11 deletions packages/orchestrator/internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"time"

"go.opentelemetry.io/otel/metric"
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox"
"github.com/e2b-dev/infra/packages/shared/pkg/connlimit"
Expand Down Expand Up @@ -103,16 +102,11 @@ func NewSandboxProxy(meterProvider metric.MeterProvider, port uint16, sandboxes
}

logger := logger.L().With(
zap.String("origin_host", r.Host),
logger.WithSandboxID(sbx.Runtime.SandboxID),
logger.WithTeamID(sbx.Runtime.TeamID),
logger.WithSandboxIP(sbx.Slot.HostIPString()),
zap.Uint64("sandbox_req_port", port),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
zap.Int64("content_length", r.ContentLength),
append(
logger.ProxyRequestFields(r, sbx.Runtime.SandboxID, port),
logger.WithTeamID(sbx.Runtime.TeamID),
logger.WithSandboxIP(sbx.Slot.HostIPString()),
)...,
)

return &pool.Destination{
Expand Down
69 changes: 69 additions & 0 deletions packages/shared/pkg/logger/fields.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package logger

import (
"fmt"
"net"
"net/http"
"strings"

"github.com/google/uuid"
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/shared/pkg/keys"
)

func WithSandboxID(sandboxID string) zap.Field {
Expand All @@ -21,6 +28,10 @@ func WithExecutionID(executionID string) zap.Field {
return zap.String("execution.id", executionID)
}

func WithUserID(userID string) zap.Field {
return zap.String("user.id", userID)
}

func WithTeamID(teamID string) zap.Field {
return zap.String("team.id", teamID)
}
Expand All @@ -44,3 +55,61 @@ func WithSandboxIP(sandboxIP string) zap.Field {
func WithEnvdVersion(envdVersion string) zap.Field {
return zap.String("envd.version", envdVersion)
}

func WithClientIP(clientIP string) zap.Field {
return zap.String("http.client_ip", clientIP)
}

func WithMaskedAPIKey(apiKey string) zap.Field {
return zap.String("auth.api_key", maskedToken(keys.ApiKeyPrefix, apiKey))
}

func WithMaskedAccessToken(accessToken string) zap.Field {
return zap.String("auth.access_token", maskedToken(keys.AccessTokenPrefix, accessToken))
}

// ProxyRequestFields returns the common logger fields for a proxied HTTP request.
func ProxyRequestFields(r *http.Request, sandboxID string, sandboxPort uint64) []zap.Field {
return []zap.Field{
zap.String("origin_host", r.Host),
WithSandboxID(sandboxID),
zap.Uint64("sandbox_req_port", sandboxPort),
zap.String("sandbox_req_path", r.URL.Path),
zap.String("sandbox_req_method", r.Method),
zap.String("sandbox_req_user_agent", r.UserAgent()),
zap.String("remote_addr", r.RemoteAddr),
WithClientIP(clientIP(r)),
zap.Int64("content_length", r.ContentLength),
}
}

func maskedToken(prefix string, token string) string {
tokenWithoutPrefix := strings.TrimPrefix(token, prefix)
masked, err := keys.MaskKey(prefix, tokenWithoutPrefix)
if err != nil {
return "invalid_token_format"
}

return fmt.Sprintf("%s%s...%s", masked.Prefix, masked.MaskedValuePrefix, masked.MaskedValueSuffix)
}

// clientIP extracts the real client IP from the request.
// It reads the first entry from X-Forwarded-For, falling back to RemoteAddr with the port stripped.
//
// This assumes a trusted upstream proxy overwrites the
// X-Forwarded-For header with the real client IP. The header value is NOT
// client-controllable in this setup because the LB always replaces it.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if ip := strings.TrimSpace(strings.SplitN(xff, ",", 2)[0]); ip != "" {
return ip
}
}

host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}

return host
}
12 changes: 12 additions & 0 deletions packages/shared/pkg/telemetry/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func WithClusterID(clusterID uuid.UUID) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithClusterID(clusterID))
}

func WithUserID(userID string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithUserID(userID))
}

func WithTeamID(teamID string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithTeamID(teamID))
}
Expand All @@ -44,6 +48,14 @@ func WithEnvdVersion(envdVersion string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithEnvdVersion(envdVersion))
}

func WithMaskedAPIKey(apiKey string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithMaskedAPIKey(apiKey))
}

func WithMaskedAccessToken(accessToken string) attribute.KeyValue {
return zapFieldToOTELAttribute(logger.WithMaskedAccessToken(accessToken))
}

func zapFieldToOTELAttribute(f zap.Field) attribute.KeyValue {
e := &ZapFieldToOTELAttributeEncoder{}
f.AddTo(e)
Expand Down
Loading