diff --git a/services/auth/internal/errors/errors.go b/services/auth/internal/errors/errors.go index 077526a..a3f3749 100644 --- a/services/auth/internal/errors/errors.go +++ b/services/auth/internal/errors/errors.go @@ -3,11 +3,12 @@ package errors import "errors" var ( + ErrInvalidEmail = errors.New("invalid email") + ErrInvalidInput = errors.New("invalid input") + ErrConflict = errors.New("conflict") ErrInternal = errors.New("internal error") - ErrNotFound = errors.New("resource not found") ErrUnauthorized = errors.New("unauthorized") - ErrConflict = errors.New("resource already exists") - ErrTokenGeneration = errors.New("token generation failed ") - ErrInvalidEmail = errors.New("invalid Email") - ErrInvalidInput = errors.New("invalid input") + ErrTokenGeneration = errors.New("token generation failed") + ErrTooManyAttempts = errors.New("too many login attempts, account locked temporarily") + ErrNotFound = errors.New("not found") ) diff --git a/services/auth/internal/service/auth_service.go b/services/auth/internal/service/auth_service.go index 6752cfb..f5502f9 100644 --- a/services/auth/internal/service/auth_service.go +++ b/services/auth/internal/service/auth_service.go @@ -5,6 +5,7 @@ import ( "errors" "log/slog" "regexp" + "time" "golang.org/x/crypto/bcrypt" @@ -26,23 +27,48 @@ type authService struct { store storage.UserStorage logger *slog.Logger tokenSvc TokenService + // Add rate limiting map or store here if needed + loginAttempts map[string]int + lockoutTime map[string]time.Time + lockoutDuration time.Duration + maxAttempts int } func NewAuthService(store storage.UserStorage, logger *slog.Logger, tokenSvc TokenService) AuthService { l := logger.With("layer", "service", "component", "authService") - return &authService{store: store, logger: l, tokenSvc: tokenSvc} + return &authService{ + store: store, + logger: l, + tokenSvc: tokenSvc, + loginAttempts: make(map[string]int), + lockoutTime: make(map[string]time.Time), + lockoutDuration: 15 * time.Minute, + 
maxAttempts: 5, + } } func (s *authService) Register(ctx context.Context, email, password string) (*model.User, error) { s.logger.Info("Register called", slog.String("email", email)) - if !regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`).MatchString(email) { + if !regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`).MatchString(email) { s.logger.Error("Invalid email") return nil, appErr.ErrInvalidEmail } - if len(password) == 0 { - s.logger.Error("Invalid password") + if len(password) < 8 { + s.logger.Error("Password too short") + return nil, appErr.ErrInvalidInput + } + + // Enforce password complexity: at least one uppercase, one lowercase, one digit, one special char + var ( + hasUpper = regexp.MustCompile(`[A-Z]`).MatchString + hasLower = regexp.MustCompile(`[a-z]`).MatchString + hasDigit = regexp.MustCompile(`[0-9]`).MatchString + hasSpecial = regexp.MustCompile(`[\W_]`).MatchString + ) + if !hasUpper(password) || !hasLower(password) || !hasDigit(password) || !hasSpecial(password) { + s.logger.Error("Password does not meet complexity requirements") return nil, appErr.ErrInvalidInput } @@ -69,6 +95,18 @@ func (s *authService) Register(ctx context.Context, email, password string) (*mo func (s *authService) Login(ctx context.Context, email, password string) (*model.User, string, error) { s.logger.Info("Login called", slog.String("email", email)) + // Check if user is locked out + if lockoutUntil, locked := s.lockoutTime[email]; locked { + if time.Now().Before(lockoutUntil) { + s.logger.Warn("User account locked due to too many failed login attempts", slog.String("email", email)) + return nil, "", appErr.ErrTooManyAttempts + } else { + // Lockout expired, reset + delete(s.lockoutTime, email) + s.loginAttempts[email] = 0 + } + } + user, err := s.store.GetUserByEmail(ctx, email) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -82,9 +120,19 @@ func (s *authService) Login(ctx context.Context, email, password string) 
(*model // Compare the provided password with the stored hashed password if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { + s.loginAttempts[email]++ + if s.loginAttempts[email] >= s.maxAttempts { + s.lockoutTime[email] = time.Now().Add(s.lockoutDuration) + s.logger.Warn("User account locked due to too many failed login attempts", slog.String("email", email)) + return nil, "", appErr.ErrTooManyAttempts + } s.logger.Warn("Invalid password", slog.String("email", email)) return nil, "", appErr.ErrUnauthorized } + + // Reset login attempts on successful login + s.loginAttempts[email] = 0 + token, err := s.tokenSvc.GenerateToken(user) if err != nil { s.logger.Error("Token generation failed ", slog.String("email", email)) diff --git a/services/auth/internal/service/token_service.go b/services/auth/internal/service/token_service.go index feb9c60..e17b0e7 100644 --- a/services/auth/internal/service/token_service.go +++ b/services/auth/internal/service/token_service.go @@ -37,6 +37,7 @@ func (s *jwtService) GenerateToken(user *model.User) (string, error) { "email": user.Email, "exp": time.Now().Add(s.expiryTime).Unix(), "iat": time.Now().Unix(), + "nbf": time.Now().Unix(), // Not valid before now } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString([]byte(s.secret)) @@ -44,6 +45,11 @@ func (s *jwtService) GenerateToken(user *model.User) (string, error) { func (s *jwtService) ValidateToken(tokenStr string) (string, string, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) { + // Validate the signing method to prevent algorithm confuses attack + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + s.logger.Error("Unexpected signing method", slog.String("method", token.Header["alg"].(string))) + return nil, jwt.ErrSignatureInvalid + } return []byte(s.secret), nil }) @@ -63,6 +69,29 @@ func (s *jwtService) ValidateToken(tokenStr string) (string, string, error) { 
return "", "", jwt.ErrTokenMalformed } email, ok := claims["email"].(string) + if !ok { + s.logger.Error("Invalid email claim", slog.String("email", email)) + return "", "", jwt.ErrTokenMalformed + } + + // validate time based claims + now := time.Now().Unix() + + // check expiry + if exp, ok := claims["exp"].(float64); ok { + if int64(exp) < now { + s.logger.Error("Token expired", slog.Int64("exp", int64(exp)), slog.Int64("now", now)) + return "", "", jwt.ErrTokenExpired + } + } + + // check not before + if nbf, ok := claims["nbf"].(float64); ok { + if int64(nbf) > now { + s.logger.Error("Token not valid yet", slog.Int64("nbf", int64(nbf)), slog.Int64("now", now)) + return "", "", jwt.ErrTokenNotValidYet + } + } return userID, email, nil } diff --git a/services/auth/internal/storage/user_storage.go b/services/auth/internal/storage/user_storage.go index 553be44..b050361 100644 --- a/services/auth/internal/storage/user_storage.go +++ b/services/auth/internal/storage/user_storage.go @@ -2,7 +2,6 @@ package storage import ( "context" - "fmt" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -60,8 +59,6 @@ func (s *userStorage) GetUserByEmail(ctx context.Context, email string) (*model. 
if err := row.Scan(&user.ID, &user.Email, &user.Password, &user.CreatedAt); err != nil { return nil, err } - fmt.Println(user) - return &user, nil } diff --git a/services/gateway/.env.example b/services/gateway/.env.example new file mode 100644 index 0000000..73c2c4a --- /dev/null +++ b/services/gateway/.env.example @@ -0,0 +1,42 @@ +# API Gateway Service Configuration + +# Server Configuration +GATEWAY_PORT=8080 +GATEWAY_HOST=0.0.0.0 +GATEWAY_READ_TIMEOUT=30s +GATEWAY_WRITE_TIMEOUT=30s +GATEWAY_IDLE_TIMEOUT=60s +GATEWAY_SHUTDOWN_TIMEOUT=15s + +# Upstream Service URLs +AUTH_SERVICE_URL=http://hcaas_auth:8081 +URL_SERVICE_URL=http://hcaas_web:8080 +NOTIFICATION_SERVICE_URL=http://hcaas_notification:8082 + +# Service Timeouts and Retries +AUTH_SERVICE_TIMEOUT=10s +AUTH_SERVICE_RETRIES=3 +URL_SERVICE_TIMEOUT=10s +URL_SERVICE_RETRIES=3 +NOTIFICATION_SERVICE_TIMEOUT=10s +NOTIFICATION_SERVICE_RETRIES=3 + +# Security Configuration +JWT_SECRET=your-super-secret-jwt-key-change-in-production-min-32-chars +ALLOWED_ORIGINS=* +TRUSTED_PROXIES= + +# Rate Limiting +RATE_LIMIT_ENABLED=true +RATE_LIMIT_RPS=100 +RATE_LIMIT_BURST=200 + +# Observability +METRICS_ENABLED=true +TRACING_ENABLED=true +OTEL_EXPORTER_OTLP_ENDPOINT=http://hcaas_jaeger_all_in_one:4317 +OTEL_SERVICE_NAME=hcaas_gateway_service + +# Logging +LOG_LEVEL=info +SERVICE_VERSION=v1.0.0 diff --git a/services/gateway/Dockerfile b/services/gateway/Dockerfile new file mode 100644 index 0000000..b7603ef --- /dev/null +++ b/services/gateway/Dockerfile @@ -0,0 +1,49 @@ +# Build stage +FROM golang:1.24-alpine AS builder + +# Set working directory +WORKDIR /app + +# Install dependencies +RUN apk add --no-cache git ca-certificates tzdata + +# Copy go mod files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy source code +COPY . . 
+
+# Build the application
+RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
+    -ldflags='-w -s -extldflags "-static"' \
+    -tags=netgo \
+    -o gateway \
+    ./cmd/gateway
+
+# Final stage
+FROM scratch
+
+# Copy CA certificates for HTTPS requests
+COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
+
+# Copy timezone data
+COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
+
+# Copy the binary
+COPY --from=builder /app/gateway /gateway
+
+# Run as an unprivileged numeric UID:GID (scratch has no passwd file)
+USER 65534:65534
+
+# Expose port
+EXPOSE 8080
+
+# Health check: pure exec form, because scratch has no shell to run "|| exit 1"
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD ["/gateway", "healthcheck"]
+
+# Run the application
+ENTRYPOINT ["/gateway"]
diff --git a/services/gateway/cmd/gateway/main.go b/services/gateway/cmd/gateway/main.go
new file mode 100644
index 0000000..cb99340
--- /dev/null
+++ b/services/gateway/cmd/gateway/main.go
@@ -0,0 +1,179 @@
+// Package main provides the entry point for the API Gateway service.
+//
+// This service acts as the single point of entry for all client requests,
+// handling authentication, rate limiting, request routing, and observability.
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"log/slog"
+	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/go-chi/chi/v5/middleware"
+	"github.com/prometheus/client_golang/prometheus/promhttp"
+
+	"github.com/samims/hcaas/pkg/tracing"
+	"github.com/samims/hcaas/services/gateway/internal/config"
+	"github.com/samims/hcaas/services/gateway/internal/handler"
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+	"github.com/samims/hcaas/services/gateway/internal/metrics"
+	gwMiddleware "github.com/samims/hcaas/services/gateway/internal/middleware"
+	"github.com/samims/hcaas/services/gateway/internal/proxy"
+)
+
+// main wires configuration, observability, middleware and routes together,
+// serves HTTP until SIGINT/SIGTERM arrives or the listener fails, and then
+// shuts the tracer and the server down within the configured timeout.
+//
+// Invoked as `gateway healthcheck` it only probes the running instance and
+// exits with the probe's result.
+func main() {
+	// Load configuration
+	cfg, err := config.Load()
+	if err != nil {
+		log.Fatalf("Failed to load configuration: %v", err)
+	}
+
+	// The container HEALTHCHECK runs `gateway healthcheck`: probe the instance
+	// that is already listening and exit, instead of booting a second server
+	// that would fail to bind the port.
+	if len(os.Args) > 1 && os.Args[1] == "healthcheck" {
+		os.Exit(runHealthcheck(cfg.Server.Host, cfg.Server.Port))
+	}
+
+	// Initialize logger
+	l := logger.New(cfg.Observability.ServiceName)
+	l.Info("Starting API Gateway",
+		logger.FieldService(cfg.Observability.ServiceName),
+	)
+
+	// Initialize tracing
+	var tracerShutdown func(context.Context) error
+	if cfg.Observability.TracingEnabled {
+		// Create slog logger for tracing setup
+		slogLogger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+
+		shutdown, err := tracing.SetupTracing(context.Background(), slogLogger)
+		if err != nil {
+			l.LogError(context.Background(), err, "Failed to setup tracing",
+				logger.FieldService("tracing"),
+			)
+			// Continue without tracing rather than failing
+		} else {
+			tracerShutdown = shutdown
+			l.Info("Tracing initialized successfully")
+		}
+	}
+
+	// Initialize metrics
+	metricsObj := metrics.New()
+
+	// Initialize service proxy
+	serviceProxy := proxy.NewServiceProxy(&cfg.Services, l, metricsObj)
+
+	// Initialize middleware
+	authMiddleware := gwMiddleware.NewAuthMiddleware(cfg.Security.JWTSecret, l, metricsObj)
+
+	var rateLimitMiddleware *gwMiddleware.RateLimitMiddleware
+	if cfg.RateLimit.Enabled {
+		rateLimitMiddleware = gwMiddleware.NewRateLimitMiddleware(
+			cfg.RateLimit.RequestsPerSecond,
+			cfg.RateLimit.BurstSize,
+			l,
+			metricsObj,
+		)
+		defer rateLimitMiddleware.Stop()
+	}
+
+	// Initialize handlers
+	healthHandler := handler.NewHealthHandler(l, &cfg.Services)
+	gatewayHandler := handler.NewGatewayHandler(serviceProxy, l, metricsObj)
+
+	// Setup router
+	r := chi.NewRouter()
+
+	// Add global middleware
+	r.Use(middleware.RequestID)
+	r.Use(middleware.RealIP)
+	r.Use(middleware.Recoverer)
+	r.Use(middleware.Timeout(60 * time.Second))
+
+	// Add custom middleware
+	r.Use(gwMiddleware.NewLoggingMiddleware(l, metricsObj).Middleware())
+	r.Use(gwMiddleware.NewCORSMiddleware().Middleware())
+
+	// Add tracing middleware if tracing is enabled
+	if cfg.Observability.TracingEnabled {
+		r.Use(gwMiddleware.NewTracingMiddleware(l).Middleware())
+	}
+
+	if cfg.RateLimit.Enabled {
+		r.Use(rateLimitMiddleware.Middleware())
+	}
+
+	r.Use(authMiddleware.Middleware())
+
+	// Health and metrics endpoints (public).
+	//
+	// Routes are registered directly on the root router: calling
+	// r.Route("/", ...) twice mounts two sub-routers on the same pattern,
+	// which chi rejects with a panic at startup.
+	r.Get("/healthz", healthHandler.HealthCheck)
+	r.Get("/readyz", healthHandler.ReadinessCheck)
+	r.Get("/metrics", promhttp.Handler().ServeHTTP)
+
+	// API routes (proxied to upstream services).
+	r.HandleFunc("/*", gatewayHandler.ProxyHandler)
+
+	// Setup HTTP server
+	server := &http.Server{
+		Addr:         fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port),
+		Handler:      r,
+		ReadTimeout:  cfg.Server.ReadTimeout,
+		WriteTimeout: cfg.Server.WriteTimeout,
+		IdleTimeout:  cfg.Server.IdleTimeout,
+	}
+
+	// Start server in a goroutine
+	serverErrors := make(chan error, 1)
+	go func() {
+		l.Info("Starting HTTP server",
+			logger.FieldService("http-server"),
+			logger.FieldPath(server.Addr),
+		)
+		serverErrors <- server.ListenAndServe()
+	}()
+
+	// Setup graceful shutdown
+	shutdown := make(chan os.Signal, 1)
+	signal.Notify(shutdown, syscall.SIGINT, syscall.SIGTERM)
+
+	select {
+	case err := <-serverErrors:
+		l.LogError(context.Background(), err, "Server startup failed")
+
+	case sig := <-shutdown:
+		l.Info("Shutdown signal received",
+			logger.FieldService("shutdown"),
+			logger.FieldUserID(sig.String()),
+		)
+
+		// Give outstanding requests time to complete
+		ctx, cancel := context.WithTimeout(context.Background(), cfg.Server.ShutdownTimeout)
+		defer cancel()
+
+		// Shutdown tracer first to flush remaining spans
+		if tracerShutdown != nil {
+			if err := tracerShutdown(ctx); err != nil {
+				l.LogError(ctx, err, "Failed to shutdown tracer")
+			} else {
+				l.Info("Tracer shutdown completed")
+			}
+		}
+
+		if err := server.Shutdown(ctx); err != nil {
+			l.LogError(ctx, err, "Graceful shutdown failed")
+
+			// Force shutdown
+			if err := server.Close(); err != nil {
+				l.LogError(ctx, err, "Force shutdown failed")
+			}
+		}
+	}
+
+	l.Info("API Gateway shutdown complete")
+}
+
+// runHealthcheck issues a GET for /healthz against the locally running
+// gateway and returns a process exit code: 0 when the endpoint answers
+// 200 OK, 1 on any transport error or other status.
+//
+// NOTE(review): this assumes the auth middleware lets /healthz through
+// without a token - confirm against the middleware package.
+func runHealthcheck(host, port string) int {
+	// A wildcard bind address cannot be dialed; probe loopback instead.
+	if host == "" || host == "0.0.0.0" {
+		host = "127.0.0.1"
+	}
+
+	client := &http.Client{Timeout: 2 * time.Second}
+	resp, err := client.Get(fmt.Sprintf("http://%s:%s/healthz", host, port))
+	if err != nil {
+		return 1
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return 1
+	}
+	return 0
+}
diff --git a/services/gateway/go.mod b/services/gateway/go.mod
new file mode 100644
index 0000000..2ff4af2
--- /dev/null
+++ b/services/gateway/go.mod
@@ -0,0 +1,60 @@
+module github.com/samims/hcaas/services/gateway
+
+go 1.24.4
+
+require (
+	github.com/go-chi/chi/v5 v5.2.0
+	github.com/golang-jwt/jwt/v5 v5.2.1
+	github.com/prometheus/client_golang v1.20.5
+	github.com/samims/hcaas/pkg v0.0.0-00010101000000-000000000000
+	go.opentelemetry.io/otel v1.37.0
+	go.opentelemetry.io/otel/trace v1.37.0
+	golang.org/x/time v0.8.0
+)
+
+replace github.com/samims/hcaas/pkg => ../../pkg
+
+require (
+	github.com/IBM/sarama v1.45.2 // indirect
+	github.com/beorn7/perks v1.0.1 // indirect
+	github.com/cenkalti/backoff/v5 v5.0.2 // indirect
+	github.com/cespare/xxhash/v2 v2.3.0 // indirect
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/eapache/go-resiliency v1.7.0 // indirect
+	github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect
+	github.com/eapache/queue v1.1.0 // indirect
+	github.com/go-logr/logr v1.4.3 // indirect
+	github.com/go-logr/stdr v1.2.2 // indirect
+	github.com/golang/snappy v0.0.4 // indirect
+	github.com/google/uuid v1.6.0 // indirect
+	github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
+	github.com/hashicorp/errwrap v1.0.0 // indirect
+	github.com/hashicorp/go-multierror v1.1.1 // indirect
+	github.com/hashicorp/go-uuid v1.0.3 // indirect
+	github.com/jcmturner/aescts/v2 v2.0.0 // indirect
+	github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
+	github.com/jcmturner/gofork v1.7.6 // indirect
+	github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
+	github.com/jcmturner/rpc/v2 v2.0.3 // indirect
+	github.com/klauspost/compress v1.18.0 // indirect
+	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
+	github.com/pierrec/lz4/v4 v4.1.22 // indirect
+	github.com/prometheus/client_model v0.6.1 // indirect
+	github.com/prometheus/common v0.61.0 // indirect
+	github.com/prometheus/procfs v0.15.1 // indirect
+	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
+	go.opentelemetry.io/auto/sdk v1.1.0 // indirect
+	go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 // indirect
+	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 // indirect
+	go.opentelemetry.io/otel/metric v1.37.0 // indirect
+	go.opentelemetry.io/otel/sdk v1.37.0 // indirect
+	go.opentelemetry.io/proto/otlp v1.7.0 // indirect
+	golang.org/x/crypto v0.39.0 // indirect
+	golang.org/x/net v0.41.0 // indirect
+	golang.org/x/sys v0.33.0 // indirect
+	golang.org/x/text v0.26.0 // indirect
+	google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
+	google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
+	google.golang.org/grpc v1.73.0 // indirect
+	google.golang.org/protobuf v1.36.6 // indirect
+)
diff --git a/services/gateway/internal/auth/jwt.go b/services/gateway/internal/auth/jwt.go
new file mode 100644
index 0000000..58aa971
--- /dev/null
+++ b/services/gateway/internal/auth/jwt.go
@@ -0,0 +1,221 @@
+// Package auth provides JWT authentication functionality for the API Gateway.
+//
+// This package implements JWT validation with proper error handling,
+// context propagation, and metrics collection following Google's security practices.
+package auth
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/golang-jwt/jwt/v5"
+)
+
+// JWTValidator handles JWT token validation and user context extraction.
+type JWTValidator struct {
+	secretKey []byte
+}
+
+// NewJWTValidator creates a new JWT validator with the given secret key.
+func NewJWTValidator(secretKey string) *JWTValidator {
+	return &JWTValidator{
+		secretKey: []byte(secretKey),
+	}
+}
+
+// UserClaims represents the JWT claims structure.
+// This should match the claims structure used by your auth service.
+type UserClaims struct {
+	UserID string `json:"user_id"`
+	Email  string `json:"email"`
+	jwt.RegisteredClaims
+}
+
+// UserContext represents authenticated user context.
+type UserContext struct {
+	UserID    string
+	Email     string
+	TokenType string
+	IssuedAt  time.Time
+	ExpiresAt time.Time
+}
+
+// ValidateToken validates a JWT token and returns user context.
+func (v *JWTValidator) ValidateToken(tokenString string) (*UserContext, error) {
+	// Parse and signature-check the token. Only HMAC algorithms are accepted:
+	// v.secretKey is a shared secret and must not be used to verify a token
+	// that announces a different algorithm family.
+	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (any, error) {
+		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+		}
+		return v.secretKey, nil
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse token: %w", err)
+	}
+
+	claims, ok := token.Claims.(*UserClaims)
+	if !ok || !token.Valid {
+		return nil, errors.New("invalid token")
+	}
+	if err := v.validateClaims(claims); err != nil {
+		return nil, fmt.Errorf("invalid claims: %w", err)
+	}
+
+	userCtx := &UserContext{
+		UserID:    claims.UserID,
+		Email:     claims.Email,
+		TokenType: "JWT",
+		// Safe to dereference: validateClaims rejects tokens without exp.
+		ExpiresAt: claims.ExpiresAt.Time,
+	}
+	// iat is optional; leave IssuedAt at its zero value when it is absent
+	// rather than dereferencing a nil *jwt.NumericDate.
+	if claims.IssuedAt != nil {
+		userCtx.IssuedAt = claims.IssuedAt.Time
+	}
+	return userCtx, nil
+}
+
+// validateClaims performs additional validation on JWT claims.
+//
+// It returns an error when the token carries no exp claim, is expired, is not
+// yet valid (nbf in the future), or lacks the user_id or email claim.
+func (v *JWTValidator) validateClaims(claims *UserClaims) error {
+	now := time.Now()
+
+	// A token without an expiry would be valid forever; refuse it.
+	if claims.ExpiresAt == nil {
+		return errors.New("token has no exp claim")
+	}
+
+	// Check if token is expired
+	if claims.ExpiresAt.Before(now) {
+		return errors.New("token is expired")
+	}
+
+	// Check if token is not valid yet (nbf - not before)
+	if claims.NotBefore != nil && claims.NotBefore.After(now) {
+		return errors.New("token is not valid yet")
+	}
+
+	// Validate required fields
+	if claims.UserID == "" {
+		return errors.New("missing user_id in token")
+	}
+
+	if claims.Email == "" {
+		return errors.New("missing email in token")
+	}
+
+	return nil
+}
+
+// ExtractTokenFromHeader extracts JWT token from Authorization header.
+// Expected format: "Bearer <token>". The scheme is matched case-insensitively
+// and whitespace surrounding the token is removed.
+//
+// It returns an error when the header is empty, has no space-separated
+// credentials part, uses a scheme other than Bearer, or carries no token.
+func ExtractTokenFromHeader(authHeader string) (string, error) {
+	if authHeader == "" {
+		return "", errors.New("missing authorization header")
+	}
+
+	scheme, rest, found := strings.Cut(authHeader, " ")
+	if !found {
+		return "", errors.New("invalid authorization header format")
+	}
+
+	if !strings.EqualFold(scheme, "bearer") {
+		return "", errors.New("unsupported authorization scheme")
+	}
+
+	// Clients sometimes send more than one space after the scheme; without
+	// trimming, the padding would become part of the token and fail parsing.
+	token := strings.TrimSpace(rest)
+	if token == "" {
+		return "", errors.New("missing token in authorization header")
+	}
+
+	return token, nil
+}
+
+// Context keys for storing user information
+type contextKey string
+
+const (
+	UserContextKey contextKey = "user_context"
+	UserIDKey      contextKey = "user_id"
+)
+
+// WithUserContext adds user context to the request context.
+// A nil userCtx leaves ctx unchanged.
+func WithUserContext(ctx context.Context, userCtx *UserContext) context.Context {
+	if userCtx == nil {
+		return ctx
+	}
+	ctx = context.WithValue(ctx, UserContextKey, userCtx)
+	ctx = context.WithValue(ctx, UserIDKey, userCtx.UserID)
+	return ctx
+}
+
+// GetUserContext extracts user context from request context.
+func GetUserContext(ctx context.Context) (*UserContext, bool) {
+	userCtx, ok := ctx.Value(UserContextKey).(*UserContext)
+	return userCtx, ok
+}
+
+// GetUserID extracts user ID from request context.
+func GetUserID(ctx context.Context) (string, bool) {
+	userID, ok := ctx.Value(UserIDKey).(string)
+	return userID, ok
+}
+
+// ValidationResult represents the result of token validation.
+type ValidationResult string
+
+const (
+	ValidationSuccess ValidationResult = "success"
+	ValidationExpired ValidationResult = "expired"
+	ValidationInvalid ValidationResult = "invalid"
+	ValidationMissing ValidationResult = "missing"
+)
+
+// GetValidationResult determines the validation result from an error.
+func GetValidationResult(err error) ValidationResult {
+	if err == nil {
+		return ValidationSuccess
+	}
+
+	errStr := strings.ToLower(err.Error())
+	switch {
+	case strings.Contains(errStr, "expired"):
+		return ValidationExpired
+	case strings.Contains(errStr, "missing"):
+		return ValidationMissing
+	default:
+		return ValidationInvalid
+	}
+}
+
+// TokenInfo provides additional information about a token for debugging/logging.
+type TokenInfo struct {
+	Valid     bool
+	UserID    string
+	Email     string
+	IssuedAt  time.Time
+	ExpiresAt time.Time
+	TimeLeft  time.Duration
+}
+
+// GetTokenInfo extracts token information for logging/debugging purposes.
+func (v *JWTValidator) GetTokenInfo(tokenString string) *TokenInfo {
+	info := &TokenInfo{Valid: false}
+
+	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
+		return v.secretKey, nil
+	})
+
+	if err != nil {
+		return info
+	}
+
+	if claims, ok := token.Claims.(*UserClaims); ok {
+		info.UserID = claims.UserID
+		info.Email = claims.Email
+		if claims.IssuedAt != nil {
+			info.IssuedAt = claims.IssuedAt.Time
+		}
+		if claims.ExpiresAt != nil {
+			info.ExpiresAt = claims.ExpiresAt.Time
+			info.TimeLeft = time.Until(claims.ExpiresAt.Time)
+		}
+		info.Valid = token.Valid && info.TimeLeft > 0
+	}
+
+	return info
+}
diff --git a/services/gateway/internal/config/config.go b/services/gateway/internal/config/config.go
new file mode 100644
index 0000000..370ac34
--- /dev/null
+++ b/services/gateway/internal/config/config.go
@@ -0,0 +1,217 @@
+// Package config provides configuration management for the API Gateway service.
+//
+// It follows Google's configuration patterns with environment-based configuration,
+// validation, and structured logging of configuration parameters.
+package config
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// Config holds all configuration for the API Gateway service.
+// The `validate` struct tags only document the intended constraints; nothing
+// in this package reads them, and validate() checks a subset of them by hand.
+type Config struct {
+	// Server configuration
+	Server ServerConfig `validate:"required"`
+
+	// Services configuration for upstream services
+	Services ServicesConfig `validate:"required"`
+
+	// Security configuration
+	Security SecurityConfig `validate:"required"`
+
+	// Observability configuration
+	Observability ObservabilityConfig `validate:"required"`
+
+	// Rate limiting configuration
+	RateLimit RateLimitConfig `validate:"required"`
+}
+
+// ServerConfig contains HTTP server configuration.
+type ServerConfig struct {
+	Port            string `validate:"required"`
+	Host            string
+	ReadTimeout     time.Duration `validate:"min=1s"`
+	WriteTimeout    time.Duration `validate:"min=1s"`
+	IdleTimeout     time.Duration `validate:"min=1s"`
+	ShutdownTimeout time.Duration `validate:"min=1s"`
+}
+
+// ServicesConfig defines upstream service endpoints.
+type ServicesConfig struct {
+	AuthService         ServiceEndpoint `validate:"required"`
+	URLService          ServiceEndpoint `validate:"required"`
+	NotificationService ServiceEndpoint `validate:"required"`
+}
+
+// ServiceEndpoint represents a single upstream service configuration.
+type ServiceEndpoint struct {
+	BaseURL string        `validate:"required,url"`
+	Timeout time.Duration `validate:"min=1s"`
+	Retries int           `validate:"min=0,max=5"`
+}
+
+// SecurityConfig contains JWT and security-related configuration.
+type SecurityConfig struct {
+	JWTSecret      string `validate:"required,min=32"`
+	AllowedOrigins []string
+	TrustedProxies []string
+}
+
+// ObservabilityConfig contains monitoring and tracing configuration.
+type ObservabilityConfig struct {
+	MetricsEnabled bool
+	TracingEnabled bool
+	OTELEndpoint   string
+	ServiceName    string `validate:"required"`
+}
+
+// RateLimitConfig defines rate limiting parameters.
+type RateLimitConfig struct {
+	Enabled           bool
+	RequestsPerSecond int `validate:"min=1"`
+	BurstSize         int `validate:"min=1"`
+}
+
+// Load reads configuration from environment variables with sensible defaults.
+// This follows Google's practice of environment-based configuration with validation.
+func Load() (*Config, error) {
+	cfg := &Config{
+		Server: ServerConfig{
+			Port:            getEnvOrDefault("GATEWAY_PORT", "8080"),
+			Host:            getEnvOrDefault("GATEWAY_HOST", "0.0.0.0"),
+			ReadTimeout:     getDurationEnvOrDefault("GATEWAY_READ_TIMEOUT", 30*time.Second),
+			WriteTimeout:    getDurationEnvOrDefault("GATEWAY_WRITE_TIMEOUT", 30*time.Second),
+			IdleTimeout:     getDurationEnvOrDefault("GATEWAY_IDLE_TIMEOUT", 60*time.Second),
+			ShutdownTimeout: getDurationEnvOrDefault("GATEWAY_SHUTDOWN_TIMEOUT", 15*time.Second),
+		},
+		Services: ServicesConfig{
+			AuthService: ServiceEndpoint{
+				BaseURL: getEnvOrDefault("AUTH_SERVICE_URL", "http://hcaas_auth:8081"),
+				Timeout: getDurationEnvOrDefault("AUTH_SERVICE_TIMEOUT", 10*time.Second),
+				Retries: getIntEnvOrDefault("AUTH_SERVICE_RETRIES", 3),
+			},
+			URLService: ServiceEndpoint{
+				BaseURL: getEnvOrDefault("URL_SERVICE_URL", "http://hcaas_web:8080"),
+				Timeout: getDurationEnvOrDefault("URL_SERVICE_TIMEOUT", 10*time.Second),
+				Retries: getIntEnvOrDefault("URL_SERVICE_RETRIES", 3),
+			},
+			NotificationService: ServiceEndpoint{
+				BaseURL: getEnvOrDefault("NOTIFICATION_SERVICE_URL", "http://hcaas_notification:8082"),
+				Timeout: getDurationEnvOrDefault("NOTIFICATION_SERVICE_TIMEOUT", 10*time.Second),
+				Retries: getIntEnvOrDefault("NOTIFICATION_SERVICE_RETRIES", 3),
+			},
+		},
+		Security: SecurityConfig{
+			// No fallback: a well-known default secret would let anyone forge
+			// tokens, so an unset JWT_SECRET must fail validation instead.
+			JWTSecret:      os.Getenv("JWT_SECRET"),
+			AllowedOrigins: getSliceEnvOrDefault("ALLOWED_ORIGINS", []string{"*"}),
+			TrustedProxies: getSliceEnvOrDefault("TRUSTED_PROXIES", []string{}),
+		},
+		Observability: ObservabilityConfig{
+			MetricsEnabled: getBoolEnvOrDefault("METRICS_ENABLED", true),
+			TracingEnabled: getBoolEnvOrDefault("TRACING_ENABLED", true),
+			OTELEndpoint:   getEnvOrDefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://hcaas_jaeger_all_in_one:4317"),
+			ServiceName:    getEnvOrDefault("OTEL_SERVICE_NAME", "hcaas_gateway_service"),
+		},
+		RateLimit: RateLimitConfig{
+			Enabled:           getBoolEnvOrDefault("RATE_LIMIT_ENABLED", true),
+			RequestsPerSecond: getIntEnvOrDefault("RATE_LIMIT_RPS", 100),
+			BurstSize:         getIntEnvOrDefault("RATE_LIMIT_BURST", 200),
+		},
+	}
+
+	// Validate configuration
+	if err := cfg.validate(); err != nil {
+		return nil, fmt.Errorf("invalid configuration: %w", err)
+	}
+
+	return cfg, nil
+}
+
+// validate performs basic validation on the configuration.
+// In a production system, you might use a validation library like go-playground/validator.
+func (c *Config) validate() error {
+	if c.Security.JWTSecret == "" {
+		return fmt.Errorf("JWT_SECRET is required")
+	}
+	if len(c.Security.JWTSecret) < 32 {
+		return fmt.Errorf("JWT_SECRET must be at least 32 characters long")
+	}
+	if c.Server.Port == "" {
+		return fmt.Errorf("GATEWAY_PORT is required")
+	}
+	if c.Services.AuthService.BaseURL == "" {
+		return fmt.Errorf("AUTH_SERVICE_URL is required")
+	}
+	if c.Services.URLService.BaseURL == "" {
+		return fmt.Errorf("URL_SERVICE_URL is required")
+	}
+	if c.Services.NotificationService.BaseURL == "" {
+		return fmt.Errorf("NOTIFICATION_SERVICE_URL is required")
+	}
+	return nil
+}
+
+// Helper functions for environment variable parsing
+
+func getEnvOrDefault(key, defaultValue string) string {
+	if value := os.Getenv(key); value != "" {
+		return value
+	}
+	return defaultValue
+}
+
+func getDurationEnvOrDefault(key string, defaultValue time.Duration) time.Duration {
+	if value := os.Getenv(key); value != "" {
+		if duration, err := time.ParseDuration(value); err == nil {
+			return duration
+		}
+	}
+	return defaultValue
+}
+
+func getIntEnvOrDefault(key string, defaultValue int) int {
+	if value := os.Getenv(key); value != "" {
+		if intValue, err := strconv.Atoi(value); err == nil {
+			return intValue
+		}
+	}
+	return defaultValue
+}
+
+func getBoolEnvOrDefault(key string, defaultValue bool) bool {
+	if value := os.Getenv(key); value != "" {
+		if boolValue, err := strconv.ParseBool(value); err == nil {
+			return boolValue
+		}
+	}
+	return defaultValue
+}
+
+// getSliceEnvOrDefault parses the variable named key as a comma-separated
+// list, trimming whitespace around each item and dropping empty items. It
+// returns defaultValue when the variable is unset, empty, or has no items.
+func getSliceEnvOrDefault(key string, defaultValue []string) []string {
+	raw := os.Getenv(key)
+	if raw == "" {
+		return defaultValue
+	}
+
+	var values []string
+	for _, item := range strings.Split(raw, ",") {
+		if item = strings.TrimSpace(item); item != "" {
+			values = append(values, item)
+		}
+	}
+	if len(values) == 0 {
+		return defaultValue
+	}
+	return values
+}
diff --git a/services/gateway/internal/handler/gateway.go b/services/gateway/internal/handler/gateway.go
new file mode 100644
index 0000000..0604c08
--- /dev/null
+++ b/services/gateway/internal/handler/gateway.go
@@ -0,0 +1,101 @@
+// Package handler provides HTTP handlers for the API Gateway service.
+//
+// This package implements request handling, response processing, and
+// integration with the service proxy following Google's API design patterns.
+package handler
+
+import (
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+	"github.com/samims/hcaas/services/gateway/internal/metrics"
+	"github.com/samims/hcaas/services/gateway/internal/proxy"
+)
+
+// GatewayHandler handles API Gateway requests and coordinates with the service proxy.
+type GatewayHandler struct {
+	proxy   *proxy.ServiceProxy
+	logger  *logger.Logger
+	metrics *metrics.Metrics
+}
+
+// NewGatewayHandler creates a new gateway handler.
+func NewGatewayHandler(proxy *proxy.ServiceProxy, logger *logger.Logger, metrics *metrics.Metrics) *GatewayHandler {
+	return &GatewayHandler{
+		proxy:   proxy,
+		logger:  logger.WithComponent("gateway-handler"),
+		metrics: metrics,
+	}
+}
+
+// ProxyHandler handles all proxy requests to upstream services.
+func (h *GatewayHandler) ProxyHandler(w http.ResponseWriter, r *http.Request) {
+	began := time.Now()
+
+	// Delegate to the service proxy and time the whole call.
+	proxyErr := h.proxy.ProxyRequest(r.Context(), w, r)
+	elapsed := time.Since(began)
+
+	if proxyErr == nil {
+		h.logger.WithContext(r.Context()).Debug("Request proxied successfully",
+			logger.FieldPath(r.URL.Path),
+			logger.FieldMethod(r.Method),
+			logger.FieldLatency(elapsed.Milliseconds()),
+		)
+		return
+	}
+
+	// Failure path: log, count, then answer the client.
+	h.logger.LogError(r.Context(), proxyErr, "Failed to proxy request",
+		logger.FieldPath(r.URL.Path),
+		logger.FieldMethod(r.Method),
+		logger.FieldLatency(elapsed.Milliseconds()),
+	)
+	h.metrics.RecordError(metrics.ComponentProxy, metrics.ErrorTypeUpstream)
+	h.handleProxyError(w, r, proxyErr)
+}
+
+// handleProxyError handles proxy errors and returns appropriate HTTP responses.
+//
+// The error is classified by substring match on its text; anything that does
+// not match a known pattern becomes a 500 with code PROXY_ERROR.
+func (h *GatewayHandler) handleProxyError(w http.ResponseWriter, r *http.Request, err error) {
+	hdr := w.Header()
+	hdr.Set("Access-Control-Allow-Origin", "*")
+	hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+	hdr.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
+	hdr.Set("Content-Type", "application/json")
+
+	type outcome struct {
+		needle  string
+		status  int
+		message string
+		code    string
+	}
+
+	// Checked in order; the first needle found in the error text wins.
+	known := []outcome{
+		{"no route found", http.StatusNotFound, "Endpoint not found", "ROUTE_NOT_FOUND"},
+		{"timeout", http.StatusGatewayTimeout, "Request timeout", "TIMEOUT_ERROR"},
+		{"connection", http.StatusBadGateway, "Upstream service unavailable", "SERVICE_UNAVAILABLE"},
+	}
+
+	picked := outcome{
+		status:  http.StatusInternalServerError,
+		message: "Internal server error",
+		code:    "PROXY_ERROR",
+	}
+	errText := err.Error()
+	for _, candidate := range known {
+		if strings.Contains(errText, candidate.needle) {
+			picked = candidate
+			break
+		}
+	}
+
+	w.WriteHeader(picked.status)
+
+	// message and code are fixed literals, so plain concatenation yields valid JSON.
+	body := `{"error": "` + picked.message + `", "code": "` + picked.code + `"}`
+	w.Write([]byte(body))
+}
diff --git a/services/gateway/internal/handler/health.go b/services/gateway/internal/handler/health.go
new file mode 100644
index 0000000..db05f34
--- /dev/null
+++ b/services/gateway/internal/handler/health.go
@@ -0,0 +1,144 @@
+package handler
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+	"time"
+
+	"github.com/samims/hcaas/services/gateway/internal/config"
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+)
+
+// HealthHandler provides health check endpoints.
+type HealthHandler struct {
+	logger *logger.Logger
+	config *config.ServicesConfig
+	client *http.Client
+}
+
+// NewHealthHandler creates a new health handler.
+// Its HTTP client uses a 5 second timeout so upstream probes stay short.
+func NewHealthHandler(logger *logger.Logger, cfg *config.ServicesConfig) *HealthHandler {
+	probeClient := &http.Client{Timeout: 5 * time.Second}
+
+	return &HealthHandler{
+		logger: logger.WithComponent("health-handler"),
+		config: cfg,
+		client: probeClient,
+	}
+}
+
+// HealthResponse represents the structure of health check responses.
+// Version and Checks are omitted from the JSON when empty.
+type HealthResponse struct {
+	Status    string            `json:"status"`
+	Timestamp time.Time         `json:"timestamp"`
+	Service   string            `json:"service"`
+	Version   string            `json:"version,omitempty"`
+	Checks    map[string]string `json:"checks,omitempty"`
+}
+
+// HealthCheck handles the /healthz endpoint.
+func (h *HealthHandler) HealthCheck(w http.ResponseWriter, r *http.Request) {
+	// SERVICE_VERSION is optional; "dev" is reported when it is unset.
+	version := os.Getenv("SERVICE_VERSION")
+	if version == "" {
+		version = "dev"
+	}
+
+	response := HealthResponse{
+		Status:    "healthy",
+		Timestamp: time.Now().UTC(),
+		Service:   "hcaas-gateway",
+		Version:   version,
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		// The status line is already sent, so the failure can only be logged.
+		h.logger.LogError(r.Context(), err, "Failed to encode health response")
+	}
+}
+
+// ReadinessCheck handles the /readyz endpoint with real upstream service validation.
+//
+// Each upstream is probed in turn under a shared 10 second deadline. The
+// response is 200 with status "ready" only when every probe succeeds;
+// otherwise it is 503 with status "not_ready". Per-service results are
+// reported in the "checks" map.
+func (h *HealthHandler) ReadinessCheck(w http.ResponseWriter, r *http.Request) {
+	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
+	defer cancel()
+
+	// Probed sequentially, in this order.
+	upstreams := []struct {
+		name      string // label passed to the probe for logging
+		checkKey  string // key reported in the response
+		healthURL string
+	}{
+		{"auth", "auth_service", h.config.AuthService.BaseURL + "/healthz"},
+		{"url", "url_service", h.config.URLService.BaseURL + "/health"},
+		{"notification", "notification_service", h.config.NotificationService.BaseURL + "/healthz"},
+	}
+
+	checks := make(map[string]string, len(upstreams)+2)
+	overallStatus := "ready"
+	statusCode := http.StatusOK
+
+	for _, upstream := range upstreams {
+		if h.checkUpstreamService(ctx, upstream.name, upstream.healthURL) {
+			checks[upstream.checkKey] = "healthy"
+			continue
+		}
+		checks[upstream.checkKey] = "unhealthy"
+		overallStatus = "not_ready"
+		statusCode = http.StatusServiceUnavailable
+	}
+
+	// Internal checks are reported as fixed values; nothing is probed for them.
+	checks["gateway_proxy"] = "ready"
+	checks["middleware_stack"] = "ready"
+
+	response := HealthResponse{
+		Status:    overallStatus,
+		Timestamp: time.Now().UTC(),
+		Service:   "hcaas-gateway",
+		Checks:    checks,
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(statusCode)
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		// The status line is already sent, so the failure can only be logged.
+		h.logger.LogError(r.Context(), err, "Failed to encode readiness response")
+	}
+}
+
+// checkUpstreamService performs a health check on an upstream service.
+//
+// It issues a GET to healthURL using the handler's client and reports true
+// only for a 2xx response. Request construction errors, transport errors and
+// non-2xx statuses are logged and reported as false.
+func (h *HealthHandler) checkUpstreamService(ctx context.Context, serviceName, healthURL string) bool {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
+	if err != nil {
+		// Log the underlying cause; without it the entry cannot be diagnosed.
+		h.logger.LogError(ctx, fmt.Errorf("building %s health check request: %w", serviceName, err),
+			"Failed to create health check request",
+			logger.FieldService(serviceName),
+			logger.FieldPath(healthURL),
+		)
+		return false
+	}
+
+	resp, err := h.client.Do(req)
+	if err != nil {
+		h.logger.LogError(ctx, fmt.Errorf("%s health check: %w", serviceName, err),
+			"Health check request failed",
+			logger.FieldService(serviceName),
+			logger.FieldPath(healthURL),
+		)
+		return false
+	}
+	defer resp.Body.Close()
+
+	// Consider 2xx status codes as healthy
+	healthy := resp.StatusCode >= 200 && resp.StatusCode < 300
+	if !healthy {
+		h.logger.WithContext(ctx).Warn("Upstream service health check failed",
+			logger.FieldService(serviceName),
+			logger.FieldStatus(resp.StatusCode),
+		)
+	}
+
+	return healthy
+}
diff --git a/services/gateway/internal/logger/logger.go b/services/gateway/internal/logger/logger.go
new file mode 100644
index 0000000..e9b89d6
--- /dev/null
+++ b/services/gateway/internal/logger/logger.go
@@ -0,0 +1,164 @@
+// Package logger provides structured logging functionality for the API Gateway service.
+//
+// This package follows Google's logging practices with structured JSON logging,
+// context-aware logging, and proper log levels.
+package logger
+
+import (
+	"context"
+	"log/slog"
+	"os"
+
+	"go.opentelemetry.io/otel/trace"
+)
+
+// Logger wraps slog.Logger with additional functionality.
+// The embedded *slog.Logger exposes the standard Debug/Info/Warn/Error methods.
+type Logger struct {
+	*slog.Logger
+}
+
+// New creates a new structured logger instance.
+// It always writes JSON to stdout; the minimum level is taken from LOG_LEVEL.
+func New(serviceName string) *Logger {
+	level := getLogLevel()
+
+	// Source locations are attached only when running at debug level.
+	handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
+		Level:     level,
+		AddSource: level == slog.LevelDebug,
+	})
+
+	// Every record carries the service name and version.
+	base := slog.New(handler).With(
+		slog.String("service", serviceName),
+		slog.String("version", getVersion()),
+	)
+
+	return &Logger{Logger: base}
+}
+
+// WithContext returns a logger with context-derived fields.
+// A trace_id and span_id are attached when the context yields non-empty values.
+func (l *Logger) WithContext(ctx context.Context) *Logger {
+	var attrs []any
+	if traceID := getTraceIDFromContext(ctx); traceID != "" {
+		attrs = append(attrs, slog.String("trace_id", traceID))
+	}
+	if spanID := getSpanIDFromContext(ctx); spanID != "" {
+		attrs = append(attrs, slog.String("span_id", spanID))
+	}
+
+	if len(attrs) == 0 {
+		return &Logger{Logger: l.Logger}
+	}
+	return &Logger{Logger: l.Logger.With(attrs...)}
+}
+
+// WithFields returns a logger with additional structured fields.
+// The attributes are attached to every record the returned logger emits.
+func (l *Logger) WithFields(fields ...slog.Attr) *Logger {
+	// slog's With takes ...any, so the typed attributes are widened first.
+	args := make([]any, len(fields))
+	for i, field := range fields {
+		args[i] = field
+	}
+	return &Logger{Logger: l.With(args...)}
+}
+
+// WithComponent returns a logger with a component field.
+// This is useful for distinguishing between different parts of the service.
+func (l *Logger) WithComponent(component string) *Logger {
+	tagged := l.With(slog.String("component", component))
+	return &Logger{Logger: tagged}
+}
+
+// LogRequest logs HTTP request information in a structured format.
+func (l *Logger) LogRequest(ctx context.Context, method, path, userAgent string, statusCode, duration int) {
+	// duration is logged as-is under "duration_ms".
+	l.WithContext(ctx).Info("HTTP request",
+		slog.String("method", method),
+		slog.String("path", path),
+		slog.String("user_agent", userAgent),
+		slog.Int("status_code", statusCode),
+		slog.Int("duration_ms", duration),
+	)
+}
+
+// LogError logs error information with additional context.
+//
+// The record is emitted at error level with the context-derived fields, an
+// "error" attribute holding err's message, and then the supplied fields.
+// A nil err is tolerated: the "error" attribute is simply left out.
+func (l *Logger) LogError(ctx context.Context, err error, msg string, fields ...slog.Attr) {
+	args := make([]any, 0, len(fields)+1)
+	if err != nil {
+		args = append(args, slog.String("error", err.Error()))
+	}
+	for _, field := range fields {
+		args = append(args, field)
+	}
+
+	l.WithContext(ctx).Error(msg, args...)
+}
+
+// LogServiceCall logs outbound service calls for observability.
+func (l *Logger) LogServiceCall(ctx context.Context, service, method, endpoint string, statusCode, duration int) {
+	// duration is logged as-is under "duration_ms".
+	l.WithContext(ctx).Info("Service call",
+		slog.String("target_service", service),
+		slog.String("method", method),
+		slog.String("endpoint", endpoint),
+		slog.Int("status_code", statusCode),
+		slog.Int("duration_ms", duration),
+	)
+}
+
+// Helper functions
+
+// getLogLevel maps LOG_LEVEL to a slog level. Only the all-upper and
+// all-lower spellings listed below are recognised; anything else, including
+// an unset variable, yields info.
+func getLogLevel() slog.Level {
+	switch os.Getenv("LOG_LEVEL") {
+	case "DEBUG", "debug":
+		return slog.LevelDebug
+	case "INFO", "info":
+		return slog.LevelInfo
+	case "WARN", "warn", "WARNING", "warning":
+		return slog.LevelWarn
+	case "ERROR", "error":
+		return slog.LevelError
+	default:
+		return slog.LevelInfo
+	}
+}
+
+// getVersion returns SERVICE_VERSION, which is empty when the variable is unset.
+func getVersion() string {
+	// In a real application, this might come from build flags or version file
+	return os.Getenv("SERVICE_VERSION")
+}
+
+// getTraceIDFromContext extracts trace ID from OpenTelemetry context.
+// It returns "" when the context carries no valid span context.
+func getTraceIDFromContext(ctx context.Context) string {
+	spanContext := trace.SpanContextFromContext(ctx)
+	if spanContext.IsValid() {
+		return spanContext.TraceID().String()
+	}
+	return ""
+}
+
+// getSpanIDFromContext extracts span ID from OpenTelemetry context.
+func getSpanIDFromContext(ctx context.Context) string {
+	sc := trace.SpanContextFromContext(ctx)
+	if !sc.IsValid() {
+		// No usable span in this context.
+		return ""
+	}
+	return sc.SpanID().String()
+}
+
+// Common log fields for consistent logging.
+//
+// Each value is a small constructor that fixes the attribute key, so call
+// sites cannot drift apart on naming.
+//
+// NOTE(review): FieldService emits the key "service", which New also attaches
+// as a base attribute; records using both carry that key twice. Confirm
+// this is intended.
+var (
+	FieldUserID    = func(userID string) slog.Attr { return slog.String("user_id", userID) }
+	FieldRequestID = func(requestID string) slog.Attr { return slog.String("request_id", requestID) }
+	FieldLatency   = func(ms int64) slog.Attr { return slog.Int64("latency_ms", ms) }
+	FieldStatus    = func(status int) slog.Attr { return slog.Int("status", status) }
+	FieldMethod    = func(method string) slog.Attr { return slog.String("method", method) }
+	FieldPath      = func(path string) slog.Attr { return slog.String("path", path) }
+	FieldService   = func(service string) slog.Attr { return slog.String("service", service) }
+)
diff --git a/services/gateway/internal/metrics/metrics.go b/services/gateway/internal/metrics/metrics.go
new file mode 100644
index 0000000..72d44cb
--- /dev/null
+++ b/services/gateway/internal/metrics/metrics.go
@@ -0,0 +1,259 @@
+// Package metrics provides Prometheus metrics for the API Gateway service.
+//
+// This package follows Google's SRE practices for monitoring and observability,
+// implementing the four golden signals: latency, traffic, errors, and saturation.
+package metrics
+
+import (
+	"strconv"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+const (
+	// namespace prefixes every metric name exported by the gateway.
+	namespace = "hcaas_gateway"
+
+	// Label names shared by several metric vectors.
+	labelMethod     = "method"
+	labelPath       = "path"
+	labelStatusCode = "status_code"
+	labelService    = "service"
+	labelEndpoint   = "endpoint"
+)
+
+// Metrics holds all Prometheus metrics for the API Gateway.
+type Metrics struct {
+	// HTTP request metrics (Golden Signal: Traffic & Latency)
+	httpRequestsTotal   *prometheus.CounterVec
+	httpRequestDuration *prometheus.HistogramVec
+
+	// Service proxy metrics
+	upstreamRequestsTotal    *prometheus.CounterVec
+	upstreamRequestDuration  *prometheus.HistogramVec
+	upstreamRequestsInFlight *prometheus.GaugeVec
+
+	// Rate limiting metrics
+	rateLimitHits   *prometheus.CounterVec
+	rateLimitActive *prometheus.GaugeVec
+
+	// Authentication metrics
+	authValidationsTotal   *prometheus.CounterVec
+	authValidationDuration *prometheus.HistogramVec
+
+	// Error metrics (Golden Signal: Errors)
+	errorTotal *prometheus.CounterVec
+
+	// Resource utilization (Golden Signal: Saturation)
+	activeConnections *prometheus.GaugeVec
+	memoryUsage       *prometheus.GaugeVec
+}
+
+// New creates and registers all Prometheus metrics for the API Gateway.
+//
+// Every collector is created through promauto, in the order listed below.
+//
+// NOTE(review): several vectors are labelled with raw request paths and
+// client IPs; confirm the resulting label cardinality is acceptable.
+func New() *Metrics {
+	m := &Metrics{}
+
+	// HTTP request metrics
+	m.httpRequestsTotal = promauto.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Name:      "http_requests_total",
+			Help:      "Total number of HTTP requests processed by the gateway.",
+		},
+		[]string{labelMethod, labelPath, labelStatusCode},
+	)
+
+	m.httpRequestDuration = promauto.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Namespace: namespace,
+			Name:      "http_request_duration_seconds",
+			Help:      "Duration of HTTP requests in seconds.",
+			// Buckets optimized for API response times
+			Buckets: []float64{.001, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10},
+		},
+		[]string{labelMethod, labelPath, labelStatusCode},
+	)
+
+	// Upstream service metrics
+	m.upstreamRequestsTotal = promauto.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Name:      "upstream_requests_total",
+			Help:      "Total number of requests sent to upstream services.",
+		},
+		[]string{labelService, labelMethod, labelEndpoint, labelStatusCode},
+	)
+
+	m.upstreamRequestDuration = promauto.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Namespace: namespace,
+			Name:      "upstream_request_duration_seconds",
+			Help:      "Duration of upstream service requests in seconds.",
+			Buckets:   []float64{.001, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10},
+		},
+		[]string{labelService, labelMethod, labelEndpoint},
+	)
+
+	m.upstreamRequestsInFlight = promauto.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: namespace,
+			Name:      "upstream_requests_in_flight",
+			Help:      "Number of upstream requests currently being processed.",
+		},
+		[]string{labelService},
+	)
+
+	// Rate limiting metrics
+	m.rateLimitHits = promauto.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Name:      "rate_limit_hits_total",
+			Help:      "Total number of rate limit hits.",
+		},
+		[]string{"client_ip", "path"},
+	)
+
+	m.rateLimitActive = promauto.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: namespace,
+			Name:      "rate_limit_active",
+			Help:      "Number of active rate limit buckets.",
+		},
+		[]string{"path"},
+	)
+
+	// Authentication metrics
+	m.authValidationsTotal = promauto.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Name:      "auth_validations_total",
+			Help:      "Total number of JWT validations performed.",
+		},
+		[]string{"result"}, // success, failure, expired, invalid
+	)
+
+	m.authValidationDuration = promauto.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Namespace: namespace,
+			Name:      "auth_validation_duration_seconds",
+			Help:      "Duration of JWT validation operations.",
+			Buckets:   prometheus.DefBuckets,
+		},
+		[]string{"result"},
+	)
+
+	// Error metrics
+	m.errorTotal = promauto.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Name:      "errors_total",
+			Help:      "Total number of errors by type and component.",
+		},
+		[]string{"component", "error_type"},
+	)
+
+	// Resource utilization
+	m.activeConnections = promauto.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: namespace,
+			Name:      "active_connections",
+			Help:      "Number of active connections to the gateway.",
+		},
+		[]string{"type"}, // http, websocket, etc.
+	)
+
+	m.memoryUsage = promauto.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: namespace,
+			Name:      "memory_usage_bytes",
+			Help:      "Memory usage in bytes by component.",
+		},
+		[]string{"component"},
+	)
+
+	return m
+}
+
+// RecordHTTPRequest records metrics for an HTTP request.
+// It bumps the request counter and observes the duration in seconds, both
+// labelled by method, path and status code.
+func (m *Metrics) RecordHTTPRequest(method, path string, statusCode int, duration time.Duration) {
+	labels := []string{method, path, strconv.Itoa(statusCode)}
+
+	m.httpRequestsTotal.WithLabelValues(labels...).Inc()
+	m.httpRequestDuration.WithLabelValues(labels...).Observe(duration.Seconds())
+}
+
+// RecordUpstreamRequest records metrics for upstream service requests.
+// The counter is labelled with the status code; the duration histogram is not.
+func (m *Metrics) RecordUpstreamRequest(service, method, endpoint string, statusCode int, duration time.Duration) {
+	status := strconv.Itoa(statusCode)
+	seconds := duration.Seconds()
+
+	m.upstreamRequestsTotal.WithLabelValues(service, method, endpoint, status).Inc()
+	m.upstreamRequestDuration.WithLabelValues(service, method, endpoint).Observe(seconds)
+}
+
+// RecordUpstreamRequestStart tracks the start of an upstream request.
+func (m *Metrics) RecordUpstreamRequestStart(service string) {
+	inFlight := m.upstreamRequestsInFlight.WithLabelValues(service)
+	inFlight.Inc()
+}
+
+// RecordUpstreamRequestEnd tracks the end of an upstream request.
+func (m *Metrics) RecordUpstreamRequestEnd(service string) {
+	inFlight := m.upstreamRequestsInFlight.WithLabelValues(service)
+	inFlight.Dec()
+}
+
+// RecordRateLimitHit records a rate limit hit.
+func (m *Metrics) RecordRateLimitHit(clientIP, path string) {
+	hits := m.rateLimitHits.WithLabelValues(clientIP, path)
+	hits.Inc()
+}
+
+// SetRateLimitActive sets the number of active rate limit buckets for a path.
+func (m *Metrics) SetRateLimitActive(path string, count float64) {
+	active := m.rateLimitActive.WithLabelValues(path)
+	active.Set(count)
+}
+
+// RecordAuthValidation records JWT validation metrics.
+func (m *Metrics) RecordAuthValidation(result string, duration time.Duration) {
+	// Count the validation, then observe how long it took, both keyed by result.
+	m.authValidationsTotal.WithLabelValues(result).Inc()
+
+	seconds := duration.Seconds()
+	m.authValidationDuration.WithLabelValues(result).Observe(seconds)
+}
+
+// RecordError records an error by component and type.
+func (m *Metrics) RecordError(component, errorType string) {
+	counter := m.errorTotal.WithLabelValues(component, errorType)
+	counter.Inc()
+}
+
+// SetActiveConnections sets the number of active connections.
+func (m *Metrics) SetActiveConnections(connType string, count float64) {
+	gauge := m.activeConnections.WithLabelValues(connType)
+	gauge.Set(count)
+}
+
+// SetMemoryUsage sets memory usage for a component.
+func (m *Metrics) SetMemoryUsage(component string, bytes float64) {
+	gauge := m.memoryUsage.WithLabelValues(component)
+	gauge.Set(bytes)
+}
+
+// AuthValidationResult constants for consistent labeling.
+// These are the documented values for the "result" label.
+const (
+	AuthValidationSuccess = "success"
+	AuthValidationFailure = "failure"
+	AuthValidationExpired = "expired"
+	AuthValidationInvalid = "invalid"
+)
+
+// ErrorType constants for consistent error categorization.
+// Pass one of these as the errorType argument of RecordError.
+const (
+	ErrorTypeUpstream   = "upstream"
+	ErrorTypeAuth       = "auth"
+	ErrorTypeRateLimit  = "rate_limit"
+	ErrorTypeValidation = "validation"
+	ErrorTypeInternal   = "internal"
+)
+
+// Component constants for consistent component identification.
+// Pass one of these as the component argument of RecordError.
+const (
+	ComponentProxy      = "proxy"
+	ComponentAuth       = "auth"
+	ComponentRateLimit  = "rate_limit"
+	ComponentMiddleware = "middleware"
+)
diff --git a/services/gateway/internal/middleware/auth.go b/services/gateway/internal/middleware/auth.go
new file mode 100644
index 0000000..7b59729
--- /dev/null
+++ b/services/gateway/internal/middleware/auth.go
@@ -0,0 +1,201 @@
+// Package middleware provides HTTP middleware for the API Gateway.
+//
+// This package implements authentication, authorization, and request processing
+// middleware following Google's security and observability best practices.
+package middleware
+
+import (
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/samims/hcaas/services/gateway/internal/auth"
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+	"github.com/samims/hcaas/services/gateway/internal/metrics"
+)
+
+// AuthMiddleware provides JWT authentication middleware.
+type AuthMiddleware struct {
+	validator *auth.JWTValidator
+	logger    *logger.Logger
+	metrics   *metrics.Metrics
+
+	// publicPaths lists the paths served without authentication.
+	publicPaths []string
+}
+
+// NewAuthMiddleware creates a new authentication middleware.
+//
+// Tokens are validated with jwtSecret. The registration, login, token
+// validation, health, readiness and metrics endpoints are public.
+func NewAuthMiddleware(jwtSecret string, logger *logger.Logger, metrics *metrics.Metrics) *AuthMiddleware {
+	mw := &AuthMiddleware{
+		validator: auth.NewJWTValidator(jwtSecret),
+		logger:    logger.WithComponent("auth-middleware"),
+		metrics:   metrics,
+	}
+
+	mw.publicPaths = []string{
+		"/auth/register",
+		"/auth/login",
+		"/auth/validate",
+		"/healthz",
+		"/health",
+		"/readyz",
+		"/metrics",
+	}
+
+	return mw
+}
+
+// Middleware returns the authentication middleware handler.
+func (m *AuthMiddleware) Middleware() func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			startTime := time.Now()
+
+			// Public endpoints bypass authentication entirely; no metric is recorded.
+			if m.isPublicPath(r.URL.Path) {
+				next.ServeHTTP(w, r)
+				return
+			}
+
+			// Extract token from Authorization header
+			token, err := auth.ExtractTokenFromHeader(r.Header.Get("Authorization"))
+			if err != nil {
+				m.handleAuthError(w, r, err, auth.ValidationMissing)
+				m.recordAuthMetrics(auth.ValidationMissing, time.Since(startTime))
+				return
+			}
+
+			// Validate the token
+			userCtx, err := m.validator.ValidateToken(token)
+			if err != nil {
+				validationResult := auth.GetValidationResult(err)
+				m.handleAuthError(w, r, err, validationResult)
+				m.recordAuthMetrics(validationResult, time.Since(startTime))
+				return
+			}
+
+			// Downstream handlers read the caller identity from the request context.
+			ctx := auth.WithUserContext(r.Context(), userCtx)
+			r = r.WithContext(ctx)
+
+			m.recordAuthMetrics(auth.ValidationSuccess, time.Since(startTime))
+
+			m.logger.WithContext(ctx).Info("Authentication successful",
+				logger.FieldUserID(userCtx.UserID),
+				logger.FieldPath(r.URL.Path),
+				logger.FieldMethod(r.Method),
+			)
+
+			next.ServeHTTP(w, r)
+		})
+	}
+}
+
+// isPublicPath checks if a path doesn't require authentication.
+//
+// A path is public when it equals a configured public path or lies beneath it
+// (the public path followed by "/"). Matching on the bare string prefix would
+// also exempt unrelated routes that merely start with the same characters,
+// for example "/metricsadmin" or "/healthcare/records".
+//
+// NOTE(review): path is compared as received; confirm that dot segments
+// ("..") are normalised before requests reach this middleware.
+func (m *AuthMiddleware) isPublicPath(path string) bool {
+	for _, publicPath := range m.publicPaths {
+		if path == publicPath || strings.HasPrefix(path, publicPath+"/") {
+			return true
+		}
+	}
+	return false
+}
+
+// handleAuthError handles authentication errors with appropriate HTTP responses.
+func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, err error, result auth.ValidationResult) {
+	// Include the underlying cause so rejected requests can be diagnosed.
+	logArgs := []any{
+		logger.FieldPath(r.URL.Path),
+		logger.FieldMethod(r.Method),
+		logger.FieldStatus(http.StatusUnauthorized),
+	}
+	if err != nil {
+		logArgs = append(logArgs, "error", err.Error())
+	}
+	m.logger.WithContext(r.Context()).Info("Authentication failed", logArgs...)
+
+	// Every failure is answered with 401; only the message and challenge vary.
+	statusCode := http.StatusUnauthorized
+	message := "Authentication required"
+	challenge := "Bearer"
+
+	switch result {
+	case auth.ValidationExpired:
+		message = "Token has expired"
+		challenge = `Bearer error="invalid_token"`
+	case auth.ValidationInvalid:
+		message = "Invalid token"
+		challenge = `Bearer error="invalid_token"`
+	case auth.ValidationMissing:
+		// No credentials were presented, so the challenge carries no error code.
+		message = "Authorization header required"
+	}
+
+	// Set JSON content type and CORS headers
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*")
+	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+	w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
+
+	// A 401 response must tell the client which authentication scheme to use.
+	w.Header().Set("WWW-Authenticate", challenge)
+
+	w.WriteHeader(statusCode)
+
+	// message is always one of the fixed literals above, so concatenation yields valid JSON.
+	errorResponse := `{"error": "` + message + `", "code": "AUTHENTICATION_FAILED"}`
+	w.Write([]byte(errorResponse))
+}
+
+// recordAuthMetrics records authentication metrics.
+// The validation result is used as the metric's label value.
+func (m *AuthMiddleware) recordAuthMetrics(result auth.ValidationResult, duration time.Duration) {
+	m.metrics.RecordAuthValidation(string(result), duration)
+}
+
+// OptionalAuthMiddleware provides optional authentication middleware.
+// This allows authenticated and unauthenticated requests to proceed, but adds
+// user context if authentication is present and valid.
+type OptionalAuthMiddleware struct {
+	validator *auth.JWTValidator
+	logger    *logger.Logger
+	metrics   *metrics.Metrics
+}
+
+// NewOptionalAuthMiddleware creates a new optional authentication middleware.
+func NewOptionalAuthMiddleware(jwtSecret string, logger *logger.Logger, metrics *metrics.Metrics) *OptionalAuthMiddleware {
+	mw := &OptionalAuthMiddleware{
+		validator: auth.NewJWTValidator(jwtSecret),
+		logger:    logger.WithComponent("optional-auth"),
+		metrics:   metrics,
+	}
+	return mw
+}
+
+// Middleware returns the optional authentication middleware handler.
+// The wrapped handler is always invoked, whether or not a token validated.
+func (m *OptionalAuthMiddleware) Middleware() func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			began := time.Now()
+			next.ServeHTTP(w, m.withUserIfValid(r, began))
+		})
+	}
+}
+
+// withUserIfValid returns r with the user context attached when the
+// Authorization header carries a token that validates, and r unchanged
+// otherwise. Metrics are recorded only once a token has been extracted.
+func (m *OptionalAuthMiddleware) withUserIfValid(r *http.Request, began time.Time) *http.Request {
+	header := r.Header.Get("Authorization")
+	if header == "" {
+		return r
+	}
+
+	token, err := auth.ExtractTokenFromHeader(header)
+	if err != nil {
+		// A malformed header is ignored without logging or metrics.
+		return r
+	}
+
+	userCtx, err := m.validator.ValidateToken(token)
+	if err != nil {
+		// Validation failures are counted and logged but never block the request.
+		m.recordAuthMetrics(auth.GetValidationResult(err), time.Since(began))
+
+		m.logger.WithContext(r.Context()).Debug("Optional auth validation failed",
+			logger.FieldPath(r.URL.Path),
+			logger.FieldMethod(r.Method),
+		)
+		return r
+	}
+
+	authed := r.WithContext(auth.WithUserContext(r.Context(), userCtx))
+	m.recordAuthMetrics(auth.ValidationSuccess, time.Since(began))
+	return authed
+}
+
+// recordAuthMetrics records authentication metrics for optional auth.
+func (m *OptionalAuthMiddleware) recordAuthMetrics(result auth.ValidationResult, duration time.Duration) {
+	label := string(result)
+	m.metrics.RecordAuthValidation(label, duration)
+}
diff --git a/services/gateway/internal/middleware/cors.go b/services/gateway/internal/middleware/cors.go
new file mode 100644
index 0000000..33e3682
--- /dev/null
+++ b/services/gateway/internal/middleware/cors.go
@@ -0,0 +1,52 @@
+package middleware
+
+import (
+	"net/http"
+)
+
+// CORSMiddleware provides Cross-Origin Resource Sharing (CORS) functionality.
+type CORSMiddleware struct {
+	allowedOrigins []string // "*" allows any origin
+	allowedMethods []string
+	allowedHeaders []string
+}
+
+// NewCORSMiddleware creates a new CORS middleware with default settings.
+// The defaults allow any origin and a fixed set of methods and request headers.
+func NewCORSMiddleware() *CORSMiddleware {
+	return &CORSMiddleware{
+		allowedOrigins: []string{"*"}, // In production, specify actual domains
+		allowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
+		allowedHeaders: []string{
+			"Accept",
+			"Content-Type",
+			"Content-Length",
+			"Accept-Encoding",
+			"Authorization",
+			"X-CSRF-Token",
+			"X-Requested-With",
+		},
+	}
+}
+
+// Middleware returns the CORS middleware handler.
+//
+// Response headers are derived from the middleware's configured origins,
+// methods and headers. With a wildcard origin the response carries
+// "Access-Control-Allow-Origin: *" and no credentials header, because browsers
+// reject the wildcard on credentialed requests. With an explicit origin list,
+// a matching request Origin is echoed back together with
+// "Access-Control-Allow-Credentials: true".
+//
+// OPTIONS requests are answered directly with 200 and never reach next.
+func (m *CORSMiddleware) Middleware() func(http.Handler) http.Handler {
+	// The lists are fixed after construction, so join them once up front.
+	methods := joinCORSValues(m.allowedMethods)
+	headers := joinCORSValues(m.allowedHeaders)
+
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			h := w.Header()
+
+			allowOrigin, credentialed := m.resolveOrigin(r.Header.Get("Origin"))
+			if allowOrigin != "*" {
+				// The response depends on the request's Origin, so caches must key on it.
+				h.Add("Vary", "Origin")
+			}
+			if allowOrigin != "" {
+				h.Set("Access-Control-Allow-Origin", allowOrigin)
+			}
+			if credentialed {
+				h.Set("Access-Control-Allow-Credentials", "true")
+			}
+
+			h.Set("Access-Control-Allow-Methods", methods)
+			h.Set("Access-Control-Allow-Headers", headers)
+			h.Set("Access-Control-Max-Age", "3600")
+
+			// Handle preflight OPTIONS request
+			if r.Method == http.MethodOptions {
+				w.WriteHeader(http.StatusOK)
+				return
+			}
+
+			// Continue with the request
+			next.ServeHTTP(w, r)
+		})
+	}
+}
+
+// resolveOrigin returns the Access-Control-Allow-Origin value for a request
+// Origin and whether credentials may be allowed alongside it. An empty value
+// means no origin header should be sent.
+func (m *CORSMiddleware) resolveOrigin(origin string) (string, bool) {
+	matched := false
+	for _, allowed := range m.allowedOrigins {
+		if allowed == "*" {
+			// A wildcard wins over any explicit entry and never allows credentials.
+			return "*", false
+		}
+		if origin != "" && allowed == origin {
+			matched = true
+		}
+	}
+	if matched {
+		return origin, true
+	}
+	return "", false
+}
+
+// joinCORSValues joins values with ", " as used in CORS list headers.
+// It is written by hand because this file imports only net/http.
+func joinCORSValues(values []string) string {
+	joined := ""
+	for i, value := range values {
+		if i > 0 {
+			joined += ", "
+		}
+		joined += value
+	}
+	return joined
+}
diff --git a/services/gateway/internal/middleware/logging.go b/services/gateway/internal/middleware/logging.go
new file mode 100644
index 0000000..77f38f9
--- /dev/null
+++ b/services/gateway/internal/middleware/logging.go
@@ -0,0 +1,84 @@
+package middleware
+
+import (
+	"net/http"
+	"time"
+
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+	"github.com/samims/hcaas/services/gateway/internal/metrics"
+)
+
+// LoggingMiddleware provides request/response logging functionality.
+type LoggingMiddleware struct {
+	logger  *logger.Logger
+	metrics *metrics.Metrics
+}
+
+// NewLoggingMiddleware creates a new logging middleware.
+// The supplied logger is tagged with the "logging-middleware" component.
+func NewLoggingMiddleware(logger *logger.Logger, metrics *metrics.Metrics) *LoggingMiddleware {
+	return &LoggingMiddleware{
+		logger:  logger.WithComponent("logging-middleware"),
+		metrics: metrics,
+	}
+}
+
+// responseWriter wraps http.ResponseWriter to capture status code.
+type responseWriter struct {
+	http.ResponseWriter
+	statusCode int
+	written    bool // true once a status code has been sent
+}
+
+// WriteHeader records and forwards the first status code; later calls are ignored.
+func (rw *responseWriter) WriteHeader(code int) {
+	if !rw.written {
+		rw.statusCode = code
+		rw.written = true
+		rw.ResponseWriter.WriteHeader(code)
+	}
+}
+
+// Write sends data, first recording an implicit 200 when no status was written yet.
+func (rw *responseWriter) Write(data []byte) (int, error) {
+	if !rw.written {
+		rw.WriteHeader(http.StatusOK)
+	}
+	return rw.ResponseWriter.Write(data)
+}
+
+// Unwrap returns the wrapped ResponseWriter.
+//
+// Embedding the interface exposes only Header, Write and WriteHeader, so
+// optional capabilities of the original writer are hidden by this wrapper.
+// http.ResponseController follows Unwrap to reach them.
+func (rw *responseWriter) Unwrap() http.ResponseWriter {
+	return rw.ResponseWriter
+}
+
+// Middleware returns the logging middleware handler.
+// After the wrapped handler returns, it logs the request and records the HTTP
+// request metric using the captured status code and elapsed time.
+func (m *LoggingMiddleware) Middleware() func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			startTime := time.Now()
+
+			// Default to 200 for handlers that return without writing anything.
+			wrappedWriter := &responseWriter{
+				ResponseWriter: w,
+				statusCode:     http.StatusOK,
+			}
+
+			// Process the request
+			next.ServeHTTP(wrappedWriter, r)
+
+			// Calculate duration
+			duration := time.Since(startTime)
+
+			// Log the request
+			m.logger.LogRequest(
+				r.Context(),
+				r.Method,
+				r.URL.Path,
+				r.UserAgent(),
+				wrappedWriter.statusCode,
+				int(duration.Milliseconds()),
+			)
+
+			// Record metrics
+			m.metrics.RecordHTTPRequest(
+				r.Method,
+				r.URL.Path,
+				wrappedWriter.statusCode,
+				duration,
+			)
+		})
+	}
+}
diff --git a/services/gateway/internal/middleware/ratelimit.go b/services/gateway/internal/middleware/ratelimit.go
new file mode 100644
index 0000000..eef4b81
--- /dev/null
+++ b/services/gateway/internal/middleware/ratelimit.go
@@ -0,0 +1,221 @@
+package middleware
+
+import (
+	"fmt"
+	"log/slog"
+	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"golang.org/x/time/rate"
+
+	"github.com/samims/hcaas/services/gateway/internal/logger"
+	"github.com/samims/hcaas/services/gateway/internal/metrics"
+)
+
+// RateLimitMiddleware provides rate limiting functionality using token bucket algorithm.
+// One limiter is kept per client identifier.
+type RateLimitMiddleware struct {
+	limiters map[string]*rate.Limiter
+	mutex    sync.RWMutex
+
+	requestsPerSecond int
+	burstSize         int
+
+	logger  *logger.Logger
+	metrics *metrics.Metrics
+
+	// Last access time tracking for cleanup
+	lastAccess map[string]time.Time
+
+	// Cleanup ticker for removing old limiters
+	cleanupTicker *time.Ticker
+	stopCleanup   chan struct{}
+}
+
+// NewRateLimitMiddleware creates a new rate limiting middleware.
+// It also starts the background goroutine that removes inactive limiters.
+func NewRateLimitMiddleware(requestsPerSecond, burstSize int, logger *logger.Logger, metrics *metrics.Metrics) *RateLimitMiddleware {
+	rl := &RateLimitMiddleware{
+		limiters:          make(map[string]*rate.Limiter),
+		lastAccess:        make(map[string]time.Time),
+		requestsPerSecond: requestsPerSecond,
+		burstSize:         burstSize,
+		logger:            logger.WithComponent("rate-limit"),
+		metrics:           metrics,
+		stopCleanup:       make(chan struct{}),
+	}
+
+	// Start cleanup goroutine to remove inactive limiters
+	rl.startCleanup()
+
+	return rl
+}
+
+// Middleware returns the rate limiting middleware handler.
+// A request that its client's limiter does not allow gets a 429 response and
+// is not forwarded to next.
+func (rl *RateLimitMiddleware) Middleware() func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			// Get client identifier (IP address)
+			clientIP := rl.getClientIP(r)
+
+			// Get or create limiter for this client
+			limiter := rl.getLimiter(clientIP)
+
+			// Check if request is allowed
+			if !limiter.Allow() {
+				rl.handleRateLimitExceeded(w, r, clientIP)
+				return
+			}
+
+			// Request is allowed, proceed
+			next.ServeHTTP(w, r)
+		})
+	}
+}
+
+// getLimiter gets or creates a rate limiter for a client.
+func (rl *RateLimitMiddleware) getLimiter(clientIP string) *rate.Limiter { + rl.mutex.Lock() + defer rl.mutex.Unlock() + + limiter, exists := rl.limiters[clientIP] + if !exists { + limiter = rate.NewLimiter(rate.Limit(rl.requestsPerSecond), rl.burstSize) + rl.limiters[clientIP] = limiter + } + + // Update last access time for cleanup + rl.lastAccess[clientIP] = time.Now() + + return limiter +} + +// getLimiterForIP gets an existing limiter for rate limit header calculation. +func (rl *RateLimitMiddleware) getLimiterForIP(clientIP string) *rate.Limiter { + rl.mutex.RLock() + defer rl.mutex.RUnlock() + + if limiter, exists := rl.limiters[clientIP]; exists { + return limiter + } + + // Return a default limiter if not found + return rate.NewLimiter(rate.Limit(rl.requestsPerSecond), rl.burstSize) +} + +// handleRateLimitExceeded handles rate limit exceeded scenarios. +func (rl *RateLimitMiddleware) handleRateLimitExceeded(w http.ResponseWriter, r *http.Request, clientIP string) { + // Record rate limit hit + rl.metrics.RecordRateLimitHit(clientIP, r.URL.Path) + + // Log the rate limit hit + rl.logger.WithContext(r.Context()).Warn("Rate limit exceeded", + logger.FieldPath(r.URL.Path), + logger.FieldMethod(r.Method), + logger.FieldStatus(http.StatusTooManyRequests), + logger.FieldUserID(clientIP), + ) + + // Calculate remaining tokens and reset time + limiter := rl.getLimiterForIP(clientIP) + remaining := int(limiter.Tokens()) + resetTime := 60 // Approximate reset time based on rate + + // Set rate limit headers with actual values + w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.requestsPerSecond)) + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining)) + w.Header().Set("X-RateLimit-Reset", strconv.Itoa(resetTime)) + w.Header().Set("Retry-After", strconv.Itoa(resetTime)) + + // Set CORS headers + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + 
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + + // Set content type and status + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + + // Write error response + errorResponse := `{"error": "Rate limit exceeded", "code": "RATE_LIMIT_EXCEEDED", "retry_after": 60}` + w.Write([]byte(errorResponse)) +} + +// getClientIP extracts client IP from request. +func (rl *RateLimitMiddleware) getClientIP(r *http.Request) string { + // Check X-Forwarded-For header first + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + return xff + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + return r.RemoteAddr +} + +// startCleanup starts a goroutine to periodically clean up inactive limiters. +func (rl *RateLimitMiddleware) startCleanup() { + rl.cleanupTicker = time.NewTicker(5 * time.Minute) // Clean up every 5 minutes + + go func() { + for { + select { + case <-rl.cleanupTicker.C: + rl.cleanup() + case <-rl.stopCleanup: + rl.cleanupTicker.Stop() + return + } + } + }() +} + +// cleanup removes limiters that haven't been used recently. +func (rl *RateLimitMiddleware) cleanup() { + rl.mutex.Lock() + defer rl.mutex.Unlock() + + now := time.Now() + cleanupThreshold := 10 * time.Minute + cleanupCount := 0 + + for clientIP, lastAccess := range rl.lastAccess { + // Remove limiters that haven't been accessed in the last 10 minutes + if now.Sub(lastAccess) > cleanupThreshold { + delete(rl.limiters, clientIP) + delete(rl.lastAccess, clientIP) + cleanupCount++ + } + } + + // Update metrics with active limiter count + rl.metrics.SetRateLimitActive("*", float64(len(rl.limiters))) + + rl.logger.Debug("Rate limiter cleanup completed", + logger.FieldService("rate-limiter"), + slog.Int("cleaned_up", cleanupCount), + slog.Int("active_limiters", len(rl.limiters)), + ) +} + +// Stop stops the rate limiter cleanup goroutine. 
+func (rl *RateLimitMiddleware) Stop() { + close(rl.stopCleanup) +} + +// GetStats returns current rate limiting statistics. +func (rl *RateLimitMiddleware) GetStats() map[string]interface{} { + rl.mutex.RLock() + defer rl.mutex.RUnlock() + + return map[string]interface{}{ + "active_limiters": len(rl.limiters), + "requests_per_second": rl.requestsPerSecond, + "burst_size": rl.burstSize, + } +} diff --git a/services/gateway/internal/middleware/tracing.go b/services/gateway/internal/middleware/tracing.go new file mode 100644 index 0000000..cbf63a6 --- /dev/null +++ b/services/gateway/internal/middleware/tracing.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "net/http" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + + "github.com/samims/hcaas/pkg/tracing" + "github.com/samims/hcaas/services/gateway/internal/logger" +) + +// TracingMiddleware provides OpenTelemetry tracing for HTTP requests. +type TracingMiddleware struct { + tracer *tracing.Tracer + logger *logger.Logger +} + +// NewTracingMiddleware creates a new tracing middleware. +func NewTracingMiddleware(logger *logger.Logger) *TracingMiddleware { + // Get the global tracer + otelTracer := otel.Tracer("hcaas-gateway") + tracer := tracing.NewTracer(otelTracer) + + return &TracingMiddleware{ + tracer: tracer, + logger: logger.WithComponent("tracing-middleware"), + } +} + +// Middleware returns the tracing middleware handler. 
+func (tm *TracingMiddleware) Middleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract trace context from incoming request + ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + + // Start a new server span + spanName := r.Method + " " + r.URL.Path + ctx, span := tm.tracer.StartServerSpan(ctx, spanName) + defer span.End() + + // Add HTTP request attributes + tm.tracer.AddRequestAttributes(span, r.Method, r.URL.Path, r.UserAgent(), 0) + + // Add additional attributes + span.SetAttributes( + attribute.String("http.scheme", getScheme(r)), + attribute.String("http.host", r.Host), + attribute.String("http.target", r.RequestURI), + attribute.String("user_agent.original", r.UserAgent()), + ) + + // Create a response writer wrapper to capture status code + wrappedWriter := &tracingResponseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + span: span, + tracer: tm.tracer, + } + + // Update request context and process + r = r.WithContext(ctx) + next.ServeHTTP(wrappedWriter, r) + + // Update status code attribute after processing + tm.tracer.AddAttributes(span, + attribute.Int("http.status_code", wrappedWriter.statusCode), + ) + + // Set span status based on HTTP status code + if wrappedWriter.statusCode >= 400 { + if wrappedWriter.statusCode >= 500 { + span.SetStatus(codes.Error, "HTTP 5xx") + } else { + span.SetStatus(codes.Error, "HTTP 4xx") + } + } else { + span.SetStatus(codes.Ok, "") + } + }) + } +} + +// tracingResponseWriter wraps http.ResponseWriter to capture status code for tracing. 
+type tracingResponseWriter struct { + http.ResponseWriter + statusCode int + written bool + span trace.Span + tracer *tracing.Tracer +} + +func (trw *tracingResponseWriter) WriteHeader(code int) { + if !trw.written { + trw.statusCode = code + trw.written = true + trw.ResponseWriter.WriteHeader(code) + } +} + +func (trw *tracingResponseWriter) Write(data []byte) (int, error) { + if !trw.written { + trw.WriteHeader(http.StatusOK) + } + return trw.ResponseWriter.Write(data) +} + +// getScheme determines the scheme (http/https) of the request. +func getScheme(r *http.Request) string { + if r.TLS != nil { + return "https" + } + + // Check X-Forwarded-Proto header + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + + return "http" +} diff --git a/services/gateway/internal/proxy/proxy.go b/services/gateway/internal/proxy/proxy.go new file mode 100644 index 0000000..f0a618d --- /dev/null +++ b/services/gateway/internal/proxy/proxy.go @@ -0,0 +1,357 @@ +// Package proxy provides HTTP reverse proxy functionality for the API Gateway. +// +// This package implements intelligent request routing, load balancing, +// circuit breaking, and retry logic following Google's reliability patterns. +package proxy + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/samims/hcaas/services/gateway/internal/config" + "github.com/samims/hcaas/services/gateway/internal/logger" + "github.com/samims/hcaas/services/gateway/internal/metrics" +) + +// ServiceProxy handles proxying requests to upstream services. +type ServiceProxy struct { + client *http.Client + config *config.ServicesConfig + logger *logger.Logger + metrics *metrics.Metrics +} + +// NewServiceProxy creates a new service proxy with configured HTTP client. 
+func NewServiceProxy(cfg *config.ServicesConfig, logger *logger.Logger, metrics *metrics.Metrics) *ServiceProxy { + // Configure HTTP client with timeouts and connection pooling + client := &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + // Default timeout - will be overridden per service + Timeout: 30 * time.Second, + } + + return &ServiceProxy{ + client: client, + config: cfg, + logger: logger.WithComponent("proxy"), + metrics: metrics, + } +} + +// ProxyRequest proxies an HTTP request to the appropriate upstream service. +func (p *ServiceProxy) ProxyRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + startTime := time.Now() + + // Determine target service based on request path + serviceName, targetURL, err := p.determineTarget(r) + if err != nil { + p.metrics.RecordError(metrics.ComponentProxy, metrics.ErrorTypeValidation) + return fmt.Errorf("failed to determine target: %w", err) + } + + // Get service configuration + serviceConfig, err := p.getServiceConfig(serviceName) + if err != nil { + // rather than client errors, so it's categorized differently for alerting + p.metrics.RecordError(metrics.ComponentProxy, metrics.ErrorTypeInternal) + return fmt.Errorf("failed to get service config: %w", err) + } + + // OBSERVABILITY PHASE: Begin tracking this upstream request for monitoring + // This creates a "request in flight" metric + p.metrics.RecordUpstreamRequestStart(serviceName) + + // helps to identify the number of ongoing request + defer p.metrics.RecordUpstreamRequestEnd(serviceName) + + // EXECUTION PHASE: Proxy the request with intelligent retry logic + // This is where the actual network call happens, potentially with multiple attempts + resp, err := p.proxyWithRetries(ctx, r, targetURL, serviceConfig) + if err != nil { + // record problems with the target service (network issues, service unavailable, timeouts, 
etc.) + p.metrics.RecordError(metrics.ComponentProxy, metrics.ErrorTypeUpstream) + return fmt.Errorf("failed to proxy request: %w", err) + } + + // HTTP response bodies must be explicitly closed or they will leak memory and connections + defer resp.Body.Close() + + // METRICS PHASE: Calculate total request duration from start to finish + duration := time.Since(startTime) + p.metrics.RecordUpstreamRequest(serviceName, r.Method, r.URL.Path, resp.StatusCode, duration) + + // RESPONSE PHASE: Stream the upstream response back to the client + // This copies status code, headers, and body from the upstream service response + // to the client response, making the proxy transparent to the client + if err := p.copyResponse(w, resp); err != nil { + p.metrics.RecordError(metrics.ComponentProxy, metrics.ErrorTypeInternal) + return fmt.Errorf("failed to copy response: %w", err) + } + + // Log the service call + p.logger.LogServiceCall(ctx, serviceName, r.Method, r.URL.Path, resp.StatusCode, int(duration.Milliseconds())) + + return nil +} + +// determineTarget determines the target service and URL based on the request path. 
+func (p *ServiceProxy) determineTarget(r *http.Request) (string, string, error) { + path := r.URL.Path + + switch { + case strings.HasPrefix(path, "/auth/") || strings.HasPrefix(path, "/me"): + // Route to auth service + targetPath := strings.TrimPrefix(path, "/") + targetURL := fmt.Sprintf("%s/%s", p.config.AuthService.BaseURL, targetPath) + return "auth", targetURL, nil + + case strings.HasPrefix(path, "/urls"): + // Route to URL service + targetURL := fmt.Sprintf("%s%s", p.config.URLService.BaseURL, path) + return "url", targetURL, nil + + case strings.HasPrefix(path, "/notifications"): + // Route to notification service + targetURL := fmt.Sprintf("%s%s", p.config.NotificationService.BaseURL, path) + return "notification", targetURL, nil + + default: + return "", "", fmt.Errorf("no route found for path: %s", path) + } +} + +// getServiceConfig returns the configuration for a specific service. +func (p *ServiceProxy) getServiceConfig(serviceName string) (*config.ServiceEndpoint, error) { + switch serviceName { + case "auth": + return &p.config.AuthService, nil + case "url": + return &p.config.URLService, nil + case "notification": + return &p.config.NotificationService, nil + default: + return nil, fmt.Errorf("unknown service: %s", serviceName) + } +} + +// proxyWithRetries implements resilient communication patterns by automatically retrying failed +// requests according to the service configuration, using exponential backoff to avoid overwhelming +// failing services while maximizing the chance of eventual success. 
+func (p *ServiceProxy) proxyWithRetries(ctx context.Context, originalReq *http.Request, targetURL string, serviceConfig *config.ServiceEndpoint) (*http.Response, error) { + // Track the last error encountered during retry attempts + // This will be returned if all attempts fail + var lastErr error + + for attempt := 0; attempt <= serviceConfig.Retries; attempt++ { + // The createProxyRequest function handles body duplication and header copying + req, err := p.createProxyRequest(ctx, originalReq, targetURL) + + if err != nil { + // If we can't even create the request, fail immediately (no retry makes sense) + return nil, fmt.Errorf("failed to create proxy request: %w", err) + } + + // Apply per-request timeout to prevent individual requests from hanging indefinitely + // Each attempt gets its own timeout context, independent of the overall operation timeout + reqCtx, cancel := context.WithTimeout(ctx, serviceConfig.Timeout) + req = req.WithContext(reqCtx) + + // Make the request + resp, err := p.client.Do(req) + cancel() // Always cancel the context to prevent resource leaks + + if err == nil { + // Success - check if we should retry based on status code + if !p.shouldRetry(resp.StatusCode) { + return resp, nil + } + // Close the response body to free resources before retrying + resp.Body.Close() + lastErr = fmt.Errorf("upstream returned retryable status: %d", resp.StatusCode) + } else { + // Save this error in case all retries fail + lastErr = err + } + + // Implement exponential backoff between retry another attempts + // Don't sleep after the final attempt since we're about to give up anyway + + if attempt < serviceConfig.Retries { + // Calculate backoff duration: increases with each attempt to reduce load on failing services + // Formula: (attempt + 1) * 100ms, so 100ms, 200ms, 300ms, etc. 
+ // The +1 ensures we don't have a 0ms delay on the first retry + backoff := time.Duration(attempt+1) * 100 * time.Millisecond + // Use select statement to respect context cancellation during backoff + select { + case <-time.After(backoff): + // Backoff period completed, proceed to next retry attempt + case <-ctx.Done(): + // Parent context was cancelled (timeout, client disconnect, etc.) + // Return immediately rather than continuing retry attempts + return nil, ctx.Err() + } + } + } + + return nil, fmt.Errorf("all retry attempts failed: %w", lastErr) +} + +// createProxyRequest creates a new HTTP request for proxying. +func (p *ServiceProxy) createProxyRequest(ctx context.Context, originalReq *http.Request, targetURL string) (*http.Request, error) { + // Parse target URL + target, err := url.Parse(targetURL) + if err != nil { + return nil, fmt.Errorf("invalid target URL: %w", err) + } + + // Handle request body preservation and duplication + // HTTP request bodies are streams that can only be read once, so we need to: + // 1. Read the entire body into memory + // 2. Create a new reader for the proxy request + // 3. 
Reset the original request body in case it needs to be read again (for retries) + var body io.Reader + if originalReq.Body != nil { + // Read all bytes from the original request body + bodyBytes, err := io.ReadAll(originalReq.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + // Create a new reader from the body bytes for the proxy request + body = bytes.NewReader(bodyBytes) + + // Reset the original request body so it can be read again if needed + // This is important for retry mechanisms or logging + originalReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + // Create the new HTTP request that will be sent to the target service + req, err := http.NewRequestWithContext(ctx, originalReq.Method, target.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Copy headers (excluding hop-by-hop headers) + for name, values := range originalReq.Header { + // Only copy headers that are not hop-by-hop headers + if !p.isHopByHopHeader(name) { + for _, value := range values { + req.Header.Add(name, value) + } + } + } + + // Add standard proxy forwarding headers to preserve client information + // These headers help the target service understand the original client context + + // X-Forwarded-For: Contains the original client's IP address + // This allows the target service to see the real client IP, not the proxy's IP + req.Header.Set("X-Forwarded-For", p.getClientIP(originalReq)) + + // X-Forwarded-Proto: Indicates the original protocol (http/https) + req.Header.Set("X-Forwarded-Proto", p.getScheme(originalReq)) + + // Helps target services understand what hostname the client originally requested + req.Header.Set("X-Forwarded-Host", originalReq.Host) + + // Add gateway identification + // Useful for troubleshooting, logging, and request tracing + req.Header.Set("X-Gateway", "hcaas-gateway") + + return req, nil +} + +// copyResponse copies the upstream response to the 
client. +func (p *ServiceProxy) copyResponse(w http.ResponseWriter, resp *http.Response) error { + // Copy headers (excluding hop-by-hop headers) + for name, values := range resp.Header { + if !p.isHopByHopHeader(name) { + for _, value := range values { + w.Header().Add(name, value) + } + } + } + + // Set status code + w.WriteHeader(resp.StatusCode) + + // Copy body + _, err := io.Copy(w, resp.Body) + return err +} + +// shouldRetry determines if a request should be retried based on status code. +func (p *ServiceProxy) shouldRetry(statusCode int) bool { + // Retry on server errors and service unavailable + return statusCode >= 500 && statusCode != 501 // Don't retry on "Not Implemented" +} + +// isHopByHopHeader checks if a header is hop-by-hop and should not be forwarded. +func (p *ServiceProxy) isHopByHopHeader(name string) bool { + hopByHopHeaders := []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailers", + "Transfer-Encoding", + "Upgrade", + } + + name = strings.ToLower(name) + for _, header := range hopByHopHeaders { + if strings.ToLower(header) == name { + return true + } + } + return false +} + +// getClientIP extracts the client IP address from the request. +func (p *ServiceProxy) getClientIP(r *http.Request) string { + // Check X-Forwarded-For header first + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the chain + if parts := strings.Split(xff, ","); len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + if host, _, found := strings.Cut(r.RemoteAddr, ":"); found { + return host + } + + return r.RemoteAddr +} + +// getScheme determines the scheme (http/https) of the original request. 
+func (p *ServiceProxy) getScheme(r *http.Request) string { + if r.TLS != nil { + return "https" + } + + // Check X-Forwarded-Proto header + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + + return "http" +}