diff --git a/internal/apirouter/destination_handlers.go b/internal/apirouter/destination_handlers.go index 9dc597cc..d1ba5592 100644 --- a/internal/apirouter/destination_handlers.go +++ b/internal/apirouter/destination_handlers.go @@ -12,22 +12,23 @@ import ( "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/telemetry" + "github.com/hookdeck/outpost/internal/tenantstore" "github.com/hookdeck/outpost/internal/util/maputil" ) type DestinationHandlers struct { logger *logging.Logger telemetry telemetry.Telemetry - entityStore models.EntityStore + tenantStore tenantstore.TenantStore topics []string registry destregistry.Registry } -func NewDestinationHandlers(logger *logging.Logger, telemetry telemetry.Telemetry, entityStore models.EntityStore, topics []string, registry destregistry.Registry) *DestinationHandlers { +func NewDestinationHandlers(logger *logging.Logger, telemetry telemetry.Telemetry, tenantStore tenantstore.TenantStore, topics []string, registry destregistry.Registry) *DestinationHandlers { return &DestinationHandlers{ logger: logger, telemetry: telemetry, - entityStore: entityStore, + tenantStore: tenantStore, topics: topics, registry: registry, } @@ -36,9 +37,9 @@ func NewDestinationHandlers(logger *logging.Logger, telemetry telemetry.Telemetr func (h *DestinationHandlers) List(c *gin.Context) { typeParams := c.QueryArray("type") topicsParams := c.QueryArray("topics") - var opts models.ListDestinationByTenantOpts + var opts tenantstore.ListDestinationByTenantOpts if len(typeParams) > 0 || len(topicsParams) > 0 { - opts = models.WithDestinationFilter(models.DestinationFilter{ + opts = tenantstore.WithDestinationFilter(tenantstore.DestinationFilter{ Type: typeParams, Topics: topicsParams, }) @@ -49,7 +50,7 @@ func (h *DestinationHandlers) List(c *gin.Context) { return } - destinations, err := h.entityStore.ListDestinationByTenant(c.Request.Context(), tenantID, opts) + destinations, err := h.tenantStore.ListDestinationByTenant(c.Request.Context(), tenantID, opts) if err != nil { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return @@ -96,7 +97,7 @@ func (h *DestinationHandlers) Create(c *gin.Context) { AbortWithValidationError(c, err) return } - if err := h.entityStore.CreateDestination(c.Request.Context(), destination); err != nil { + if err := h.tenantStore.CreateDestination(c.Request.Context(), destination); err != nil { h.handleUpsertDestinationError(c, err) return } @@ -196,7 +197,7 @@ func (h *DestinationHandlers) Update(c *gin.Context) { // Update destination. 
updatedDestination.UpdatedAt = time.Now() - if err := h.entityStore.UpsertDestination(c.Request.Context(), updatedDestination); err != nil { + if err := h.tenantStore.UpsertDestination(c.Request.Context(), updatedDestination); err != nil { h.handleUpsertDestinationError(c, err) return } @@ -218,7 +219,7 @@ func (h *DestinationHandlers) Delete(c *gin.Context) { if destination == nil { return } - if err := h.entityStore.DeleteDestination(c.Request.Context(), destination.TenantID, destination.ID); err != nil { + if err := h.tenantStore.DeleteDestination(c.Request.Context(), destination.TenantID, destination.ID); err != nil { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return } @@ -274,7 +275,7 @@ func (h *DestinationHandlers) setDisabilityHandler(c *gin.Context, disabled bool destination.DisabledAt = nil } if shouldUpdate { - if err := h.entityStore.UpsertDestination(c.Request.Context(), *destination); err != nil { + if err := h.tenantStore.UpsertDestination(c.Request.Context(), *destination); err != nil { h.handleUpsertDestinationError(c, err) return } @@ -289,9 +290,9 @@ func (h *DestinationHandlers) setDisabilityHandler(c *gin.Context, disabled bool } func (h *DestinationHandlers) mustRetrieveDestination(c *gin.Context, tenantID, destinationID string) *models.Destination { - destination, err := h.entityStore.RetrieveDestination(c.Request.Context(), tenantID, destinationID) + destination, err := h.tenantStore.RetrieveDestination(c.Request.Context(), tenantID, destinationID) if err != nil { - if errors.Is(err, models.ErrDestinationDeleted) { + if errors.Is(err, tenantstore.ErrDestinationDeleted) { c.Status(http.StatusNotFound) return nil } @@ -310,7 +311,7 @@ func (h *DestinationHandlers) handleUpsertDestinationError(c *gin.Context, err e AbortWithValidationError(c, err) return } - if errors.Is(err, models.ErrDuplicateDestination) { + if errors.Is(err, tenantstore.ErrDuplicateDestination) { AbortWithError(c, http.StatusBadRequest, NewErrBadRequest(err)) return } diff --git a/internal/apirouter/destination_handlers_test.go b/internal/apirouter/destination_handlers_test.go index 8f49abc0..da71d666 100644 --- a/internal/apirouter/destination_handlers_test.go +++ b/internal/apirouter/destination_handlers_test.go @@ -18,7 +18,7 @@ func TestDestinationCreateHandler(t *testing.T) { t.Parallel() router, _, redisClient := setupTestRouter(t, "", "") - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should set updated_at equal to created_at on creation", func(t *testing.T) { t.Parallel() @@ -30,7 +30,7 @@ func TestDestinationCreateHandler(t *testing.T) { CreatedAt: time.Now(), UpdatedAt: time.Now(), } - err := entityStore.UpsertTenant(context.Background(), tenant) + err := tenantStore.UpsertTenant(context.Background(), tenant) if err != nil { t.Fatal(err) } @@ -60,8 +60,8 @@ func TestDestinationCreateHandler(t *testing.T) { // Cleanup if destID, ok := response["id"].(string); ok { - entityStore.DeleteDestination(context.Background(), tenantID, destID) + tenantStore.DeleteDestination(context.Background(), tenantID, destID) } - entityStore.DeleteTenant(context.Background(), tenantID) + tenantStore.DeleteTenant(context.Background(), tenantID) }) } diff --git a/internal/apirouter/log_handlers.go b/internal/apirouter/log_handlers.go index 7f22339c..d3152000 100644 --- a/internal/apirouter/log_handlers.go +++ b/internal/apirouter/log_handlers.go @@ -88,13 +88,13 @@ func parseIncludeOptions(c *gin.Context) 
IncludeOptions { // APIAttempt is the API response for an attempt type APIAttempt struct { - ID string `json:"id"` - Status string `json:"status"` - DeliveredAt time.Time `json:"delivered_at"` - Code string `json:"code,omitempty"` - ResponseData map[string]interface{} `json:"response_data,omitempty"` - AttemptNumber int `json:"attempt_number"` - Manual bool `json:"manual"` + ID string `json:"id"` + Status string `json:"status"` + DeliveredAt time.Time `json:"delivered_at"` + Code string `json:"code,omitempty"` + ResponseData map[string]interface{} `json:"response_data,omitempty"` + AttemptNumber int `json:"attempt_number"` + Manual bool `json:"manual"` // Expandable fields - string (ID) or object depending on expand Event interface{} `json:"event"` @@ -146,8 +146,8 @@ type EventPaginatedResult struct { func toAPIAttempt(ar *logstore.AttemptRecord, opts IncludeOptions) APIAttempt { api := APIAttempt{ AttemptNumber: ar.Attempt.AttemptNumber, - Manual: ar.Attempt.Manual, - Destination: ar.Attempt.DestinationID, + Manual: ar.Attempt.Manual, + Destination: ar.Attempt.DestinationID, } if ar.Attempt != nil { @@ -186,7 +186,7 @@ func toAPIAttempt(ar *logstore.AttemptRecord, opts IncludeOptions) APIAttempt { } // TODO: Handle destination expansion - // This would require injecting EntityStore into LogHandlers and batch-fetching + // This would require injecting TenantStore into LogHandlers and batch-fetching // destinations by ID. Consider if this is needed - clients can fetch destination // details separately via GET /destinations/:id if needed. diff --git a/internal/apirouter/log_handlers_test.go b/internal/apirouter/log_handlers_test.go index 431e9fdb..a4c56517 100644 --- a/internal/apirouter/log_handlers_test.go +++ b/internal/apirouter/log_handlers_test.go @@ -23,11 +23,11 @@ func TestListAttempts(t *testing.T) { // Create a tenant tenantID := idgen.String() destinationID := idgen.Destination() - require.NoError(t, result.entityStore.UpsertTenant(context.Background(), models.Tenant{ + require.NoError(t, result.tenantStore.UpsertTenant(context.Background(), models.Tenant{ ID: tenantID, CreatedAt: time.Now(), })) - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: destinationID, TenantID: tenantID, Type: "webhook", @@ -295,11 +295,11 @@ func TestRetrieveAttempt(t *testing.T) { // Create a tenant tenantID := idgen.String() destinationID := idgen.Destination() - require.NoError(t, result.entityStore.UpsertTenant(context.Background(), models.Tenant{ + require.NoError(t, result.tenantStore.UpsertTenant(context.Background(), models.Tenant{ ID: tenantID, CreatedAt: time.Now(), })) - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: destinationID, TenantID: tenantID, Type: "webhook", @@ -404,11 +404,11 @@ func TestRetrieveEvent(t *testing.T) { // Create a tenant tenantID := idgen.String() destinationID := idgen.Destination() - require.NoError(t, result.entityStore.UpsertTenant(context.Background(), models.Tenant{ + require.NoError(t, result.tenantStore.UpsertTenant(context.Background(), models.Tenant{ ID: tenantID, CreatedAt: time.Now(), })) - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, 
result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: destinationID, TenantID: tenantID, Type: "webhook", @@ -489,11 +489,11 @@ func TestListEvents(t *testing.T) { // Create a tenant tenantID := idgen.String() destinationID := idgen.Destination() - require.NoError(t, result.entityStore.UpsertTenant(context.Background(), models.Tenant{ + require.NoError(t, result.tenantStore.UpsertTenant(context.Background(), models.Tenant{ ID: tenantID, CreatedAt: time.Now(), })) - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: destinationID, TenantID: tenantID, Type: "webhook", diff --git a/internal/apirouter/requiretenant_middleware.go b/internal/apirouter/requiretenant_middleware.go index 022ea1a0..98dbe9d6 100644 --- a/internal/apirouter/requiretenant_middleware.go +++ b/internal/apirouter/requiretenant_middleware.go @@ -7,9 +7,10 @@ import ( "github.com/gin-gonic/gin" "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore" ) -func RequireTenantMiddleware(entityStore models.EntityStore) gin.HandlerFunc { +func RequireTenantMiddleware(tenantStore tenantstore.TenantStore) gin.HandlerFunc { return func(c *gin.Context) { tenantID, exists := c.Get("tenantID") if !exists { @@ -17,9 +18,9 @@ func RequireTenantMiddleware(entityStore models.EntityStore) gin.HandlerFunc { return } - tenant, err := entityStore.RetrieveTenant(c.Request.Context(), tenantID.(string)) + tenant, err := tenantStore.RetrieveTenant(c.Request.Context(), tenantID.(string)) if err != nil { - if err == models.ErrTenantDeleted { + if err == tenantstore.ErrTenantDeleted { c.AbortWithStatus(http.StatusNotFound) return } diff --git a/internal/apirouter/requiretenant_middleware_test.go b/internal/apirouter/requiretenant_middleware_test.go index 2b09a300..6d639262 100644 --- a/internal/apirouter/requiretenant_middleware_test.go +++ b/internal/apirouter/requiretenant_middleware_test.go @@ -32,8 +32,8 @@ func TestRequireTenantMiddleware(t *testing.T) { tenant := models.Tenant{ ID: idgen.String(), } - entityStore := setupTestEntityStore(t, redisClient, nil) - err := entityStore.UpsertTenant(context.Background(), tenant) + tenantStore := setupTestTenantStore(t, redisClient) + err := tenantStore.UpsertTenant(context.Background(), tenant) require.Nil(t, err) w := httptest.NewRecorder() diff --git a/internal/apirouter/retry_handlers.go b/internal/apirouter/retry_handlers.go index f52f99d3..8f5fbed3 100644 --- a/internal/apirouter/retry_handlers.go +++ b/internal/apirouter/retry_handlers.go @@ -8,25 +8,26 @@ import ( "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/logstore" "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore" "go.uber.org/zap" ) type RetryHandlers struct { logger *logging.Logger - entityStore models.EntityStore + tenantStore tenantstore.TenantStore logStore logstore.LogStore deliveryMQ *deliverymq.DeliveryMQ } func NewRetryHandlers( logger *logging.Logger, - entityStore models.EntityStore, + tenantStore tenantstore.TenantStore, logStore logstore.LogStore, deliveryMQ *deliverymq.DeliveryMQ, ) *RetryHandlers { return &RetryHandlers{ logger: logger, - entityStore: entityStore, + tenantStore: tenantStore, logStore: logStore, deliveryMQ: deliveryMQ, } @@ -58,7 +59,7 @@ func (h *RetryHandlers) RetryAttempt(c *gin.Context) { } // 2. 
Check destination exists and is enabled - destination, err := h.entityStore.RetrieveDestination(c.Request.Context(), tenant.ID, attemptRecord.Attempt.DestinationID) + destination, err := h.tenantStore.RetrieveDestination(c.Request.Context(), tenant.ID, attemptRecord.Attempt.DestinationID) if err != nil { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return diff --git a/internal/apirouter/retry_handlers_test.go b/internal/apirouter/retry_handlers_test.go index 74c7c65e..070873a8 100644 --- a/internal/apirouter/retry_handlers_test.go +++ b/internal/apirouter/retry_handlers_test.go @@ -23,11 +23,11 @@ func TestRetryAttempt(t *testing.T) { // Create a tenant and destination tenantID := idgen.String() destinationID := idgen.Destination() - require.NoError(t, result.entityStore.UpsertTenant(context.Background(), models.Tenant{ + require.NoError(t, result.tenantStore.UpsertTenant(context.Background(), models.Tenant{ ID: tenantID, CreatedAt: time.Now(), })) - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: destinationID, TenantID: tenantID, Type: "webhook", @@ -117,7 +117,7 @@ func TestRetryAttempt(t *testing.T) { // Create a new destination that's disabled disabledDestinationID := idgen.Destination() disabledAt := time.Now() - require.NoError(t, result.entityStore.UpsertDestination(context.Background(), models.Destination{ + require.NoError(t, result.tenantStore.UpsertDestination(context.Background(), models.Destination{ ID: disabledDestinationID, TenantID: tenantID, Type: "webhook", diff --git a/internal/apirouter/router.go b/internal/apirouter/router.go index f5938d67..a6279073 100644 --- a/internal/apirouter/router.go +++ b/internal/apirouter/router.go @@ -13,11 +13,11 @@ import ( "github.com/hookdeck/outpost/internal/destregistry" "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/logstore" - "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/portal" "github.com/hookdeck/outpost/internal/publishmq" "github.com/hookdeck/outpost/internal/redis" "github.com/hookdeck/outpost/internal/telemetry" + "github.com/hookdeck/outpost/internal/tenantstore" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) @@ -98,7 +98,7 @@ func NewRouter( logger *logging.Logger, redisClient redis.Cmdable, deliveryMQ *deliverymq.DeliveryMQ, - entityStore models.EntityStore, + tenantStore tenantstore.TenantStore, logStore logstore.LogStore, publishmqEventHandler publishmq.EventHandler, telemetry telemetry.Telemetry, @@ -139,11 +139,11 @@ func NewRouter( apiRouter := r.Group("/api/v1") apiRouter.Use(SetTenantIDMiddleware()) - tenantHandlers := NewTenantHandlers(logger, telemetry, cfg.JWTSecret, cfg.DeploymentID, entityStore) - destinationHandlers := NewDestinationHandlers(logger, telemetry, entityStore, cfg.Topics, cfg.Registry) + tenantHandlers := NewTenantHandlers(logger, telemetry, cfg.JWTSecret, cfg.DeploymentID, tenantStore) + destinationHandlers := NewDestinationHandlers(logger, telemetry, tenantStore, cfg.Topics, cfg.Registry) publishHandlers := NewPublishHandlers(logger, publishmqEventHandler) logHandlers := NewLogHandlers(logger, logStore) - retryHandlers := NewRetryHandlers(logger, entityStore, logStore, deliveryMQ) + retryHandlers := NewRetryHandlers(logger, tenantStore, logStore, deliveryMQ) topicHandlers := NewTopicHandlers(logger, 
cfg.Topics) // Non-tenant routes (no :tenantID in path) @@ -196,7 +196,7 @@ func NewRouter( AuthScope: AuthScopeAdmin, Mode: RouteModePortal, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -206,7 +206,7 @@ func NewRouter( AuthScope: AuthScopeAdmin, Mode: RouteModePortal, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, } @@ -246,7 +246,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -256,7 +256,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, @@ -268,7 +268,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -278,7 +278,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -288,7 +288,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -298,7 +298,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -308,7 +308,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -318,7 +318,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -328,7 +328,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, @@ -340,7 +340,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -350,7 +350,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -360,7 +360,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, @@ -372,7 +372,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -382,7 +382,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, @@ -394,7 +394,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, 
Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -404,7 +404,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, { @@ -414,7 +414,7 @@ func NewRouter( AuthScope: AuthScopeAdminOrTenant, Mode: RouteModeAlways, Middlewares: []gin.HandlerFunc{ - RequireTenantMiddleware(entityStore), + RequireTenantMiddleware(tenantStore), }, }, } diff --git a/internal/apirouter/router_test.go b/internal/apirouter/router_test.go index d26840c0..307c96d3 100644 --- a/internal/apirouter/router_test.go +++ b/internal/apirouter/router_test.go @@ -15,10 +15,10 @@ import ( "github.com/hookdeck/outpost/internal/idgen" "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/logstore" - "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/publishmq" "github.com/hookdeck/outpost/internal/redis" "github.com/hookdeck/outpost/internal/telemetry" + "github.com/hookdeck/outpost/internal/tenantstore" "github.com/hookdeck/outpost/internal/apirouter" "github.com/hookdeck/outpost/internal/util/testutil" @@ -32,7 +32,7 @@ type testRouterResult struct { router http.Handler logger *logging.Logger redisClient redis.Client - entityStore models.EntityStore + tenantStore tenantstore.TenantStore logStore logstore.LogStore deliveryMQ *deliverymq.DeliveryMQ } @@ -49,9 +49,9 @@ func setupTestRouterFull(t *testing.T, apiKey, jwtSecret string, funcs ...func(t deliveryMQ := deliverymq.New() deliveryMQ.Init(context.Background()) eventTracer := eventtracer.NewNoopEventTracer() - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) logStore := setupTestLogStore(t, funcs...) 
- eventHandler := publishmq.NewEventHandler(logger, deliveryMQ, entityStore, eventTracer, testutil.TestTopics, idempotence.New(redisClient, idempotence.WithSuccessfulTTL(24*time.Hour))) + eventHandler := publishmq.NewEventHandler(logger, deliveryMQ, tenantStore, eventTracer, testutil.TestTopics, idempotence.New(redisClient, idempotence.WithSuccessfulTTL(24*time.Hour))) router := apirouter.NewRouter( apirouter.RouterConfig{ ServiceName: "", @@ -63,7 +63,7 @@ func setupTestRouterFull(t *testing.T, apiKey, jwtSecret string, funcs ...func(t logger, redisClient, deliveryMQ, - entityStore, + tenantStore, logStore, eventHandler, &telemetry.NoopTelemetry{}, @@ -72,7 +72,7 @@ func setupTestRouterFull(t *testing.T, apiKey, jwtSecret string, funcs ...func(t router: router, logger: logger, redisClient: redisClient, - entityStore: entityStore, + tenantStore: tenantStore, logStore: logStore, deliveryMQ: deliveryMQ, } @@ -93,14 +93,12 @@ func setupTestLogStore(t *testing.T, funcs ...func(t *testing.T) clickhouse.DB) return logStore } -func setupTestEntityStore(_ *testing.T, redisClient redis.Client, cipher models.Cipher) models.EntityStore { - if cipher == nil { - cipher = models.NewAESCipher("secret") - } - return models.NewEntityStore(redisClient, - models.WithCipher(cipher), - models.WithAvailableTopics(testutil.TestTopics), - ) +func setupTestTenantStore(_ *testing.T, redisClient redis.Client) tenantstore.TenantStore { + return tenantstore.New(tenantstore.Config{ + RedisClient: redisClient, + Secret: "secret", + AvailableTopics: testutil.TestTopics, + }) } func TestRouterWithAPIKey(t *testing.T) { diff --git a/internal/apirouter/tenant_handlers.go b/internal/apirouter/tenant_handlers.go index 53772394..cbe00806 100644 --- a/internal/apirouter/tenant_handlers.go +++ b/internal/apirouter/tenant_handlers.go @@ -10,6 +10,7 @@ import ( "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/telemetry" + "github.com/hookdeck/outpost/internal/tenantstore" ) type TenantHandlers struct { @@ -17,7 +18,7 @@ type TenantHandlers struct { telemetry telemetry.Telemetry jwtSecret string deploymentID string - entityStore models.EntityStore + tenantStore tenantstore.TenantStore } func NewTenantHandlers( @@ -25,14 +26,14 @@ func NewTenantHandlers( telemetry telemetry.Telemetry, jwtSecret string, deploymentID string, - entityStore models.EntityStore, + tenantStore tenantstore.TenantStore, ) *TenantHandlers { return &TenantHandlers{ logger: logger, telemetry: telemetry, jwtSecret: jwtSecret, deploymentID: deploymentID, - entityStore: entityStore, + tenantStore: tenantStore, } } @@ -55,8 +56,8 @@ func (h *TenantHandlers) Upsert(c *gin.Context) { } // Check existing tenant. 
- existingTenant, err := h.entityStore.RetrieveTenant(c.Request.Context(), tenantID) - if err != nil && err != models.ErrTenantDeleted { + existingTenant, err := h.tenantStore.RetrieveTenant(c.Request.Context(), tenantID) + if err != nil && err != tenantstore.ErrTenantDeleted { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return } @@ -65,7 +66,7 @@ func (h *TenantHandlers) Upsert(c *gin.Context) { if existingTenant != nil { existingTenant.Metadata = input.Metadata existingTenant.UpdatedAt = time.Now() - if err := h.entityStore.UpsertTenant(c.Request.Context(), *existingTenant); err != nil { + if err := h.tenantStore.UpsertTenant(c.Request.Context(), *existingTenant); err != nil { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return } @@ -82,7 +83,7 @@ func (h *TenantHandlers) Upsert(c *gin.Context) { CreatedAt: now, UpdatedAt: now, } - if err := h.entityStore.UpsertTenant(c.Request.Context(), *tenant); err != nil { + if err := h.tenantStore.UpsertTenant(c.Request.Context(), *tenant); err != nil { AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err)) return } @@ -113,7 +114,7 @@ func (h *TenantHandlers) List(c *gin.Context) { return } - req := models.ListTenantRequest{ + req := tenantstore.ListTenantRequest{ Next: cursors.Next, Prev: cursors.Prev, Dir: dir, @@ -130,10 +131,10 @@ func (h *TenantHandlers) List(c *gin.Context) { } // Call entity store - resp, err := h.entityStore.ListTenant(c.Request.Context(), req) + resp, err := h.tenantStore.ListTenant(c.Request.Context(), req) if err != nil { // Map errors to HTTP status codes - if errors.Is(err, models.ErrListTenantNotSupported) { + if errors.Is(err, tenantstore.ErrListTenantNotSupported) { AbortWithError(c, http.StatusNotImplemented, ErrorResponse{ Err: err, Code: http.StatusNotImplemented, @@ -141,15 +142,15 @@ func (h *TenantHandlers) List(c *gin.Context) { }) return } - if errors.Is(err, models.ErrConflictingCursors) { + if errors.Is(err, tenantstore.ErrConflictingCursors) { AbortWithError(c, http.StatusBadRequest, NewErrBadRequest(err)) return } - if errors.Is(err, models.ErrInvalidCursor) { + if errors.Is(err, tenantstore.ErrInvalidCursor) { AbortWithError(c, http.StatusBadRequest, NewErrBadRequest(err)) return } - if errors.Is(err, models.ErrInvalidOrder) { + if errors.Is(err, tenantstore.ErrInvalidOrder) { AbortWithError(c, http.StatusBadRequest, NewErrBadRequest(err)) return } @@ -166,9 +167,9 @@ func (h *TenantHandlers) Delete(c *gin.Context) { return } - err := h.entityStore.DeleteTenant(c.Request.Context(), tenantID) + err := h.tenantStore.DeleteTenant(c.Request.Context(), tenantID) if err != nil { - if err == models.ErrTenantNotFound { + if err == tenantstore.ErrTenantNotFound { c.Status(http.StatusNotFound) return } diff --git a/internal/apirouter/tenant_handlers_test.go b/internal/apirouter/tenant_handlers_test.go index 5291068e..09288f60 100644 --- a/internal/apirouter/tenant_handlers_test.go +++ b/internal/apirouter/tenant_handlers_test.go @@ -17,7 +17,7 @@ func TestDestinationUpsertHandler(t *testing.T) { t.Parallel() router, _, redisClient := setupTestRouter(t, "", "") - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should create when there's no existing tenant", func(t *testing.T) { t.Parallel() @@ -46,7 +46,7 @@ func TestDestinationUpsertHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), 
existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -66,7 +66,7 @@ func TestDestinationUpsertHandler(t *testing.T) { assert.Equal(t, existingResource.CreatedAt.Unix(), createdAt.Unix()) // Cleanup - entityStore.DeleteTenant(context.Background(), existingResource.ID) + tenantStore.DeleteTenant(context.Background(), existingResource.ID) }) } @@ -74,7 +74,7 @@ func TestTenantRetrieveHandler(t *testing.T) { t.Parallel() router, _, redisClient := setupTestRouter(t, "", "") - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should return 404 when there's no tenant", func(t *testing.T) { t.Parallel() @@ -94,7 +94,7 @@ func TestTenantRetrieveHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -114,7 +114,7 @@ func TestTenantRetrieveHandler(t *testing.T) { assert.Equal(t, existingResource.CreatedAt.Unix(), createdAt.Unix()) // Cleanup - entityStore.DeleteTenant(context.Background(), existingResource.ID) + tenantStore.DeleteTenant(context.Background(), existingResource.ID) }) } @@ -122,7 +122,7 @@ func TestTenantDeleteHandler(t *testing.T) { t.Parallel() router, _, redisClient := setupTestRouter(t, "", "") - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should return 404 when there's no tenant", func(t *testing.T) { t.Parallel() @@ -142,7 +142,7 @@ func TestTenantDeleteHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -164,7 +164,7 @@ func TestTenantDeleteHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) inputDestination := models.Destination{ Type: "webhook", Topics: []string{"user.created", "user.updated"}, @@ -176,7 +176,7 @@ func TestTenantDeleteHandler(t *testing.T) { ids[i] = idgen.String() inputDestination.ID = ids[i] inputDestination.CreatedAt = time.Now() - entityStore.UpsertDestination(context.Background(), inputDestination) + tenantStore.UpsertDestination(context.Background(), inputDestination) } // Request @@ -190,7 +190,7 @@ func TestTenantDeleteHandler(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, true, response["success"]) - destinations, err := entityStore.ListDestinationByTenant(context.Background(), existingResource.ID) + destinations, err := tenantStore.ListDestinationByTenant(context.Background(), existingResource.ID) assert.Nil(t, err) assert.Equal(t, 0, len(destinations)) }) @@ -202,7 +202,7 @@ func TestTenantRetrieveTokenHandler(t *testing.T) { apiKey := "api_key" jwtSecret := "jwt_secret" router, _, redisClient := setupTestRouter(t, apiKey, jwtSecret) - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should return token and tenant_id", func(t *testing.T) { t.Parallel() @@ -212,7 +212,7 @@ func TestTenantRetrieveTokenHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), 
existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -228,7 +228,7 @@ func TestTenantRetrieveTokenHandler(t *testing.T) { assert.Equal(t, existingResource.ID, response["tenant_id"]) // Cleanup - entityStore.DeleteTenant(context.Background(), existingResource.ID) + tenantStore.DeleteTenant(context.Background(), existingResource.ID) }) } @@ -238,7 +238,7 @@ func TestTenantRetrievePortalHandler(t *testing.T) { apiKey := "api_key" jwtSecret := "jwt_secret" router, _, redisClient := setupTestRouter(t, apiKey, jwtSecret) - entityStore := setupTestEntityStore(t, redisClient, nil) + tenantStore := setupTestTenantStore(t, redisClient) t.Run("should return redirect_url with token and tenant_id in body", func(t *testing.T) { t.Parallel() @@ -248,7 +248,7 @@ func TestTenantRetrievePortalHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -265,7 +265,7 @@ func TestTenantRetrievePortalHandler(t *testing.T) { assert.Equal(t, existingResource.ID, response["tenant_id"]) // Cleanup - entityStore.DeleteTenant(context.Background(), existingResource.ID) + tenantStore.DeleteTenant(context.Background(), existingResource.ID) }) t.Run("should include theme in redirect_url when provided", func(t *testing.T) { @@ -276,7 +276,7 @@ func TestTenantRetrievePortalHandler(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(context.Background(), existingResource) + tenantStore.UpsertTenant(context.Background(), existingResource) // Request w := httptest.NewRecorder() @@ -293,7 +293,7 @@ func TestTenantRetrievePortalHandler(t *testing.T) { assert.Equal(t, existingResource.ID, response["tenant_id"]) // Cleanup - entityStore.DeleteTenant(context.Background(), existingResource.ID) + tenantStore.DeleteTenant(context.Background(), existingResource.ID) }) } @@ -301,7 +301,7 @@ func TestTenantListHandler(t *testing.T) { t.Parallel() router, _, redisClient := setupTestRouter(t, "", "") - _ = setupTestEntityStore(t, redisClient, nil) + _ = setupTestTenantStore(t, redisClient) // Note: These tests use miniredis which doesn't support RediSearch. // The ListTenant feature requires RediSearch, so we expect 501 Not Implemented. 
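For reference while reviewing the handler changes above: they program against the new internal/tenantstore package, and the sketch below reconstructs its surface purely from the call sites in this diff. Every signature, the Config field types, and the List* request/response/option shapes are inferred and may differ from the real package; consult internal/tenantstore for the authoritative definitions.

// Sketch (not part of the patch): inferred tenantstore surface.
package tenantstore

import (
	"context"

	"github.com/hookdeck/outpost/internal/models"
	"github.com/hookdeck/outpost/internal/redis"
)

// Config replaces the old models.NewEntityStore(redisClient,
// models.WithCipher(...), models.WithAvailableTopics(...)) wiring seen in the
// removed setupTestEntityStore helper.
type Config struct {
	RedisClient     redis.Client // a redis.Cmdable may also be accepted; not visible in this diff
	Secret          string       // encryption secret, formerly models.NewAESCipher("secret")
	AvailableTopics []string
}

func New(cfg Config) TenantStore { panic("sketch only") }

type TenantStore interface {
	// Tenant operations used by the tenant handlers and RequireTenantMiddleware.
	UpsertTenant(ctx context.Context, tenant models.Tenant) error
	RetrieveTenant(ctx context.Context, tenantID string) (*models.Tenant, error) // may return ErrTenantDeleted
	DeleteTenant(ctx context.Context, tenantID string) error                     // may return ErrTenantNotFound
	ListTenant(ctx context.Context, req ListTenantRequest) (ListTenantResponse, error)

	// Destination operations used by the destination/retry handlers and deliverymq.
	CreateDestination(ctx context.Context, destination models.Destination) error
	UpsertDestination(ctx context.Context, destination models.Destination) error // may return ErrDuplicateDestination
	RetrieveDestination(ctx context.Context, tenantID, destinationID string) (*models.Destination, error) // may return ErrDestinationDeleted
	DeleteDestination(ctx context.Context, tenantID, destinationID string) error
	ListDestinationByTenant(ctx context.Context, tenantID string, opts ...ListDestinationByTenantOpts) ([]models.Destination, error)
}

// Placeholder shapes: only the fields visible in this diff are listed.
type ListTenantRequest struct {
	Next, Prev, Dir string // additional paging/order fields are set later in tenant_handlers.go
}

type ListTenantResponse struct {
	Tenants []models.Tenant // assumed; the response shape is not shown in this diff
}

type DestinationFilter struct {
	Type   []string
	Topics []string
}

// Declared as a value type because the handlers use `var opts ListDestinationByTenantOpts`;
// the real definition may differ.
type ListDestinationByTenantOpts struct {
	Filter *DestinationFilter
}

func WithDestinationFilter(f DestinationFilter) ListDestinationByTenantOpts {
	return ListDestinationByTenantOpts{Filter: &f}
}

// Sentinel errors referenced by the handlers in this diff: ErrTenantDeleted,
// ErrTenantNotFound, ErrDestinationDeleted, ErrDestinationNotFound,
// ErrDuplicateDestination, ErrListTenantNotSupported, ErrConflictingCursors,
// ErrInvalidCursor, ErrInvalidOrder.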
diff --git a/internal/config/config.go b/internal/config/config.go index 35fed1d5..60d6c470 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -453,7 +453,7 @@ func (c *TelemetryConfig) ToTelemetryConfig() telemetry.TelemetryConfig { func (c *Config) ToTelemetryApplicationInfo() telemetry.ApplicationInfo { portalEnabled := c.APIKey != "" && c.APIJWTSecret != "" - entityStore := "redis" + tenantStore := "redis" logStore := "" if c.ClickHouse.Addr != "" { logStore = "clickhouse" @@ -466,7 +466,7 @@ func (c *Config) ToTelemetryApplicationInfo() telemetry.ApplicationInfo { Version: version.Version(), MQ: c.MQs.GetInfraType(), PortalEnabled: portalEnabled, - EntityStore: entityStore, + TenantStore: tenantStore, LogStore: logStore, } } diff --git a/internal/deliverymq/messagehandler.go b/internal/deliverymq/messagehandler.go index cf91c06d..6f4fa3d6 100644 --- a/internal/deliverymq/messagehandler.go +++ b/internal/deliverymq/messagehandler.go @@ -15,6 +15,7 @@ import ( "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/mqs" "github.com/hookdeck/outpost/internal/scheduler" + "github.com/hookdeck/outpost/internal/tenantstore" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" ) @@ -68,7 +69,7 @@ type messageHandler struct { eventTracer DeliveryTracer logger *logging.Logger logMQ LogPublisher - entityStore DestinationGetter + tenantStore DestinationGetter retryScheduler RetryScheduler retryBackoff backoff.Backoff retryMaxLimit int @@ -105,7 +106,7 @@ type AlertMonitor interface { func NewMessageHandler( logger *logging.Logger, logMQ LogPublisher, - entityStore DestinationGetter, + tenantStore DestinationGetter, publisher Publisher, eventTracer DeliveryTracer, retryScheduler RetryScheduler, @@ -118,7 +119,7 @@ func NewMessageHandler( eventTracer: eventTracer, logger: logger, logMQ: logMQ, - entityStore: entityStore, + tenantStore: tenantStore, publisher: publisher, retryScheduler: retryScheduler, retryBackoff: retryBackoff, @@ -163,7 +164,7 @@ func (h *messageHandler) handleError(msg *mqs.Message, err error) error { // Don't return error for expected cases var preErr *PreDeliveryError if errors.As(err, &preErr) { - if errors.Is(preErr.err, models.ErrDestinationDeleted) || errors.Is(preErr.err, errDestinationDisabled) { + if errors.Is(preErr.err, tenantstore.ErrDestinationDeleted) || errors.Is(preErr.err, errDestinationDisabled) { return nil } } @@ -359,7 +360,7 @@ func (h *messageHandler) shouldNackError(err error) bool { var preErr *PreDeliveryError if errors.As(err, &preErr) { // Don't nack if it's a permanent error - if errors.Is(preErr.err, models.ErrDestinationDeleted) || errors.Is(preErr.err, errDestinationDisabled) { + if errors.Is(preErr.err, tenantstore.ErrDestinationDeleted) || errors.Is(preErr.err, errDestinationDisabled) { return false } return true // Nack other pre-delivery errors @@ -428,7 +429,7 @@ func (h *messageHandler) scheduleRetry(ctx context.Context, task models.Delivery // Returns an error if the destination is not found, deleted, disabled, or any other state that // would prevent publishing. 
func (h *messageHandler) ensurePublishableDestination(ctx context.Context, task models.DeliveryTask) (*models.Destination, error) { - destination, err := h.entityStore.RetrieveDestination(ctx, task.Event.TenantID, task.DestinationID) + destination, err := h.tenantStore.RetrieveDestination(ctx, task.Event.TenantID, task.DestinationID) if err != nil { logger := h.logger.Ctx(ctx) fields := []zap.Field{ @@ -438,7 +439,7 @@ func (h *messageHandler) ensurePublishableDestination(ctx context.Context, task zap.String("destination_id", task.DestinationID), } - if errors.Is(err, models.ErrDestinationDeleted) { + if errors.Is(err, tenantstore.ErrDestinationDeleted) { logger.Info("destination deleted", fields...) } else { // Unexpected errors like DB connection issues @@ -451,7 +452,7 @@ func (h *messageHandler) ensurePublishableDestination(ctx context.Context, task zap.String("event_id", task.Event.ID), zap.String("tenant_id", task.Event.TenantID), zap.String("destination_id", task.DestinationID)) - return nil, models.ErrDestinationNotFound + return nil, tenantstore.ErrDestinationNotFound } if destination.DisabledAt != nil { h.logger.Ctx(ctx).Info("skipping disabled destination", diff --git a/internal/deliverymq/messagehandler_test.go b/internal/deliverymq/messagehandler_test.go index 0a3a5d75..2b3944e8 100644 --- a/internal/deliverymq/messagehandler_test.go +++ b/internal/deliverymq/messagehandler_test.go @@ -13,6 +13,7 @@ import ( "github.com/hookdeck/outpost/internal/idempotence" "github.com/hookdeck/outpost/internal/idgen" "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore" "github.com/hookdeck/outpost/internal/util/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -153,7 +154,7 @@ func TestMessageHandler_DestinationDeleted(t *testing.T) { ) // Setup mocks - destGetter := &mockDestinationGetter{err: models.ErrDestinationDeleted} + destGetter := &mockDestinationGetter{err: tenantstore.ErrDestinationDeleted} retryScheduler := newMockRetryScheduler() logPublisher := newMockLogPublisher(nil) alertMonitor := newMockAlertMonitor() diff --git a/internal/destinationmockserver/mocksdk/mocksdk.go b/internal/destinationmockserver/mocksdk/mocksdk.go index 4db056d4..590bfb60 100644 --- a/internal/destinationmockserver/mocksdk/mocksdk.go +++ b/internal/destinationmockserver/mocksdk/mocksdk.go @@ -11,7 +11,7 @@ import ( "github.com/hookdeck/outpost/internal/models" ) -func New(baseURL string) destinationmockserver.EntityStore { +func New(baseURL string) destinationmockserver.MockStore { parsedURL, err := url.Parse(baseURL) if err != nil { panic(err) diff --git a/internal/destinationmockserver/model.go b/internal/destinationmockserver/model.go index dc638874..81674f81 100644 --- a/internal/destinationmockserver/model.go +++ b/internal/destinationmockserver/model.go @@ -21,7 +21,7 @@ type Event struct { Payload map[string]interface{} `json:"payload"` } -type EntityStore interface { +type MockStore interface { ListDestination(ctx context.Context) ([]models.Destination, error) RetrieveDestination(ctx context.Context, id string) (*models.Destination, error) UpsertDestination(ctx context.Context, destination models.Destination) error @@ -32,20 +32,20 @@ type EntityStore interface { ClearEvents(ctx context.Context, destinationID string) error } -type entityStore struct { +type mockStore struct { mu sync.RWMutex destinations map[string]models.Destination events map[string][]Event } -func NewEntityStore() EntityStore { - return 
&entityStore{ +func NewMockStore() MockStore { + return &mockStore{ destinations: make(map[string]models.Destination), events: make(map[string][]Event), } } -func (s *entityStore) ListDestination(ctx context.Context) ([]models.Destination, error) { +func (s *mockStore) ListDestination(ctx context.Context) ([]models.Destination, error) { s.mu.RLock() defer s.mu.RUnlock() destinationList := make([]models.Destination, len(s.destinations)) @@ -57,7 +57,7 @@ func (s *entityStore) ListDestination(ctx context.Context) ([]models.Destination return destinationList, nil } -func (s *entityStore) RetrieveDestination(ctx context.Context, id string) (*models.Destination, error) { +func (s *mockStore) RetrieveDestination(ctx context.Context, id string) (*models.Destination, error) { s.mu.RLock() defer s.mu.RUnlock() destination, ok := s.destinations[id] @@ -67,14 +67,14 @@ func (s *entityStore) RetrieveDestination(ctx context.Context, id string) (*mode return &destination, nil } -func (s *entityStore) UpsertDestination(ctx context.Context, destination models.Destination) error { +func (s *mockStore) UpsertDestination(ctx context.Context, destination models.Destination) error { s.mu.Lock() defer s.mu.Unlock() s.destinations[destination.ID] = destination return nil } -func (s *entityStore) DeleteDestination(ctx context.Context, id string) error { +func (s *mockStore) DeleteDestination(ctx context.Context, id string) error { s.mu.Lock() defer s.mu.Unlock() if _, ok := s.destinations[id]; !ok { @@ -85,7 +85,7 @@ func (s *entityStore) DeleteDestination(ctx context.Context, id string) error { return nil } -func (s *entityStore) ReceiveEvent(ctx context.Context, destinationID string, payload map[string]interface{}, metadata map[string]string) (*Event, error) { +func (s *mockStore) ReceiveEvent(ctx context.Context, destinationID string, payload map[string]interface{}, metadata map[string]string) (*Event, error) { s.mu.Lock() defer s.mu.Unlock() destination, ok := s.destinations[destinationID] @@ -220,7 +220,7 @@ func verifySignature(secret string, payload []byte, signature string, algorithm return false } -func (s *entityStore) ListEvent(ctx context.Context, destinationID string) ([]Event, error) { +func (s *mockStore) ListEvent(ctx context.Context, destinationID string) ([]Event, error) { s.mu.RLock() defer s.mu.RUnlock() events, ok := s.events[destinationID] @@ -230,7 +230,7 @@ func (s *entityStore) ListEvent(ctx context.Context, destinationID string) ([]Ev return events, nil } -func (s *entityStore) ClearEvents(ctx context.Context, destinationID string) error { +func (s *mockStore) ClearEvents(ctx context.Context, destinationID string) error { s.mu.Lock() defer s.mu.Unlock() if _, ok := s.destinations[destinationID]; !ok { diff --git a/internal/destinationmockserver/router.go b/internal/destinationmockserver/router.go index 6acebef7..7278f83c 100644 --- a/internal/destinationmockserver/router.go +++ b/internal/destinationmockserver/router.go @@ -9,11 +9,11 @@ import ( "github.com/hookdeck/outpost/internal/models" ) -func NewRouter(entityStore EntityStore) http.Handler { +func NewRouter(store MockStore) http.Handler { r := gin.Default() handlers := Handlers{ - entityStore: entityStore, + store: store, } r.GET("/healthz", func(c *gin.Context) { @@ -33,11 +33,11 @@ func NewRouter(entityStore EntityStore) http.Handler { } type Handlers struct { - entityStore EntityStore + store MockStore } func (h *Handlers) ListDestination(c *gin.Context) { - if destinations, err := 
h.entityStore.ListDestination(c.Request.Context()); err != nil { + if destinations, err := h.store.ListDestination(c.Request.Context()); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return } else { @@ -51,7 +51,7 @@ func (h *Handlers) UpsertDestination(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } - if err := h.entityStore.UpsertDestination(c.Request.Context(), input); err != nil { + if err := h.store.UpsertDestination(c.Request.Context(), input); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return } @@ -59,7 +59,7 @@ func (h *Handlers) UpsertDestination(c *gin.Context) { } func (h *Handlers) DeleteDestination(c *gin.Context) { - if err := h.entityStore.DeleteDestination(c.Request.Context(), c.Param("destinationID")); err != nil { + if err := h.store.DeleteDestination(c.Request.Context(), c.Param("destinationID")); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } @@ -68,7 +68,7 @@ func (h *Handlers) DeleteDestination(c *gin.Context) { func (h *Handlers) ReceiveWebhookEvent(c *gin.Context) { destinationID := c.Param("destinationID") - destination, err := h.entityStore.RetrieveDestination(c.Request.Context(), destinationID) + destination, err := h.store.RetrieveDestination(c.Request.Context(), destinationID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return @@ -94,7 +94,7 @@ func (h *Handlers) ReceiveWebhookEvent(c *gin.Context) { } log.Println("metadata", metadata) - if event, err := h.entityStore.ReceiveEvent(c.Request.Context(), destinationID, input, metadata); err != nil { + if event, err := h.store.ReceiveEvent(c.Request.Context(), destinationID, input, metadata); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return } else { @@ -108,7 +108,7 @@ func (h *Handlers) ReceiveWebhookEvent(c *gin.Context) { func (h *Handlers) ListEvent(c *gin.Context) { destinationID := c.Param("destinationID") - destination, err := h.entityStore.RetrieveDestination(c.Request.Context(), destinationID) + destination, err := h.store.RetrieveDestination(c.Request.Context(), destinationID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return @@ -117,7 +117,7 @@ func (h *Handlers) ListEvent(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"message": "destination not found"}) return } - if events, err := h.entityStore.ListEvent(c.Request.Context(), destinationID); err != nil { + if events, err := h.store.ListEvent(c.Request.Context(), destinationID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return } else { @@ -127,7 +127,7 @@ func (h *Handlers) ListEvent(c *gin.Context) { func (h *Handlers) ClearEvents(c *gin.Context) { destinationID := c.Param("destinationID") - destination, err := h.entityStore.RetrieveDestination(c.Request.Context(), destinationID) + destination, err := h.store.RetrieveDestination(c.Request.Context(), destinationID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return @@ -136,7 +136,7 @@ func (h *Handlers) ClearEvents(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"message": "destination not found"}) return } - if err := h.entityStore.ClearEvents(c.Request.Context(), destinationID); err != nil { + if err := h.store.ClearEvents(c.Request.Context(), destinationID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return } diff --git 
a/internal/destinationmockserver/server.go b/internal/destinationmockserver/server.go index 9821a83a..d6953778 100644 --- a/internal/destinationmockserver/server.go +++ b/internal/destinationmockserver/server.go @@ -55,8 +55,8 @@ func (s *DestinationMockServer) Run(ctx context.Context) error { func New(config DestinationMockServerConfig) DestinationMockServer { logger, _ := zap.NewDevelopment() - entityStore := NewEntityStore() - router := NewRouter(entityStore) + store := NewMockStore() + router := NewRouter(store) return DestinationMockServer{ logger: logger, diff --git a/internal/models/destination.go b/internal/models/destination.go deleted file mode 100644 index 5bd41136..00000000 --- a/internal/models/destination.go +++ /dev/null @@ -1,351 +0,0 @@ -package models - -import ( - "encoding" - "encoding/json" - "errors" - "fmt" - "log" - "slices" - "strings" - "time" - - "github.com/hookdeck/outpost/internal/redis" - "github.com/hookdeck/outpost/internal/simplejsonmatch" -) - -var ( - ErrInvalidTopics = errors.New("validation failed: invalid topics") - ErrInvalidTopicsFormat = errors.New("validation failed: invalid topics format") -) - -type Destination struct { - ID string `json:"id" redis:"id"` - TenantID string `json:"tenant_id" redis:"-"` - Type string `json:"type" redis:"type"` - Topics Topics `json:"topics" redis:"-"` - Filter Filter `json:"filter,omitempty" redis:"-"` - Config Config `json:"config" redis:"-"` - Credentials Credentials `json:"credentials" redis:"-"` - DeliveryMetadata DeliveryMetadata `json:"delivery_metadata,omitempty" redis:"-"` - Metadata Metadata `json:"metadata,omitempty" redis:"-"` - CreatedAt time.Time `json:"created_at" redis:"created_at"` - UpdatedAt time.Time `json:"updated_at" redis:"updated_at"` - DisabledAt *time.Time `json:"disabled_at" redis:"disabled_at"` -} - -func (d *Destination) parseRedisHash(cmd *redis.MapStringStringCmd, cipher Cipher) error { - hash, err := cmd.Result() - if err != nil { - return err - } - if len(hash) == 0 { - return redis.Nil - } - // Check for deleted resource before scanning - if _, exists := hash["deleted_at"]; exists { - return ErrDestinationDeleted - } - - // Parse basic fields manually (Scan doesn't handle numeric timestamps) - d.ID = hash["id"] - d.Type = hash["type"] - - // Parse created_at - supports both numeric (Unix) and RFC3339 formats - d.CreatedAt, err = parseTimestamp(hash["created_at"]) - if err != nil { - return fmt.Errorf("invalid created_at: %w", err) - } - - // Parse updated_at - same lazy migration support - if hash["updated_at"] != "" { - d.UpdatedAt, err = parseTimestamp(hash["updated_at"]) - if err != nil { - d.UpdatedAt = d.CreatedAt - } - } else { - d.UpdatedAt = d.CreatedAt - } - - // Parse disabled_at if present - if hash["disabled_at"] != "" { - disabledAt, err := parseTimestamp(hash["disabled_at"]) - if err == nil { - d.DisabledAt = &disabledAt - } - } - err = d.Topics.UnmarshalBinary([]byte(hash["topics"])) - if err != nil { - return fmt.Errorf("invalid topics: %w", err) - } - err = d.Config.UnmarshalBinary([]byte(hash["config"])) - if err != nil { - return fmt.Errorf("invalid config: %w", err) - } - credentialsBytes, err := cipher.Decrypt([]byte(hash["credentials"])) - if err != nil { - return fmt.Errorf("invalid credentials: %w", err) - } - err = d.Credentials.UnmarshalBinary(credentialsBytes) - if err != nil { - return fmt.Errorf("invalid credentials: %w", err) - } - // Decrypt and deserialize delivery_metadata if present - if deliveryMetadataStr, exists := hash["delivery_metadata"]; 
exists && deliveryMetadataStr != "" { - deliveryMetadataBytes, err := cipher.Decrypt([]byte(deliveryMetadataStr)) - if err != nil { - return fmt.Errorf("invalid delivery_metadata: %w", err) - } - err = d.DeliveryMetadata.UnmarshalBinary(deliveryMetadataBytes) - if err != nil { - return fmt.Errorf("invalid delivery_metadata: %w", err) - } - } - // Deserialize metadata if present - if metadataStr, exists := hash["metadata"]; exists && metadataStr != "" { - err = d.Metadata.UnmarshalBinary([]byte(metadataStr)) - if err != nil { - return fmt.Errorf("invalid metadata: %w", err) - } - } - // Deserialize filter if present - if filterStr, exists := hash["filter"]; exists && filterStr != "" { - err = d.Filter.UnmarshalBinary([]byte(filterStr)) - if err != nil { - return fmt.Errorf("invalid filter: %w", err) - } - } - return nil -} - -func (d *Destination) Validate(topics []string) error { - if err := d.Topics.Validate(topics); err != nil { - return err - } - return nil -} - -type DestinationSummary struct { - ID string `json:"id"` - Type string `json:"type"` - Topics Topics `json:"topics"` - Filter Filter `json:"filter,omitempty"` - Disabled bool `json:"disabled"` -} - -var _ encoding.BinaryMarshaler = &DestinationSummary{} -var _ encoding.BinaryUnmarshaler = &DestinationSummary{} - -func (ds *DestinationSummary) MarshalBinary() ([]byte, error) { - return json.Marshal(ds) -} - -func (ds *DestinationSummary) UnmarshalBinary(data []byte) error { - return json.Unmarshal(data, ds) -} - -func (d *Destination) ToSummary() *DestinationSummary { - return &DestinationSummary{ - ID: d.ID, - Type: d.Type, - Topics: d.Topics, - Filter: d.Filter, - Disabled: d.DisabledAt != nil, - } -} - -// MatchEvent checks if the destination matches the given event. -// Returns true if the destination is enabled, topic matches, and filter matches. -func (d *Destination) MatchEvent(event Event) bool { - if d.DisabledAt != nil { - return false - } - if !d.Topics.MatchTopic(event.Topic) { - return false - } - return matchFilter(d.Filter, event) -} - -// MatchFilter checks if the given event matches the destination's filter. -// Returns true if no filter is set (nil or empty) or if the event matches the filter. -func (ds *DestinationSummary) MatchFilter(event Event) bool { - return matchFilter(ds.Filter, event) -} - -// matchFilter is the shared implementation for filter matching. -// Returns true if no filter is set (nil or empty) or if the event matches the filter. 
-func matchFilter(filter Filter, event Event) bool { - if len(filter) == 0 { - return true - } - // Build the filter input from the event - filterInput := map[string]any{ - "id": event.ID, - "topic": event.Topic, - "time": event.Time.Format("2006-01-02T15:04:05Z07:00"), - "metadata": map[string]any{}, - "data": map[string]any{}, - } - // Convert metadata to map[string]any - if event.Metadata != nil { - metadata := make(map[string]any) - for k, v := range event.Metadata { - metadata[k] = v - } - filterInput["metadata"] = metadata - } - // Copy data - if event.Data != nil { - filterInput["data"] = map[string]any(event.Data) - } - return simplejsonmatch.Match(filterInput, map[string]any(filter)) -} - -// ============================== Types ============================== - -type Topics []string - -var _ encoding.BinaryMarshaler = &Topics{} -var _ encoding.BinaryUnmarshaler = &Topics{} -var _ json.Marshaler = &Topics{} -var _ json.Unmarshaler = &Topics{} - -func (t *Topics) MatchesAll() bool { - return len(*t) == 1 && (*t)[0] == "*" -} - -func (t *Topics) MatchTopic(eventTopic string) bool { - return eventTopic == "" || eventTopic == "*" || t.MatchesAll() || slices.Contains(*t, eventTopic) -} - -func (t *Topics) Validate(availableTopics []string) error { - if len(*t) == 0 { - return ErrInvalidTopics - } - if t.MatchesAll() { - return nil - } - // If no available topics are configured, allow any topics - if len(availableTopics) == 0 { - return nil - } - for _, topic := range *t { - if topic == "*" { - return ErrInvalidTopics - } - if !slices.Contains(availableTopics, topic) { - return ErrInvalidTopics - } - } - return nil -} - -func (t *Topics) MarshalBinary() ([]byte, error) { - str := strings.Join(*t, ",") - return []byte(str), nil -} - -func (t *Topics) UnmarshalBinary(data []byte) error { - *t = TopicsFromString(string(data)) - return nil -} - -func (t *Topics) MarshalJSON() ([]byte, error) { - return json.Marshal(*t) -} - -func (t *Topics) UnmarshalJSON(data []byte) error { - if string(data) == `"*"` { - *t = TopicsFromString("*") - return nil - } - var arr []string - if err := json.Unmarshal(data, &arr); err != nil { - log.Println(err) - return ErrInvalidTopicsFormat - } - *t = arr - return nil -} - -func TopicsFromString(s string) Topics { - return Topics(strings.Split(s, ",")) -} - -type Config = MapStringString -type Credentials = MapStringString -type DeliveryMetadata = MapStringString -type MapStringString map[string]string - -var _ encoding.BinaryMarshaler = &MapStringString{} -var _ encoding.BinaryUnmarshaler = &MapStringString{} -var _ json.Unmarshaler = &MapStringString{} - -func (m *MapStringString) MarshalBinary() ([]byte, error) { - return json.Marshal(m) -} - -func (m *MapStringString) UnmarshalBinary(data []byte) error { - return json.Unmarshal(data, m) -} - -func (m *MapStringString) UnmarshalJSON(data []byte) error { - // First try to unmarshal as map[string]string - var stringMap map[string]string - if err := json.Unmarshal(data, &stringMap); err == nil { - *m = stringMap - return nil - } - - // If that fails, try map[string]interface{} to handle mixed types - var mixedMap map[string]interface{} - if err := json.Unmarshal(data, &mixedMap); err != nil { - return err - } - - // Convert all values to strings - result := make(map[string]string) - for k, v := range mixedMap { - switch val := v.(type) { - case string: - result[k] = val - case bool: - result[k] = fmt.Sprintf("%v", val) - case float64: - result[k] = fmt.Sprintf("%v", val) - case nil: - result[k] = "" - default: - 
// For other types, try to convert to string using JSON marshaling - if b, err := json.Marshal(val); err == nil { - result[k] = string(b) - } else { - result[k] = fmt.Sprintf("%v", val) - } - } - } - - *m = result - return nil -} - -// Filter represents a JSON schema filter for event matching. -// It uses the simplejsonmatch schema syntax for filtering events. -type Filter map[string]any - -var _ encoding.BinaryMarshaler = &Filter{} -var _ encoding.BinaryUnmarshaler = &Filter{} - -func (f *Filter) MarshalBinary() ([]byte, error) { - if f == nil || len(*f) == 0 { - return nil, nil - } - return json.Marshal(f) -} - -func (f *Filter) UnmarshalBinary(data []byte) error { - if len(data) == 0 { - return nil - } - return json.Unmarshal(data, f) -} diff --git a/internal/models/encryption_test.go b/internal/models/encryption_test.go deleted file mode 100644 index 72d1cd2e..00000000 --- a/internal/models/encryption_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package models_test - -import ( - "testing" - - "github.com/hookdeck/outpost/internal/models" - "github.com/stretchr/testify/assert" -) - -func TestCipher(t *testing.T) { - cipher := models.NewAESCipher("secret") - - const value = "hello world" - - var err error - var encrypted []byte - - t.Run("should encrypt", func(t *testing.T) { - encrypted, err = cipher.Encrypt([]byte(value)) - assert.Nil(t, err) - assert.NotNil(t, encrypted) - }) - - t.Run("should decrypt", func(t *testing.T) { - decrypted, err := cipher.Decrypt(encrypted) - assert.Nil(t, err) - assert.Equal(t, value, string(decrypted)) - }) -} diff --git a/internal/models/entities.go b/internal/models/entities.go new file mode 100644 index 00000000..68bca1ed --- /dev/null +++ b/internal/models/entities.go @@ -0,0 +1,158 @@ +package models + +import ( + "errors" + "slices" + "strings" + "time" + + "github.com/hookdeck/outpost/internal/simplejsonmatch" +) + +var ( + ErrInvalidTopics = errors.New("validation failed: invalid topics") + ErrInvalidTopicsFormat = errors.New("validation failed: invalid topics format") +) + +type Tenant struct { + ID string `json:"id" redis:"id"` + DestinationsCount int `json:"destinations_count" redis:"-"` + Topics []string `json:"topics" redis:"-"` + Metadata Metadata `json:"metadata,omitempty" redis:"-"` + CreatedAt time.Time `json:"created_at" redis:"created_at"` + UpdatedAt time.Time `json:"updated_at" redis:"updated_at"` +} + +type Destination struct { + ID string `json:"id" redis:"id"` + TenantID string `json:"tenant_id" redis:"-"` + Type string `json:"type" redis:"type"` + Topics Topics `json:"topics" redis:"-"` + Filter Filter `json:"filter,omitempty" redis:"-"` + Config Config `json:"config" redis:"-"` + Credentials Credentials `json:"credentials" redis:"-"` + DeliveryMetadata DeliveryMetadata `json:"delivery_metadata,omitempty" redis:"-"` + Metadata Metadata `json:"metadata,omitempty" redis:"-"` + CreatedAt time.Time `json:"created_at" redis:"created_at"` + UpdatedAt time.Time `json:"updated_at" redis:"updated_at"` + DisabledAt *time.Time `json:"disabled_at" redis:"disabled_at"` +} + +func (d *Destination) Validate(topics []string) error { + if err := d.Topics.Validate(topics); err != nil { + return err + } + return nil +} + +// MatchEvent checks if the destination matches the given event. +// Returns true if the destination is enabled, topic matches, and filter matches. 
+func (d *Destination) MatchEvent(event Event) bool { + if d.DisabledAt != nil { + return false + } + if !d.Topics.MatchTopic(event.Topic) { + return false + } + return MatchFilter(d.Filter, event) +} + +// MatchFilter checks if the given event matches the filter. +// Returns true if no filter is set (nil or empty) or if the event matches the filter. +func MatchFilter(filter Filter, event Event) bool { + if len(filter) == 0 { + return true + } + // Build the filter input from the event + filterInput := map[string]any{ + "id": event.ID, + "topic": event.Topic, + "time": event.Time.Format("2006-01-02T15:04:05Z07:00"), + "metadata": map[string]any{}, + "data": map[string]any{}, + } + // Convert metadata to map[string]any + if event.Metadata != nil { + metadata := make(map[string]any) + for k, v := range event.Metadata { + metadata[k] = v + } + filterInput["metadata"] = metadata + } + // Copy data + if event.Data != nil { + filterInput["data"] = map[string]any(event.Data) + } + return simplejsonmatch.Match(filterInput, map[string]any(filter)) +} + +type Event struct { + ID string `json:"id"` + TenantID string `json:"tenant_id"` + DestinationID string `json:"destination_id"` + Topic string `json:"topic"` + EligibleForRetry bool `json:"eligible_for_retry"` + Time time.Time `json:"time"` + Metadata Metadata `json:"metadata"` + Data Data `json:"data"` + Status string `json:"status,omitempty"` + + // Telemetry data, must exist to properly trace events between publish receiver & delivery handler + Telemetry *EventTelemetry `json:"telemetry,omitempty"` +} + +const ( + AttemptStatusSuccess = "success" + AttemptStatusFailed = "failed" +) + +type Attempt struct { + ID string `json:"id"` + TenantID string `json:"tenant_id"` + EventID string `json:"event_id"` + DestinationID string `json:"destination_id"` + AttemptNumber int `json:"attempt_number"` + Manual bool `json:"manual"` + Status string `json:"status"` + Time time.Time `json:"time"` + Code string `json:"code"` + ResponseData map[string]interface{} `json:"response_data"` +} + +// ============================== Types ============================== + +type Topics []string + +func (t *Topics) MatchesAll() bool { + return len(*t) == 1 && (*t)[0] == "*" +} + +func (t *Topics) MatchTopic(eventTopic string) bool { + return eventTopic == "" || eventTopic == "*" || t.MatchesAll() || slices.Contains(*t, eventTopic) +} + +func (t *Topics) Validate(availableTopics []string) error { + if len(*t) == 0 { + return ErrInvalidTopics + } + if t.MatchesAll() { + return nil + } + // If no available topics are configured, allow any topics + if len(availableTopics) == 0 { + return nil + } + for _, topic := range *t { + if topic == "*" { + return ErrInvalidTopics + } + if !slices.Contains(availableTopics, topic) { + return ErrInvalidTopics + } + } + return nil +} + +func TopicsFromString(s string) Topics { + return Topics(strings.Split(s, ",")) +} diff --git a/internal/models/destination_test.go b/internal/models/entities_test.go similarity index 85% rename from internal/models/destination_test.go rename to internal/models/entities_test.go index 753971ef..c48e6186 100644 --- a/internal/models/destination_test.go +++ b/internal/models/entities_test.go @@ -314,7 +314,7 @@ func TestFilter_UnmarshalBinary(t *testing.T) { }) } -func TestDestinationSummary_MatchFilter(t *testing.T) { +func TestMatchFilter(t *testing.T) { t.Parallel() baseEvent := testutil.EventFactory.Any( @@ -600,13 +600,7 @@ func TestDestinationSummary_MatchFilter(t *testing.T) { for _, tc := range testCases { 
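Note on the new matching API above: Destination.MatchEvent short-circuits on DisabledAt and on the topic list before delegating to the package-level MatchFilter, which treats a nil or empty filter as match-all. The following is a minimal, illustrative sketch of how a caller might exercise it (the IDs, topics, and literal struct construction are examples, not taken from the patch):

package main

import (
	"fmt"
	"time"

	"github.com/hookdeck/outpost/internal/models"
)

func main() {
	now := time.Now()

	dest := models.Destination{
		ID:     "dest_1",
		Type:   "webhook",
		Topics: models.Topics{"user.created"},
		// No Filter set: MatchFilter treats an empty filter as match-all.
	}
	event := models.Event{
		ID:    "evt_1",
		Topic: "user.created",
		Time:  now,
	}

	// Enabled destination, matching topic, no filter.
	fmt.Println(dest.MatchEvent(event)) // true

	// MatchFilter can be called directly; a nil/empty filter always matches.
	fmt.Println(models.MatchFilter(nil, event)) // true

	// A disabled destination never matches, regardless of topic or filter.
	dest.DisabledAt = &now
	fmt.Println(dest.MatchEvent(event)) // false

	// A topic mismatch short-circuits before the filter is evaluated.
	dest.DisabledAt = nil
	fmt.Println(dest.MatchEvent(models.Event{ID: "evt_2", Topic: "order.created", Time: now})) // false
}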
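The Topics helpers moved into entities.go keep their previous semantics: "*" is only valid as the sole entry, every other topic must come from the configured list when one is provided, and an empty configured list accepts anything. A small sketch under those assumptions (values are illustrative):

package main

import (
	"fmt"

	"github.com/hookdeck/outpost/internal/models"
)

func main() {
	available := []string{"user.created", "user.updated"}

	// "*" is only valid on its own, and then matches every topic.
	wildcard := models.Topics{"*"}
	fmt.Println(wildcard.Validate(available), wildcard.MatchTopic("user.deleted")) // <nil> true

	// Otherwise every topic must come from the configured list.
	ok := models.Topics{"user.created"}
	bad := models.Topics{"order.created"}
	fmt.Println(ok.Validate(available))  // <nil>
	fmt.Println(bad.Validate(available)) // validation failed: invalid topics

	// With no configured topics, any non-empty list is accepted.
	fmt.Println(bad.Validate(nil)) // <nil>

	// TopicsFromString splits a comma-separated string.
	parsed := models.TopicsFromString("user.created,user.updated")
	fmt.Println(parsed.MatchTopic("user.updated")) // true
}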
t.Run(tc.name, func(t *testing.T) { t.Parallel() - ds := models.DestinationSummary{ - ID: "dest_1", - Type: "webhook", - Topics: []string{"*"}, - Filter: tc.filter, - } - assert.Equal(t, tc.expected, ds.MatchFilter(tc.event)) + assert.Equal(t, tc.expected, models.MatchFilter(tc.filter, tc.event)) }) } } @@ -639,24 +633,96 @@ func TestDestination_JSONMarshalWithFilter(t *testing.T) { assert.Equal(t, "order.created", unmarshaled.Filter["data"].(map[string]any)["type"]) } -func TestDestinationSummary_ToSummaryIncludesFilter(t *testing.T) { +// ============================== Tenant tests ============================== + +func TestTenant_JSONMarshalWithMetadata(t *testing.T) { t.Parallel() - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_123"), - testutil.DestinationFactory.WithType("webhook"), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "type": "order.created", - }, + tenant := testutil.TenantFactory.Any( + testutil.TenantFactory.WithID("tenant_123"), + testutil.TenantFactory.WithMetadata(map[string]string{ + "environment": "production", + "team": "platform", + "region": "us-east-1", }), ) - summary := destination.ToSummary() + // Marshal to JSON + jsonBytes, err := json.Marshal(tenant) + assert.NoError(t, err) + + // Unmarshal back + var unmarshaled models.Tenant + err = json.Unmarshal(jsonBytes, &unmarshaled) + assert.NoError(t, err) + + // Verify metadata is preserved + assert.Equal(t, tenant.Metadata, unmarshaled.Metadata) + assert.Equal(t, "production", unmarshaled.Metadata["environment"]) + assert.Equal(t, "platform", unmarshaled.Metadata["team"]) + assert.Equal(t, "us-east-1", unmarshaled.Metadata["region"]) + + // Verify other fields still work + assert.Equal(t, tenant.ID, unmarshaled.ID) +} + +func TestTenant_JSONMarshalWithoutMetadata(t *testing.T) { + t.Parallel() + + tenant := testutil.TenantFactory.Any( + testutil.TenantFactory.WithID("tenant_123"), + ) + + // Marshal to JSON + jsonBytes, err := json.Marshal(tenant) + assert.NoError(t, err) + + // Unmarshal back + var unmarshaled models.Tenant + err = json.Unmarshal(jsonBytes, &unmarshaled) + assert.NoError(t, err) + + // Verify metadata is nil when not provided + assert.Nil(t, unmarshaled.Metadata) +} + +func TestTenant_JSONUnmarshalEmptyMetadata(t *testing.T) { + t.Parallel() + + jsonData := `{ + "id": "tenant_123", + "destinations_count": 0, + "topics": [], + "metadata": {}, + "created_at": "2024-01-01T00:00:00Z" + }` + + var tenant models.Tenant + err := json.Unmarshal([]byte(jsonData), &tenant) + assert.NoError(t, err) + + // Empty maps should be preserved as empty, not nil + assert.NotNil(t, tenant.Metadata) + assert.Empty(t, tenant.Metadata) +} + +func TestTenant_JSONMarshalWithUpdatedAt(t *testing.T) { + t.Parallel() + + tenant := testutil.TenantFactory.Any( + testutil.TenantFactory.WithID("tenant_123"), + ) + + // Marshal to JSON + jsonBytes, err := json.Marshal(tenant) + assert.NoError(t, err) + + // Unmarshal back + var unmarshaled models.Tenant + err = json.Unmarshal(jsonBytes, &unmarshaled) + assert.NoError(t, err) - assert.Equal(t, destination.ID, summary.ID) - assert.Equal(t, destination.Type, summary.Type) - assert.Equal(t, destination.Topics, summary.Topics) - assert.Equal(t, destination.Filter, summary.Filter) + // Verify updated_at is preserved + assert.Equal(t, tenant.UpdatedAt.Unix(), unmarshaled.UpdatedAt.Unix()) + assert.Equal(t, tenant.CreatedAt.Unix(), 
unmarshaled.CreatedAt.Unix()) } diff --git a/internal/models/entity.go b/internal/models/entity.go deleted file mode 100644 index 0c969249..00000000 --- a/internal/models/entity.go +++ /dev/null @@ -1,942 +0,0 @@ -package models - -import ( - "context" - "errors" - "fmt" - "slices" - "sort" - "strconv" - "time" - - "github.com/hookdeck/outpost/internal/cursor" - "github.com/hookdeck/outpost/internal/pagination" - "github.com/hookdeck/outpost/internal/redis" -) - -const defaultMaxDestinationsPerTenant = 20 - -type EntityStore interface { - Init(ctx context.Context) error - RetrieveTenant(ctx context.Context, tenantID string) (*Tenant, error) - UpsertTenant(ctx context.Context, tenant Tenant) error - DeleteTenant(ctx context.Context, tenantID string) error - ListTenant(ctx context.Context, req ListTenantRequest) (*TenantPaginatedResult, error) - ListDestinationByTenant(ctx context.Context, tenantID string, options ...ListDestinationByTenantOpts) ([]Destination, error) - RetrieveDestination(ctx context.Context, tenantID, destinationID string) (*Destination, error) - CreateDestination(ctx context.Context, destination Destination) error - UpsertDestination(ctx context.Context, destination Destination) error - DeleteDestination(ctx context.Context, tenantID, destinationID string) error - MatchEvent(ctx context.Context, event Event) ([]DestinationSummary, error) -} - -var ( - ErrTenantNotFound = errors.New("tenant does not exist") - ErrTenantDeleted = errors.New("tenant has been deleted") - ErrDuplicateDestination = errors.New("destination already exists") - ErrDestinationNotFound = errors.New("destination does not exist") - ErrDestinationDeleted = errors.New("destination has been deleted") - ErrMaxDestinationsPerTenantReached = errors.New("maximum number of destinations per tenant reached") - ErrListTenantNotSupported = errors.New("list tenant feature is not enabled") - ErrInvalidCursor = errors.New("invalid cursor") - ErrInvalidOrder = errors.New("invalid order: must be 'asc' or 'desc'") - ErrConflictingCursors = errors.New("cannot specify both next and prev cursors") -) - -// ListTenantRequest contains parameters for listing tenants. -type ListTenantRequest struct { - Limit int // Number of results per page (default: 20) - Next string // Cursor for next page - Prev string // Cursor for previous page - Dir string // Sort direction: "asc" or "desc" (default: "desc") -} - -// SeekPagination represents cursor-based pagination metadata for list responses. -type SeekPagination struct { - OrderBy string `json:"order_by"` - Dir string `json:"dir"` - Limit int `json:"limit"` - Next *string `json:"next"` - Prev *string `json:"prev"` -} - -// TenantPaginatedResult contains the paginated list of tenants. -type TenantPaginatedResult struct { - Models []Tenant `json:"models"` - Pagination SeekPagination `json:"pagination"` - Count int `json:"count"` -} - -type entityStoreImpl struct { - redisClient redis.Cmdable - cipher Cipher - availableTopics []string - maxDestinationsPerTenant int - deploymentID string - listTenantSupported bool -} - -// doCmd executes an arbitrary Redis command using the Do method. -// Returns an error if the client doesn't support Do (e.g., mock clients). -func (s *entityStoreImpl) doCmd(ctx context.Context, args ...interface{}) *redis.Cmd { - if dc, ok := s.redisClient.(redis.DoContext); ok { - return dc.Do(ctx, args...) 
- } - // Return an error cmd if Do is not supported - cmd := &redis.Cmd{} - cmd.SetErr(errors.New("redis client does not support Do command")) - return cmd -} - -// deploymentPrefix returns the deployment prefix for Redis keys -func (s *entityStoreImpl) deploymentPrefix() string { - if s.deploymentID == "" { - return "" - } - return fmt.Sprintf("%s:", s.deploymentID) -} - -// New cluster-compatible key formats with hash tags -func (s *entityStoreImpl) redisTenantID(tenantID string) string { - return fmt.Sprintf("%stenant:{%s}:tenant", s.deploymentPrefix(), tenantID) -} - -func (s *entityStoreImpl) redisTenantDestinationSummaryKey(tenantID string) string { - return fmt.Sprintf("%stenant:{%s}:destinations", s.deploymentPrefix(), tenantID) -} - -func (s *entityStoreImpl) redisDestinationID(destinationID, tenantID string) string { - return fmt.Sprintf("%stenant:{%s}:destination:%s", s.deploymentPrefix(), tenantID, destinationID) -} - -var _ EntityStore = (*entityStoreImpl)(nil) - -type EntityStoreOption func(*entityStoreImpl) - -func WithCipher(cipher Cipher) EntityStoreOption { - return func(s *entityStoreImpl) { - s.cipher = cipher - } -} - -func WithAvailableTopics(topics []string) EntityStoreOption { - return func(s *entityStoreImpl) { - s.availableTopics = topics - } -} - -func WithMaxDestinationsPerTenant(maxDestinationsPerTenant int) EntityStoreOption { - return func(s *entityStoreImpl) { - s.maxDestinationsPerTenant = maxDestinationsPerTenant - } -} - -func WithDeploymentID(deploymentID string) EntityStoreOption { - return func(s *entityStoreImpl) { - s.deploymentID = deploymentID - } -} - -func NewEntityStore(redisClient redis.Cmdable, opts ...EntityStoreOption) EntityStore { - store := &entityStoreImpl{ - redisClient: redisClient, - cipher: NewAESCipher(""), - availableTopics: []string{}, - maxDestinationsPerTenant: defaultMaxDestinationsPerTenant, - } - - for _, opt := range opts { - opt(store) - } - - return store -} - -// tenantIndexName returns the RediSearch index name for tenants. -func (s *entityStoreImpl) tenantIndexName() string { - return s.deploymentPrefix() + "tenant_idx" -} - -// tenantKeyPrefix returns the key prefix for tenant hashes (for RediSearch). -func (s *entityStoreImpl) tenantKeyPrefix() string { - return s.deploymentPrefix() + "tenant:" -} - -// Init initializes the entity store, probing for RediSearch support. -// If RediSearch is available, it creates the tenant index. -// If RediSearch is not available, ListTenant will return ErrListTenantNotSupported. -func (s *entityStoreImpl) Init(ctx context.Context) error { - // Probe for RediSearch support using FT._LIST - _, err := s.doCmd(ctx, "FT._LIST").Result() - if err != nil { - // RediSearch not available - this is not an error, just disable the feature - s.listTenantSupported = false - return nil - } - - // Try to create the tenant index, gracefully disable if it fails - // TODO: consider logging this error, but we don't have a logger in this context - if err := s.ensureTenantIndex(ctx); err != nil { - s.listTenantSupported = false - return nil - } - - s.listTenantSupported = true - return nil -} - -// ensureTenantIndex creates the RediSearch index for tenants if it doesn't exist. 
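For readers following the deleted entity.go: the key layout it used is easy to reconstruct from the format strings above. The sketch below is illustrative only (the helper name and example IDs are made up); the "{tenantID}" hash tag is what kept all of a tenant's keys in one Redis Cluster slot, so the per-tenant MULTI/EXEC pipelines stayed cluster-compatible, and the optional deployment prefix is what the removed TestDeploymentIsolation relied on.

package main

import "fmt"

// tenantKeys mirrors deploymentPrefix, redisTenantID,
// redisTenantDestinationSummaryKey and redisDestinationID above.
func tenantKeys(deploymentID, tenantID, destinationID string) []string {
	prefix := ""
	if deploymentID != "" {
		prefix = deploymentID + ":"
	}
	return []string{
		fmt.Sprintf("%stenant:{%s}:tenant", prefix, tenantID),
		fmt.Sprintf("%stenant:{%s}:destinations", prefix, tenantID),
		fmt.Sprintf("%stenant:{%s}:destination:%s", prefix, tenantID, destinationID),
	}
}

func main() {
	for _, k := range tenantKeys("", "tenant_123", "dest_456") {
		fmt.Println(k) // e.g. tenant:{tenant_123}:tenant
	}
	for _, k := range tenantKeys("dp_001", "tenant_123", "dest_456") {
		fmt.Println(k) // e.g. dp_001:tenant:{tenant_123}:destination:dest_456
	}
}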
-func (s *entityStoreImpl) ensureTenantIndex(ctx context.Context) error { - indexName := s.tenantIndexName() - - // Check if index already exists using FT.INFO - _, err := s.doCmd(ctx, "FT.INFO", indexName).Result() - if err == nil { - return nil - } - - // Create the index - // FT.CREATE index ON HASH PREFIX 1 prefix FILTER '@entity == "tenant"' SCHEMA ... - // Note: created_at and deleted_at are stored as Unix timestamps - // deleted_at is indexed so we can filter out deleted tenants in FT.SEARCH queries - // FILTER ensures only tenant entities are indexed (not destinations which share prefix) - prefix := s.tenantKeyPrefix() - _, err = s.doCmd(ctx, "FT.CREATE", indexName, - "ON", "HASH", - "PREFIX", "1", prefix, - "FILTER", `@entity == "tenant"`, - "SCHEMA", - "id", "TAG", - "entity", "TAG", - "created_at", "NUMERIC", "SORTABLE", - "deleted_at", "NUMERIC", - ).Result() - - if err != nil { - return fmt.Errorf("failed to create tenant index: %w", err) - } - - return nil -} - -func (s *entityStoreImpl) RetrieveTenant(ctx context.Context, tenantID string) (*Tenant, error) { - pipe := s.redisClient.Pipeline() - tenantCmd := pipe.HGetAll(ctx, s.redisTenantID(tenantID)) - destinationListCmd := pipe.HGetAll(ctx, s.redisTenantDestinationSummaryKey(tenantID)) - - if _, err := pipe.Exec(ctx); err != nil { - return nil, err - } - - tenantHash, err := tenantCmd.Result() - if err != nil { - return nil, err - } - if len(tenantHash) == 0 { - return nil, nil - } - tenant := &Tenant{} - if err := tenant.parseRedisHash(tenantHash); err != nil { - return nil, err - } - - destinationSummaryList, err := s.parseListDestinationSummaryByTenantCmd(destinationListCmd, ListDestinationByTenantOpts{}) - if err != nil { - return nil, err - } - tenant.DestinationsCount = len(destinationSummaryList) - tenant.Topics = s.parseTenantTopics(destinationSummaryList) - - return tenant, err -} - -func (s *entityStoreImpl) UpsertTenant(ctx context.Context, tenant Tenant) error { - key := s.redisTenantID(tenant.ID) - - // For cluster compatibility, execute commands individually instead of in a transaction - // Support overriding deleted resources - if err := s.redisClient.Persist(ctx, key).Err(); err != nil && err != redis.Nil { - return err - } - - if err := s.redisClient.HDel(ctx, key, "deleted_at").Err(); err != nil && err != redis.Nil { - return err - } - - // Auto-generate timestamps if not provided - now := time.Now() - if tenant.CreatedAt.IsZero() { - tenant.CreatedAt = now - } - if tenant.UpdatedAt.IsZero() { - tenant.UpdatedAt = now - } - - // Set tenant data - store timestamps as Unix milliseconds for timezone-agnostic sorting - // entity field is used by RediSearch FILTER to distinguish tenants from destinations - if err := s.redisClient.HSet(ctx, key, - "id", tenant.ID, - "entity", "tenant", - "created_at", tenant.CreatedAt.UnixMilli(), - "updated_at", tenant.UpdatedAt.UnixMilli(), - ).Err(); err != nil { - return err - } - - // Store metadata if present, otherwise delete field - if tenant.Metadata != nil { - if err := s.redisClient.HSet(ctx, key, "metadata", &tenant.Metadata).Err(); err != nil { - return err - } - } else { - if err := s.redisClient.HDel(ctx, key, "metadata").Err(); err != nil && err != redis.Nil { - return err - } - } - - return nil -} - -func (s *entityStoreImpl) DeleteTenant(ctx context.Context, tenantID string) error { - if exists, err := s.redisClient.Exists(ctx, s.redisTenantID(tenantID)).Result(); err != nil { - return err - } else if exists == 0 { - return ErrTenantNotFound - } - - // Get 
destination IDs before transaction - destinationIDs, err := s.redisClient.HKeys(ctx, s.redisTenantDestinationSummaryKey(tenantID)).Result() - if err != nil { - return err - } - - // All operations on same tenant - cluster compatible transaction - _, err = s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - nowUnixMilli := time.Now().UnixMilli() - - // Delete all destinations atomically - for _, destinationID := range destinationIDs { - destKey := s.redisDestinationID(destinationID, tenantID) - pipe.HSet(ctx, destKey, "deleted_at", nowUnixMilli) - pipe.Expire(ctx, destKey, 7*24*time.Hour) - } - - // Delete summary and mark tenant as deleted - // Store deleted_at as Unix milliseconds so it can be filtered in FT.SEARCH - pipe.Del(ctx, s.redisTenantDestinationSummaryKey(tenantID)) - pipe.HSet(ctx, s.redisTenantID(tenantID), "deleted_at", nowUnixMilli) - pipe.Expire(ctx, s.redisTenantID(tenantID), 7*24*time.Hour) - - return nil - }) - - return err -} - -const ( - defaultListTenantLimit = 20 - maxListTenantLimit = 100 -) - -// ListTenant returns a paginated list of tenants using RediSearch. -func (s *entityStoreImpl) ListTenant(ctx context.Context, req ListTenantRequest) (*TenantPaginatedResult, error) { - if !s.listTenantSupported { - return nil, ErrListTenantNotSupported - } - - // Validate: cannot specify both Next and Prev - if req.Next != "" && req.Prev != "" { - return nil, ErrConflictingCursors - } - - // Apply defaults and validate limit - limit := req.Limit - if limit <= 0 { - limit = defaultListTenantLimit - } - if limit > maxListTenantLimit { - limit = maxListTenantLimit - } - - // Validate and apply dir (sort direction) - dir := req.Dir - if dir == "" { - dir = "desc" - } - if dir != "asc" && dir != "desc" { - return nil, ErrInvalidOrder - } - - // Base filter for tenant search - // Filter: @entity:{tenant} ensures only tenant records (not destinations) - // Filter: -@deleted_at:[1 +inf] excludes deleted tenants - baseFilter := "@entity:{tenant} -@deleted_at:[1 +inf]" - - // Use pagination package for cursor-based pagination with n+1 pattern - result, err := pagination.Run(ctx, pagination.Config[Tenant]{ - Limit: limit, - Order: dir, - Next: req.Next, - Prev: req.Prev, - Cursor: pagination.Cursor[Tenant]{ - Encode: func(t Tenant) string { - return cursor.Encode("tnt", 1, strconv.FormatInt(t.CreatedAt.UnixMilli(), 10)) - }, - Decode: func(c string) (string, error) { - data, err := cursor.Decode(c, "tnt", 1) - if err != nil { - return "", fmt.Errorf("%w: %v", ErrInvalidCursor, err) - } - return data, nil - }, - }, - Fetch: func(ctx context.Context, q pagination.QueryInput) ([]Tenant, error) { - return s.fetchTenants(ctx, baseFilter, q) - }, - }) - if err != nil { - return nil, err - } - - tenants := result.Items - - // Batch fetch destination summaries for all tenants in a single Redis round-trip - if len(tenants) > 0 { - pipe := s.redisClient.Pipeline() - cmds := make([]*redis.MapStringStringCmd, len(tenants)) - for i, t := range tenants { - cmds[i] = pipe.HGetAll(ctx, s.redisTenantDestinationSummaryKey(t.ID)) - } - if _, err := pipe.Exec(ctx); err != nil { - return nil, fmt.Errorf("failed to fetch destination summaries: %w", err) - } - - // Compute destinations_count and topics for each tenant - for i := range tenants { - destinationSummaryList, err := s.parseListDestinationSummaryByTenantCmd(cmds[i], ListDestinationByTenantOpts{}) - if err != nil { - return nil, err - } - tenants[i].DestinationsCount = len(destinationSummaryList) - tenants[i].Topics = 
s.parseTenantTopics(destinationSummaryList) - } - } - - // Get total count of all tenants (excluding deleted) - cheap query with LIMIT 0 0 - var totalCount int - countResult, err := s.doCmd(ctx, "FT.SEARCH", s.tenantIndexName(), - baseFilter, - "LIMIT", 0, 0, - ).Result() - if err == nil { - _, totalCount, _ = s.parseSearchResult(ctx, countResult) - } - - // Convert empty cursors to nil pointers (Hookdeck returns null for empty cursors) - var nextCursor, prevCursor *string - if result.Next != "" { - nextCursor = &result.Next - } - if result.Prev != "" { - prevCursor = &result.Prev - } - - return &TenantPaginatedResult{ - Models: tenants, - Pagination: SeekPagination{ - OrderBy: "created_at", - Dir: dir, - Limit: limit, - Next: nextCursor, - Prev: prevCursor, - }, - Count: totalCount, - }, nil -} - -// fetchTenants builds and executes the FT.SEARCH query for tenant pagination. -func (s *entityStoreImpl) fetchTenants(ctx context.Context, baseFilter string, q pagination.QueryInput) ([]Tenant, error) { - // Build FT.SEARCH query with timestamp filter (keyset pagination) - var query string - sortDir := "DESC" - if q.SortDir == "asc" { - sortDir = "ASC" - } - - if q.CursorPos == "" { - // First page - only filter out deleted - query = baseFilter - } else { - // Parse cursor timestamp - cursorTimestamp, err := strconv.ParseInt(q.CursorPos, 10, 64) - if err != nil { - return nil, fmt.Errorf("%w: invalid timestamp", ErrInvalidCursor) - } - - // Build cursor condition based on compare operator - // Use inclusive ranges with cursor-1/cursor+1 for better compatibility - if q.Compare == "<" { - // Get records with created_at < cursor - query = fmt.Sprintf("(@created_at:[0 %d]) %s", cursorTimestamp-1, baseFilter) - } else { - // Get records with created_at > cursor - query = fmt.Sprintf("(@created_at:[%d +inf]) %s", cursorTimestamp+1, baseFilter) - } - } - - // Execute FT.SEARCH query - result, err := s.doCmd(ctx, "FT.SEARCH", s.tenantIndexName(), - query, - "SORTBY", "created_at", sortDir, - "LIMIT", 0, q.Limit, - ).Result() - if err != nil { - return nil, fmt.Errorf("failed to search tenants: %w", err) - } - - // Parse FT.SEARCH result - tenants, _, err := s.parseSearchResult(ctx, result) - if err != nil { - return nil, err - } - - return tenants, nil -} - -// parseSearchResult parses the FT.SEARCH result into a list of tenants. -// Supports both RESP2 (array) and RESP3 (map) formats. -func (s *entityStoreImpl) parseSearchResult(_ context.Context, result interface{}) ([]Tenant, int, error) { - // RESP3 format (go-redis v9): map with "total_results", "results", etc. - if resultMap, ok := result.(map[interface{}]interface{}); ok { - return s.parseResp3SearchResult(resultMap) - } - - // RESP2 format: [total_count, doc1_key, doc1_fields, doc2_key, doc2_fields, ...] - arr, ok := result.([]interface{}) - if !ok || len(arr) == 0 { - return []Tenant{}, 0, nil - } - - totalCount, ok := arr[0].(int64) - if !ok { - return nil, 0, fmt.Errorf("invalid search result: expected total count") - } - - tenants := make([]Tenant, 0, (len(arr)-1)/2) - - // Iterate through results (skip first element which is count) - for i := 1; i < len(arr); i += 2 { - if i+1 >= len(arr) { - break - } - - // arr[i] is the document key, arr[i+1] is the fields - // Fields can be either: - // - []interface{} array (Redis Stack RESP2): [field1, val1, field2, val2, ...] 
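The keyset condition that the removed fetchTenants assembled for FT.SEARCH can be hard to read inline; below is a standalone sketch of that cursor-to-query translation (the function name and signature are hypothetical, the logic mirrors the code above): the base filter excludes deleted tenants, and the cursor, a created_at Unix-millisecond timestamp, becomes an inclusive range using cursor-1 or cursor+1 depending on traversal direction.

package main

import (
	"fmt"
	"strconv"
)

// buildTenantQuery reproduces the cursor condition from the removed fetchTenants.
func buildTenantQuery(baseFilter, cursorPos, compare string) (string, error) {
	if cursorPos == "" {
		// First page: no cursor condition, only the base filter.
		return baseFilter, nil
	}
	ts, err := strconv.ParseInt(cursorPos, 10, 64)
	if err != nil {
		return "", fmt.Errorf("invalid cursor timestamp: %w", err)
	}
	if compare == "<" {
		// Records strictly older than the cursor.
		return fmt.Sprintf("(@created_at:[0 %d]) %s", ts-1, baseFilter), nil
	}
	// Records strictly newer than the cursor.
	return fmt.Sprintf("(@created_at:[%d +inf]) %s", ts+1, baseFilter), nil
}

func main() {
	base := `@entity:{tenant} -@deleted_at:[1 +inf]`
	q, _ := buildTenantQuery(base, "1700000000000", "<")
	fmt.Println(q) // (@created_at:[0 1699999999999]) @entity:{tenant} -@deleted_at:[1 +inf]
}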
- // - map[interface{}]interface{} (Dragonfly): {field1: val1, field2: val2, ...} - hash := make(map[string]string) - - switch fields := arr[i+1].(type) { - case []interface{}: - // Redis Stack RESP2 format: array of alternating key/value - for j := 0; j < len(fields)-1; j += 2 { - key, keyOk := fields[j].(string) - val, valOk := fields[j+1].(string) - if keyOk && valOk { - hash[key] = val - } - } - case map[interface{}]interface{}: - // Dragonfly format: map of key/value pairs - for k, v := range fields { - key, keyOk := k.(string) - if !keyOk { - continue - } - switch val := v.(type) { - case string: - hash[key] = val - case float64: - hash[key] = fmt.Sprintf("%.0f", val) - case int64: - hash[key] = fmt.Sprintf("%d", val) - } - } - default: - continue - } - - // Skip deleted tenants - if _, deleted := hash["deleted_at"]; deleted { - continue - } - - tenant := &Tenant{} - if err := tenant.parseRedisHash(hash); err != nil { - continue // Skip invalid entries - } - - tenants = append(tenants, *tenant) - } - - return tenants, int(totalCount), nil -} - -// parseResp3SearchResult parses the RESP3 map format from FT.SEARCH. -func (s *entityStoreImpl) parseResp3SearchResult(resultMap map[interface{}]interface{}) ([]Tenant, int, error) { - totalCount := 0 - if tc, ok := resultMap["total_results"].(int64); ok { - totalCount = int(tc) - } - - results, ok := resultMap["results"].([]interface{}) - if !ok { - return []Tenant{}, totalCount, nil - } - - tenants := make([]Tenant, 0, len(results)) - - for _, r := range results { - docMap, ok := r.(map[interface{}]interface{}) - if !ok { - continue - } - - // Get extra_attributes which contains the hash fields - extraAttrs, ok := docMap["extra_attributes"].(map[interface{}]interface{}) - if !ok { - continue - } - - // Convert to string map - hash := make(map[string]string) - for k, v := range extraAttrs { - if keyStr, ok := k.(string); ok { - if valStr, ok := v.(string); ok { - hash[keyStr] = valStr - } - } - } - - // Skip deleted tenants - if _, deleted := hash["deleted_at"]; deleted { - continue - } - - tenant := &Tenant{} - if err := tenant.parseRedisHash(hash); err != nil { - continue // Skip invalid entries - } - - tenants = append(tenants, *tenant) - } - - return tenants, totalCount, nil -} - -func (s *entityStoreImpl) listDestinationSummaryByTenant(ctx context.Context, tenantID string, opts ListDestinationByTenantOpts) ([]DestinationSummary, error) { - return s.parseListDestinationSummaryByTenantCmd(s.redisClient.HGetAll(ctx, s.redisTenantDestinationSummaryKey(tenantID)), opts) -} - -func (s *entityStoreImpl) parseListDestinationSummaryByTenantCmd(cmd *redis.MapStringStringCmd, opts ListDestinationByTenantOpts) ([]DestinationSummary, error) { - destinationSummaryListHash, err := cmd.Result() - if err != nil { - if err == redis.Nil { - return []DestinationSummary{}, nil - } - return nil, err - } - destinationSummaryList := make([]DestinationSummary, 0, len(destinationSummaryListHash)) - for _, destinationSummaryStr := range destinationSummaryListHash { - destinationSummary := DestinationSummary{} - if err := destinationSummary.UnmarshalBinary([]byte(destinationSummaryStr)); err != nil { - return nil, err - } - included := true - if opts.Filter != nil { - included = opts.Filter.match(destinationSummary) - } - if included { - destinationSummaryList = append(destinationSummaryList, destinationSummary) - } - } - return destinationSummaryList, nil -} - -func (s *entityStoreImpl) ListDestinationByTenant(ctx context.Context, tenantID string, options 
...ListDestinationByTenantOpts) ([]Destination, error) { - var opts ListDestinationByTenantOpts - if len(options) > 0 { - opts = options[0] - } else { - opts = ListDestinationByTenantOpts{} - } - - destinationSummaryList, err := s.listDestinationSummaryByTenant(ctx, tenantID, opts) - if err != nil { - return nil, err - } - - pipe := s.redisClient.Pipeline() - cmds := make([]*redis.MapStringStringCmd, len(destinationSummaryList)) - for i, destinationSummary := range destinationSummaryList { - cmds[i] = pipe.HGetAll(ctx, s.redisDestinationID(destinationSummary.ID, tenantID)) - } - _, err = pipe.Exec(ctx) - if err != nil { - return nil, err - } - - destinations := make([]Destination, len(destinationSummaryList)) - for i, cmd := range cmds { - destination := &Destination{TenantID: tenantID} - err = destination.parseRedisHash(cmd, s.cipher) - if err != nil { - return []Destination{}, err - } - destinations[i] = *destination - } - - sort.Slice(destinations, func(i, j int) bool { - return destinations[i].CreatedAt.Before(destinations[j].CreatedAt) - }) - - return destinations, nil -} - -func (s *entityStoreImpl) RetrieveDestination(ctx context.Context, tenantID, destinationID string) (*Destination, error) { - cmd := s.redisClient.HGetAll(ctx, s.redisDestinationID(destinationID, tenantID)) - destination := &Destination{TenantID: tenantID} - if err := destination.parseRedisHash(cmd, s.cipher); err != nil { - if err == redis.Nil { - return nil, nil - } - return nil, err - } - return destination, nil -} - -func (s *entityStoreImpl) CreateDestination(ctx context.Context, destination Destination) error { - key := s.redisDestinationID(destination.ID, destination.TenantID) - // Check if destination exists - if fields, err := s.redisClient.HGetAll(ctx, key).Result(); err != nil { - return err - } else if len(fields) > 0 { - if _, isDeleted := fields["deleted_at"]; !isDeleted { - return ErrDuplicateDestination - } - } - - // Check if tenant has reached max destinations by counting entries in the summary hash - count, err := s.redisClient.HLen(ctx, s.redisTenantDestinationSummaryKey(destination.TenantID)).Result() - if err != nil { - return err - } - if count >= int64(s.maxDestinationsPerTenant) { - return ErrMaxDestinationsPerTenantReached - } - - return s.UpsertDestination(ctx, destination) -} - -func (s *entityStoreImpl) UpsertDestination(ctx context.Context, destination Destination) error { - key := s.redisDestinationID(destination.ID, destination.TenantID) - - // Pre-marshal and encrypt credentials and delivery_metadata BEFORE starting Redis transaction - // This isolates marshaling failures from Redis transaction failures - credentialsBytes, err := destination.Credentials.MarshalBinary() - if err != nil { - return fmt.Errorf("invalid destination credentials: %w", err) - } - encryptedCredentials, err := s.cipher.Encrypt(credentialsBytes) - if err != nil { - return fmt.Errorf("failed to encrypt destination credentials: %w", err) - } - - // Encrypt delivery_metadata if present (contains sensitive data like auth tokens) - var encryptedDeliveryMetadata []byte - if destination.DeliveryMetadata != nil { - deliveryMetadataBytes, err := destination.DeliveryMetadata.MarshalBinary() - if err != nil { - return fmt.Errorf("invalid destination delivery_metadata: %w", err) - } - encryptedDeliveryMetadata, err = s.cipher.Encrypt(deliveryMetadataBytes) - if err != nil { - return fmt.Errorf("failed to encrypt destination delivery_metadata: %w", err) - } - } - - // Auto-generate timestamps if not provided - now := 
time.Now() - if destination.CreatedAt.IsZero() { - destination.CreatedAt = now - } - if destination.UpdatedAt.IsZero() { - destination.UpdatedAt = now - } - - // All keys use same tenant prefix - cluster compatible transaction - summaryKey := s.redisTenantDestinationSummaryKey(destination.TenantID) - - _, err = s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - // Clear deletion markers - pipe.Persist(ctx, key) - pipe.HDel(ctx, key, "deleted_at") - - // Set all destination fields atomically - // Store timestamps as Unix milliseconds for timezone-agnostic handling - // entity field is used for consistency with tenants (both tagged for RediSearch filtering) - pipe.HSet(ctx, key, "id", destination.ID) - pipe.HSet(ctx, key, "entity", "destination") - pipe.HSet(ctx, key, "type", destination.Type) - pipe.HSet(ctx, key, "topics", &destination.Topics) - pipe.HSet(ctx, key, "config", &destination.Config) - pipe.HSet(ctx, key, "credentials", encryptedCredentials) - pipe.HSet(ctx, key, "created_at", destination.CreatedAt.UnixMilli()) - pipe.HSet(ctx, key, "updated_at", destination.UpdatedAt.UnixMilli()) - - if destination.DisabledAt != nil { - pipe.HSet(ctx, key, "disabled_at", destination.DisabledAt.UnixMilli()) - } else { - pipe.HDel(ctx, key, "disabled_at") - } - - // Store encrypted delivery_metadata if present - if destination.DeliveryMetadata != nil { - pipe.HSet(ctx, key, "delivery_metadata", encryptedDeliveryMetadata) - } else { - pipe.HDel(ctx, key, "delivery_metadata") - } - - // Store metadata if present - if destination.Metadata != nil { - pipe.HSet(ctx, key, "metadata", &destination.Metadata) - } else { - pipe.HDel(ctx, key, "metadata") - } - - // Store filter if present - if len(destination.Filter) > 0 { - pipe.HSet(ctx, key, "filter", &destination.Filter) - } else { - pipe.HDel(ctx, key, "filter") - } - - // Update summary atomically - pipe.HSet(ctx, summaryKey, destination.ID, destination.ToSummary()) - return nil - }) - - return err -} - -func (s *entityStoreImpl) DeleteDestination(ctx context.Context, tenantID, destinationID string) error { - key := s.redisDestinationID(destinationID, tenantID) - summaryKey := s.redisTenantDestinationSummaryKey(tenantID) - - // Check if destination exists - if exists, err := s.redisClient.Exists(ctx, key).Result(); err != nil { - return err - } else if exists == 0 { - return ErrDestinationNotFound - } - - // Atomic deletion with same-tenant keys - _, err := s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - nowUnixMilli := time.Now().UnixMilli() - - // Remove from summary and mark as deleted atomically - pipe.HDel(ctx, summaryKey, destinationID) - pipe.HSet(ctx, key, "deleted_at", nowUnixMilli) - pipe.Expire(ctx, key, 7*24*time.Hour) - - return nil - }) - - return err -} - -func (s *entityStoreImpl) MatchEvent(ctx context.Context, event Event) ([]DestinationSummary, error) { - destinationSummaryList, err := s.listDestinationSummaryByTenant(ctx, event.TenantID, ListDestinationByTenantOpts{}) - if err != nil { - return nil, err - } - - matchedDestinationSummaryList := []DestinationSummary{} - - for _, destinationSummary := range destinationSummaryList { - if destinationSummary.Disabled { - continue - } - // Match by topic first (if topic is provided) - if event.Topic != "" && !destinationSummary.Topics.MatchTopic(event.Topic) { - continue - } - // Then apply filter (if filter is set) - if !destinationSummary.MatchFilter(event) { - continue - } - matchedDestinationSummaryList = append(matchedDestinationSummaryList, 
destinationSummary) - } - - return matchedDestinationSummaryList, nil -} - -func (s *entityStoreImpl) parseTenantTopics(destinationSummaryList []DestinationSummary) []string { - all := false - topicsSet := make(map[string]struct{}) - for _, destination := range destinationSummaryList { - for _, topic := range destination.Topics { - if topic == "*" { - all = true - break - } - topicsSet[topic] = struct{}{} - } - } - - if all { - return []string{"*"} - } - - topics := make([]string, 0, len(topicsSet)) - for topic := range topicsSet { - topics = append(topics, topic) - } - - sort.Strings(topics) - return topics -} - -type ListDestinationByTenantOpts struct { - Filter *DestinationFilter -} - -type DestinationFilter struct { - Type []string - Topics []string -} - -func WithDestinationFilter(filter DestinationFilter) ListDestinationByTenantOpts { - return ListDestinationByTenantOpts{Filter: &filter} -} - -// match returns true if the destinationSummary matches the options -func (filter DestinationFilter) match(destinationSummary DestinationSummary) bool { - if len(filter.Type) > 0 && !slices.Contains(filter.Type, destinationSummary.Type) { - return false - } - if len(filter.Topics) > 0 { - filterMatchesAll := len(filter.Topics) == 1 && filter.Topics[0] == "*" - if !destinationSummary.Topics.MatchesAll() { - if filterMatchesAll { - return false - } - for _, topic := range filter.Topics { - if !slices.Contains(destinationSummary.Topics, topic) { - return false - } - } - } - } - return true -} diff --git a/internal/models/entity_test.go b/internal/models/entity_test.go deleted file mode 100644 index d6c8f3ab..00000000 --- a/internal/models/entity_test.go +++ /dev/null @@ -1,480 +0,0 @@ -package models_test - -import ( - "context" - "encoding/json" - "fmt" - "testing" - "time" - - "github.com/hookdeck/outpost/internal/idgen" - "github.com/hookdeck/outpost/internal/models" - "github.com/hookdeck/outpost/internal/pagination/paginationtest" - "github.com/hookdeck/outpost/internal/redis" - "github.com/hookdeck/outpost/internal/util/testinfra" - "github.com/hookdeck/outpost/internal/util/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// miniredisClientFactory creates a miniredis client (in-memory, no RediSearch) -func miniredisClientFactory(t *testing.T) redis.Cmdable { - return testutil.CreateTestRedisClient(t) -} - -// redisStackClientFactory creates a Redis Stack client on DB 0 (RediSearch works) -// Tests using this are serialized since RediSearch only works on DB 0. -func redisStackClientFactory(t *testing.T) redis.Cmdable { - testinfra.Start(t) - redisCfg := testinfra.NewRedisStackConfig(t) - client, err := redis.New(context.Background(), redisCfg) - if err != nil { - t.Fatalf("failed to create redis client: %v", err) - } - t.Cleanup(func() { client.Close() }) - return client -} - -// dragonflyClientFactory creates a Dragonfly client (DB 1-15, no RediSearch). -// Tests can run in parallel since each gets its own DB. -func dragonflyClientFactory(t *testing.T) redis.Cmdable { - testinfra.Start(t) - redisCfg := testinfra.NewDragonflyConfig(t) - client, err := redis.New(context.Background(), redisCfg) - if err != nil { - t.Fatalf("failed to create dragonfly client: %v", err) - } - t.Cleanup(func() { client.Close() }) - return client -} - -// dragonflyStackClientFactory creates a Dragonfly client on DB 0 (RediSearch works). -// Tests using this are serialized since RediSearch only works on DB 0. 
-func dragonflyStackClientFactory(t *testing.T) redis.Cmdable { - testinfra.Start(t) - redisCfg := testinfra.NewDragonflyStackConfig(t) - client, err := redis.New(context.Background(), redisCfg) - if err != nil { - t.Fatalf("failed to create dragonfly stack client: %v", err) - } - t.Cleanup(func() { client.Close() }) - return client -} - -// ============================================================================= -// EntityTestSuite with miniredis (in-memory, no RediSearch) -// ============================================================================= - -func TestEntityStore_Miniredis_WithoutDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: miniredisClientFactory, - deploymentID: "", - }) -} - -func TestEntityStore_Miniredis_WithDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: miniredisClientFactory, - deploymentID: "dp_test_001", - }) -} - -// ============================================================================= -// EntityTestSuite with Redis Stack (real Redis with RediSearch) -// ============================================================================= - -func TestEntityStore_RedisStack_WithoutDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: redisStackClientFactory, - deploymentID: "", - }) -} - -func TestEntityStore_RedisStack_WithDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: redisStackClientFactory, - deploymentID: "dp_test_001", - }) -} - -// ============================================================================= -// EntityTestSuite with Dragonfly (DB 1-15, no RediSearch, faster parallel tests) -// ============================================================================= - -func TestEntityStore_Dragonfly_WithoutDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: dragonflyClientFactory, - deploymentID: "", - }) -} - -func TestEntityStore_Dragonfly_WithDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &EntityTestSuite{ - RedisClientFactory: dragonflyClientFactory, - deploymentID: "dp_test_001", - }) -} - -// ============================================================================= -// ListTenantTestSuite - only runs with Redis Stack (requires RediSearch) -// ============================================================================= - -// Parallel OK - different deployment IDs use different indexes -func TestListTenant_RedisStack_WithoutDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &ListTenantTestSuite{ - RedisClientFactory: redisStackClientFactory, - deploymentID: "", - }) -} - -// Parallel OK - different deployment IDs use different indexes -func TestListTenant_RedisStack_WithDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &ListTenantTestSuite{ - RedisClientFactory: redisStackClientFactory, - deploymentID: "dp_test_001", - }) -} - -// ============================================================================= -// ListTenantTestSuite with Dragonfly Stack (DB 0 for RediSearch) -// ============================================================================= - -// Parallel OK - different deployment IDs use different indexes -func TestListTenant_Dragonfly_WithoutDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &ListTenantTestSuite{ - RedisClientFactory: dragonflyStackClientFactory, - deploymentID: "", - }) -} - -// Parallel OK - different deployment IDs use 
different indexes -func TestListTenant_Dragonfly_WithDeploymentID(t *testing.T) { - t.Parallel() - suite.Run(t, &ListTenantTestSuite{ - RedisClientFactory: dragonflyStackClientFactory, - deploymentID: "dp_test_001", - }) -} - -// ============================================================================= -// ListTenant Pagination Suite - Tests using paginationtest.Suite -// These tests verify correct cursor-based pagination behavior including: -// - Forward/backward traversal -// - n+1 pattern (hasMore detection) -// - First page has no Prev, last page has no Next -// ============================================================================= - -func TestListTenantPagination(t *testing.T) { - t.Parallel() - runListTenantPaginationSuite(t, dragonflyStackClientFactory, "") -} - -func TestListTenantPagination_WithDeploymentID(t *testing.T) { - t.Parallel() - runListTenantPaginationSuite(t, dragonflyStackClientFactory, "dp_pagination_test") -} - -func TestListTenantPagination_Compat_RedisStack(t *testing.T) { - t.Parallel() - runListTenantPaginationSuite(t, redisStackClientFactory, "") -} - -func TestListTenantPagination_Compat_RedisStack_WithDeploymentID(t *testing.T) { - t.Parallel() - runListTenantPaginationSuite(t, redisStackClientFactory, "dp_pagination_test") -} - -func runListTenantPaginationSuite(t *testing.T, factory RedisClientFactory, deploymentID string) { - ctx := context.Background() - redisClient := factory(t) - - // Add unique suffix to deployment ID to isolate test data between parallel runs - if deploymentID != "" { - deploymentID = fmt.Sprintf("%s_%d", deploymentID, time.Now().UnixNano()) - } else { - deploymentID = fmt.Sprintf("pagination_test_%d", time.Now().UnixNano()) - } - - entityStore := models.NewEntityStore(redisClient, - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - models.WithDeploymentID(deploymentID), - ) - - // Initialize entity store (probes for RediSearch) - err := entityStore.Init(ctx) - require.NoError(t, err) - - // Track created tenant IDs for cleanup between subtests - var createdTenantIDs []string - baseTime := time.Now() - - paginationSuite := paginationtest.Suite[models.Tenant]{ - Name: "entitystore_ListTenant", - - NewItem: func(index int) models.Tenant { - return models.Tenant{ - ID: fmt.Sprintf("tenant_pagination_%d_%d", time.Now().UnixNano(), index), - CreatedAt: baseTime.Add(time.Duration(index) * time.Second), - UpdatedAt: baseTime.Add(time.Duration(index) * time.Second), - } - }, - - InsertMany: func(ctx context.Context, items []models.Tenant) error { - for _, item := range items { - if err := entityStore.UpsertTenant(ctx, item); err != nil { - return err - } - createdTenantIDs = append(createdTenantIDs, item.ID) - } - return nil - }, - - List: func(ctx context.Context, opts paginationtest.ListOpts) (paginationtest.ListResult[models.Tenant], error) { - resp, err := entityStore.ListTenant(ctx, models.ListTenantRequest{ - Limit: opts.Limit, - Dir: opts.Order, - Next: opts.Next, - Prev: opts.Prev, - }) - if err != nil { - return paginationtest.ListResult[models.Tenant]{}, err - } - // Convert *string cursors to string for test framework - var next, prev string - if resp.Pagination.Next != nil { - next = *resp.Pagination.Next - } - if resp.Pagination.Prev != nil { - prev = *resp.Pagination.Prev - } - return paginationtest.ListResult[models.Tenant]{ - Items: resp.Models, - Next: next, - Prev: prev, - }, nil - }, - - GetID: func(t models.Tenant) string { - return t.ID - }, - - Cleanup: 
func(ctx context.Context) error { - for _, id := range createdTenantIDs { - _ = entityStore.DeleteTenant(ctx, id) - } - createdTenantIDs = nil - return nil - }, - } - - paginationSuite.Run(t) -} - -// TestDestinationCredentialsEncryption verifies that credentials and delivery_metadata -// are properly encrypted when stored in Redis. -// -// NOTE: This test accesses Redis implementation details directly to verify encryption. -// While this couples the test to the storage implementation, it's necessary to confirm -// that sensitive fields are actually encrypted at rest. -func TestDestinationCredentialsEncryption(t *testing.T) { - t.Parallel() - - ctx := context.Background() - redisClient := testutil.CreateTestRedisClient(t) - cipher := models.NewAESCipher("secret") - - entityStore := models.NewEntityStore(redisClient, - models.WithCipher(cipher), - models.WithAvailableTopics(testutil.TestTopics), - ) - - input := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithType("rabbitmq"), - testutil.DestinationFactory.WithTopics([]string{"user.created", "user.updated"}), - testutil.DestinationFactory.WithConfig(map[string]string{ - "server_url": "localhost:5672", - "exchange": "events", - }), - testutil.DestinationFactory.WithCredentials(map[string]string{ - "username": "guest", - "password": "guest", - }), - testutil.DestinationFactory.WithDeliveryMetadata(map[string]string{ - "Authorization": "Bearer secret-token", - "X-API-Key": "sensitive-key", - }), - ) - - err := entityStore.UpsertDestination(ctx, input) - require.NoError(t, err) - - // Access Redis directly to verify encryption (implementation detail) - keyFormat := "tenant:{%s}:destination:%s" - actual, err := redisClient.HGetAll(ctx, fmt.Sprintf(keyFormat, input.TenantID, input.ID)).Result() - require.NoError(t, err) - - // Verify credentials are encrypted (not plaintext) - assert.NotEqual(t, input.Credentials, actual["credentials"]) - - // Verify we can decrypt credentials back to original - decryptedCredentials, err := cipher.Decrypt([]byte(actual["credentials"])) - require.NoError(t, err) - jsonCredentials, _ := json.Marshal(input.Credentials) - assert.Equal(t, string(jsonCredentials), string(decryptedCredentials)) - - // Verify delivery_metadata is encrypted (not plaintext) - assert.NotEqual(t, input.DeliveryMetadata, actual["delivery_metadata"]) - - // Verify we can decrypt delivery_metadata back to original - decryptedDeliveryMetadata, err := cipher.Decrypt([]byte(actual["delivery_metadata"])) - require.NoError(t, err) - jsonDeliveryMetadata, _ := json.Marshal(input.DeliveryMetadata) - assert.Equal(t, string(jsonDeliveryMetadata), string(decryptedDeliveryMetadata)) -} - -// TestMaxDestinationsPerTenant verifies that the entity store properly enforces -// the maximum destinations per tenant limit. 
-func TestMaxDestinationsPerTenant(t *testing.T) { - t.Parallel() - - ctx := context.Background() - redisClient := testutil.CreateTestRedisClient(t) - maxDestinations := 2 - - limitedStore := models.NewEntityStore(redisClient, - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - models.WithMaxDestinationsPerTenant(maxDestinations), - ) - - tenant := models.Tenant{ - ID: idgen.String(), - CreatedAt: time.Now(), - } - require.NoError(t, limitedStore.UpsertTenant(ctx, tenant)) - - // Should be able to create up to maxDestinations - for i := 0; i < maxDestinations; i++ { - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - ) - err := limitedStore.CreateDestination(ctx, destination) - require.NoError(t, err, "Should be able to create destination %d", i+1) - } - - // Should fail when trying to create one more - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - ) - err := limitedStore.CreateDestination(ctx, destination) - require.Error(t, err) - require.ErrorIs(t, err, models.ErrMaxDestinationsPerTenantReached) - - // Should be able to create after deleting one - destinations, err := limitedStore.ListDestinationByTenant(ctx, tenant.ID) - require.NoError(t, err) - require.NoError(t, limitedStore.DeleteDestination(ctx, tenant.ID, destinations[0].ID)) - - destination = testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - ) - err = limitedStore.CreateDestination(ctx, destination) - require.NoError(t, err, "Should be able to create destination after deleting one") -} - -// TestDeploymentIsolation verifies that entity stores with different deployment IDs -// are completely isolated from each other, even when sharing the same Redis instance. 
-func TestDeploymentIsolation(t *testing.T) { - t.Parallel() - - ctx := context.Background() - redisClient := testutil.CreateTestRedisClient(t) - - // Create two entity stores with different deployment IDs - store1 := models.NewEntityStore(redisClient, - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - models.WithDeploymentID("dp_001"), - ) - - store2 := models.NewEntityStore(redisClient, - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - models.WithDeploymentID("dp_002"), - ) - - // Use the SAME tenant ID and destination ID for both deployments - tenantID := idgen.String() - destinationID := idgen.Destination() - - // Create tenant in both deployments - tenant := models.Tenant{ - ID: tenantID, - CreatedAt: time.Now(), - } - require.NoError(t, store1.UpsertTenant(ctx, tenant)) - require.NoError(t, store2.UpsertTenant(ctx, tenant)) - - // Create destination with same ID but different config in each deployment - destination1 := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(destinationID), - testutil.DestinationFactory.WithTenantID(tenantID), - testutil.DestinationFactory.WithConfig(map[string]string{ - "deployment": "dp_001", - }), - ) - destination2 := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(destinationID), - testutil.DestinationFactory.WithTenantID(tenantID), - testutil.DestinationFactory.WithConfig(map[string]string{ - "deployment": "dp_002", - }), - ) - - require.NoError(t, store1.CreateDestination(ctx, destination1)) - require.NoError(t, store2.CreateDestination(ctx, destination2)) - - // Verify store1 only sees its own data - retrieved1, err := store1.RetrieveDestination(ctx, tenantID, destinationID) - require.NoError(t, err) - assert.Equal(t, "dp_001", retrieved1.Config["deployment"], "Store 1 should see its own data") - - // Verify store2 only sees its own data - retrieved2, err := store2.RetrieveDestination(ctx, tenantID, destinationID) - require.NoError(t, err) - assert.Equal(t, "dp_002", retrieved2.Config["deployment"], "Store 2 should see its own data") - - // Verify list operations are also isolated - list1, err := store1.ListDestinationByTenant(ctx, tenantID) - require.NoError(t, err) - require.Len(t, list1, 1, "Store 1 should only see 1 destination") - assert.Equal(t, "dp_001", list1[0].Config["deployment"]) - - list2, err := store2.ListDestinationByTenant(ctx, tenantID) - require.NoError(t, err) - require.Len(t, list2, 1, "Store 2 should only see 1 destination") - assert.Equal(t, "dp_002", list2[0].Config["deployment"]) - - // Verify deleting from one deployment doesn't affect the other - require.NoError(t, store1.DeleteDestination(ctx, tenantID, destinationID)) - - // Store1 should not find the destination - _, err = store1.RetrieveDestination(ctx, tenantID, destinationID) - require.ErrorIs(t, err, models.ErrDestinationDeleted) - - // Store2 should still have its destination - retrieved2Again, err := store2.RetrieveDestination(ctx, tenantID, destinationID) - require.NoError(t, err) - assert.Equal(t, "dp_002", retrieved2Again.Config["deployment"], "Store 2 data should be unaffected") -} diff --git a/internal/models/entitysuite_test.go b/internal/models/entitysuite_test.go deleted file mode 100644 index ff025a3d..00000000 --- a/internal/models/entitysuite_test.go +++ /dev/null @@ -1,1519 +0,0 @@ -package models_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/hookdeck/outpost/internal/idgen" - 
"github.com/hookdeck/outpost/internal/models" - "github.com/hookdeck/outpost/internal/redis" - "github.com/hookdeck/outpost/internal/util/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// Helper function used by test suite -func assertEqualDestination(t *testing.T, expected, actual models.Destination) { - assert.Equal(t, expected.ID, actual.ID) - assert.Equal(t, expected.Type, actual.Type) - assert.Equal(t, expected.Topics, actual.Topics) - assert.Equal(t, expected.Filter, actual.Filter) - assert.Equal(t, expected.Config, actual.Config) - assert.Equal(t, expected.Credentials, actual.Credentials) - assert.Equal(t, expected.DeliveryMetadata, actual.DeliveryMetadata) - assert.Equal(t, expected.Metadata, actual.Metadata) - // Use time.Time.Equal() to compare instants (ignores timezone/nanoseconds) - // Timestamps are stored as Unix milliseconds, so sub-millisecond precision is lost and times return as UTC - assertEqualTime(t, expected.CreatedAt, actual.CreatedAt, "CreatedAt") - assertEqualTime(t, expected.UpdatedAt, actual.UpdatedAt, "UpdatedAt") - assertEqualTimePtr(t, expected.DisabledAt, actual.DisabledAt, "DisabledAt") -} - -// assertEqualTime compares two times by truncating to millisecond precision -// since timestamps are stored as Unix milliseconds. -func assertEqualTime(t *testing.T, expected, actual time.Time, field string) { - t.Helper() - // Truncate to milliseconds since Unix timestamps lose sub-millisecond precision - expectedTrunc := expected.Truncate(time.Millisecond) - actualTrunc := actual.Truncate(time.Millisecond) - assert.True(t, expectedTrunc.Equal(actualTrunc), - "expected %s %v, got %v", field, expectedTrunc, actualTrunc) -} - -// assertEqualTimePtr compares two optional times by truncating to millisecond precision. -func assertEqualTimePtr(t *testing.T, expected, actual *time.Time, field string) { - t.Helper() - if expected == nil { - assert.Nil(t, actual, "%s should be nil", field) - return - } - require.NotNil(t, actual, "%s should not be nil", field) - assertEqualTime(t, *expected, *actual, field) -} - -// RedisClientFactory creates a Redis client for testing. -// Required - each test suite must explicitly provide one. -type RedisClientFactory func(t *testing.T) redis.Cmdable - -// EntityTestSuite contains all entity store tests. -// Requires a RedisClientFactory to be set before running. -type EntityTestSuite struct { - suite.Suite - ctx context.Context - redisClient redis.Cmdable - entityStore models.EntityStore - deploymentID string - RedisClientFactory RedisClientFactory // Required - must be set -} - -func (s *EntityTestSuite) SetupSuite() { - s.ctx = context.Background() - - require.NotNil(s.T(), s.RedisClientFactory, "RedisClientFactory must be set") - s.redisClient = s.RedisClientFactory(s.T()) - - opts := []models.EntityStoreOption{ - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - } - if s.deploymentID != "" { - opts = append(opts, models.WithDeploymentID(s.deploymentID)) - } - s.entityStore = models.NewEntityStore(s.redisClient, opts...) 
- - // Initialize entity store (probes for RediSearch) - err := s.entityStore.Init(s.ctx) - require.NoError(s.T(), err) -} - -func (s *EntityTestSuite) TestInitIdempotency() { - // Calling Init multiple times should not fail (index already exists is handled gracefully) - for i := 0; i < 3; i++ { - err := s.entityStore.Init(s.ctx) - require.NoError(s.T(), err, "Init call %d should not fail", i+1) - } -} - -func (s *EntityTestSuite) TestListTenantNotSupported() { - // This test verifies behavior when RediSearch is NOT available (miniredis case) - // When running with Redis Stack, this test will pass but ListTenant will work - // When running with miniredis, ListTenant should return ErrListTenantNotSupported - - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{}) - - // Check if we're on miniredis (no RediSearch support) - // We can detect this by checking if the error is ErrListTenantNotSupported - if err != nil { - assert.ErrorIs(s.T(), err, models.ErrListTenantNotSupported, - "ListTenant should return ErrListTenantNotSupported when RediSearch is not available") - } - // If err is nil, we're on Redis Stack and ListTenant works - that's fine too -} - -func (s *EntityTestSuite) TestTenantCRUD() { - t := s.T() - now := time.Now() - input := models.Tenant{ - ID: idgen.String(), - CreatedAt: now, - UpdatedAt: now, - } - - t.Run("gets empty", func(t *testing.T) { - actual, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - assert.Nil(s.T(), actual) - assert.NoError(s.T(), err) - }) - - t.Run("sets", func(t *testing.T) { - err := s.entityStore.UpsertTenant(s.ctx, input) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.ID, retrieved.ID) - assertEqualTime(t, input.CreatedAt, retrieved.CreatedAt, "CreatedAt") - }) - - t.Run("gets", func(t *testing.T) { - actual, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.ID, actual.ID) - assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") - }) - - t.Run("overrides", func(t *testing.T) { - input.CreatedAt = time.Now() - - err := s.entityStore.UpsertTenant(s.ctx, input) - require.NoError(s.T(), err) - - actual, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.ID, actual.ID) - assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") - }) - - t.Run("clears", func(t *testing.T) { - require.NoError(s.T(), s.entityStore.DeleteTenant(s.ctx, input.ID)) - - actual, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - assert.ErrorIs(s.T(), err, models.ErrTenantDeleted) - assert.Nil(s.T(), actual) - }) - - t.Run("deletes again", func(t *testing.T) { - assert.NoError(s.T(), s.entityStore.DeleteTenant(s.ctx, input.ID)) - }) - - t.Run("deletes non-existent", func(t *testing.T) { - assert.ErrorIs(s.T(), s.entityStore.DeleteTenant(s.ctx, "non-existent-tenant"), models.ErrTenantNotFound) - }) - - t.Run("creates & overrides deleted resource", func(t *testing.T) { - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, input)) - - actual, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.ID, actual.ID) - assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") - }) - - t.Run("upserts with metadata", func(t *testing.T) { - input.Metadata = map[string]string{ - "environment": "production", - "team": "platform", - } - - err := 
s.entityStore.UpsertTenant(s.ctx, input) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.ID, retrieved.ID) - assert.Equal(s.T(), input.Metadata, retrieved.Metadata) - }) - - t.Run("updates metadata", func(t *testing.T) { - input.Metadata = map[string]string{ - "environment": "staging", - "team": "engineering", - "region": "us-west-2", - } - - err := s.entityStore.UpsertTenant(s.ctx, input) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), input.Metadata, retrieved.Metadata) - }) - - t.Run("handles nil metadata", func(t *testing.T) { - input.Metadata = nil - - err := s.entityStore.UpsertTenant(s.ctx, input) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, input.ID) - require.NoError(s.T(), err) - assert.Nil(s.T(), retrieved.Metadata) - }) - - // UpdatedAt tests - t.Run("sets updated_at on create", func(t *testing.T) { - newTenant := testutil.TenantFactory.Any() - - err := s.entityStore.UpsertTenant(s.ctx, newTenant) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, newTenant.ID) - require.NoError(s.T(), err) - assert.True(s.T(), newTenant.UpdatedAt.Unix() == retrieved.UpdatedAt.Unix()) - }) - - t.Run("updates updated_at on upsert", func(t *testing.T) { - // Use explicit timestamps 1 second apart (Unix timestamps have second precision) - originalTime := time.Now().Add(-2 * time.Second).Truncate(time.Second) - updatedTime := originalTime.Add(1 * time.Second) - - original := testutil.TenantFactory.Any() - original.UpdatedAt = originalTime - - err := s.entityStore.UpsertTenant(s.ctx, original) - require.NoError(s.T(), err) - - // Update the tenant with a later timestamp - updated := original - updated.UpdatedAt = updatedTime - - err = s.entityStore.UpsertTenant(s.ctx, updated) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, updated.ID) - require.NoError(s.T(), err) - - // updated_at should be newer than original (comparing truncated times) - assert.True(s.T(), retrieved.UpdatedAt.After(originalTime)) - assert.True(s.T(), updated.UpdatedAt.Unix() == retrieved.UpdatedAt.Unix()) - }) - - t.Run("fallback updated_at to created_at for existing records", func(t *testing.T) { - // Create a tenant normally first - oldTenant := testutil.TenantFactory.Any() - err := s.entityStore.UpsertTenant(s.ctx, oldTenant) - require.NoError(s.T(), err) - - // Now manually remove the updated_at field from Redis to simulate old record - key := "tenant:" + oldTenant.ID - err = s.redisClient.HDel(s.ctx, key, "updated_at").Err() - require.NoError(s.T(), err) - - // Retrieve should fallback updated_at to created_at - retrieved, err := s.entityStore.RetrieveTenant(s.ctx, oldTenant.ID) - require.NoError(s.T(), err) - assert.True(s.T(), retrieved.UpdatedAt.Equal(retrieved.CreatedAt)) - }) -} - -func (s *EntityTestSuite) TestDestinationCRUD() { - t := s.T() - now := time.Now() - input := models.Destination{ - ID: idgen.Destination(), - Type: "rabbitmq", - Topics: []string{"user.created", "user.updated"}, - Config: map[string]string{ - "server_url": "localhost:5672", - "exchange": "events", - }, - Credentials: map[string]string{ - "username": "guest", - "password": "guest", - }, - DeliveryMetadata: map[string]string{ - "app-id": "test-app", - "source": "outpost", - }, - Metadata: map[string]string{ - 
"environment": "test", - "team": "platform", - }, - CreatedAt: now, - UpdatedAt: now, - DisabledAt: nil, - TenantID: idgen.String(), - } - - t.Run("gets empty", func(t *testing.T) { - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - assert.Nil(s.T(), actual) - }) - - t.Run("sets", func(t *testing.T) { - err := s.entityStore.CreateDestination(s.ctx, input) - require.NoError(s.T(), err) - }) - - t.Run("gets", func(t *testing.T) { - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - assertEqualDestination(t, input, *actual) - }) - - t.Run("updates", func(t *testing.T) { - input.Topics = []string{"*"} - input.DeliveryMetadata = map[string]string{ - "app-id": "updated-app", - "version": "2.0", - } - input.Metadata = map[string]string{ - "environment": "staging", - } - - err := s.entityStore.UpsertDestination(s.ctx, input) - require.NoError(s.T(), err) - - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - assertEqualDestination(t, input, *actual) - }) - - t.Run("clears", func(t *testing.T) { - err := s.entityStore.DeleteDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - assert.ErrorIs(s.T(), err, models.ErrDestinationDeleted) - assert.Nil(s.T(), actual) - }) - - t.Run("creates & overrides deleted resource", func(t *testing.T) { - err := s.entityStore.CreateDestination(s.ctx, input) - require.NoError(s.T(), err) - - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - assertEqualDestination(t, input, *actual) - }) - - t.Run("err when creates duplicate", func(t *testing.T) { - assert.ErrorIs(s.T(), s.entityStore.CreateDestination(s.ctx, input), models.ErrDuplicateDestination) - - // cleanup - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, input.TenantID, input.ID)) - }) - - t.Run("handles nil delivery_metadata and metadata", func(t *testing.T) { - // Factory defaults to nil for DeliveryMetadata and Metadata - inputWithNilFields := testutil.DestinationFactory.Any() - - err := s.entityStore.CreateDestination(s.ctx, inputWithNilFields) - require.NoError(s.T(), err) - - actual, err := s.entityStore.RetrieveDestination(s.ctx, inputWithNilFields.TenantID, inputWithNilFields.ID) - require.NoError(s.T(), err) - assert.Nil(t, actual.DeliveryMetadata) - assert.Nil(t, actual.Metadata) - - // cleanup - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, inputWithNilFields.TenantID, inputWithNilFields.ID)) - }) - - // UpdatedAt tests - t.Run("sets updated_at on create", func(t *testing.T) { - newDest := testutil.DestinationFactory.Any() - - err := s.entityStore.CreateDestination(s.ctx, newDest) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, newDest.TenantID, newDest.ID) - require.NoError(s.T(), err) - assert.True(s.T(), newDest.UpdatedAt.Unix() == retrieved.UpdatedAt.Unix()) - - // cleanup - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, newDest.TenantID, newDest.ID)) - }) - - t.Run("updates updated_at on upsert", func(t *testing.T) { - // Use explicit timestamps 1 second apart (Unix timestamps have second precision) - originalTime := time.Now().Add(-2 * time.Second).Truncate(time.Second) - updatedTime := originalTime.Add(1 * time.Second) - - original := 
testutil.DestinationFactory.Any() - original.UpdatedAt = originalTime - - err := s.entityStore.CreateDestination(s.ctx, original) - require.NoError(s.T(), err) - - // Update the destination with a later timestamp - updated := original - updated.UpdatedAt = updatedTime - updated.Topics = []string{"updated.topic"} - - err = s.entityStore.UpsertDestination(s.ctx, updated) - require.NoError(s.T(), err) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, updated.TenantID, updated.ID) - require.NoError(s.T(), err) - - // updated_at should be newer than original (comparing truncated times) - assert.True(s.T(), retrieved.UpdatedAt.After(originalTime)) - assert.True(s.T(), updated.UpdatedAt.Unix() == retrieved.UpdatedAt.Unix()) - - // cleanup - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, updated.TenantID, updated.ID)) - }) - - t.Run("fallback updated_at to created_at for existing records", func(t *testing.T) { - // Create a destination normally first - oldDest := testutil.DestinationFactory.Any() - err := s.entityStore.CreateDestination(s.ctx, oldDest) - require.NoError(s.T(), err) - - // Now manually remove the updated_at field from Redis to simulate old record - key := "destination:" + oldDest.TenantID + ":" + oldDest.ID - err = s.redisClient.HDel(s.ctx, key, "updated_at").Err() - require.NoError(s.T(), err) - - // Retrieve should fallback updated_at to created_at - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, oldDest.TenantID, oldDest.ID) - require.NoError(s.T(), err) - assert.True(s.T(), retrieved.UpdatedAt.Equal(retrieved.CreatedAt)) - - // cleanup - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, oldDest.TenantID, oldDest.ID)) - }) -} - -func (s *EntityTestSuite) TestListDestinationEmpty() { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, idgen.String()) - require.NoError(s.T(), err) - assert.Empty(s.T(), destinations) -} - -func (s *EntityTestSuite) TestDeleteTenantAndAssociatedDestinations() { - tenant := models.Tenant{ - ID: idgen.String(), - CreatedAt: time.Now(), - } - // Arrange - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, tenant)) - destinationIDs := []string{idgen.Destination(), idgen.Destination(), idgen.Destination()} - for _, id := range destinationIDs { - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(id), - testutil.DestinationFactory.WithTenantID(tenant.ID), - ))) - } - // Act - require.NoError(s.T(), s.entityStore.DeleteTenant(s.ctx, tenant.ID)) - // Assert - _, err := s.entityStore.RetrieveTenant(s.ctx, tenant.ID) - assert.ErrorIs(s.T(), err, models.ErrTenantDeleted) - for _, id := range destinationIDs { - _, err := s.entityStore.RetrieveDestination(s.ctx, tenant.ID, id) - assert.ErrorIs(s.T(), err, models.ErrDestinationDeleted) - } -} - -// Helper struct for multi-destination tests -type multiDestinationData struct { - tenant models.Tenant - destinations []models.Destination -} - -func (s *EntityTestSuite) setupMultiDestination() multiDestinationData { - data := multiDestinationData{ - tenant: models.Tenant{ - ID: idgen.String(), - CreatedAt: time.Now(), - }, - destinations: make([]models.Destination, 5), - } - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, data.tenant)) - - destinationTopicList := [][]string{ - {"*"}, - {"user.created"}, - {"user.updated"}, - {"user.deleted"}, - {"user.created", "user.updated"}, - } - // Use explicit timestamps 1 second apart to ensure deterministic sort 
order - // (Unix timestamps have second precision) - baseTime := time.Now().Add(-10 * time.Second).Truncate(time.Second) - for i := 0; i < 5; i++ { - id := idgen.Destination() - data.destinations[i] = testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(id), - testutil.DestinationFactory.WithTenantID(data.tenant.ID), - testutil.DestinationFactory.WithTopics(destinationTopicList[i]), - testutil.DestinationFactory.WithCreatedAt(baseTime.Add(time.Duration(i)*time.Second)), - ) - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, data.destinations[i])) - } - - // Insert & Delete destination to ensure it's cleaned up properly - toBeDeletedID := idgen.Destination() - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, - testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(toBeDeletedID), - testutil.DestinationFactory.WithTenantID(data.tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - ))) - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, toBeDeletedID)) - - return data -} - -func (s *EntityTestSuite) TestMultiDestinationRetrieveTenantDestinationsCount() { - data := s.setupMultiDestination() - - tenant, err := s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), 5, tenant.DestinationsCount) -} - -func (s *EntityTestSuite) TestMultiDestinationRetrieveTenantTopics() { - data := s.setupMultiDestination() - - // destinations[0] has topics ["*"], so tenant.Topics should be ["*"] - tenant, err := s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{"*"}, tenant.Topics) - - // After deleting the wildcard destination, tenant.Topics should aggregate remaining topics - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[0].ID)) - tenant, err = s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) - - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[1].ID)) - tenant, err = s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) - - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[2].ID)) - tenant, err = s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) - - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[3].ID)) - tenant, err = s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{"user.created", "user.updated"}, tenant.Topics) - - require.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[4].ID)) - tenant, err = s.entityStore.RetrieveTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Equal(s.T(), []string{}, tenant.Topics) -} - -func (s *EntityTestSuite) TestMultiDestinationListDestinationByTenant() { - data := s.setupMultiDestination() - - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 5) - for index, destination := range 
destinations { - require.Equal(s.T(), data.destinations[index].ID, destination.ID) - } -} - -func (s *EntityTestSuite) TestMultiDestinationListDestinationWithOpts() { - t := s.T() - data := s.setupMultiDestination() - - t.Run("filter by type: webhook", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Type: []string{"webhook"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 5) - }) - - t.Run("filter by type: rabbitmq", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Type: []string{"rabbitmq"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 0) - }) - - t.Run("filter by type: webhook,rabbitmq", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Type: []string{"webhook", "rabbitmq"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 5) - }) - - t.Run("filter by topic: user.created", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Topics: []string{"user.created"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 3) - }) - - t.Run("filter by topic: user.created,user.updated", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Topics: []string{"user.created", "user.updated"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 2) - }) - - t.Run("filter by type: rabbitmq, topic: user.created,user.updated", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Type: []string{"rabbitmq"}, - Topics: []string{"user.created", "user.updated"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 0) - }) - - t.Run("filter by topic: *", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID, models.WithDestinationFilter(models.DestinationFilter{ - Topics: []string{"*"}, - })) - require.NoError(s.T(), err) - require.Len(s.T(), destinations, 1) - }) -} - -func (s *EntityTestSuite) TestMultiDestinationMatchEvent() { - t := s.T() - data := s.setupMultiDestination() - - t.Run("match by topic", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - Topic: "user.created", - Time: time.Now(), - TenantID: data.tenant.ID, - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - require.Len(s.T(), matchedDestinationSummaryList, 3) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[1].ID, data.destinations[4].ID}, summary.ID) - } - }) - - // MatchEvent IGNORES destination_id and only matches by topic. - // These tests verify that destination_id in the event is intentionally ignored. - // Specific destination matching is handled at a higher level (publishmq package). 
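The comment above pins down MatchEvent's contract: routing considers only the destination's topic subscriptions (with "*" acting as a wildcard), never the event's destination_id. As a minimal, in-memory sketch of that topic rule only (illustrative, not the store's Redis-backed implementation; the helper name matchesTopic is hypothetical):

package main

import "fmt"

// matchesTopic mirrors the topic rule exercised in the tests above: a destination
// matches when it subscribes to the wildcard "*" or to the event's exact topic.
func matchesTopic(destinationTopics []string, eventTopic string) bool {
	for _, topic := range destinationTopics {
		if topic == "*" || topic == eventTopic {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(matchesTopic([]string{"*"}, "user.created"))                            // true
	fmt.Println(matchesTopic([]string{"user.created", "user.updated"}, "user.created")) // true
	fmt.Println(matchesTopic([]string{"user.deleted"}, "user.created"))                 // false
}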
- t.Run("ignores destination_id and matches by topic only", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - Topic: "user.created", - Time: time.Now(), - TenantID: data.tenant.ID, - DestinationID: data.destinations[1].ID, // This should be IGNORED - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - // Should match all destinations with "user.created" topic, not just the specified destination_id - require.Len(s.T(), matchedDestinationSummaryList, 3) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[1].ID, data.destinations[4].ID}, summary.ID) - } - }) - - t.Run("ignores non-existent destination_id", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - Topic: "user.created", - Time: time.Now(), - TenantID: data.tenant.ID, - DestinationID: "not-found", // This should be IGNORED - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - // Should still match all destinations with "user.created" topic - require.Len(s.T(), matchedDestinationSummaryList, 3) - }) - - t.Run("ignores destination_id with mismatched topic", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - Topic: "user.created", - Time: time.Now(), - TenantID: data.tenant.ID, - DestinationID: data.destinations[3].ID, // "user.deleted" destination - should be IGNORED - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - // Should match all destinations with "user.created" topic, not the specified "user.deleted" destination - require.Len(s.T(), matchedDestinationSummaryList, 3) - }) - - t.Run("match after destination is updated", func(t *testing.T) { - updatedIndex := 2 - updatedTopics := []string{"user.created"} - updatedDestination := data.destinations[updatedIndex] - updatedDestination.Topics = updatedTopics - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, updatedDestination)) - - actual, err := s.entityStore.RetrieveDestination(s.ctx, updatedDestination.TenantID, updatedDestination.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), updatedDestination.Topics, actual.Topics) - - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, data.tenant.ID) - require.NoError(s.T(), err) - assert.Len(s.T(), destinations, 5) - - // Match user.created - event := models.Event{ - ID: idgen.Event(), - Topic: "user.created", - Time: time.Now(), - TenantID: data.tenant.ID, - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 4) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[1].ID, data.destinations[2].ID, data.destinations[4].ID}, summary.ID) - } - - // Match user.updated - event = models.Event{ - ID: idgen.Event(), - Topic: "user.updated", - Time: time.Now(), - TenantID: data.tenant.ID, - Metadata: map[string]string{}, - Data: map[string]interface{}{}, - } - matchedDestinationSummaryList, err = s.entityStore.MatchEvent(s.ctx, event) - 
require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 2) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[4].ID}, summary.ID) - } - }) -} - -func (s *EntityTestSuite) TestDestinationEnableDisable() { - t := s.T() - input := testutil.DestinationFactory.Any() - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, input)) - - assertDestination := func(t *testing.T, expected models.Destination) { - actual, err := s.entityStore.RetrieveDestination(s.ctx, input.TenantID, input.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), expected.ID, actual.ID) - assertEqualTimePtr(t, expected.DisabledAt, actual.DisabledAt, "DisabledAt") - } - - t.Run("should disable", func(t *testing.T) { - now := time.Now() - input.DisabledAt = &now - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, input)) - assertDestination(t, input) - }) - - t.Run("should enable", func(t *testing.T) { - input.DisabledAt = nil - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, input)) - assertDestination(t, input) - }) -} - -func (s *EntityTestSuite) TestMultiSuiteDisableAndMatch() { - t := s.T() - data := s.setupMultiDestination() - - t.Run("initial match user.deleted", func(t *testing.T) { - event := testutil.EventFactory.Any( - testutil.EventFactory.WithTenantID(data.tenant.ID), - testutil.EventFactory.WithTopic("user.deleted"), - ) - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 2) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[3].ID}, summary.ID) - } - }) - - t.Run("should not match disabled destination", func(t *testing.T) { - destination := data.destinations[0] - now := time.Now() - destination.DisabledAt = &now - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, destination)) - - event := testutil.EventFactory.Any( - testutil.EventFactory.WithTenantID(data.tenant.ID), - testutil.EventFactory.WithTopic("user.deleted"), - ) - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 1) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[3].ID}, summary.ID) - } - }) - - t.Run("should match after re-enabled destination", func(t *testing.T) { - destination := data.destinations[0] - destination.DisabledAt = nil - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, destination)) - - event := testutil.EventFactory.Any( - testutil.EventFactory.WithTenantID(data.tenant.ID), - testutil.EventFactory.WithTopic("user.deleted"), - ) - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 2) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[0].ID, data.destinations[3].ID}, summary.ID) - } - }) - -} - -func (s *EntityTestSuite) TestDeleteDestination() { - t := s.T() - destination := testutil.DestinationFactory.Any() - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destination)) - - t.Run("should not return error when deleting existing destination", func(t *testing.T) { - assert.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, 
destination.TenantID, destination.ID)) - }) - - t.Run("should not return error when deleting already-deleted destination", func(t *testing.T) { - assert.NoError(s.T(), s.entityStore.DeleteDestination(s.ctx, destination.TenantID, destination.ID)) - }) - - t.Run("should return error when deleting non-existent destination", func(t *testing.T) { - err := s.entityStore.DeleteDestination(s.ctx, destination.TenantID, idgen.Destination()) - assert.ErrorIs(s.T(), err, models.ErrDestinationNotFound) - }) - - t.Run("should return ErrDestinationDeleted when retrieving deleted destination", func(t *testing.T) { - dest, err := s.entityStore.RetrieveDestination(s.ctx, destination.TenantID, destination.ID) - assert.ErrorIs(s.T(), err, models.ErrDestinationDeleted) - assert.Nil(s.T(), dest) - }) - - t.Run("should not return deleted destination in list", func(t *testing.T) { - destinations, err := s.entityStore.ListDestinationByTenant(s.ctx, destination.TenantID) - assert.NoError(s.T(), err) - assert.Empty(s.T(), destinations) - }) -} - -func (s *EntityTestSuite) TestMultiSuiteDeleteAndMatch() { - t := s.T() - data := s.setupMultiDestination() - - t.Run("delete first destination", func(t *testing.T) { - require.NoError(s.T(), - s.entityStore.DeleteDestination(s.ctx, data.tenant.ID, data.destinations[0].ID), - ) - }) - - t.Run("match event", func(t *testing.T) { - event := testutil.EventFactory.Any( - testutil.EventFactory.WithTenantID(data.tenant.ID), - testutil.EventFactory.WithTopic("user.created"), - ) - - matchedDestinationSummaryList, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - require.Len(s.T(), matchedDestinationSummaryList, 2) - for _, summary := range matchedDestinationSummaryList { - require.Contains(s.T(), []string{data.destinations[1].ID, data.destinations[4].ID}, summary.ID) - } - }) -} - -func (s *EntityTestSuite) TestDestinationFilterPersistence() { - t := s.T() - tenant := models.Tenant{ID: idgen.String()} - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, tenant)) - - t.Run("stores and retrieves destination with filter", func(t *testing.T) { - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "type": "order.created", - }, - }), - ) - - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destination)) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, tenant.ID, destination.ID) - require.NoError(s.T(), err) - assert.NotNil(s.T(), retrieved.Filter) - assert.Equal(s.T(), "order.created", retrieved.Filter["data"].(map[string]any)["type"]) - }) - - t.Run("stores destination with nil filter", func(t *testing.T) { - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - ) - - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destination)) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, tenant.ID, destination.ID) - require.NoError(s.T(), err) - assert.Nil(s.T(), retrieved.Filter) - }) - - t.Run("updates destination filter", func(t *testing.T) { - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{"type": "order.created"}, - }), 
- ) - - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destination)) - - // Update filter - destination.Filter = models.Filter{ - "data": map[string]any{"type": "order.updated"}, - } - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, destination)) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, tenant.ID, destination.ID) - require.NoError(s.T(), err) - assert.Equal(s.T(), "order.updated", retrieved.Filter["data"].(map[string]any)["type"]) - }) - - t.Run("removes destination filter", func(t *testing.T) { - destination := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{"type": "order.created"}, - }), - ) - - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destination)) - - // Remove filter - destination.Filter = nil - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, destination)) - - retrieved, err := s.entityStore.RetrieveDestination(s.ctx, tenant.ID, destination.ID) - require.NoError(s.T(), err) - assert.Nil(s.T(), retrieved.Filter) - }) -} - -func (s *EntityTestSuite) TestMatchEventWithFilter() { - t := s.T() - tenant := models.Tenant{ID: idgen.String()} - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, tenant)) - - // Create destinations with different filters - destNoFilter := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_no_filter"), - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - ) - - destFilterOrderCreated := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_filter_order_created"), - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "type": "order.created", - }, - }), - ) - - destFilterOrderUpdated := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_filter_order_updated"), - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "type": "order.updated", - }, - }), - ) - - destFilterPremiumCustomer := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_filter_premium"), - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"*"}), - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "customer": map[string]any{ - "tier": "premium", - }, - }, - }), - ) - - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destNoFilter)) - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destFilterOrderCreated)) - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destFilterOrderUpdated)) - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destFilterPremiumCustomer)) - - t.Run("event without filter field matches only destinations with matching filter", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - TenantID: tenant.ID, - Topic: "order", - Time: time.Now(), - Metadata: map[string]string{}, - Data: map[string]interface{}{ - "type": "order.created", - }, - } - - matchedDestinations, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - 
- // Should match: destNoFilter (no filter), destFilterOrderCreated (matches type) - // Should NOT match: destFilterOrderUpdated (wrong type), destFilterPremiumCustomer (missing customer.tier) - assert.Len(s.T(), matchedDestinations, 2) - ids := []string{} - for _, dest := range matchedDestinations { - ids = append(ids, dest.ID) - } - assert.Contains(s.T(), ids, "dest_no_filter") - assert.Contains(s.T(), ids, "dest_filter_order_created") - }) - - t.Run("event with nested data matches nested filter", func(t *testing.T) { - event := models.Event{ - ID: idgen.Event(), - TenantID: tenant.ID, - Topic: "order", - Time: time.Now(), - Metadata: map[string]string{}, - Data: map[string]interface{}{ - "type": "order.created", - "customer": map[string]interface{}{ - "id": "cust_123", - "tier": "premium", - }, - }, - } - - matchedDestinations, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - // Should match: destNoFilter, destFilterOrderCreated, destFilterPremiumCustomer - // Should NOT match: destFilterOrderUpdated (wrong type) - assert.Len(s.T(), matchedDestinations, 3) - ids := []string{} - for _, dest := range matchedDestinations { - ids = append(ids, dest.ID) - } - assert.Contains(s.T(), ids, "dest_no_filter") - assert.Contains(s.T(), ids, "dest_filter_order_created") - assert.Contains(s.T(), ids, "dest_filter_premium") - }) - - t.Run("topic filter takes precedence before content filter", func(t *testing.T) { - // Create a destination with specific topic AND filter - destTopicAndFilter := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID("dest_topic_and_filter"), - testutil.DestinationFactory.WithTenantID(tenant.ID), - testutil.DestinationFactory.WithTopics([]string{"user.created"}), // Specific topic - testutil.DestinationFactory.WithFilter(models.Filter{ - "data": map[string]any{ - "type": "order.created", - }, - }), - ) - require.NoError(s.T(), s.entityStore.CreateDestination(s.ctx, destTopicAndFilter)) - - // Event with matching filter but wrong topic - event := models.Event{ - ID: idgen.Event(), - TenantID: tenant.ID, - Topic: "order", - Time: time.Now(), - Metadata: map[string]string{}, - Data: map[string]interface{}{ - "type": "order.created", - }, - } - - matchedDestinations, err := s.entityStore.MatchEvent(s.ctx, event) - require.NoError(s.T(), err) - - // Should NOT match destTopicAndFilter because topic doesn't match - for _, dest := range matchedDestinations { - assert.NotEqual(s.T(), "dest_topic_and_filter", dest.ID) - } - }) -} - -// ============================================================================= -// ListTenantTestSuite - Tests for ListTenant functionality (requires RediSearch) -// ============================================================================= - -// ListTenantTestSuite tests ListTenant functionality. -// Only runs with Redis Stack since it requires RediSearch. 
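The suite below needs Redis Stack because ListTenant is served through RediSearch (FT.SEARCH); on plain Redis or miniredis the store returns models.ErrListTenantNotSupported instead. A hedged sketch of how a caller could probe for that support, written against the EntityStore identifiers shown in the removed suite (the helper name supportsListTenant is hypothetical):

package example

import (
	"context"
	"errors"

	"github.com/hookdeck/outpost/internal/models"
)

// supportsListTenant issues a cheap ListTenant call and checks for the
// sentinel error to decide whether RediSearch-backed listing is available.
func supportsListTenant(ctx context.Context, store models.EntityStore) (bool, error) {
	_, err := store.ListTenant(ctx, models.ListTenantRequest{Limit: 1})
	if errors.Is(err, models.ErrListTenantNotSupported) {
		return false, nil // plain Redis / miniredis: listing unavailable
	}
	if err != nil {
		return false, err // unrelated failure
	}
	return true, nil // Redis Stack: RediSearch available
}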
-type ListTenantTestSuite struct { - suite.Suite - ctx context.Context - redisClient redis.Cmdable - entityStore models.EntityStore - deploymentID string - RedisClientFactory RedisClientFactory // Required - must be set - - // Test data created in SetupSuite - tenants []models.Tenant // 25 tenants for pagination tests - tenantWithDests models.Tenant // First tenant, has destinations -} - -func (s *ListTenantTestSuite) SetupSuite() { - s.ctx = context.Background() - - require.NotNil(s.T(), s.RedisClientFactory, "RedisClientFactory must be set") - s.redisClient = s.RedisClientFactory(s.T()) - - opts := []models.EntityStoreOption{ - models.WithCipher(models.NewAESCipher("secret")), - models.WithAvailableTopics(testutil.TestTopics), - } - if s.deploymentID != "" { - opts = append(opts, models.WithDeploymentID(s.deploymentID)) - } - s.entityStore = models.NewEntityStore(s.redisClient, opts...) - - // Initialize entity store (probes for RediSearch) - err := s.entityStore.Init(s.ctx) - require.NoError(s.T(), err) - - // Verify Init is idempotent - for i := 0; i < 3; i++ { - err := s.entityStore.Init(s.ctx) - require.NoError(s.T(), err, "Init call %d should not fail", i+1) - } - - // Test empty list BEFORE creating any data - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{}) - require.NoError(s.T(), err) - require.Empty(s.T(), resp.Models, "should be empty before creating data") - require.Nil(s.T(), resp.Pagination.Next) - require.Nil(s.T(), resp.Pagination.Prev) - - // Create 25 tenants for pagination tests - s.tenants = make([]models.Tenant, 25) - baseTime := time.Now() - for i := range s.tenants { - s.tenants[i] = testutil.TenantFactory.Any( - testutil.TenantFactory.WithCreatedAt(baseTime.Add(time.Duration(i)*time.Second)), - testutil.TenantFactory.WithUpdatedAt(baseTime.Add(time.Duration(i)*time.Second)), - ) - require.NoError(s.T(), s.entityStore.UpsertTenant(s.ctx, s.tenants[i])) - } - // Use newest tenant (last in array) for destinations so it appears on page 1 with DESC order - s.tenantWithDests = s.tenants[24] - - // Add destinations to newest tenant - for i := range 2 { - dest := testutil.DestinationFactory.Any( - testutil.DestinationFactory.WithID(fmt.Sprintf("dest_suite_%d", i)), - testutil.DestinationFactory.WithTenantID(s.tenantWithDests.ID), - ) - require.NoError(s.T(), s.entityStore.UpsertDestination(s.ctx, dest)) - } -} - -// TestListTenantEnrichment tests ListTenant-specific fields: Count, destinations_count, topics. -// Pagination behavior is tested separately in paginationtest.Suite. 
-func (s *ListTenantTestSuite) TestListTenantEnrichment() { - // Uses 25 tenants created in SetupSuite, first tenant has 2 destinations - - s.T().Run("returns total count independent of pagination", func(t *testing.T) { - // First page with limit - resp1, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{Limit: 2}) - require.NoError(t, err) - assert.Equal(t, 25, resp1.Count, "count should be total tenants, not page size") - assert.Len(t, resp1.Models, 2, "data should respect limit") - - // Second page - count should still be total - var nextCursor string - if resp1.Pagination.Next != nil { - nextCursor = *resp1.Pagination.Next - } - resp2, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{Limit: 2, Next: nextCursor}) - require.NoError(t, err) - assert.Equal(t, 25, resp2.Count, "count should remain total across pages") - }) - - s.T().Run("does not include destinations in tenant list", func(t *testing.T) { - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{Limit: 100}) - require.NoError(t, err) - assert.Equal(t, 25, resp.Count, "count should be tenants only") - - for _, tenant := range resp.Models { - assert.NotContains(t, tenant.ID, "dest_", "destination should not appear in tenant list") - } - }) - - s.T().Run("returns destinations_count and topics", func(t *testing.T) { - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{Limit: 100}) - require.NoError(t, err) - - // tenantWithDests has 2 destinations from SetupSuite - var tenantWithDests *models.Tenant - for i := range resp.Models { - if resp.Models[i].ID == s.tenantWithDests.ID { - tenantWithDests = &resp.Models[i] - break - } - } - require.NotNil(t, tenantWithDests, "should find tenant with destinations") - assert.Equal(t, 2, tenantWithDests.DestinationsCount, "should have 2 destinations") - assert.NotNil(t, tenantWithDests.Topics, "topics should not be nil") - - // Verify tenants without destinations have 0 count - var tenantWithoutDests *models.Tenant - for i := range resp.Models { - if resp.Models[i].ID != s.tenantWithDests.ID { - tenantWithoutDests = &resp.Models[i] - break - } - } - require.NotNil(t, tenantWithoutDests, "should find tenant without destinations") - assert.Equal(t, 0, tenantWithoutDests.DestinationsCount, "tenant without destinations should have 0 count") - assert.Empty(t, tenantWithoutDests.Topics, "tenant without destinations should have empty topics") - }) -} - -func (s *ListTenantTestSuite) TestListTenantExcludesDeleted() { - s.T().Run("deleted tenant not returned", func(t *testing.T) { - // Get initial count (includes 25 tenants from SetupSuite) - initialResp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{}) - require.NoError(t, err) - initialCount := initialResp.Count - - // Create 2 additional tenants - tenant1 := testutil.TenantFactory.Any() - tenant2 := testutil.TenantFactory.Any() - require.NoError(t, s.entityStore.UpsertTenant(s.ctx, tenant1)) - require.NoError(t, s.entityStore.UpsertTenant(s.ctx, tenant2)) - - // List should show initial + 2 - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{}) - require.NoError(t, err) - assert.Equal(t, initialCount+2, resp.Count) - - // Delete one - require.NoError(t, s.entityStore.DeleteTenant(s.ctx, tenant1.ID)) - - // List should show initial + 1 - resp, err = s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{}) - require.NoError(t, err) - assert.Equal(t, initialCount+1, resp.Count) - - // Verify deleted tenant is not in results - for _, tenant := range 
resp.Models { - assert.NotEqual(t, tenant1.ID, tenant.ID, "deleted tenant should not appear") - } - - // Cleanup - _ = s.entityStore.DeleteTenant(s.ctx, tenant2.ID) - }) - - s.T().Run("deleted tenants do not consume LIMIT slots", func(t *testing.T) { - // This tests that deleted tenants are filtered at the FT.SEARCH query level, - // not in Go code after fetching. If filtered in Go, requesting limit=2 might - // return fewer results if deleted tenants consumed the LIMIT slots. - - // Create 5 tenants with distinct timestamps far in the future (to be first in DESC order) - baseTime := time.Now().Add(30 * time.Hour) - prefix := fmt.Sprintf("limit_test_%d_", time.Now().UnixNano()) - tenantIDs := make([]string, 5) - for i := 0; i < 5; i++ { - tenantIDs[i] = fmt.Sprintf("%s%d", prefix, i) - tenant := models.Tenant{ - ID: tenantIDs[i], - CreatedAt: baseTime.Add(time.Duration(i) * time.Second), - UpdatedAt: baseTime.Add(time.Duration(i) * time.Second), - } - require.NoError(t, s.entityStore.UpsertTenant(s.ctx, tenant)) - } - - // Delete the 2 newest tenants (index 3 and 4) - require.NoError(t, s.entityStore.DeleteTenant(s.ctx, tenantIDs[3])) - require.NoError(t, s.entityStore.DeleteTenant(s.ctx, tenantIDs[4])) - - // Request limit=2 with DESC order - the 2 newest active (index 2,1) should be returned - // NOT the deleted ones (index 3,4) which would have been first if not filtered - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Limit: 2, - Dir: "desc", - }) - require.NoError(t, err) - require.GreaterOrEqual(t, len(resp.Models), 2, "should get at least 2 tenants") - - // Verify deleted tenants don't appear in the first 2 results - for i := 0; i < 2 && i < len(resp.Models); i++ { - assert.NotEqual(t, tenantIDs[3], resp.Models[i].ID, "deleted tenant should not appear") - assert.NotEqual(t, tenantIDs[4], resp.Models[i].ID, "deleted tenant should not appear") - } - - // Cleanup - for _, id := range tenantIDs { - _ = s.entityStore.DeleteTenant(s.ctx, id) - } - }) -} - -// TestListTenantKeysetPagination verifies keyset pagination handles concurrent modifications correctly. -func (s *ListTenantTestSuite) TestListTenantKeysetPagination() { - s.T().Run("add during traversal does not cause duplicate", func(t *testing.T) { - // With keyset pagination, adding a new item with a newer timestamp - // does NOT cause duplicates because the cursor is based on timestamp, - // not offset. The new item falls outside the timestamp range. 
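The comment above is the crux of the keyset design: the Next cursor encodes the boundary timestamp of the last row returned, so the following page asks for rows strictly older than that boundary rather than "skip N rows". A row inserted mid-traversal with a newer timestamp therefore cannot duplicate onto a later page. A minimal in-memory sketch of that idea (illustrative only; the real store issues the equivalent query through RediSearch and encodes the cursor, and tenantRow/listPage are hypothetical names):

package main

import (
	"fmt"
	"sort"
	"time"
)

type tenantRow struct {
	ID        string
	CreatedAt time.Time
}

// listPage returns up to limit rows strictly older than cursor, newest first,
// along with the boundary timestamp to use as the next cursor.
func listPage(rows []tenantRow, cursor time.Time, limit int) ([]tenantRow, time.Time) {
	sort.Slice(rows, func(i, j int) bool { return rows[i].CreatedAt.After(rows[j].CreatedAt) })
	var page []tenantRow
	for _, r := range rows {
		if !r.CreatedAt.Before(cursor) {
			continue // at or after the boundary: belongs to an earlier page
		}
		page = append(page, r)
		if len(page) == limit {
			break
		}
	}
	next := cursor
	if len(page) > 0 {
		next = page[len(page)-1].CreatedAt
	}
	return page, next
}

func main() {
	base := time.Now()
	rows := []tenantRow{
		{"t1", base.Add(1 * time.Second)},
		{"t2", base.Add(2 * time.Second)},
		{"t3", base.Add(3 * time.Second)},
	}
	page1, cursor := listPage(rows, base.Add(time.Hour), 2) // t3, t2
	// A row added now, even with the newest timestamp, sits outside the
	// cursor's range and cannot duplicate onto page 2.
	rows = append(rows, tenantRow{"t4", base.Add(10 * time.Second)})
	page2, _ := listPage(rows, cursor, 2) // t1 only
	fmt.Println(len(page1), len(page2))   // 2 1
}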
- - // Create 15 tenants with unique prefix and timestamps far in the future - prefix := fmt.Sprintf("add_edge_%d_", time.Now().UnixNano()) - tenantIDs := make([]string, 15) - baseTime := time.Now().Add(20 * time.Hour) // Even further future - for i := 0; i < 15; i++ { - tenantIDs[i] = fmt.Sprintf("%s%02d", prefix, i) - tenant := models.Tenant{ - ID: tenantIDs[i], - CreatedAt: baseTime.Add(time.Duration(i) * time.Second), - UpdatedAt: baseTime, - } - require.NoError(t, s.entityStore.UpsertTenant(s.ctx, tenant)) - } - - // Fetch page 1 (items 14, 13, 12, 11, 10 with DESC order) - resp1, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{Limit: 5}) - require.NoError(t, err) - require.Len(t, resp1.Models, 5, "page 1 should have 5 items") - - // Verify we got our test tenants - for _, tenant := range resp1.Models { - require.Contains(t, tenant.ID, prefix, "page 1 should contain our test tenants") - } - - // Add a new tenant that will sort BEFORE all existing ones (newest) - newTenantID := prefix + "NEW" - newTenant := models.Tenant{ - ID: newTenantID, - CreatedAt: baseTime.Add(time.Hour), // Definitely newest in our set - UpdatedAt: baseTime, - } - require.NoError(t, s.entityStore.UpsertTenant(s.ctx, newTenant)) - tenantIDs = append(tenantIDs, newTenantID) - - // Fetch page 2 using cursor from page 1 - // With keyset pagination, the cursor is the timestamp of item 10 - // Page 2 will get items with timestamp < cursor, so items 9, 8, 7, 6, 5 - // The new tenant has a newer timestamp so it won't appear on page 2 - var nextCursor string - if resp1.Pagination.Next != nil { - nextCursor = *resp1.Pagination.Next - } - resp2, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Limit: 5, - Next: nextCursor, - }) - require.NoError(t, err) - require.NotEmpty(t, resp2.Models, "page 2 should have items") - - // Verify no duplicates - first item on page 2 should NOT be the last from page 1 - page1IDs := make(map[string]bool) - for _, tenant := range resp1.Models { - page1IDs[tenant.ID] = true - } - for _, tenant := range resp2.Models { - assert.False(t, page1IDs[tenant.ID], - "keyset pagination: no duplicates when adding during traversal, but found %s", tenant.ID) - } - - // Cleanup - for _, id := range tenantIDs { - _ = s.entityStore.DeleteTenant(s.ctx, id) - } - }) -} - -// TestListTenantInputValidation tests input validation and error handling. 
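The validation subtests that follow pin down the request defaults: a zero or negative Limit falls back to 20, anything above 100 is capped, and an empty Dir means descending order. A small sketch of the limit normalization those tests imply (the names defaultListLimit, maxListLimit, and normalizeLimit are hypothetical, not the store's actual identifiers):

package example

// Limit normalization implied by the tests below: non-positive limits use the
// default of 20, oversized limits are capped at 100.
const (
	defaultListLimit = 20
	maxListLimit     = 100
)

func normalizeLimit(limit int) int {
	switch {
	case limit <= 0:
		return defaultListLimit
	case limit > maxListLimit:
		return maxListLimit
	default:
		return limit
	}
}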
-func (s *ListTenantTestSuite) TestListTenantInputValidation() { - s.T().Run("invalid dir returns error", func(t *testing.T) { - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Dir: "invalid", - }) - require.Error(t, err) - assert.ErrorIs(t, err, models.ErrInvalidOrder) - }) - - s.T().Run("conflicting cursors returns error", func(t *testing.T) { - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Next: "somecursor", - Prev: "anothercursor", - }) - require.Error(t, err) - assert.ErrorIs(t, err, models.ErrConflictingCursors) - }) - - s.T().Run("invalid next cursor returns error", func(t *testing.T) { - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Next: "not-valid-base62!!!", - }) - require.Error(t, err) - assert.ErrorIs(t, err, models.ErrInvalidCursor) - }) - - s.T().Run("invalid prev cursor returns error", func(t *testing.T) { - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Prev: "not-valid-base62!!!", - }) - require.Error(t, err) - assert.ErrorIs(t, err, models.ErrInvalidCursor) - }) - - s.T().Run("malformed cursor format returns error", func(t *testing.T) { - // Valid base62 but wrong format (missing version prefix) - _, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Next: "abc123", - }) - require.Error(t, err) - assert.ErrorIs(t, err, models.ErrInvalidCursor) - }) - - s.T().Run("limit zero uses default", func(t *testing.T) { - // Uses 25 tenants from SetupSuite - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Limit: 0, // Should use default (20) - }) - require.NoError(t, err) - // Default limit is 20, should return up to 20 of the 25 tenants - assert.Equal(t, 20, len(resp.Models), "default limit should be 20") - assert.Equal(t, 25, resp.Count, "total count should be 25") - }) - - s.T().Run("limit negative uses default", func(t *testing.T) { - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Limit: -5, // Should use default (20) - }) - require.NoError(t, err) - // Should succeed, not error - assert.NotNil(t, resp) - }) - - s.T().Run("limit exceeding max is capped", func(t *testing.T) { - // Uses 25 tenants from SetupSuite - // Request with limit > max (100) - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Limit: 1000, // Should be capped to 100 - }) - require.NoError(t, err) - // Should succeed and return all 25 (capped to 100, but we only have 25) - assert.NotNil(t, resp) - assert.Equal(t, 25, len(resp.Models), "should return all 25 tenants") - assert.Equal(t, 25, resp.Count) - }) - - s.T().Run("empty dir uses default desc", func(t *testing.T) { - // Uses 25 tenants from SetupSuite created with sequential timestamps - // s.tenants[0] is oldest, s.tenants[24] is newest - resp, err := s.entityStore.ListTenant(s.ctx, models.ListTenantRequest{ - Dir: "", // Should default to "desc" - }) - require.NoError(t, err) - require.Len(t, resp.Models, 20, "default limit is 20") - - // With DESC dir, newest (s.tenants[24]) should be first - // and older tenants should follow - assert.Equal(t, s.tenants[24].ID, resp.Models[0].ID, "newest tenant should be first with desc dir") - assert.Equal(t, s.tenants[23].ID, resp.Models[1].ID, "second newest should be second") - }) -} diff --git a/internal/models/serialization.go b/internal/models/serialization.go new file mode 100644 index 00000000..b7dd2a46 --- /dev/null +++ b/internal/models/serialization.go @@ -0,0 +1,154 @@ +package models + +import ( + "encoding" + 
"encoding/json" + "fmt" + "log" + "strings" +) + +// ============================== Interface assertions ============================== + +var _ encoding.BinaryMarshaler = &Topics{} +var _ encoding.BinaryUnmarshaler = &Topics{} +var _ json.Marshaler = &Topics{} +var _ json.Unmarshaler = &Topics{} + +var _ encoding.BinaryMarshaler = &Filter{} +var _ encoding.BinaryUnmarshaler = &Filter{} + +var _ encoding.BinaryMarshaler = &MapStringString{} +var _ encoding.BinaryUnmarshaler = &MapStringString{} +var _ json.Unmarshaler = &MapStringString{} + +var _ fmt.Stringer = &Data{} +var _ encoding.BinaryUnmarshaler = &Data{} + +// ============================== Topics serialization ============================== + +func (t *Topics) MarshalBinary() ([]byte, error) { + str := strings.Join(*t, ",") + return []byte(str), nil +} + +func (t *Topics) UnmarshalBinary(data []byte) error { + *t = TopicsFromString(string(data)) + return nil +} + +func (t *Topics) MarshalJSON() ([]byte, error) { + return json.Marshal(*t) +} + +func (t *Topics) UnmarshalJSON(data []byte) error { + if string(data) == `"*"` { + *t = TopicsFromString("*") + return nil + } + var arr []string + if err := json.Unmarshal(data, &arr); err != nil { + log.Println(err) + return ErrInvalidTopicsFormat + } + *t = arr + return nil +} + +// ============================== Filter ============================== + +// Filter represents a JSON schema filter for event matching. +// It uses the simplejsonmatch schema syntax for filtering events. +type Filter map[string]any + +func (f *Filter) MarshalBinary() ([]byte, error) { + if f == nil || len(*f) == 0 { + return nil, nil + } + return json.Marshal(f) +} + +func (f *Filter) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, f) +} + +// ============================== MapStringString ============================== + +type Config = MapStringString +type Credentials = MapStringString +type DeliveryMetadata = MapStringString +type MapStringString map[string]string + +func (m *MapStringString) MarshalBinary() ([]byte, error) { + return json.Marshal(m) +} + +func (m *MapStringString) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, m) +} + +func (m *MapStringString) UnmarshalJSON(data []byte) error { + // First try to unmarshal as map[string]string + var stringMap map[string]string + if err := json.Unmarshal(data, &stringMap); err == nil { + *m = stringMap + return nil + } + + // If that fails, try map[string]interface{} to handle mixed types + var mixedMap map[string]interface{} + if err := json.Unmarshal(data, &mixedMap); err != nil { + return err + } + + // Convert all values to strings + result := make(map[string]string) + for k, v := range mixedMap { + switch val := v.(type) { + case string: + result[k] = val + case bool: + result[k] = fmt.Sprintf("%v", val) + case float64: + result[k] = fmt.Sprintf("%v", val) + case nil: + result[k] = "" + default: + // For other types, try to convert to string using JSON marshaling + if b, err := json.Marshal(val); err == nil { + result[k] = string(b) + } else { + result[k] = fmt.Sprintf("%v", val) + } + } + } + + *m = result + return nil +} + +// ============================== Data ============================== + +type Data map[string]interface{} + +func (d *Data) String() string { + data, err := json.Marshal(d) + if err != nil { + return "" + } + return string(data) +} + +func (d *Data) UnmarshalBinary(data []byte) error { + if string(data) == "" { + return nil + } + return 
json.Unmarshal(data, d) +} + +// ============================== Metadata ============================== + +type Metadata = MapStringString diff --git a/internal/models/event.go b/internal/models/tasks.go similarity index 65% rename from internal/models/event.go rename to internal/models/tasks.go index f6f6e3b0..ae685ade 100644 --- a/internal/models/event.go +++ b/internal/models/tasks.go @@ -1,55 +1,20 @@ package models import ( - "encoding" "encoding/json" - "fmt" - "time" "github.com/hookdeck/outpost/internal/mqs" ) -type Data map[string]interface{} - -var _ fmt.Stringer = &Data{} -var _ encoding.BinaryUnmarshaler = &Data{} - -func (d *Data) String() string { - data, err := json.Marshal(d) - if err != nil { - return "" - } - return string(data) -} - -func (d *Data) UnmarshalBinary(data []byte) error { - if string(data) == "" { - return nil - } - return json.Unmarshal(data, d) -} - -type Metadata = MapStringString - type EventTelemetry struct { TraceID string SpanID string ReceivedTime string // format time.RFC3339Nano } -type Event struct { - ID string `json:"id"` - TenantID string `json:"tenant_id"` - DestinationID string `json:"destination_id"` - Topic string `json:"topic"` - EligibleForRetry bool `json:"eligible_for_retry"` - Time time.Time `json:"time"` - Metadata Metadata `json:"metadata"` - Data Data `json:"data"` - Status string `json:"status,omitempty"` - - // Telemetry data, must exist to properly trace events between publish receiver & delivery handler - Telemetry *EventTelemetry `json:"telemetry,omitempty"` +type DeliveryTelemetry struct { + TraceID string + SpanID string } var _ mqs.IncomingMessage = &Event{} @@ -66,11 +31,6 @@ func (e *Event) ToMessage() (*mqs.Message, error) { return &mqs.Message{Body: data}, nil } -type DeliveryTelemetry struct { - TraceID string - SpanID string -} - // DeliveryTask represents a task to deliver an event to a destination. // This is a message type (no ID) used by: publishmq -> deliverymq, retry -> deliverymq type DeliveryTask struct { @@ -127,11 +87,6 @@ func NewManualDeliveryTask(event Event, destinationID string) DeliveryTask { return task } -const ( - AttemptStatusSuccess = "success" - AttemptStatusFailed = "failed" -) - // LogEntry represents a message for the log queue. // // IMPORTANT: Both Event and Attempt are REQUIRED. 
The logstore requires both @@ -155,16 +110,3 @@ func (e *LogEntry) ToMessage() (*mqs.Message, error) { } return &mqs.Message{Body: data}, nil } - -type Attempt struct { - ID string `json:"id"` - TenantID string `json:"tenant_id"` - EventID string `json:"event_id"` - DestinationID string `json:"destination_id"` - AttemptNumber int `json:"attempt_number"` - Manual bool `json:"manual"` - Status string `json:"status"` - Time time.Time `json:"time"` - Code string `json:"code"` - ResponseData map[string]interface{} `json:"response_data"` -} diff --git a/internal/models/tenant.go b/internal/models/tenant.go deleted file mode 100644 index 0b30ef65..00000000 --- a/internal/models/tenant.go +++ /dev/null @@ -1,76 +0,0 @@ -package models - -import ( - "fmt" - "strconv" - "time" -) - -type Tenant struct { - ID string `json:"id" redis:"id"` - DestinationsCount int `json:"destinations_count" redis:"-"` - Topics []string `json:"topics" redis:"-"` - Metadata Metadata `json:"metadata,omitempty" redis:"-"` - CreatedAt time.Time `json:"created_at" redis:"created_at"` - UpdatedAt time.Time `json:"updated_at" redis:"updated_at"` -} - -func (t *Tenant) parseRedisHash(hash map[string]string) error { - if _, ok := hash["deleted_at"]; ok { - return ErrTenantDeleted - } - if hash["id"] == "" { - return fmt.Errorf("missing id") - } - t.ID = hash["id"] - - // Parse created_at - supports both numeric (Unix timestamp) and RFC3339 formats - // This enables lazy migration: old records have RFC3339, new records have Unix timestamps - var err error - t.CreatedAt, err = parseTimestamp(hash["created_at"]) - if err != nil { - return fmt.Errorf("invalid created_at: %w", err) - } - - // Parse updated_at - same lazy migration support - if hash["updated_at"] != "" { - t.UpdatedAt, err = parseTimestamp(hash["updated_at"]) - if err != nil { - // Fallback to created_at if updated_at is invalid - t.UpdatedAt = t.CreatedAt - } - } else { - t.UpdatedAt = t.CreatedAt - } - - // Deserialize metadata if present - if metadataStr, exists := hash["metadata"]; exists && metadataStr != "" { - err = t.Metadata.UnmarshalBinary([]byte(metadataStr)) - if err != nil { - return fmt.Errorf("invalid metadata: %w", err) - } - } - - return nil -} - -// parseTimestamp parses a timestamp from either numeric (Unix milliseconds) or RFC3339 format. -// This supports lazy migration from RFC3339 strings to Unix millisecond timestamps. 
-func parseTimestamp(value string) (time.Time, error) { - if value == "" { - return time.Time{}, fmt.Errorf("missing timestamp") - } - - // Try to parse as Unix milliseconds (numeric) first - new format - if ts, err := strconv.ParseInt(value, 10, 64); err == nil { - return time.UnixMilli(ts).UTC(), nil - } - - // Fallback to RFC3339Nano (old format) - if t, err := time.Parse(time.RFC3339Nano, value); err == nil { - return t, nil - } - - // Fallback to RFC3339 - return time.Parse(time.RFC3339, value) -} diff --git a/internal/models/tenant_test.go b/internal/models/tenant_test.go deleted file mode 100644 index 26957d2c..00000000 --- a/internal/models/tenant_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package models_test - -import ( - "encoding/json" - "testing" - - "github.com/hookdeck/outpost/internal/models" - "github.com/hookdeck/outpost/internal/util/testutil" - "github.com/stretchr/testify/assert" -) - -func TestTenant_JSONMarshalWithMetadata(t *testing.T) { - t.Parallel() - - tenant := testutil.TenantFactory.Any( - testutil.TenantFactory.WithID("tenant_123"), - testutil.TenantFactory.WithMetadata(map[string]string{ - "environment": "production", - "team": "platform", - "region": "us-east-1", - }), - ) - - // Marshal to JSON - jsonBytes, err := json.Marshal(tenant) - assert.NoError(t, err) - - // Unmarshal back - var unmarshaled models.Tenant - err = json.Unmarshal(jsonBytes, &unmarshaled) - assert.NoError(t, err) - - // Verify metadata is preserved - assert.Equal(t, tenant.Metadata, unmarshaled.Metadata) - assert.Equal(t, "production", unmarshaled.Metadata["environment"]) - assert.Equal(t, "platform", unmarshaled.Metadata["team"]) - assert.Equal(t, "us-east-1", unmarshaled.Metadata["region"]) - - // Verify other fields still work - assert.Equal(t, tenant.ID, unmarshaled.ID) -} - -func TestTenant_JSONMarshalWithoutMetadata(t *testing.T) { - t.Parallel() - - tenant := testutil.TenantFactory.Any( - testutil.TenantFactory.WithID("tenant_123"), - ) - - // Marshal to JSON - jsonBytes, err := json.Marshal(tenant) - assert.NoError(t, err) - - // Unmarshal back - var unmarshaled models.Tenant - err = json.Unmarshal(jsonBytes, &unmarshaled) - assert.NoError(t, err) - - // Verify metadata is nil when not provided - assert.Nil(t, unmarshaled.Metadata) -} - -func TestTenant_JSONUnmarshalEmptyMetadata(t *testing.T) { - t.Parallel() - - jsonData := `{ - "id": "tenant_123", - "destinations_count": 0, - "topics": [], - "metadata": {}, - "created_at": "2024-01-01T00:00:00Z" - }` - - var tenant models.Tenant - err := json.Unmarshal([]byte(jsonData), &tenant) - assert.NoError(t, err) - - // Empty maps should be preserved as empty, not nil - assert.NotNil(t, tenant.Metadata) - assert.Empty(t, tenant.Metadata) -} - -func TestTenant_JSONMarshalWithUpdatedAt(t *testing.T) { - t.Parallel() - - tenant := testutil.TenantFactory.Any( - testutil.TenantFactory.WithID("tenant_123"), - ) - - // Marshal to JSON - jsonBytes, err := json.Marshal(tenant) - assert.NoError(t, err) - - // Unmarshal back - var unmarshaled models.Tenant - err = json.Unmarshal(jsonBytes, &unmarshaled) - assert.NoError(t, err) - - // Verify updated_at is preserved - assert.Equal(t, tenant.UpdatedAt.Unix(), unmarshaled.UpdatedAt.Unix()) - assert.Equal(t, tenant.CreatedAt.Unix(), unmarshaled.CreatedAt.Unix()) -} diff --git a/internal/publishmq/eventhandler.go b/internal/publishmq/eventhandler.go index 740dbb3a..36a4f470 100644 --- a/internal/publishmq/eventhandler.go +++ b/internal/publishmq/eventhandler.go @@ -11,6 +11,7 @@ import ( 
"github.com/hookdeck/outpost/internal/idempotence" "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore" "go.uber.org/zap" "golang.org/x/sync/errgroup" ) @@ -35,14 +36,14 @@ type eventHandler struct { logger *logging.Logger idempotence idempotence.Idempotence deliveryMQ *deliverymq.DeliveryMQ - entityStore models.EntityStore + tenantStore tenantstore.TenantStore topics []string } func NewEventHandler( logger *logging.Logger, deliveryMQ *deliverymq.DeliveryMQ, - entityStore models.EntityStore, + tenantStore tenantstore.TenantStore, eventTracer eventtracer.EventTracer, topics []string, idempotence idempotence.Idempotence, @@ -52,7 +53,7 @@ func NewEventHandler( logger: logger, idempotence: idempotence, deliveryMQ: deliveryMQ, - entityStore: entityStore, + tenantStore: tenantStore, eventTracer: eventTracer, topics: topics, emeter: emeter, @@ -78,7 +79,7 @@ func (h *eventHandler) Handle(ctx context.Context, event *models.Event) (*Handle zap.String("destination_id", event.DestinationID), zap.String("topic", event.Topic)) - var matchedDestinations []models.DestinationSummary + var matchedDestinations []string var err error // Branch: specific destination vs topic-based matching @@ -90,7 +91,7 @@ func (h *eventHandler) Handle(ctx context.Context, event *models.Event) (*Handle } } else { // Topic-based matching path - matchedDestinations, err = h.entityStore.MatchEvent(ctx, *event) + matchedDestinations, err = h.tenantStore.MatchEvent(ctx, *event) if err != nil { logger.Error("failed to match event destinations", zap.Error(err), @@ -132,15 +133,14 @@ func (h *eventHandler) Handle(ctx context.Context, event *models.Event) (*Handle return result, nil } -func (h *eventHandler) doPublish(ctx context.Context, event *models.Event, matchedDestinations []models.DestinationSummary) error { +func (h *eventHandler) doPublish(ctx context.Context, event *models.Event, matchedDestinations []string) error { _, span := h.eventTracer.Receive(ctx, event) defer span.End() h.emeter.EventEligbible(ctx, event) var g errgroup.Group - for _, destinationSummary := range matchedDestinations { - destID := destinationSummary.ID + for _, destID := range matchedDestinations { g.Go(func() error { return h.enqueueDeliveryTask(ctx, models.NewDeliveryTask(*event, destID)) }) @@ -153,27 +153,27 @@ func (h *eventHandler) doPublish(ctx context.Context, event *models.Event, match } // matchSpecificDestination handles the case where a specific destination_id is provided. -// It retrieves the destination and validates it, returning the matched destinations. -func (h *eventHandler) matchSpecificDestination(ctx context.Context, event *models.Event) ([]models.DestinationSummary, error) { - destination, err := h.entityStore.RetrieveDestination(ctx, event.TenantID, event.DestinationID) +// It retrieves the destination and validates it, returning the matched destination IDs. 
+func (h *eventHandler) matchSpecificDestination(ctx context.Context, event *models.Event) ([]string, error) { + destination, err := h.tenantStore.RetrieveDestination(ctx, event.TenantID, event.DestinationID) if err != nil { h.logger.Ctx(ctx).Warn("failed to retrieve destination", zap.Error(err), zap.String("event_id", event.ID), zap.String("tenant_id", event.TenantID), zap.String("destination_id", event.DestinationID)) - return []models.DestinationSummary{}, nil + return []string{}, nil } if destination == nil { - return []models.DestinationSummary{}, nil + return []string{}, nil } if !destination.MatchEvent(*event) { - return []models.DestinationSummary{}, nil + return []string{}, nil } - return []models.DestinationSummary{*destination.ToSummary()}, nil + return []string{destination.ID}, nil } func (h *eventHandler) enqueueDeliveryTask(ctx context.Context, task models.DeliveryTask) error { diff --git a/internal/publishmq/eventhandler_test.go b/internal/publishmq/eventhandler_test.go index 74b147fe..b7a53b97 100644 --- a/internal/publishmq/eventhandler_test.go +++ b/internal/publishmq/eventhandler_test.go @@ -11,6 +11,7 @@ import ( "github.com/hookdeck/outpost/internal/idgen" "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/publishmq" + "github.com/hookdeck/outpost/internal/tenantstore" "github.com/hookdeck/outpost/internal/util/testinfra" "github.com/hookdeck/outpost/internal/util/testutil" "github.com/stretchr/testify/require" @@ -26,7 +27,7 @@ func TestIntegrationPublishMQEventHandler_Concurrency(t *testing.T) { ctx := context.Background() logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) - entityStore := models.NewEntityStore(redisClient, models.WithAvailableTopics(testutil.TestTopics)) + tenantStore := tenantstore.New(tenantstore.Config{RedisClient: redisClient, AvailableTopics: testutil.TestTopics}) mqConfig := testinfra.NewMQAWSConfig(t, nil) deliveryMQ := deliverymq.New(deliverymq.WithQueue(&mqConfig)) cleanup, err := deliveryMQ.Init(ctx) @@ -34,7 +35,7 @@ func TestIntegrationPublishMQEventHandler_Concurrency(t *testing.T) { defer cleanup() eventHandler := publishmq.NewEventHandler(logger, deliveryMQ, - entityStore, + tenantStore, mockEventTracer, testutil.TestTopics, idempotence.New(testutil.CreateTestRedisClient(t), idempotence.WithSuccessfulTTL(24*time.Hour)), @@ -44,10 +45,10 @@ func TestIntegrationPublishMQEventHandler_Concurrency(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(ctx, tenant) + tenantStore.UpsertTenant(ctx, tenant) destFactory := testutil.DestinationFactory for i := 0; i < 5; i++ { - entityStore.UpsertDestination(ctx, destFactory.Any(destFactory.WithTenantID(tenant.ID))) + tenantStore.UpsertDestination(ctx, destFactory.Any(destFactory.WithTenantID(tenant.ID))) } _, err = eventHandler.Handle(ctx, testutil.EventFactory.AnyPointer( @@ -84,7 +85,7 @@ func TestEventHandler_WildcardTopic(t *testing.T) { ctx := context.Background() logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) - entityStore := models.NewEntityStore(redisClient, models.WithAvailableTopics(testutil.TestTopics)) + tenantStore := tenantstore.New(tenantstore.Config{RedisClient: redisClient, AvailableTopics: testutil.TestTopics}) mqConfig := testinfra.NewMQAWSConfig(t, nil) deliveryMQ := deliverymq.New(deliverymq.WithQueue(&mqConfig)) cleanup, err := deliveryMQ.Init(ctx) @@ -97,7 +98,7 @@ func TestEventHandler_WildcardTopic(t *testing.T) { eventHandler := 
publishmq.NewEventHandler(logger, deliveryMQ, - entityStore, + tenantStore, mockEventTracer, testutil.TestTopics, idempotence.New(testutil.CreateTestRedisClient(t), idempotence.WithSuccessfulTTL(24*time.Hour)), @@ -107,7 +108,7 @@ func TestEventHandler_WildcardTopic(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - entityStore.UpsertTenant(ctx, tenant) + tenantStore.UpsertTenant(ctx, tenant) // Create destinations with different topics destFactory := testutil.DestinationFactory @@ -126,7 +127,7 @@ func TestEventHandler_WildcardTopic(t *testing.T) { ), } for _, dest := range destinations { - err := entityStore.UpsertDestination(ctx, dest) + err := tenantStore.UpsertDestination(ctx, dest) require.NoError(t, err) } @@ -137,7 +138,7 @@ func TestEventHandler_WildcardTopic(t *testing.T) { ) now := time.Now() disabledDest.DisabledAt = &now - err = entityStore.UpsertDestination(ctx, disabledDest) + err = tenantStore.UpsertDestination(ctx, disabledDest) require.NoError(t, err) // Test publishing with wildcard topic @@ -216,7 +217,7 @@ func TestEventHandler_HandleResult(t *testing.T) { ctx := context.Background() logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) - entityStore := models.NewEntityStore(redisClient, models.WithAvailableTopics(testutil.TestTopics)) + tenantStore := tenantstore.New(tenantstore.Config{RedisClient: redisClient, AvailableTopics: testutil.TestTopics}) mqConfig := testinfra.NewMQAWSConfig(t, nil) deliveryMQ := deliverymq.New(deliverymq.WithQueue(&mqConfig)) cleanup, err := deliveryMQ.Init(ctx) @@ -226,7 +227,7 @@ func TestEventHandler_HandleResult(t *testing.T) { eventHandler := publishmq.NewEventHandler( logger, deliveryMQ, - entityStore, + tenantStore, testutil.NewMockEventTracer(tracetest.NewInMemoryExporter()), testutil.TestTopics, idempotence.New(testutil.CreateTestRedisClient(t), idempotence.WithSuccessfulTTL(24*time.Hour)), @@ -236,13 +237,13 @@ func TestEventHandler_HandleResult(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - require.NoError(t, entityStore.UpsertTenant(ctx, tenant)) + require.NoError(t, tenantStore.UpsertTenant(ctx, tenant)) t.Run("normal publish with matches", func(t *testing.T) { // Create 3 destinations destFactory := testutil.DestinationFactory for i := 0; i < 3; i++ { - require.NoError(t, entityStore.UpsertDestination(ctx, destFactory.Any( + require.NoError(t, tenantStore.UpsertDestination(ctx, destFactory.Any( destFactory.WithTenantID(tenant.ID), destFactory.WithTopics([]string{"user.created"}), ))) @@ -279,7 +280,7 @@ func TestEventHandler_HandleResult(t *testing.T) { testutil.DestinationFactory.WithTenantID(tenant.ID), testutil.DestinationFactory.WithTopics([]string{"user.deleted"}), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -302,7 +303,7 @@ func TestEventHandler_HandleResult(t *testing.T) { testutil.DestinationFactory.WithTenantID(tenant.ID), testutil.DestinationFactory.WithTopics([]string{"user.updated"}), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -320,7 +321,7 @@ func TestEventHandler_HandleResult(t *testing.T) { testutil.DestinationFactory.WithTenantID(tenant.ID), 
testutil.DestinationFactory.WithTopics([]string{"user.created"}), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -346,7 +347,7 @@ func TestEventHandler_HandleResult(t *testing.T) { ) now := time.Now() dest.DisabledAt = &now - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -376,7 +377,7 @@ func TestEventHandler_HandleResult(t *testing.T) { testutil.DestinationFactory.WithTenantID(tenant.ID), testutil.DestinationFactory.WithTopics([]string{"order.created"}), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -396,7 +397,7 @@ func TestEventHandler_Filter(t *testing.T) { ctx := context.Background() logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) - entityStore := models.NewEntityStore(redisClient, models.WithAvailableTopics(testutil.TestTopics)) + tenantStore := tenantstore.New(tenantstore.Config{RedisClient: redisClient, AvailableTopics: testutil.TestTopics}) mqConfig := testinfra.NewMQAWSConfig(t, nil) deliveryMQ := deliverymq.New(deliverymq.WithQueue(&mqConfig)) cleanup, err := deliveryMQ.Init(ctx) @@ -410,7 +411,7 @@ func TestEventHandler_Filter(t *testing.T) { eventHandler := publishmq.NewEventHandler( logger, deliveryMQ, - entityStore, + tenantStore, testutil.NewMockEventTracer(tracetest.NewInMemoryExporter()), testutil.TestTopics, idempotence.New(testutil.CreateTestRedisClient(t), idempotence.WithSuccessfulTTL(24*time.Hour)), @@ -420,7 +421,7 @@ func TestEventHandler_Filter(t *testing.T) { ID: idgen.String(), CreatedAt: time.Now(), } - require.NoError(t, entityStore.UpsertTenant(ctx, tenant)) + require.NoError(t, tenantStore.UpsertTenant(ctx, tenant)) t.Run("topic-based matching with filter - event matches", func(t *testing.T) { // Create destination with filter @@ -433,7 +434,7 @@ func TestEventHandler_Filter(t *testing.T) { }, }), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) // Event that matches the filter event := testutil.EventFactory.AnyPointer( @@ -471,7 +472,7 @@ func TestEventHandler_Filter(t *testing.T) { }, }), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) // Event that does NOT match the filter (amount < 100) event := testutil.EventFactory.AnyPointer( @@ -502,7 +503,7 @@ func TestEventHandler_Filter(t *testing.T) { }, }), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), @@ -538,7 +539,7 @@ func TestEventHandler_Filter(t *testing.T) { }, }), ) - require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) // Event with different currency - should NOT match event := testutil.EventFactory.AnyPointer( @@ -567,7 +568,7 @@ func TestEventHandler_Filter(t *testing.T) { testutil.DestinationFactory.WithTopics([]string{"user.updated"}), // No filter ) - 
require.NoError(t, entityStore.UpsertDestination(ctx, dest)) + require.NoError(t, tenantStore.UpsertDestination(ctx, dest)) event := testutil.EventFactory.AnyPointer( testutil.EventFactory.WithTenantID(tenant.ID), diff --git a/internal/services/builder.go b/internal/services/builder.go index 4bb2c969..a8be0188 100644 --- a/internal/services/builder.go +++ b/internal/services/builder.go @@ -18,11 +18,11 @@ import ( "github.com/hookdeck/outpost/internal/logging" "github.com/hookdeck/outpost/internal/logmq" "github.com/hookdeck/outpost/internal/logstore" - "github.com/hookdeck/outpost/internal/models" "github.com/hookdeck/outpost/internal/publishmq" "github.com/hookdeck/outpost/internal/redis" "github.com/hookdeck/outpost/internal/scheduler" "github.com/hookdeck/outpost/internal/telemetry" + "github.com/hookdeck/outpost/internal/tenantstore" "github.com/hookdeck/outpost/internal/worker" "go.uber.org/zap" ) @@ -47,7 +47,7 @@ type serviceInstance struct { // Common infrastructure redisClient redis.Client logStore logstore.LogStore - entityStore models.EntityStore + tenantStore tenantstore.TenantStore destRegistry destregistry.Registry eventTracer eventtracer.EventTracer deliveryMQ *deliverymq.DeliveryMQ @@ -172,7 +172,7 @@ func (b *ServiceBuilder) BuildAPIWorkers(baseRouter *gin.Engine) error { if err := svc.initEventTracer(b.cfg, b.logger); err != nil { return err } - if err := svc.initEntityStore(b.ctx, b.cfg, b.logger); err != nil { + if err := svc.initTenantStore(b.ctx, b.cfg, b.logger); err != nil { return err } @@ -188,7 +188,7 @@ func (b *ServiceBuilder) BuildAPIWorkers(baseRouter *gin.Engine) error { idempotence.WithSuccessfulTTL(time.Duration(b.cfg.PublishIdempotencyKeyTTL)*time.Second), idempotence.WithDeploymentID(b.cfg.DeploymentID), ) - eventHandler := publishmq.NewEventHandler(b.logger, svc.deliveryMQ, svc.entityStore, svc.eventTracer, b.cfg.Topics, publishIdempotence) + eventHandler := publishmq.NewEventHandler(b.logger, svc.deliveryMQ, svc.tenantStore, svc.eventTracer, b.cfg.Topics, publishIdempotence) apiHandler := apirouter.NewRouter( apirouter.RouterConfig{ @@ -204,7 +204,7 @@ func (b *ServiceBuilder) BuildAPIWorkers(baseRouter *gin.Engine) error { b.logger, svc.redisClient, svc.deliveryMQ, - svc.entityStore, + svc.tenantStore, svc.logStore, eventHandler, b.telemetry, @@ -263,7 +263,7 @@ func (b *ServiceBuilder) BuildDeliveryWorker(baseRouter *gin.Engine) error { if err := svc.initEventTracer(b.cfg, b.logger); err != nil { return err } - if err := svc.initEntityStore(b.ctx, b.cfg, b.logger); err != nil { + if err := svc.initTenantStore(b.ctx, b.cfg, b.logger); err != nil { return err } if err := svc.initLogStore(b.ctx, b.cfg, b.logger); err != nil { @@ -280,7 +280,7 @@ func (b *ServiceBuilder) BuildDeliveryWorker(baseRouter *gin.Engine) error { alertNotifier = alert.NewHTTPAlertNotifier(b.cfg.Alert.CallbackURL, alert.NotifierWithBearerToken(b.cfg.APIKey)) } if b.cfg.Alert.AutoDisableDestination { - destinationDisabler = newDestinationDisabler(svc.entityStore) + destinationDisabler = newDestinationDisabler(svc.tenantStore) } alertMonitor := alert.NewAlertMonitor( b.logger, @@ -304,7 +304,7 @@ func (b *ServiceBuilder) BuildDeliveryWorker(baseRouter *gin.Engine) error { handler := deliverymq.NewMessageHandler( b.logger, svc.logMQ, - svc.entityStore, + svc.tenantStore, svc.destRegistry, svc.eventTracer, svc.retryScheduler, @@ -403,17 +403,17 @@ func (b *ServiceBuilder) BuildLogWorker(baseRouter *gin.Engine) error { // destinationDisabler implements alert.DestinationDisabler type 
destinationDisabler struct { - entityStore models.EntityStore + tenantStore tenantstore.TenantStore } -func newDestinationDisabler(entityStore models.EntityStore) alert.DestinationDisabler { +func newDestinationDisabler(tenantStore tenantstore.TenantStore) alert.DestinationDisabler { return &destinationDisabler{ - entityStore: entityStore, + tenantStore: tenantStore, } } func (d *destinationDisabler) DisableDestination(ctx context.Context, tenantID, destinationID string) error { - destination, err := d.entityStore.RetrieveDestination(ctx, tenantID, destinationID) + destination, err := d.tenantStore.RetrieveDestination(ctx, tenantID, destinationID) if err != nil { return err } @@ -422,7 +422,7 @@ func (d *destinationDisabler) DisableDestination(ctx context.Context, tenantID, } now := time.Now() destination.DisabledAt = &now - return d.entityStore.UpsertDestination(ctx, *destination) + return d.tenantStore.UpsertDestination(ctx, *destination) } // Helper methods for serviceInstance to initialize common dependencies @@ -463,19 +463,20 @@ func (s *serviceInstance) initLogStore(ctx context.Context, cfg *config.Config, return nil } -func (s *serviceInstance) initEntityStore(ctx context.Context, cfg *config.Config, logger *logging.Logger) error { +func (s *serviceInstance) initTenantStore(ctx context.Context, cfg *config.Config, logger *logging.Logger) error { if s.redisClient == nil { - return fmt.Errorf("redis client must be initialized before entity store") - } - logger.Debug("creating entity store", zap.String("service", s.name)) - s.entityStore = models.NewEntityStore(s.redisClient, - models.WithCipher(models.NewAESCipher(cfg.AESEncryptionSecret)), - models.WithAvailableTopics(cfg.Topics), - models.WithMaxDestinationsPerTenant(cfg.MaxDestinationsPerTenant), - models.WithDeploymentID(cfg.DeploymentID), - ) - if err := s.entityStore.Init(ctx); err != nil { - return fmt.Errorf("failed to initialize entity store: %w", err) + return fmt.Errorf("redis client must be initialized before tenant store") + } + logger.Debug("creating tenant store", zap.String("service", s.name)) + s.tenantStore = tenantstore.New(tenantstore.Config{ + RedisClient: s.redisClient, + Secret: cfg.AESEncryptionSecret, + AvailableTopics: cfg.Topics, + MaxDestinationsPerTenant: cfg.MaxDestinationsPerTenant, + DeploymentID: cfg.DeploymentID, + }) + if err := s.tenantStore.Init(ctx); err != nil { + return fmt.Errorf("failed to initialize tenant store: %w", err) } return nil } diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index 576ea087..d380bdb9 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -212,7 +212,7 @@ func (t *telemetryImpl) makeEvent(eventType string, data map[string]interface{}) type ApplicationInfo struct { Version string MQ string - EntityStore string + TenantStore string LogStore string PortalEnabled bool } @@ -221,7 +221,7 @@ func (a *ApplicationInfo) ToData() map[string]interface{} { return map[string]interface{}{ "version": a.Version, "mq": a.MQ, - "entity_store": a.EntityStore, + "tenant_store": a.TenantStore, "log_store": a.LogStore, "portal_enabled": a.PortalEnabled, } diff --git a/internal/tenantstore/driver/driver.go b/internal/tenantstore/driver/driver.go new file mode 100644 index 00000000..3e57c559 --- /dev/null +++ b/internal/tenantstore/driver/driver.go @@ -0,0 +1,77 @@ +// Package driver defines the TenantStore interface and associated types. 
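+// Implementations are exercised against the shared conformance suite in the drivertest package.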
+package driver + +import ( + "context" + "errors" + + "github.com/hookdeck/outpost/internal/models" +) + +// TenantStore is the interface for tenant and destination storage. +type TenantStore interface { + Init(ctx context.Context) error + RetrieveTenant(ctx context.Context, tenantID string) (*models.Tenant, error) + UpsertTenant(ctx context.Context, tenant models.Tenant) error + DeleteTenant(ctx context.Context, tenantID string) error + ListTenant(ctx context.Context, req ListTenantRequest) (*TenantPaginatedResult, error) + ListDestinationByTenant(ctx context.Context, tenantID string, options ...ListDestinationByTenantOpts) ([]models.Destination, error) + RetrieveDestination(ctx context.Context, tenantID, destinationID string) (*models.Destination, error) + CreateDestination(ctx context.Context, destination models.Destination) error + UpsertDestination(ctx context.Context, destination models.Destination) error + DeleteDestination(ctx context.Context, tenantID, destinationID string) error + MatchEvent(ctx context.Context, event models.Event) ([]string, error) +} + +var ( + ErrTenantNotFound = errors.New("tenant does not exist") + ErrTenantDeleted = errors.New("tenant has been deleted") + ErrDuplicateDestination = errors.New("destination already exists") + ErrDestinationNotFound = errors.New("destination does not exist") + ErrDestinationDeleted = errors.New("destination has been deleted") + ErrMaxDestinationsPerTenantReached = errors.New("maximum number of destinations per tenant reached") + ErrListTenantNotSupported = errors.New("list tenant feature is not enabled") + ErrInvalidCursor = errors.New("invalid cursor") + ErrInvalidOrder = errors.New("invalid order: must be 'asc' or 'desc'") + ErrConflictingCursors = errors.New("cannot specify both next and prev cursors") +) + +// ListTenantRequest contains parameters for listing tenants. +type ListTenantRequest struct { + Limit int // Number of results per page (default: 20) + Next string // Cursor for next page + Prev string // Cursor for previous page + Dir string // Sort direction: "asc" or "desc" (default: "desc") +} + +// SeekPagination represents cursor-based pagination metadata for list responses. +type SeekPagination struct { + OrderBy string `json:"order_by"` + Dir string `json:"dir"` + Limit int `json:"limit"` + Next *string `json:"next"` + Prev *string `json:"prev"` +} + +// TenantPaginatedResult contains the paginated list of tenants. +type TenantPaginatedResult struct { + Models []models.Tenant `json:"models"` + Pagination SeekPagination `json:"pagination"` + Count int `json:"count"` +} + +// ListDestinationByTenantOpts contains options for listing destinations. +type ListDestinationByTenantOpts struct { + Filter *DestinationFilter +} + +// DestinationFilter specifies criteria for filtering destinations. +type DestinationFilter struct { + Type []string + Topics []string +} + +// WithDestinationFilter creates a ListDestinationByTenantOpts with the given filter. 
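+//
+// Typical usage (store is any TenantStore implementation; ctx and tenantID are caller-supplied), as exercised by the drivertest conformance suite:
+//
+//	destinations, err := store.ListDestinationByTenant(ctx, tenantID,
+//		WithDestinationFilter(DestinationFilter{Topics: []string{"user.created"}}))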
+func WithDestinationFilter(filter DestinationFilter) ListDestinationByTenantOpts { + return ListDestinationByTenantOpts{Filter: &filter} +} diff --git a/internal/tenantstore/drivertest/crud.go b/internal/tenantstore/drivertest/crud.go new file mode 100644 index 00000000..68c1f7bb --- /dev/null +++ b/internal/tenantstore/drivertest/crud.go @@ -0,0 +1,548 @@ +package drivertest + +import ( + "context" + "testing" + "time" + + "github.com/hookdeck/outpost/internal/idgen" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/util/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testCRUD(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("InitIdempotency", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + err := store.Init(ctx) + require.NoError(t, err, "Init call %d should not fail", i+1) + } + }) + + t.Run("TenantCRUD", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + now := time.Now() + input := models.Tenant{ + ID: idgen.String(), + CreatedAt: now, + UpdatedAt: now, + } + + t.Run("gets empty", func(t *testing.T) { + actual, err := store.RetrieveTenant(ctx, input.ID) + assert.Nil(t, actual) + assert.NoError(t, err) + }) + + t.Run("sets", func(t *testing.T) { + err := store.UpsertTenant(ctx, input) + require.NoError(t, err) + + retrieved, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.ID, retrieved.ID) + assertEqualTime(t, input.CreatedAt, retrieved.CreatedAt, "CreatedAt") + }) + + t.Run("gets", func(t *testing.T) { + actual, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.ID, actual.ID) + assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") + }) + + t.Run("overrides", func(t *testing.T) { + input.CreatedAt = time.Now() + err := store.UpsertTenant(ctx, input) + require.NoError(t, err) + + actual, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.ID, actual.ID) + assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") + }) + + t.Run("clears", func(t *testing.T) { + require.NoError(t, store.DeleteTenant(ctx, input.ID)) + + actual, err := store.RetrieveTenant(ctx, input.ID) + assert.ErrorIs(t, err, driver.ErrTenantDeleted) + assert.Nil(t, actual) + }) + + t.Run("deletes again", func(t *testing.T) { + assert.NoError(t, store.DeleteTenant(ctx, input.ID)) + }) + + t.Run("deletes non-existent", func(t *testing.T) { + assert.ErrorIs(t, store.DeleteTenant(ctx, "non-existent-tenant"), driver.ErrTenantNotFound) + }) + + t.Run("creates & overrides deleted resource", func(t *testing.T) { + require.NoError(t, store.UpsertTenant(ctx, input)) + + actual, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.ID, actual.ID) + assertEqualTime(t, input.CreatedAt, actual.CreatedAt, "CreatedAt") + }) + + t.Run("upserts with metadata", func(t *testing.T) { + input.Metadata = map[string]string{ + "environment": "production", + "team": "platform", + } + err := store.UpsertTenant(ctx, input) + require.NoError(t, err) + + retrieved, err := 
store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.ID, retrieved.ID) + assert.Equal(t, input.Metadata, retrieved.Metadata) + }) + + t.Run("updates metadata", func(t *testing.T) { + input.Metadata = map[string]string{ + "environment": "staging", + "team": "engineering", + "region": "us-west-2", + } + err := store.UpsertTenant(ctx, input) + require.NoError(t, err) + + retrieved, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Equal(t, input.Metadata, retrieved.Metadata) + }) + + t.Run("handles nil metadata", func(t *testing.T) { + input.Metadata = nil + err := store.UpsertTenant(ctx, input) + require.NoError(t, err) + + retrieved, err := store.RetrieveTenant(ctx, input.ID) + require.NoError(t, err) + assert.Nil(t, retrieved.Metadata) + }) + + t.Run("sets updated_at on create", func(t *testing.T) { + newTenant := testutil.TenantFactory.Any() + err := store.UpsertTenant(ctx, newTenant) + require.NoError(t, err) + + retrieved, err := store.RetrieveTenant(ctx, newTenant.ID) + require.NoError(t, err) + assertEqualTime(t, newTenant.UpdatedAt, retrieved.UpdatedAt, "UpdatedAt") + }) + + t.Run("updates updated_at on upsert", func(t *testing.T) { + originalTime := time.Now().Add(-2 * time.Second).Truncate(time.Millisecond) + updatedTime := originalTime.Add(1 * time.Second) + + original := testutil.TenantFactory.Any( + testutil.TenantFactory.WithUpdatedAt(originalTime), + ) + err := store.UpsertTenant(ctx, original) + require.NoError(t, err) + + updated := original + updated.UpdatedAt = updatedTime + err = store.UpsertTenant(ctx, updated) + require.NoError(t, err) + + retrieved, err := store.RetrieveTenant(ctx, updated.ID) + require.NoError(t, err) + assert.True(t, retrieved.UpdatedAt.After(originalTime) || retrieved.UpdatedAt.Equal(originalTime.Add(time.Second))) + }) + }) + + t.Run("DestinationCRUD", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + now := time.Now() + input := models.Destination{ + ID: idgen.Destination(), + Type: "rabbitmq", + Topics: []string{"user.created", "user.updated"}, + Config: map[string]string{ + "server_url": "localhost:5672", + "exchange": "events", + }, + Credentials: map[string]string{ + "username": "guest", + "password": "guest", + }, + DeliveryMetadata: map[string]string{ + "app-id": "test-app", + "source": "outpost", + }, + Metadata: map[string]string{ + "environment": "test", + "team": "platform", + }, + CreatedAt: now, + UpdatedAt: now, + DisabledAt: nil, + TenantID: idgen.String(), + } + + t.Run("gets empty", func(t *testing.T) { + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assert.Nil(t, actual) + }) + + t.Run("sets", func(t *testing.T) { + err := store.CreateDestination(ctx, input) + require.NoError(t, err) + }) + + t.Run("gets", func(t *testing.T) { + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assertEqualDestination(t, input, *actual) + }) + + t.Run("updates", func(t *testing.T) { + input.Topics = []string{"*"} + input.DeliveryMetadata = map[string]string{ + "app-id": "updated-app", + "version": "2.0", + } + input.Metadata = map[string]string{ + "environment": "staging", + } + err := store.UpsertDestination(ctx, input) + require.NoError(t, err) + + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, 
err) + assertEqualDestination(t, input, *actual) + }) + + t.Run("clears", func(t *testing.T) { + err := store.DeleteDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + assert.ErrorIs(t, err, driver.ErrDestinationDeleted) + assert.Nil(t, actual) + }) + + t.Run("creates & overrides deleted resource", func(t *testing.T) { + err := store.CreateDestination(ctx, input) + require.NoError(t, err) + + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assertEqualDestination(t, input, *actual) + }) + + t.Run("err when creates duplicate", func(t *testing.T) { + assert.ErrorIs(t, store.CreateDestination(ctx, input), driver.ErrDuplicateDestination) + require.NoError(t, store.DeleteDestination(ctx, input.TenantID, input.ID)) + }) + + t.Run("handles nil delivery_metadata and metadata", func(t *testing.T) { + inputWithNilFields := testutil.DestinationFactory.Any() + err := store.CreateDestination(ctx, inputWithNilFields) + require.NoError(t, err) + + actual, err := store.RetrieveDestination(ctx, inputWithNilFields.TenantID, inputWithNilFields.ID) + require.NoError(t, err) + assert.Nil(t, actual.DeliveryMetadata) + assert.Nil(t, actual.Metadata) + require.NoError(t, store.DeleteDestination(ctx, inputWithNilFields.TenantID, inputWithNilFields.ID)) + }) + + t.Run("sets updated_at on create", func(t *testing.T) { + newDest := testutil.DestinationFactory.Any() + err := store.CreateDestination(ctx, newDest) + require.NoError(t, err) + + retrieved, err := store.RetrieveDestination(ctx, newDest.TenantID, newDest.ID) + require.NoError(t, err) + assertEqualTime(t, newDest.UpdatedAt, retrieved.UpdatedAt, "UpdatedAt") + require.NoError(t, store.DeleteDestination(ctx, newDest.TenantID, newDest.ID)) + }) + + t.Run("updates updated_at on upsert", func(t *testing.T) { + originalTime := time.Now().Add(-2 * time.Second).Truncate(time.Millisecond) + updatedTime := originalTime.Add(1 * time.Second) + + original := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithUpdatedAt(originalTime), + ) + err := store.CreateDestination(ctx, original) + require.NoError(t, err) + + updated := original + updated.UpdatedAt = updatedTime + updated.Topics = []string{"updated.topic"} + err = store.UpsertDestination(ctx, updated) + require.NoError(t, err) + + retrieved, err := store.RetrieveDestination(ctx, updated.TenantID, updated.ID) + require.NoError(t, err) + assert.True(t, retrieved.UpdatedAt.After(originalTime) || retrieved.UpdatedAt.Equal(originalTime.Add(time.Second))) + require.NoError(t, store.DeleteDestination(ctx, updated.TenantID, updated.ID)) + }) + }) + + t.Run("ListDestinationEmpty", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + destinations, err := store.ListDestinationByTenant(ctx, idgen.String()) + require.NoError(t, err) + assert.Empty(t, destinations) + }) + + t.Run("DeleteTenantAndAssociatedDestinations", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + tenant := models.Tenant{ + ID: idgen.String(), + CreatedAt: time.Now(), + } + require.NoError(t, store.UpsertTenant(ctx, tenant)) + + destinationIDs := []string{idgen.Destination(), idgen.Destination(), idgen.Destination()} + 
for _, id := range destinationIDs { + require.NoError(t, store.UpsertDestination(ctx, testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(id), + testutil.DestinationFactory.WithTenantID(tenant.ID), + ))) + } + + require.NoError(t, store.DeleteTenant(ctx, tenant.ID)) + + _, err = store.RetrieveTenant(ctx, tenant.ID) + assert.ErrorIs(t, err, driver.ErrTenantDeleted) + for _, id := range destinationIDs { + _, err := store.RetrieveDestination(ctx, tenant.ID, id) + assert.ErrorIs(t, err, driver.ErrDestinationDeleted) + } + }) + + t.Run("DeleteDestination", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + destination := testutil.DestinationFactory.Any() + require.NoError(t, store.CreateDestination(ctx, destination)) + + t.Run("no error when deleting existing", func(t *testing.T) { + assert.NoError(t, store.DeleteDestination(ctx, destination.TenantID, destination.ID)) + }) + + t.Run("no error when deleting already-deleted", func(t *testing.T) { + assert.NoError(t, store.DeleteDestination(ctx, destination.TenantID, destination.ID)) + }) + + t.Run("error when deleting non-existent", func(t *testing.T) { + err := store.DeleteDestination(ctx, destination.TenantID, idgen.Destination()) + assert.ErrorIs(t, err, driver.ErrDestinationNotFound) + }) + + t.Run("returns ErrDestinationDeleted when retrieving deleted", func(t *testing.T) { + dest, err := store.RetrieveDestination(ctx, destination.TenantID, destination.ID) + assert.ErrorIs(t, err, driver.ErrDestinationDeleted) + assert.Nil(t, dest) + }) + + t.Run("does not return deleted in list", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, destination.TenantID) + assert.NoError(t, err) + assert.Empty(t, destinations) + }) + }) + + t.Run("DestinationEnableDisable", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + input := testutil.DestinationFactory.Any() + require.NoError(t, store.UpsertDestination(ctx, input)) + + t.Run("should disable", func(t *testing.T) { + now := time.Now() + input.DisabledAt = &now + require.NoError(t, store.UpsertDestination(ctx, input)) + + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assertEqualTimePtr(t, input.DisabledAt, actual.DisabledAt, "DisabledAt") + }) + + t.Run("should enable", func(t *testing.T) { + input.DisabledAt = nil + require.NoError(t, store.UpsertDestination(ctx, input)) + + actual, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assertEqualTimePtr(t, input.DisabledAt, actual.DisabledAt, "DisabledAt") + }) + }) + + t.Run("FilterPersistence", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + tenant := models.Tenant{ID: idgen.String()} + require.NoError(t, store.UpsertTenant(ctx, tenant)) + + t.Run("stores and retrieves destination with filter", func(t *testing.T) { + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.created"}, + }), + ) 
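+ // Persist the destination, then read it back to verify the filter survives a storage round trip.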
+ require.NoError(t, store.CreateDestination(ctx, destination)) + + retrieved, err := store.RetrieveDestination(ctx, tenant.ID, destination.ID) + require.NoError(t, err) + assert.NotNil(t, retrieved.Filter) + assert.Equal(t, "order.created", retrieved.Filter["data"].(map[string]any)["type"]) + }) + + t.Run("stores destination with nil filter", func(t *testing.T) { + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + ) + require.NoError(t, store.CreateDestination(ctx, destination)) + + retrieved, err := store.RetrieveDestination(ctx, tenant.ID, destination.ID) + require.NoError(t, err) + assert.Nil(t, retrieved.Filter) + }) + + t.Run("updates destination filter", func(t *testing.T) { + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.created"}, + }), + ) + require.NoError(t, store.CreateDestination(ctx, destination)) + + destination.Filter = models.Filter{ + "data": map[string]any{"type": "order.updated"}, + } + require.NoError(t, store.UpsertDestination(ctx, destination)) + + retrieved, err := store.RetrieveDestination(ctx, tenant.ID, destination.ID) + require.NoError(t, err) + assert.Equal(t, "order.updated", retrieved.Filter["data"].(map[string]any)["type"]) + }) + + t.Run("removes destination filter", func(t *testing.T) { + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.created"}, + }), + ) + require.NoError(t, store.CreateDestination(ctx, destination)) + + destination.Filter = nil + require.NoError(t, store.UpsertDestination(ctx, destination)) + + retrieved, err := store.RetrieveDestination(ctx, tenant.ID, destination.ID) + require.NoError(t, err) + assert.Nil(t, retrieved.Filter) + }) + }) +} + +// assertEqualTime compares two times by truncating to millisecond precision. +func assertEqualTime(t *testing.T, expected, actual time.Time, field string) { + t.Helper() + expectedTrunc := expected.Truncate(time.Millisecond) + actualTrunc := actual.Truncate(time.Millisecond) + assert.True(t, expectedTrunc.Equal(actualTrunc), + "expected %s %v, got %v", field, expectedTrunc, actualTrunc) +} + +// assertEqualTimePtr compares two optional times by truncating to millisecond precision. +func assertEqualTimePtr(t *testing.T, expected, actual *time.Time, field string) { + t.Helper() + if expected == nil { + assert.Nil(t, actual, "%s should be nil", field) + return + } + require.NotNil(t, actual, "%s should not be nil", field) + assertEqualTime(t, *expected, *actual, field) +} + +// assertEqualDestination compares two destinations field-by-field. 
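+// Timestamps are compared via assertEqualTime, i.e. at millisecond precision, to tolerate storage round-trip rounding.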
+func assertEqualDestination(t *testing.T, expected, actual models.Destination) { + t.Helper() + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.Type, actual.Type) + assert.Equal(t, expected.Topics, actual.Topics) + assert.Equal(t, expected.Filter, actual.Filter) + assert.Equal(t, expected.Config, actual.Config) + assert.Equal(t, expected.Credentials, actual.Credentials) + assert.Equal(t, expected.DeliveryMetadata, actual.DeliveryMetadata) + assert.Equal(t, expected.Metadata, actual.Metadata) + assertEqualTime(t, expected.CreatedAt, actual.CreatedAt, "CreatedAt") + assertEqualTime(t, expected.UpdatedAt, actual.UpdatedAt, "UpdatedAt") + assertEqualTimePtr(t, expected.DisabledAt, actual.DisabledAt, "DisabledAt") +} diff --git a/internal/tenantstore/drivertest/drivertest.go b/internal/tenantstore/drivertest/drivertest.go new file mode 100644 index 00000000..5ac0bf3b --- /dev/null +++ b/internal/tenantstore/drivertest/drivertest.go @@ -0,0 +1,61 @@ +// Package drivertest provides a conformance test suite for tenantstore drivers. +package drivertest + +import ( + "context" + "testing" + + "github.com/hookdeck/outpost/internal/tenantstore/driver" +) + +// Harness provides the test infrastructure for a tenantstore driver implementation. +type Harness interface { + // MakeDriver creates a driver with default settings. + MakeDriver(ctx context.Context) (driver.TenantStore, error) + // MakeDriverWithMaxDest creates a driver with a specific max destinations limit. + MakeDriverWithMaxDest(ctx context.Context, maxDest int) (driver.TenantStore, error) + // MakeIsolatedDrivers creates two drivers that share the same backend + // but are isolated from each other (e.g., different deployment IDs). + MakeIsolatedDrivers(ctx context.Context) (store1, store2 driver.TenantStore, err error) + Close() +} + +// HarnessMaker creates a new Harness for each test. +type HarnessMaker func(ctx context.Context, t *testing.T) (Harness, error) + +// RunConformanceTests executes the core conformance test suite for a tenantstore driver. +// The suite is organized into four parts: +// - CRUD: tenant and destination create/read/update/delete +// - List: destination listing and filtering operations +// - Match: event matching operations +// - Misc: max destinations, deployment isolation +func RunConformanceTests(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("CRUD", func(t *testing.T) { + testCRUD(t, newHarness) + }) + t.Run("List", func(t *testing.T) { + testList(t, newHarness) + }) + t.Run("Match", func(t *testing.T) { + testMatch(t, newHarness) + }) + t.Run("Misc", func(t *testing.T) { + testMisc(t, newHarness) + }) +} + +// RunListTenantTests executes the ListTenant test suite, which requires RediSearch. +// Run this only on backends that support RediSearch (Redis Stack, Dragonfly Stack). 
+// The suite covers: +// - Enrichment: tenant list includes destinations count and topics +// - ExcludesDeleted: deleted tenants are excluded from results +// - InputValidation: limit, order direction, cursor validation +// - KeysetPagination: cursor-based pagination edge cases +// - PaginationSuite: comprehensive forward/backward/round-trip pagination +func RunListTenantTests(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + testListTenant(t, newHarness) +} diff --git a/internal/tenantstore/drivertest/list.go b/internal/tenantstore/drivertest/list.go new file mode 100644 index 00000000..5a42a6e9 --- /dev/null +++ b/internal/tenantstore/drivertest/list.go @@ -0,0 +1,569 @@ +package drivertest + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/hookdeck/outpost/internal/idgen" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/pagination/paginationtest" + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/util/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// multiDestinationData holds shared test data for multi-destination tests. +type multiDestinationData struct { + tenant models.Tenant + destinations []models.Destination +} + +func setupMultiDestination(t *testing.T, ctx context.Context, store driver.TenantStore) multiDestinationData { + t.Helper() + data := multiDestinationData{ + tenant: models.Tenant{ + ID: idgen.String(), + CreatedAt: time.Now(), + }, + destinations: make([]models.Destination, 5), + } + require.NoError(t, store.UpsertTenant(ctx, data.tenant)) + + destinationTopicList := [][]string{ + {"*"}, + {"user.created"}, + {"user.updated"}, + {"user.deleted"}, + {"user.created", "user.updated"}, + } + baseTime := time.Now().Add(-10 * time.Second).Truncate(time.Second) + for i := 0; i < 5; i++ { + data.destinations[i] = testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(idgen.Destination()), + testutil.DestinationFactory.WithTenantID(data.tenant.ID), + testutil.DestinationFactory.WithTopics(destinationTopicList[i]), + testutil.DestinationFactory.WithCreatedAt(baseTime.Add(time.Duration(i)*time.Second)), + ) + require.NoError(t, store.UpsertDestination(ctx, data.destinations[i])) + } + + // Insert & Delete destination to ensure cleanup + toBeDeletedID := idgen.Destination() + require.NoError(t, store.UpsertDestination(ctx, + testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(toBeDeletedID), + testutil.DestinationFactory.WithTenantID(data.tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + ))) + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, toBeDeletedID)) + + return data +} + +func testList(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("MultiDestinationRetrieveTenantDestinationsCount", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + tenant, err := store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, 5, tenant.DestinationsCount) + }) + + t.Run("MultiDestinationRetrieveTenantTopics", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, 
store) + + // destinations[0] has topics ["*"] + tenant, err := store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{"*"}, tenant.Topics) + + // After deleting wildcard destination, topics should aggregate remaining + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, data.destinations[0].ID)) + tenant, err = store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) + + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, data.destinations[1].ID)) + tenant, err = store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) + + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, data.destinations[2].ID)) + tenant, err = store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{"user.created", "user.deleted", "user.updated"}, tenant.Topics) + + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, data.destinations[3].ID)) + tenant, err = store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{"user.created", "user.updated"}, tenant.Topics) + + require.NoError(t, store.DeleteDestination(ctx, data.tenant.ID, data.destinations[4].ID)) + tenant, err = store.RetrieveTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Equal(t, []string{}, tenant.Topics) + }) + + t.Run("MultiDestinationListDestinationByTenant", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID) + require.NoError(t, err) + require.Len(t, destinations, 5) + for index, destination := range destinations { + require.Equal(t, data.destinations[index].ID, destination.ID) + } + }) + + t.Run("MultiDestinationListDestinationWithOpts", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + t.Run("filter by type: webhook", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Type: []string{"webhook"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 5) + }) + + t.Run("filter by type: rabbitmq", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Type: []string{"rabbitmq"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 0) + }) + + t.Run("filter by type: webhook,rabbitmq", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Type: []string{"webhook", "rabbitmq"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 5) + }) + + t.Run("filter by topic: user.created", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Topics: []string{"user.created"}, + })) + require.NoError(t, err) + require.Len(t, 
destinations, 3) + }) + + t.Run("filter by topic: user.created,user.updated", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Topics: []string{"user.created", "user.updated"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 2) + }) + + t.Run("filter by type: rabbitmq, topic: user.created,user.updated", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Type: []string{"rabbitmq"}, + Topics: []string{"user.created", "user.updated"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 0) + }) + + t.Run("filter by topic: *", func(t *testing.T) { + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID, driver.WithDestinationFilter(driver.DestinationFilter{ + Topics: []string{"*"}, + })) + require.NoError(t, err) + require.Len(t, destinations, 1) + }) + }) + +} + +func testListTenant(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("Enrichment", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + require.NoError(t, store.Init(ctx)) + + // Create 25 tenants + tenants := make([]models.Tenant, 25) + baseTime := time.Now() + for i := range tenants { + tenants[i] = testutil.TenantFactory.Any( + testutil.TenantFactory.WithCreatedAt(baseTime.Add(time.Duration(i)*time.Second)), + testutil.TenantFactory.WithUpdatedAt(baseTime.Add(time.Duration(i)*time.Second)), + ) + require.NoError(t, store.UpsertTenant(ctx, tenants[i])) + } + tenantWithDests := tenants[24] + for i := range 2 { + dest := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(fmt.Sprintf("dest_suite_%d", i)), + testutil.DestinationFactory.WithTenantID(tenantWithDests.ID), + ) + require.NoError(t, store.UpsertDestination(ctx, dest)) + } + + t.Run("returns total count independent of pagination", func(t *testing.T) { + resp1, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 2}) + require.NoError(t, err) + assert.Equal(t, 25, resp1.Count) + assert.Len(t, resp1.Models, 2) + + var nextCursor string + if resp1.Pagination.Next != nil { + nextCursor = *resp1.Pagination.Next + } + resp2, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 2, Next: nextCursor}) + require.NoError(t, err) + assert.Equal(t, 25, resp2.Count) + }) + + t.Run("does not include destinations in tenant list", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 100}) + require.NoError(t, err) + assert.Equal(t, 25, resp.Count) + for _, tenant := range resp.Models { + assert.NotContains(t, tenant.ID, "dest_") + } + }) + + t.Run("returns destinations_count and topics", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 100}) + require.NoError(t, err) + + var found *models.Tenant + for i := range resp.Models { + if resp.Models[i].ID == tenantWithDests.ID { + found = &resp.Models[i] + break + } + } + require.NotNil(t, found) + assert.Equal(t, 2, found.DestinationsCount) + assert.NotNil(t, found.Topics) + + var without *models.Tenant + for i := range resp.Models { + if resp.Models[i].ID != tenantWithDests.ID { + without = &resp.Models[i] + break + } + } + require.NotNil(t, without) + assert.Equal(t, 0, without.DestinationsCount) + assert.Empty(t, without.Topics) + }) + }) + 
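+	// A minimal sketch of how a caller might walk every page with the Next
+	// cursor, assuming only the driver.ListTenantRequest and
+	// driver.TenantPaginatedResult shapes exercised above; listAllTenants is
+	// an illustrative helper name, not part of the driver contract:
+	//
+	//	func listAllTenants(ctx context.Context, store driver.TenantStore) ([]models.Tenant, error) {
+	//		var all []models.Tenant
+	//		req := driver.ListTenantRequest{Limit: 20}
+	//		for {
+	//			resp, err := store.ListTenant(ctx, req)
+	//			if err != nil {
+	//				return nil, err
+	//			}
+	//			all = append(all, resp.Models...)
+	//			if len(resp.Models) == 0 || resp.Pagination.Next == nil {
+	//				return all, nil
+	//			}
+	//			req.Next = *resp.Pagination.Next
+	//		}
+	//	}
+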
+ t.Run("ExcludesDeleted", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + require.NoError(t, store.Init(ctx)) + + // Create initial tenants + for i := 0; i < 5; i++ { + require.NoError(t, store.UpsertTenant(ctx, testutil.TenantFactory.Any())) + } + + t.Run("deleted tenant not returned", func(t *testing.T) { + initialResp, err := store.ListTenant(ctx, driver.ListTenantRequest{}) + require.NoError(t, err) + initialCount := initialResp.Count + + tenant1 := testutil.TenantFactory.Any() + tenant2 := testutil.TenantFactory.Any() + require.NoError(t, store.UpsertTenant(ctx, tenant1)) + require.NoError(t, store.UpsertTenant(ctx, tenant2)) + + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{}) + require.NoError(t, err) + assert.Equal(t, initialCount+2, resp.Count) + + require.NoError(t, store.DeleteTenant(ctx, tenant1.ID)) + + resp, err = store.ListTenant(ctx, driver.ListTenantRequest{}) + require.NoError(t, err) + assert.Equal(t, initialCount+1, resp.Count) + + for _, tenant := range resp.Models { + assert.NotEqual(t, tenant1.ID, tenant.ID) + } + + _ = store.DeleteTenant(ctx, tenant2.ID) + }) + }) + + t.Run("InputValidation", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + require.NoError(t, store.Init(ctx)) + + // Create 25 tenants for pagination tests + tenants := make([]models.Tenant, 25) + baseTime := time.Now() + for i := range tenants { + tenants[i] = testutil.TenantFactory.Any( + testutil.TenantFactory.WithCreatedAt(baseTime.Add(time.Duration(i)*time.Second)), + testutil.TenantFactory.WithUpdatedAt(baseTime.Add(time.Duration(i)*time.Second)), + ) + require.NoError(t, store.UpsertTenant(ctx, tenants[i])) + } + + t.Run("invalid dir returns error", func(t *testing.T) { + _, err := store.ListTenant(ctx, driver.ListTenantRequest{Dir: "invalid"}) + require.Error(t, err) + assert.ErrorIs(t, err, driver.ErrInvalidOrder) + }) + + t.Run("conflicting cursors returns error", func(t *testing.T) { + _, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Next: "somecursor", + Prev: "anothercursor", + }) + require.Error(t, err) + assert.ErrorIs(t, err, driver.ErrConflictingCursors) + }) + + t.Run("invalid next cursor returns error", func(t *testing.T) { + _, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Next: "not-valid-base62!!!", + }) + require.Error(t, err) + assert.ErrorIs(t, err, driver.ErrInvalidCursor) + }) + + t.Run("invalid prev cursor returns error", func(t *testing.T) { + _, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Prev: "not-valid-base62!!!", + }) + require.Error(t, err) + assert.ErrorIs(t, err, driver.ErrInvalidCursor) + }) + + t.Run("malformed cursor format returns error", func(t *testing.T) { + _, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Next: "abc123", + }) + require.Error(t, err) + assert.ErrorIs(t, err, driver.ErrInvalidCursor) + }) + + t.Run("limit zero uses default", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 0}) + require.NoError(t, err) + assert.Equal(t, 20, len(resp.Models)) + assert.Equal(t, 25, resp.Count) + }) + + t.Run("limit negative uses default", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: -5}) + require.NoError(t, err) + assert.NotNil(t, 
resp) + }) + + t.Run("limit exceeding max is capped", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 1000}) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 25, len(resp.Models)) + assert.Equal(t, 25, resp.Count) + }) + + t.Run("empty dir uses default desc", func(t *testing.T) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{Dir: ""}) + require.NoError(t, err) + require.Len(t, resp.Models, 20) + assert.Equal(t, tenants[24].ID, resp.Models[0].ID) + assert.Equal(t, tenants[23].ID, resp.Models[1].ID) + }) + }) + + t.Run("KeysetPagination", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + require.NoError(t, store.Init(ctx)) + + t.Run("add during traversal does not cause duplicate", func(t *testing.T) { + prefix := fmt.Sprintf("add_edge_%d_", time.Now().UnixNano()) + tenantIDs := make([]string, 15) + baseTime := time.Now().Add(20 * time.Hour) + for i := 0; i < 15; i++ { + tenantIDs[i] = fmt.Sprintf("%s%02d", prefix, i) + tenant := models.Tenant{ + ID: tenantIDs[i], + CreatedAt: baseTime.Add(time.Duration(i) * time.Second), + UpdatedAt: baseTime, + } + require.NoError(t, store.UpsertTenant(ctx, tenant)) + } + + resp1, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 5}) + require.NoError(t, err) + require.Len(t, resp1.Models, 5) + + // Add new tenant with newest timestamp + newTenantID := prefix + "NEW" + newTenant := models.Tenant{ + ID: newTenantID, + CreatedAt: baseTime.Add(time.Hour), + UpdatedAt: baseTime, + } + require.NoError(t, store.UpsertTenant(ctx, newTenant)) + tenantIDs = append(tenantIDs, newTenantID) + + var nextCursor string + if resp1.Pagination.Next != nil { + nextCursor = *resp1.Pagination.Next + } + resp2, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Limit: 5, + Next: nextCursor, + }) + require.NoError(t, err) + require.NotEmpty(t, resp2.Models) + + page1IDs := make(map[string]bool) + for _, tenant := range resp1.Models { + page1IDs[tenant.ID] = true + } + for _, tenant := range resp2.Models { + assert.False(t, page1IDs[tenant.ID], + "keyset pagination: no duplicates, but found %s", tenant.ID) + } + + // Cleanup + for _, id := range tenantIDs { + _ = store.DeleteTenant(ctx, id) + } + }) + }) + + t.Run("PaginationSuite", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + require.NoError(t, store.Init(ctx)) + + var createdTenantIDs []string + baseTime := time.Now() + + suite := paginationtest.Suite[models.Tenant]{ + Name: "ListTenant", + + NewItem: func(index int) models.Tenant { + return models.Tenant{ + ID: fmt.Sprintf("tenant_pagination_%d_%d", time.Now().UnixNano(), index), + CreatedAt: baseTime.Add(time.Duration(index) * time.Second), + UpdatedAt: baseTime.Add(time.Duration(index) * time.Second), + } + }, + + InsertMany: func(ctx context.Context, items []models.Tenant) error { + for _, item := range items { + if err := store.UpsertTenant(ctx, item); err != nil { + return err + } + createdTenantIDs = append(createdTenantIDs, item.ID) + } + return nil + }, + + List: func(ctx context.Context, opts paginationtest.ListOpts) (paginationtest.ListResult[models.Tenant], error) { + resp, err := store.ListTenant(ctx, driver.ListTenantRequest{ + Limit: opts.Limit, + Dir: opts.Order, + Next: opts.Next, + 
Prev: opts.Prev, + }) + if err != nil { + return paginationtest.ListResult[models.Tenant]{}, err + } + var next, prev string + if resp.Pagination.Next != nil { + next = *resp.Pagination.Next + } + if resp.Pagination.Prev != nil { + prev = *resp.Pagination.Prev + } + return paginationtest.ListResult[models.Tenant]{ + Items: resp.Models, + Next: next, + Prev: prev, + }, nil + }, + + GetID: func(t models.Tenant) string { + return t.ID + }, + + Cleanup: func(ctx context.Context) error { + for _, id := range createdTenantIDs { + _ = store.DeleteTenant(ctx, id) + } + createdTenantIDs = nil + return nil + }, + } + + suite.Run(t) + }) +} diff --git a/internal/tenantstore/drivertest/match.go b/internal/tenantstore/drivertest/match.go new file mode 100644 index 00000000..bd6e64d9 --- /dev/null +++ b/internal/tenantstore/drivertest/match.go @@ -0,0 +1,326 @@ +package drivertest + +import ( + "context" + "testing" + "time" + + "github.com/hookdeck/outpost/internal/idgen" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/util/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testMatch(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("MatchByTopic", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + t.Run("match by topic", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + Topic: "user.created", + Time: time.Now(), + TenantID: data.tenant.ID, + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 3) + for _, id := range matched { + require.Contains(t, []string{data.destinations[0].ID, data.destinations[1].ID, data.destinations[4].ID}, id) + } + }) + + t.Run("ignores destination_id and matches by topic only", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + Topic: "user.created", + Time: time.Now(), + TenantID: data.tenant.ID, + DestinationID: data.destinations[1].ID, + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 3) + }) + + t.Run("ignores non-existent destination_id", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + Topic: "user.created", + Time: time.Now(), + TenantID: data.tenant.ID, + DestinationID: "not-found", + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 3) + }) + + t.Run("ignores destination_id with mismatched topic", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + Topic: "user.created", + Time: time.Now(), + TenantID: data.tenant.ID, + DestinationID: data.destinations[3].ID, // user.deleted + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 3) + }) + + t.Run("match after destination is updated", func(t *testing.T) { + updatedDestination := data.destinations[2] // user.updated + updatedDestination.Topics = []string{"user.created"} + require.NoError(t, store.UpsertDestination(ctx, updatedDestination)) + + actual, err := store.RetrieveDestination(ctx, 
updatedDestination.TenantID, updatedDestination.ID) + require.NoError(t, err) + assert.Equal(t, updatedDestination.Topics, actual.Topics) + + destinations, err := store.ListDestinationByTenant(ctx, data.tenant.ID) + require.NoError(t, err) + assert.Len(t, destinations, 5) + + // Match user.created (now 4 destinations match) + event := models.Event{ + ID: idgen.Event(), + Topic: "user.created", + Time: time.Now(), + TenantID: data.tenant.ID, + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 4) + + // Match user.updated (now only 2: wildcard + destinations[4]) + event = models.Event{ + ID: idgen.Event(), + Topic: "user.updated", + Time: time.Now(), + TenantID: data.tenant.ID, + Metadata: map[string]string{}, + Data: map[string]interface{}{}, + } + matched, err = store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 2) + for _, id := range matched { + require.Contains(t, []string{data.destinations[0].ID, data.destinations[4].ID}, id) + } + }) + }) + + t.Run("MatchEventWithFilter", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + + tenant := models.Tenant{ID: idgen.String()} + require.NoError(t, store.UpsertTenant(ctx, tenant)) + + destNoFilter := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID("dest_no_filter"), + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + ) + destFilterOrderCreated := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID("dest_filter_order_created"), + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.created"}, + }), + ) + destFilterOrderUpdated := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID("dest_filter_order_updated"), + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.updated"}, + }), + ) + destFilterPremiumCustomer := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID("dest_filter_premium"), + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"*"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{ + "customer": map[string]any{"tier": "premium"}, + }, + }), + ) + + require.NoError(t, store.CreateDestination(ctx, destNoFilter)) + require.NoError(t, store.CreateDestination(ctx, destFilterOrderCreated)) + require.NoError(t, store.CreateDestination(ctx, destFilterOrderUpdated)) + require.NoError(t, store.CreateDestination(ctx, destFilterPremiumCustomer)) + + t.Run("event matches only destinations with matching filter", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + TenantID: tenant.ID, + Topic: "order", + Time: time.Now(), + Metadata: map[string]string{}, + Data: map[string]interface{}{"type": "order.created"}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + assert.Len(t, matched, 2) + assert.Contains(t, matched, "dest_no_filter") + assert.Contains(t, matched, "dest_filter_order_created") + }) + + 
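+		// The assertions in this block rely on filters acting as partial
+		// (subset) matches: a destination with no filter matches any event its
+		// topics allow, and a nested filter such as
+		// {"data": {"customer": {"tier": "premium"}}} matches any event whose
+		// Data carries that key/value, extra fields included. A rough sketch of
+		// that semantic, assuming (inferred from the filters above) that
+		// matching is evaluated against a document shaped like
+		// {"data": event.Data}; the real behaviour lives in models.MatchFilter
+		// and may differ in detail:
+		//
+		//	func isSubset(filter, doc map[string]any) bool {
+		//		for key, want := range filter {
+		//			got, ok := doc[key]
+		//			if !ok {
+		//				return false
+		//			}
+		//			if wantMap, isMap := want.(map[string]any); isMap {
+		//				gotMap, ok := got.(map[string]any)
+		//				if !ok || !isSubset(wantMap, gotMap) {
+		//					return false
+		//				}
+		//				continue
+		//			}
+		//			if got != want {
+		//				return false
+		//			}
+		//		}
+		//		return true
+		//	}
+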
t.Run("event with nested data matches nested filter", func(t *testing.T) { + event := models.Event{ + ID: idgen.Event(), + TenantID: tenant.ID, + Topic: "order", + Time: time.Now(), + Metadata: map[string]string{}, + Data: map[string]interface{}{ + "type": "order.created", + "customer": map[string]interface{}{ + "id": "cust_123", + "tier": "premium", + }, + }, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + assert.Len(t, matched, 3) + assert.Contains(t, matched, "dest_no_filter") + assert.Contains(t, matched, "dest_filter_order_created") + assert.Contains(t, matched, "dest_filter_premium") + }) + + t.Run("topic filter takes precedence before content filter", func(t *testing.T) { + destTopicAndFilter := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID("dest_topic_and_filter"), + testutil.DestinationFactory.WithTenantID(tenant.ID), + testutil.DestinationFactory.WithTopics([]string{"user.created"}), + testutil.DestinationFactory.WithFilter(models.Filter{ + "data": map[string]any{"type": "order.created"}, + }), + ) + require.NoError(t, store.CreateDestination(ctx, destTopicAndFilter)) + + event := models.Event{ + ID: idgen.Event(), + TenantID: tenant.ID, + Topic: "order", + Time: time.Now(), + Metadata: map[string]string{}, + Data: map[string]interface{}{"type": "order.created"}, + } + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + for _, id := range matched { + assert.NotEqual(t, "dest_topic_and_filter", id) + } + }) + }) + + t.Run("DisableAndMatch", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + t.Run("initial match user.deleted", func(t *testing.T) { + event := testutil.EventFactory.Any( + testutil.EventFactory.WithTenantID(data.tenant.ID), + testutil.EventFactory.WithTopic("user.deleted"), + ) + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 2) + for _, id := range matched { + require.Contains(t, []string{data.destinations[0].ID, data.destinations[3].ID}, id) + } + }) + + t.Run("should not match disabled destination", func(t *testing.T) { + destination := data.destinations[0] + now := time.Now() + destination.DisabledAt = &now + require.NoError(t, store.UpsertDestination(ctx, destination)) + + event := testutil.EventFactory.Any( + testutil.EventFactory.WithTenantID(data.tenant.ID), + testutil.EventFactory.WithTopic("user.deleted"), + ) + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 1) + require.Equal(t, data.destinations[3].ID, matched[0]) + }) + + t.Run("should match after re-enabled destination", func(t *testing.T) { + destination := data.destinations[0] + destination.DisabledAt = nil + require.NoError(t, store.UpsertDestination(ctx, destination)) + + event := testutil.EventFactory.Any( + testutil.EventFactory.WithTenantID(data.tenant.ID), + testutil.EventFactory.WithTopic("user.deleted"), + ) + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 2) + }) + }) + + t.Run("DeleteAndMatch", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store, err := h.MakeDriver(ctx) + require.NoError(t, err) + data := setupMultiDestination(t, ctx, store) + + require.NoError(t, store.DeleteDestination(ctx, 
data.tenant.ID, data.destinations[0].ID)) + + event := testutil.EventFactory.Any( + testutil.EventFactory.WithTenantID(data.tenant.ID), + testutil.EventFactory.WithTopic("user.created"), + ) + matched, err := store.MatchEvent(ctx, event) + require.NoError(t, err) + require.Len(t, matched, 2) + for _, id := range matched { + require.Contains(t, []string{data.destinations[1].ID, data.destinations[4].ID}, id) + } + }) +} diff --git a/internal/tenantstore/drivertest/misc.go b/internal/tenantstore/drivertest/misc.go new file mode 100644 index 00000000..804b60de --- /dev/null +++ b/internal/tenantstore/drivertest/misc.go @@ -0,0 +1,129 @@ +package drivertest + +import ( + "context" + "testing" + "time" + + "github.com/hookdeck/outpost/internal/idgen" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/util/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testMisc(t *testing.T, newHarness HarnessMaker) { + t.Helper() + + t.Run("MaxDestinationsPerTenant", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + maxDestinations := 2 + store, err := h.MakeDriverWithMaxDest(ctx, maxDestinations) + require.NoError(t, err) + + tenant := models.Tenant{ + ID: idgen.String(), + CreatedAt: time.Now(), + } + require.NoError(t, store.UpsertTenant(ctx, tenant)) + + // Should be able to create up to maxDestinations + for i := 0; i < maxDestinations; i++ { + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + ) + err := store.CreateDestination(ctx, destination) + require.NoError(t, err, "Should be able to create destination %d", i+1) + } + + // Should fail when trying to create one more + destination := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + ) + err = store.CreateDestination(ctx, destination) + require.Error(t, err) + require.ErrorIs(t, err, driver.ErrMaxDestinationsPerTenantReached) + + // Should be able to create after deleting one + destinations, err := store.ListDestinationByTenant(ctx, tenant.ID) + require.NoError(t, err) + require.NoError(t, store.DeleteDestination(ctx, tenant.ID, destinations[0].ID)) + + destination = testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithTenantID(tenant.ID), + ) + err = store.CreateDestination(ctx, destination) + require.NoError(t, err, "Should be able to create destination after deleting one") + }) + + t.Run("DeploymentIsolation", func(t *testing.T) { + ctx := context.Background() + h, err := newHarness(ctx, t) + require.NoError(t, err) + t.Cleanup(h.Close) + + store1, store2, err := h.MakeIsolatedDrivers(ctx) + require.NoError(t, err) + + // Use same tenant ID and destination ID for both + tenantID := idgen.String() + destinationID := idgen.Destination() + + tenant := models.Tenant{ + ID: tenantID, + CreatedAt: time.Now(), + } + require.NoError(t, store1.UpsertTenant(ctx, tenant)) + require.NoError(t, store2.UpsertTenant(ctx, tenant)) + + // Create destination with different config in each + destination1 := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(destinationID), + testutil.DestinationFactory.WithTenantID(tenantID), + testutil.DestinationFactory.WithConfig(map[string]string{"deployment": "dp_001"}), + ) + destination2 := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(destinationID), + 
testutil.DestinationFactory.WithTenantID(tenantID), + testutil.DestinationFactory.WithConfig(map[string]string{"deployment": "dp_002"}), + ) + + require.NoError(t, store1.CreateDestination(ctx, destination1)) + require.NoError(t, store2.CreateDestination(ctx, destination2)) + + // Verify each store sees its own data + retrieved1, err := store1.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_001", retrieved1.Config["deployment"]) + + retrieved2, err := store2.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_002", retrieved2.Config["deployment"]) + + // Verify list operations are isolated + list1, err := store1.ListDestinationByTenant(ctx, tenantID) + require.NoError(t, err) + require.Len(t, list1, 1) + assert.Equal(t, "dp_001", list1[0].Config["deployment"]) + + list2, err := store2.ListDestinationByTenant(ctx, tenantID) + require.NoError(t, err) + require.Len(t, list2, 1) + assert.Equal(t, "dp_002", list2[0].Config["deployment"]) + + // Verify deleting from one doesn't affect the other + require.NoError(t, store1.DeleteDestination(ctx, tenantID, destinationID)) + + _, err = store1.RetrieveDestination(ctx, tenantID, destinationID) + require.ErrorIs(t, err, driver.ErrDestinationDeleted) + + retrieved2Again, err := store2.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_002", retrieved2Again.Config["deployment"]) + }) +} diff --git a/internal/tenantstore/memtenantstore/memtenantstore.go b/internal/tenantstore/memtenantstore/memtenantstore.go new file mode 100644 index 00000000..aa9f6ace --- /dev/null +++ b/internal/tenantstore/memtenantstore/memtenantstore.go @@ -0,0 +1,449 @@ +// Package memtenantstore provides an in-memory implementation of driver.TenantStore. +package memtenantstore + +import ( + "context" + "fmt" + "slices" + "sort" + "strconv" + "sync" + "time" + + "github.com/hookdeck/outpost/internal/cursor" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/pagination" + "github.com/hookdeck/outpost/internal/tenantstore/driver" +) + +const defaultMaxDestinationsPerTenant = 20 + +const ( + defaultListTenantLimit = 20 + maxListTenantLimit = 100 +) + +type tenantRecord struct { + tenant models.Tenant + deletedAt *time.Time +} + +type destinationRecord struct { + destination models.Destination + deletedAt *time.Time +} + +type store struct { + mu sync.RWMutex + + tenants map[string]*tenantRecord // tenantID -> record + destinations map[string]*destinationRecord // "tenantID\x00destID" -> record + destsByTenant map[string]map[string]struct{} // tenantID -> set of destIDs + + maxDestinationsPerTenant int +} + +var _ driver.TenantStore = (*store)(nil) + +// Option configures a memtenantstore. +type Option func(*store) + +// WithMaxDestinationsPerTenant sets the max destinations per tenant. +func WithMaxDestinationsPerTenant(max int) Option { + return func(s *store) { + s.maxDestinationsPerTenant = max + } +} + +// New creates a new in-memory TenantStore. 
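+// The returned store starts empty and caps each tenant at
+// defaultMaxDestinationsPerTenant unless WithMaxDestinationsPerTenant is
+// given. A minimal usage sketch (tenant ID and context values are
+// illustrative):
+//
+//	store := New(WithMaxDestinationsPerTenant(5))
+//	_ = store.Init(ctx)
+//	_ = store.UpsertTenant(ctx, models.Tenant{ID: "tenant_1"})
+//	tenant, _ := store.RetrieveTenant(ctx, "tenant_1")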
+func New(opts ...Option) driver.TenantStore { + s := &store{ + tenants: make(map[string]*tenantRecord), + destinations: make(map[string]*destinationRecord), + destsByTenant: make(map[string]map[string]struct{}), + maxDestinationsPerTenant: defaultMaxDestinationsPerTenant, + } + for _, opt := range opts { + opt(s) + } + return s +} + +func destKey(tenantID, destID string) string { + return tenantID + "\x00" + destID +} + +func (s *store) Init(_ context.Context) error { + return nil +} + +func (s *store) RetrieveTenant(_ context.Context, tenantID string) (*models.Tenant, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + rec, ok := s.tenants[tenantID] + if !ok { + return nil, nil + } + if rec.deletedAt != nil { + return nil, driver.ErrTenantDeleted + } + + t := rec.tenant + destIDs := s.destsByTenant[tenantID] + t.DestinationsCount = len(destIDs) + t.Topics = s.computeTenantTopics(tenantID) + return &t, nil +} + +func (s *store) UpsertTenant(_ context.Context, tenant models.Tenant) error { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + if tenant.CreatedAt.IsZero() { + tenant.CreatedAt = now + } + if tenant.UpdatedAt.IsZero() { + tenant.UpdatedAt = now + } + + s.tenants[tenant.ID] = &tenantRecord{tenant: tenant} + return nil +} + +func (s *store) DeleteTenant(_ context.Context, tenantID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + rec, ok := s.tenants[tenantID] + if !ok { + return driver.ErrTenantNotFound + } + // Already deleted is OK (idempotent) + now := time.Now() + rec.deletedAt = &now + + // Delete all destinations + if destIDs, ok := s.destsByTenant[tenantID]; ok { + for destID := range destIDs { + if drec, ok := s.destinations[destKey(tenantID, destID)]; ok { + drec.deletedAt = &now + } + } + delete(s.destsByTenant, tenantID) + } + + return nil +} + +func (s *store) ListTenant(ctx context.Context, req driver.ListTenantRequest) (*driver.TenantPaginatedResult, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if req.Next != "" && req.Prev != "" { + return nil, driver.ErrConflictingCursors + } + + limit := req.Limit + if limit <= 0 { + limit = defaultListTenantLimit + } + if limit > maxListTenantLimit { + limit = maxListTenantLimit + } + + dir := req.Dir + if dir == "" { + dir = "desc" + } + if dir != "asc" && dir != "desc" { + return nil, driver.ErrInvalidOrder + } + + // Collect non-deleted tenants + var activeTenants []models.Tenant + for _, rec := range s.tenants { + if rec.deletedAt != nil { + continue + } + activeTenants = append(activeTenants, rec.tenant) + } + totalCount := len(activeTenants) + + result, err := pagination.Run(ctx, pagination.Config[models.Tenant]{ + Limit: limit, + Order: dir, + Next: req.Next, + Prev: req.Prev, + Cursor: pagination.Cursor[models.Tenant]{ + Encode: func(t models.Tenant) string { + return cursor.Encode("tnt", 1, strconv.FormatInt(t.CreatedAt.UnixMilli(), 10)) + }, + Decode: func(c string) (string, error) { + data, err := cursor.Decode(c, "tnt", 1) + if err != nil { + return "", fmt.Errorf("%w: %v", driver.ErrInvalidCursor, err) + } + return data, nil + }, + }, + Fetch: func(_ context.Context, q pagination.QueryInput) ([]models.Tenant, error) { + return s.fetchTenants(activeTenants, q) + }, + }) + if err != nil { + return nil, err + } + + tenants := result.Items + + // Enrich with DestinationsCount and Topics + for i := range tenants { + destIDs := s.destsByTenant[tenants[i].ID] + tenants[i].DestinationsCount = len(destIDs) + tenants[i].Topics = s.computeTenantTopics(tenants[i].ID) + } + + var nextCursor, prevCursor 
*string + if result.Next != "" { + nextCursor = &result.Next + } + if result.Prev != "" { + prevCursor = &result.Prev + } + + return &driver.TenantPaginatedResult{ + Models: tenants, + Pagination: driver.SeekPagination{ + OrderBy: "created_at", + Dir: dir, + Limit: limit, + Next: nextCursor, + Prev: prevCursor, + }, + Count: totalCount, + }, nil +} + +func (s *store) fetchTenants(activeTenants []models.Tenant, q pagination.QueryInput) ([]models.Tenant, error) { + var filtered []models.Tenant + + if q.CursorPos == "" { + filtered = append(filtered, activeTenants...) + } else { + cursorTs, err := strconv.ParseInt(q.CursorPos, 10, 64) + if err != nil { + return nil, fmt.Errorf("%w: invalid timestamp", driver.ErrInvalidCursor) + } + for _, t := range activeTenants { + ts := t.CreatedAt.UnixMilli() + if q.Compare == "<" && ts < cursorTs { + filtered = append(filtered, t) + } else if q.Compare == ">" && ts > cursorTs { + filtered = append(filtered, t) + } + } + } + + // Sort + if q.SortDir == "desc" { + sort.Slice(filtered, func(i, j int) bool { + return filtered[i].CreatedAt.After(filtered[j].CreatedAt) + }) + } else { + sort.Slice(filtered, func(i, j int) bool { + return filtered[i].CreatedAt.Before(filtered[j].CreatedAt) + }) + } + + // Apply limit + if len(filtered) > q.Limit { + filtered = filtered[:q.Limit] + } + + return filtered, nil +} + +func (s *store) ListDestinationByTenant(_ context.Context, tenantID string, options ...driver.ListDestinationByTenantOpts) ([]models.Destination, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var opts driver.ListDestinationByTenantOpts + if len(options) > 0 { + opts = options[0] + } + + destIDs := s.destsByTenant[tenantID] + if len(destIDs) == 0 { + return []models.Destination{}, nil + } + + var destinations []models.Destination + for destID := range destIDs { + drec, ok := s.destinations[destKey(tenantID, destID)] + if !ok || drec.deletedAt != nil { + continue + } + if opts.Filter != nil && !matchDestFilter(opts.Filter, drec.destination) { + continue + } + destinations = append(destinations, drec.destination) + } + + sort.Slice(destinations, func(i, j int) bool { + return destinations[i].CreatedAt.Before(destinations[j].CreatedAt) + }) + + return destinations, nil +} + +func (s *store) RetrieveDestination(_ context.Context, tenantID, destinationID string) (*models.Destination, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + drec, ok := s.destinations[destKey(tenantID, destinationID)] + if !ok { + return nil, nil + } + if drec.deletedAt != nil { + return nil, driver.ErrDestinationDeleted + } + d := drec.destination + return &d, nil +} + +func (s *store) CreateDestination(_ context.Context, destination models.Destination) error { + s.mu.Lock() + defer s.mu.Unlock() + + key := destKey(destination.TenantID, destination.ID) + + // Check for existing non-deleted destination + if drec, ok := s.destinations[key]; ok && drec.deletedAt == nil { + return driver.ErrDuplicateDestination + } + + // Check max destinations + destIDs := s.destsByTenant[destination.TenantID] + if len(destIDs) >= s.maxDestinationsPerTenant { + return driver.ErrMaxDestinationsPerTenantReached + } + + return s.upsertDestinationLocked(destination) +} + +func (s *store) UpsertDestination(_ context.Context, destination models.Destination) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.upsertDestinationLocked(destination) +} + +func (s *store) upsertDestinationLocked(destination models.Destination) error { + now := time.Now() + if destination.CreatedAt.IsZero() { + 
destination.CreatedAt = now + } + if destination.UpdatedAt.IsZero() { + destination.UpdatedAt = now + } + + key := destKey(destination.TenantID, destination.ID) + s.destinations[key] = &destinationRecord{destination: destination} + + // Update destsByTenant index + if s.destsByTenant[destination.TenantID] == nil { + s.destsByTenant[destination.TenantID] = make(map[string]struct{}) + } + s.destsByTenant[destination.TenantID][destination.ID] = struct{}{} + return nil +} + +func (s *store) DeleteDestination(_ context.Context, tenantID, destinationID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + key := destKey(tenantID, destinationID) + drec, ok := s.destinations[key] + if !ok { + return driver.ErrDestinationNotFound + } + // Already deleted is OK (idempotent) + now := time.Now() + drec.deletedAt = &now + + // Remove from destsByTenant index + if destIDs, ok := s.destsByTenant[tenantID]; ok { + delete(destIDs, destinationID) + } + + return nil +} + +func (s *store) MatchEvent(_ context.Context, event models.Event) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + destIDs := s.destsByTenant[event.TenantID] + var matched []string + for destID := range destIDs { + drec, ok := s.destinations[destKey(event.TenantID, destID)] + if !ok || drec.deletedAt != nil { + continue + } + if drec.destination.MatchEvent(event) { + matched = append(matched, destID) + } + } + return matched, nil +} + +func (s *store) computeTenantTopics(tenantID string) []string { + destIDs := s.destsByTenant[tenantID] + all := false + topicsSet := make(map[string]struct{}) + for destID := range destIDs { + drec, ok := s.destinations[destKey(tenantID, destID)] + if !ok || drec.deletedAt != nil { + continue + } + for _, topic := range drec.destination.Topics { + if topic == "*" { + all = true + break + } + topicsSet[topic] = struct{}{} + } + } + + if all { + return []string{"*"} + } + + topics := make([]string, 0, len(topicsSet)) + for topic := range topicsSet { + topics = append(topics, topic) + } + sort.Strings(topics) + return topics +} + +func matchDestFilter(filter *driver.DestinationFilter, dest models.Destination) bool { + if len(filter.Type) > 0 && !slices.Contains(filter.Type, dest.Type) { + return false + } + if len(filter.Topics) > 0 { + filterMatchesAll := len(filter.Topics) == 1 && filter.Topics[0] == "*" + if !dest.Topics.MatchesAll() { + if filterMatchesAll { + return false + } + for _, topic := range filter.Topics { + if !slices.Contains(dest.Topics, topic) { + return false + } + } + } + } + return true +} diff --git a/internal/tenantstore/memtenantstore/memtenantstore_test.go b/internal/tenantstore/memtenantstore/memtenantstore_test.go new file mode 100644 index 00000000..e23edc68 --- /dev/null +++ b/internal/tenantstore/memtenantstore/memtenantstore_test.go @@ -0,0 +1,33 @@ +package memtenantstore + +import ( + "context" + "testing" + + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/tenantstore/drivertest" +) + +type memTenantStoreHarness struct{} + +func (h *memTenantStoreHarness) MakeDriver(_ context.Context) (driver.TenantStore, error) { + return New(), nil +} + +func (h *memTenantStoreHarness) MakeDriverWithMaxDest(_ context.Context, maxDest int) (driver.TenantStore, error) { + return New(WithMaxDestinationsPerTenant(maxDest)), nil +} + +func (h *memTenantStoreHarness) MakeIsolatedDrivers(_ context.Context) (driver.TenantStore, driver.TenantStore, error) { + return New(), New(), nil +} + +func (h *memTenantStoreHarness) Close() {} + 
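+// Any other backend can plug into the same conformance coverage by
+// implementing drivertest.Harness and handing a maker function to
+// RunConformanceTests, as this file does for the in-memory store. A sketch
+// for a hypothetical backend (the newBackendHarness name is made up):
+//
+//	func TestNewBackendConformance(t *testing.T) {
+//		drivertest.RunConformanceTests(t, func(_ context.Context, _ *testing.T) (drivertest.Harness, error) {
+//			return &newBackendHarness{}, nil
+//		})
+//	}
+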
+func newHarness(_ context.Context, _ *testing.T) (drivertest.Harness, error) { + return &memTenantStoreHarness{}, nil +} + +func TestMemTenantStoreConformance(t *testing.T) { + drivertest.RunConformanceTests(t, newHarness) +} diff --git a/internal/models/encryption.go b/internal/tenantstore/redistenantstore/cipher.go similarity index 68% rename from internal/models/encryption.go rename to internal/tenantstore/redistenantstore/cipher.go index ec9cef03..59a12d17 100644 --- a/internal/models/encryption.go +++ b/internal/tenantstore/redistenantstore/cipher.go @@ -1,4 +1,4 @@ -package models +package redistenantstore import ( "crypto/aes" @@ -9,18 +9,11 @@ import ( "io" ) -type Cipher interface { - Encrypt(data []byte) ([]byte, error) - Decrypt(data []byte) ([]byte, error) -} - -type AESCipher struct { +type aesCipher struct { secret string } -var _ Cipher = (*AESCipher)(nil) - -func (a *AESCipher) Encrypt(toBeEncrypted []byte) ([]byte, error) { +func (a *aesCipher) encrypt(toBeEncrypted []byte) ([]byte, error) { aead, err := a.aead() if err != nil { return nil, err @@ -33,11 +26,10 @@ func (a *AESCipher) Encrypt(toBeEncrypted []byte) ([]byte, error) { } encrypted := aead.Seal(nonce, nonce, toBeEncrypted, nil) - return encrypted, nil } -func (a *AESCipher) Decrypt(toBeDecrypted []byte) ([]byte, error) { +func (a *aesCipher) decrypt(toBeDecrypted []byte) ([]byte, error) { aead, err := a.aead() if err != nil { return nil, err @@ -54,7 +46,7 @@ func (a *AESCipher) Decrypt(toBeDecrypted []byte) ([]byte, error) { return decrypted, nil } -func (a *AESCipher) aead() (cipher.AEAD, error) { +func (a *aesCipher) aead() (cipher.AEAD, error) { aesBlock, err := aes.NewCipher([]byte(mdHashing(a.secret))) if err != nil { return nil, err @@ -62,10 +54,8 @@ func (a *AESCipher) aead() (cipher.AEAD, error) { return cipher.NewGCM(aesBlock) } -func NewAESCipher(secret string) Cipher { - return &AESCipher{ - secret: secret, - } +func newAESCipher(secret string) *aesCipher { + return &aesCipher{secret: secret} } func mdHashing(input string) string { diff --git a/internal/tenantstore/redistenantstore/redistenantstore.go b/internal/tenantstore/redistenantstore/redistenantstore.go new file mode 100644 index 00000000..ae0c0983 --- /dev/null +++ b/internal/tenantstore/redistenantstore/redistenantstore.go @@ -0,0 +1,606 @@ +// Package redistenantstore provides a Redis-backed implementation of driver.TenantStore. +package redistenantstore + +import ( + "context" + "errors" + "fmt" + "sort" + "strconv" + "time" + + "github.com/hookdeck/outpost/internal/cursor" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/pagination" + "github.com/hookdeck/outpost/internal/redis" + "github.com/hookdeck/outpost/internal/tenantstore/driver" +) + +const defaultMaxDestinationsPerTenant = 20 + +const ( + defaultListTenantLimit = 20 + maxListTenantLimit = 100 +) + +type store struct { + redisClient redis.Cmdable + cipher *aesCipher + availableTopics []string + maxDestinationsPerTenant int + deploymentID string + listTenantSupported bool +} + +var _ driver.TenantStore = (*store)(nil) + +// Option configures a redistenantstore. +type Option func(*store) + +// WithSecret sets the encryption secret for credentials. +func WithSecret(secret string) Option { + return func(s *store) { + s.cipher = newAESCipher(secret) + } +} + +// WithAvailableTopics sets the available topics for destination validation. 
+func WithAvailableTopics(topics []string) Option { + return func(s *store) { + s.availableTopics = topics + } +} + +// WithMaxDestinationsPerTenant sets the maximum number of destinations per tenant. +func WithMaxDestinationsPerTenant(max int) Option { + return func(s *store) { + s.maxDestinationsPerTenant = max + } +} + +// WithDeploymentID sets the deployment ID for key isolation. +func WithDeploymentID(deploymentID string) Option { + return func(s *store) { + s.deploymentID = deploymentID + } +} + +// New creates a new Redis-backed TenantStore. +func New(redisClient redis.Cmdable, opts ...Option) driver.TenantStore { + s := &store{ + redisClient: redisClient, + cipher: newAESCipher(""), + availableTopics: []string{}, + maxDestinationsPerTenant: defaultMaxDestinationsPerTenant, + } + for _, opt := range opts { + opt(s) + } + return s +} + +// doCmd executes an arbitrary Redis command using the Do method. +func (s *store) doCmd(ctx context.Context, args ...interface{}) *redis.Cmd { + if dc, ok := s.redisClient.(redis.DoContext); ok { + return dc.Do(ctx, args...) + } + cmd := &redis.Cmd{} + cmd.SetErr(errors.New("redis client does not support Do command")) + return cmd +} + +func (s *store) deploymentPrefix() string { + if s.deploymentID == "" { + return "" + } + return fmt.Sprintf("%s:", s.deploymentID) +} + +func (s *store) redisTenantID(tenantID string) string { + return fmt.Sprintf("%stenant:{%s}:tenant", s.deploymentPrefix(), tenantID) +} + +func (s *store) redisTenantDestinationSummaryKey(tenantID string) string { + return fmt.Sprintf("%stenant:{%s}:destinations", s.deploymentPrefix(), tenantID) +} + +func (s *store) redisDestinationID(destinationID, tenantID string) string { + return fmt.Sprintf("%stenant:{%s}:destination:%s", s.deploymentPrefix(), tenantID, destinationID) +} + +func (s *store) tenantIndexName() string { + return s.deploymentPrefix() + "tenant_idx" +} + +func (s *store) tenantKeyPrefix() string { + return s.deploymentPrefix() + "tenant:" +} + +// Init initializes the store, probing for RediSearch support. 
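+// When the probe fails (for example on miniredis or Dragonfly without the
+// search module), the store keeps serving all per-tenant operations, but
+// ListTenant returns driver.ErrListTenantNotSupported. Callers that can do
+// without tenant listing may handle it along these lines (illustrative):
+//
+//	if _, err := store.ListTenant(ctx, driver.ListTenantRequest{Limit: 20}); errors.Is(err, driver.ErrListTenantNotSupported) {
+//		// tenant listing is unavailable on this Redis deployment; fall back
+//	}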
+func (s *store) Init(ctx context.Context) error { + _, err := s.doCmd(ctx, "FT._LIST").Result() + if err != nil { + s.listTenantSupported = false + return nil + } + + if err := s.ensureTenantIndex(ctx); err != nil { + s.listTenantSupported = false + return nil + } + + s.listTenantSupported = true + return nil +} + +func (s *store) ensureTenantIndex(ctx context.Context) error { + indexName := s.tenantIndexName() + + _, err := s.doCmd(ctx, "FT.INFO", indexName).Result() + if err == nil { + return nil + } + + prefix := s.tenantKeyPrefix() + _, err = s.doCmd(ctx, "FT.CREATE", indexName, + "ON", "HASH", + "PREFIX", "1", prefix, + "FILTER", `@entity == "tenant"`, + "SCHEMA", + "id", "TAG", + "entity", "TAG", + "created_at", "NUMERIC", "SORTABLE", + "deleted_at", "NUMERIC", + ).Result() + + if err != nil { + return fmt.Errorf("failed to create tenant index: %w", err) + } + + return nil +} + +func (s *store) RetrieveTenant(ctx context.Context, tenantID string) (*models.Tenant, error) { + pipe := s.redisClient.Pipeline() + tenantCmd := pipe.HGetAll(ctx, s.redisTenantID(tenantID)) + destinationListCmd := pipe.HGetAll(ctx, s.redisTenantDestinationSummaryKey(tenantID)) + + if _, err := pipe.Exec(ctx); err != nil { + return nil, err + } + + tenantHash, err := tenantCmd.Result() + if err != nil { + return nil, err + } + if len(tenantHash) == 0 { + return nil, nil + } + tenant, err := parseTenantHash(tenantHash) + if err != nil { + return nil, err + } + + destinationSummaryList, err := parseListDestinationSummaryByTenantCmd(destinationListCmd, driver.ListDestinationByTenantOpts{}) + if err != nil { + return nil, err + } + tenant.DestinationsCount = len(destinationSummaryList) + tenant.Topics = parseTenantTopics(destinationSummaryList) + + return tenant, nil +} + +func (s *store) UpsertTenant(ctx context.Context, tenant models.Tenant) error { + key := s.redisTenantID(tenant.ID) + + if err := s.redisClient.Persist(ctx, key).Err(); err != nil && err != redis.Nil { + return err + } + + if err := s.redisClient.HDel(ctx, key, "deleted_at").Err(); err != nil && err != redis.Nil { + return err + } + + now := time.Now() + if tenant.CreatedAt.IsZero() { + tenant.CreatedAt = now + } + if tenant.UpdatedAt.IsZero() { + tenant.UpdatedAt = now + } + + if err := s.redisClient.HSet(ctx, key, + "id", tenant.ID, + "entity", "tenant", + "created_at", tenant.CreatedAt.UnixMilli(), + "updated_at", tenant.UpdatedAt.UnixMilli(), + ).Err(); err != nil { + return err + } + + if tenant.Metadata != nil { + if err := s.redisClient.HSet(ctx, key, "metadata", &tenant.Metadata).Err(); err != nil { + return err + } + } else { + if err := s.redisClient.HDel(ctx, key, "metadata").Err(); err != nil && err != redis.Nil { + return err + } + } + + return nil +} + +func (s *store) DeleteTenant(ctx context.Context, tenantID string) error { + if exists, err := s.redisClient.Exists(ctx, s.redisTenantID(tenantID)).Result(); err != nil { + return err + } else if exists == 0 { + return driver.ErrTenantNotFound + } + + destinationIDs, err := s.redisClient.HKeys(ctx, s.redisTenantDestinationSummaryKey(tenantID)).Result() + if err != nil { + return err + } + + _, err = s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + nowUnixMilli := time.Now().UnixMilli() + + for _, destinationID := range destinationIDs { + destKey := s.redisDestinationID(destinationID, tenantID) + pipe.HSet(ctx, destKey, "deleted_at", nowUnixMilli) + pipe.Expire(ctx, destKey, 7*24*time.Hour) + } + + pipe.Del(ctx, s.redisTenantDestinationSummaryKey(tenantID)) + 
pipe.HSet(ctx, s.redisTenantID(tenantID), "deleted_at", nowUnixMilli) + pipe.Expire(ctx, s.redisTenantID(tenantID), 7*24*time.Hour) + + return nil + }) + + return err +} + +func (s *store) ListTenant(ctx context.Context, req driver.ListTenantRequest) (*driver.TenantPaginatedResult, error) { + if !s.listTenantSupported { + return nil, driver.ErrListTenantNotSupported + } + + if req.Next != "" && req.Prev != "" { + return nil, driver.ErrConflictingCursors + } + + limit := req.Limit + if limit <= 0 { + limit = defaultListTenantLimit + } + if limit > maxListTenantLimit { + limit = maxListTenantLimit + } + + dir := req.Dir + if dir == "" { + dir = "desc" + } + if dir != "asc" && dir != "desc" { + return nil, driver.ErrInvalidOrder + } + + baseFilter := "@entity:{tenant} -@deleted_at:[1 +inf]" + + result, err := pagination.Run(ctx, pagination.Config[models.Tenant]{ + Limit: limit, + Order: dir, + Next: req.Next, + Prev: req.Prev, + Cursor: pagination.Cursor[models.Tenant]{ + Encode: func(t models.Tenant) string { + return cursor.Encode("tnt", 1, strconv.FormatInt(t.CreatedAt.UnixMilli(), 10)) + }, + Decode: func(c string) (string, error) { + data, err := cursor.Decode(c, "tnt", 1) + if err != nil { + return "", fmt.Errorf("%w: %v", driver.ErrInvalidCursor, err) + } + return data, nil + }, + }, + Fetch: func(ctx context.Context, q pagination.QueryInput) ([]models.Tenant, error) { + return s.fetchTenants(ctx, baseFilter, q) + }, + }) + if err != nil { + return nil, err + } + + tenants := result.Items + + if len(tenants) > 0 { + pipe := s.redisClient.Pipeline() + cmds := make([]*redis.MapStringStringCmd, len(tenants)) + for i, t := range tenants { + cmds[i] = pipe.HGetAll(ctx, s.redisTenantDestinationSummaryKey(t.ID)) + } + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to fetch destination summaries: %w", err) + } + + for i := range tenants { + destinationSummaryList, err := parseListDestinationSummaryByTenantCmd(cmds[i], driver.ListDestinationByTenantOpts{}) + if err != nil { + return nil, err + } + tenants[i].DestinationsCount = len(destinationSummaryList) + tenants[i].Topics = parseTenantTopics(destinationSummaryList) + } + } + + var totalCount int + countResult, err := s.doCmd(ctx, "FT.SEARCH", s.tenantIndexName(), + baseFilter, + "LIMIT", 0, 0, + ).Result() + if err == nil { + _, totalCount, _ = parseSearchResult(countResult) + } + + var nextCursor, prevCursor *string + if result.Next != "" { + nextCursor = &result.Next + } + if result.Prev != "" { + prevCursor = &result.Prev + } + + return &driver.TenantPaginatedResult{ + Models: tenants, + Pagination: driver.SeekPagination{ + OrderBy: "created_at", + Dir: dir, + Limit: limit, + Next: nextCursor, + Prev: prevCursor, + }, + Count: totalCount, + }, nil +} + +func (s *store) fetchTenants(ctx context.Context, baseFilter string, q pagination.QueryInput) ([]models.Tenant, error) { + var query string + sortDir := "DESC" + if q.SortDir == "asc" { + sortDir = "ASC" + } + + if q.CursorPos == "" { + query = baseFilter + } else { + cursorTimestamp, err := strconv.ParseInt(q.CursorPos, 10, 64) + if err != nil { + return nil, fmt.Errorf("%w: invalid timestamp", driver.ErrInvalidCursor) + } + + if q.Compare == "<" { + query = fmt.Sprintf("(@created_at:[0 %d]) %s", cursorTimestamp-1, baseFilter) + } else { + query = fmt.Sprintf("(@created_at:[%d +inf]) %s", cursorTimestamp+1, baseFilter) + } + } + + result, err := s.doCmd(ctx, "FT.SEARCH", s.tenantIndexName(), + query, + "SORTBY", "created_at", sortDir, + "LIMIT", 0, q.Limit, + 
).Result() + if err != nil { + return nil, fmt.Errorf("failed to search tenants: %w", err) + } + + tenants, _, err := parseSearchResult(result) + if err != nil { + return nil, err + } + + return tenants, nil +} + +func (s *store) listDestinationSummaryByTenant(ctx context.Context, tenantID string, opts driver.ListDestinationByTenantOpts) ([]destinationSummary, error) { + return parseListDestinationSummaryByTenantCmd(s.redisClient.HGetAll(ctx, s.redisTenantDestinationSummaryKey(tenantID)), opts) +} + +func (s *store) ListDestinationByTenant(ctx context.Context, tenantID string, options ...driver.ListDestinationByTenantOpts) ([]models.Destination, error) { + var opts driver.ListDestinationByTenantOpts + if len(options) > 0 { + opts = options[0] + } + + destinationSummaryList, err := s.listDestinationSummaryByTenant(ctx, tenantID, opts) + if err != nil { + return nil, err + } + + pipe := s.redisClient.Pipeline() + cmds := make([]*redis.MapStringStringCmd, len(destinationSummaryList)) + for i, destinationSummary := range destinationSummaryList { + cmds[i] = pipe.HGetAll(ctx, s.redisDestinationID(destinationSummary.ID, tenantID)) + } + _, err = pipe.Exec(ctx) + if err != nil { + return nil, err + } + + destinations := make([]models.Destination, len(destinationSummaryList)) + for i, cmd := range cmds { + destination, err := parseDestinationHash(cmd, tenantID, s.cipher) + if err != nil { + return []models.Destination{}, err + } + destinations[i] = *destination + } + + sort.Slice(destinations, func(i, j int) bool { + return destinations[i].CreatedAt.Before(destinations[j].CreatedAt) + }) + + return destinations, nil +} + +func (s *store) RetrieveDestination(ctx context.Context, tenantID, destinationID string) (*models.Destination, error) { + cmd := s.redisClient.HGetAll(ctx, s.redisDestinationID(destinationID, tenantID)) + destination, err := parseDestinationHash(cmd, tenantID, s.cipher) + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + return destination, nil +} + +func (s *store) CreateDestination(ctx context.Context, destination models.Destination) error { + key := s.redisDestinationID(destination.ID, destination.TenantID) + if fields, err := s.redisClient.HGetAll(ctx, key).Result(); err != nil { + return err + } else if len(fields) > 0 { + if _, isDeleted := fields["deleted_at"]; !isDeleted { + return driver.ErrDuplicateDestination + } + } + + count, err := s.redisClient.HLen(ctx, s.redisTenantDestinationSummaryKey(destination.TenantID)).Result() + if err != nil { + return err + } + if count >= int64(s.maxDestinationsPerTenant) { + return driver.ErrMaxDestinationsPerTenantReached + } + + return s.UpsertDestination(ctx, destination) +} + +func (s *store) UpsertDestination(ctx context.Context, destination models.Destination) error { + key := s.redisDestinationID(destination.ID, destination.TenantID) + + credentialsBytes, err := destination.Credentials.MarshalBinary() + if err != nil { + return fmt.Errorf("invalid destination credentials: %w", err) + } + encryptedCredentials, err := s.cipher.encrypt(credentialsBytes) + if err != nil { + return fmt.Errorf("failed to encrypt destination credentials: %w", err) + } + + var encryptedDeliveryMetadata []byte + if destination.DeliveryMetadata != nil { + deliveryMetadataBytes, err := destination.DeliveryMetadata.MarshalBinary() + if err != nil { + return fmt.Errorf("invalid destination delivery_metadata: %w", err) + } + encryptedDeliveryMetadata, err = s.cipher.encrypt(deliveryMetadataBytes) + if err != nil { + return 
fmt.Errorf("failed to encrypt destination delivery_metadata: %w", err) + } + } + + now := time.Now() + if destination.CreatedAt.IsZero() { + destination.CreatedAt = now + } + if destination.UpdatedAt.IsZero() { + destination.UpdatedAt = now + } + + summaryKey := s.redisTenantDestinationSummaryKey(destination.TenantID) + + _, err = s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Persist(ctx, key) + pipe.HDel(ctx, key, "deleted_at") + + pipe.HSet(ctx, key, "id", destination.ID) + pipe.HSet(ctx, key, "entity", "destination") + pipe.HSet(ctx, key, "type", destination.Type) + pipe.HSet(ctx, key, "topics", &destination.Topics) + pipe.HSet(ctx, key, "config", &destination.Config) + pipe.HSet(ctx, key, "credentials", encryptedCredentials) + pipe.HSet(ctx, key, "created_at", destination.CreatedAt.UnixMilli()) + pipe.HSet(ctx, key, "updated_at", destination.UpdatedAt.UnixMilli()) + + if destination.DisabledAt != nil { + pipe.HSet(ctx, key, "disabled_at", destination.DisabledAt.UnixMilli()) + } else { + pipe.HDel(ctx, key, "disabled_at") + } + + if destination.DeliveryMetadata != nil { + pipe.HSet(ctx, key, "delivery_metadata", encryptedDeliveryMetadata) + } else { + pipe.HDel(ctx, key, "delivery_metadata") + } + + if destination.Metadata != nil { + pipe.HSet(ctx, key, "metadata", &destination.Metadata) + } else { + pipe.HDel(ctx, key, "metadata") + } + + if len(destination.Filter) > 0 { + pipe.HSet(ctx, key, "filter", &destination.Filter) + } else { + pipe.HDel(ctx, key, "filter") + } + + pipe.HSet(ctx, summaryKey, destination.ID, newDestinationSummary(destination)) + return nil + }) + + return err +} + +func (s *store) DeleteDestination(ctx context.Context, tenantID, destinationID string) error { + key := s.redisDestinationID(destinationID, tenantID) + summaryKey := s.redisTenantDestinationSummaryKey(tenantID) + + if exists, err := s.redisClient.Exists(ctx, key).Result(); err != nil { + return err + } else if exists == 0 { + return driver.ErrDestinationNotFound + } + + _, err := s.redisClient.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + nowUnixMilli := time.Now().UnixMilli() + + pipe.HDel(ctx, summaryKey, destinationID) + pipe.HSet(ctx, key, "deleted_at", nowUnixMilli) + pipe.Expire(ctx, key, 7*24*time.Hour) + + return nil + }) + + return err +} + +func (s *store) MatchEvent(ctx context.Context, event models.Event) ([]string, error) { + destinationSummaryList, err := s.listDestinationSummaryByTenant(ctx, event.TenantID, driver.ListDestinationByTenantOpts{}) + if err != nil { + return nil, err + } + + var matched []string + + for _, ds := range destinationSummaryList { + if ds.Disabled { + continue + } + if event.Topic != "" && !ds.Topics.MatchTopic(event.Topic) { + continue + } + if !models.MatchFilter(ds.Filter, event) { + continue + } + matched = append(matched, ds.ID) + } + + return matched, nil +} diff --git a/internal/tenantstore/redistenantstore/redistenantstore_test.go b/internal/tenantstore/redistenantstore/redistenantstore_test.go new file mode 100644 index 00000000..4e02d501 --- /dev/null +++ b/internal/tenantstore/redistenantstore/redistenantstore_test.go @@ -0,0 +1,340 @@ +package redistenantstore_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/hookdeck/outpost/internal/idgen" + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/redis" + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/tenantstore/drivertest" + 
"github.com/hookdeck/outpost/internal/tenantstore/redistenantstore" + "github.com/hookdeck/outpost/internal/util/testinfra" + "github.com/hookdeck/outpost/internal/util/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// redisClientFactory is a function that creates a Redis client for testing. +type redisClientFactory func(t *testing.T) redis.Cmdable + +// miniredisFactory creates a miniredis client (in-memory, no RediSearch). +func miniredisFactory(t *testing.T) redis.Cmdable { + return testutil.CreateTestRedisClient(t) +} + +// redisStackFactory creates a Redis Stack client (with RediSearch). +func redisStackFactory(t *testing.T) redis.Cmdable { + testinfra.Start(t) + redisCfg := testinfra.NewRedisStackConfig(t) + client, err := redis.New(context.Background(), redisCfg) + if err != nil { + t.Fatalf("failed to create redis client: %v", err) + } + t.Cleanup(func() { client.Close() }) + return client +} + +// dragonflyFactory creates a Dragonfly client (no RediSearch). +func dragonflyFactory(t *testing.T) redis.Cmdable { + testinfra.Start(t) + redisCfg := testinfra.NewDragonflyConfig(t) + client, err := redis.New(context.Background(), redisCfg) + if err != nil { + t.Fatalf("failed to create dragonfly client: %v", err) + } + t.Cleanup(func() { client.Close() }) + return client +} + +// dragonflyStackFactory creates a Dragonfly client on DB 0 (with RediSearch). +func dragonflyStackFactory(t *testing.T) redis.Cmdable { + testinfra.Start(t) + redisCfg := testinfra.NewDragonflyStackConfig(t) + client, err := redis.New(context.Background(), redisCfg) + if err != nil { + t.Fatalf("failed to create dragonfly stack client: %v", err) + } + t.Cleanup(func() { client.Close() }) + return client +} + +// redisTenantStoreHarness implements drivertest.Harness for Redis-backed stores. 
+type redisTenantStoreHarness struct { + factory redisClientFactory + t *testing.T + deploymentID string +} + +func (h *redisTenantStoreHarness) MakeDriver(ctx context.Context) (driver.TenantStore, error) { + client := h.factory(h.t) + s := redistenantstore.New(client, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID(h.deploymentID), + ) + if err := s.Init(ctx); err != nil { + return nil, err + } + return s, nil +} + +func (h *redisTenantStoreHarness) MakeDriverWithMaxDest(ctx context.Context, maxDest int) (driver.TenantStore, error) { + client := h.factory(h.t) + s := redistenantstore.New(client, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID(h.deploymentID), + redistenantstore.WithMaxDestinationsPerTenant(maxDest), + ) + if err := s.Init(ctx); err != nil { + return nil, err + } + return s, nil +} + +func (h *redisTenantStoreHarness) MakeIsolatedDrivers(ctx context.Context) (driver.TenantStore, driver.TenantStore, error) { + client := h.factory(h.t) + s1 := redistenantstore.New(client, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID("dp_001"), + ) + s2 := redistenantstore.New(client, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID("dp_002"), + ) + if err := s1.Init(ctx); err != nil { + return nil, nil, err + } + if err := s2.Init(ctx); err != nil { + return nil, nil, err + } + return s1, s2, nil +} + +func (h *redisTenantStoreHarness) Close() {} + +func newHarness(factory redisClientFactory, deploymentID string) drivertest.HarnessMaker { + return func(_ context.Context, t *testing.T) (drivertest.Harness, error) { + return &redisTenantStoreHarness{ + factory: factory, + t: t, + deploymentID: deploymentID, + }, nil + } +} + +// ============================================================================= +// Conformance Tests with miniredis +// ============================================================================= + +func TestMiniredis(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(miniredisFactory, "")) +} + +func TestMiniredis_WithDeploymentID(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(miniredisFactory, "dp_test_001")) +} + +// ============================================================================= +// Conformance Tests with Redis Stack +// ============================================================================= + +func TestRedisStack(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(redisStackFactory, "")) +} + +func TestRedisStack_WithDeploymentID(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(redisStackFactory, "dp_test_001")) +} + +// ============================================================================= +// Conformance Tests with Dragonfly +// ============================================================================= + +func TestDragonfly(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(dragonflyFactory, "")) +} + +func TestDragonfly_WithDeploymentID(t *testing.T) { + t.Parallel() + drivertest.RunConformanceTests(t, newHarness(dragonflyFactory, "dp_test_001")) +} + +// 
============================================================================= +// ListTenant Tests with Redis Stack (requires RediSearch) +// ============================================================================= + +func TestRedisStack_ListTenant(t *testing.T) { + t.Parallel() + drivertest.RunListTenantTests(t, newHarness(redisStackFactory, "")) +} + +func TestRedisStack_ListTenant_WithDeploymentID(t *testing.T) { + t.Parallel() + drivertest.RunListTenantTests(t, newHarness(redisStackFactory, "dp_test_001")) +} + +// ============================================================================= +// ListTenant Tests with Dragonfly Stack (requires RediSearch) +// ============================================================================= + +func TestDragonflyStack_ListTenant(t *testing.T) { + t.Parallel() + drivertest.RunListTenantTests(t, newHarness(dragonflyStackFactory, "")) +} + +func TestDragonflyStack_ListTenant_WithDeploymentID(t *testing.T) { + t.Parallel() + drivertest.RunListTenantTests(t, newHarness(dragonflyStackFactory, "dp_test_001")) +} + +// ============================================================================= +// Standalone: Credentials Encryption +// ============================================================================= + +func TestDestinationCredentialsEncryption(t *testing.T) { + t.Parallel() + + ctx := context.Background() + redisClient := testutil.CreateTestRedisClient(t) + secret := "test-secret" + + store := redistenantstore.New(redisClient, + redistenantstore.WithSecret(secret), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + ) + + input := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithType("rabbitmq"), + testutil.DestinationFactory.WithTopics([]string{"user.created", "user.updated"}), + testutil.DestinationFactory.WithConfig(map[string]string{ + "server_url": "localhost:5672", + "exchange": "events", + }), + testutil.DestinationFactory.WithCredentials(map[string]string{ + "username": "guest", + "password": "guest", + }), + testutil.DestinationFactory.WithDeliveryMetadata(map[string]string{ + "Authorization": "Bearer secret-token", + "X-API-Key": "sensitive-key", + }), + ) + + err := store.UpsertDestination(ctx, input) + require.NoError(t, err) + + // Access Redis directly to verify encryption (implementation detail) + keyFormat := "tenant:{%s}:destination:%s" + actual, err := redisClient.HGetAll(ctx, fmt.Sprintf(keyFormat, input.TenantID, input.ID)).Result() + require.NoError(t, err) + + // Verify credentials are encrypted (not plaintext JSON) + jsonCredentials, _ := json.Marshal(input.Credentials) + assert.NotEqual(t, string(jsonCredentials), actual["credentials"]) + + // Verify delivery_metadata is encrypted (not plaintext JSON) + jsonDeliveryMetadata, _ := json.Marshal(input.DeliveryMetadata) + assert.NotEqual(t, string(jsonDeliveryMetadata), actual["delivery_metadata"]) + + // Verify round-trip: retrieve destination and check values match + retrieved, err := store.RetrieveDestination(ctx, input.TenantID, input.ID) + require.NoError(t, err) + assert.Equal(t, input.Credentials, retrieved.Credentials) + assert.Equal(t, input.DeliveryMetadata, retrieved.DeliveryMetadata) +} + +// ============================================================================= +// Standalone: ListTenant not supported (miniredis has no RediSearch) +// ============================================================================= + +func TestListTenantNotSupported(t *testing.T) { + t.Parallel() + + ctx := context.Background() + 
redisClient := testutil.CreateTestRedisClient(t) + + store := redistenantstore.New(redisClient, + redistenantstore.WithSecret("test-secret"), + ) + require.NoError(t, store.Init(ctx)) + + _, err := store.ListTenant(ctx, driver.ListTenantRequest{}) + require.ErrorIs(t, err, driver.ErrListTenantNotSupported) +} + +// ============================================================================= +// Standalone: Deployment Isolation (same Redis, different deployment IDs) +// ============================================================================= + +func TestDeploymentIsolation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + redisClient := testutil.CreateTestRedisClient(t) + + store1 := redistenantstore.New(redisClient, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID("dp_001"), + ) + store2 := redistenantstore.New(redisClient, + redistenantstore.WithSecret("test-secret"), + redistenantstore.WithAvailableTopics(testutil.TestTopics), + redistenantstore.WithDeploymentID("dp_002"), + ) + + tenantID := idgen.String() + destinationID := idgen.Destination() + + tenant := models.Tenant{ + ID: tenantID, + CreatedAt: time.Now(), + } + require.NoError(t, store1.UpsertTenant(ctx, tenant)) + require.NoError(t, store2.UpsertTenant(ctx, tenant)) + + destination1 := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(destinationID), + testutil.DestinationFactory.WithTenantID(tenantID), + testutil.DestinationFactory.WithConfig(map[string]string{"deployment": "dp_001"}), + ) + destination2 := testutil.DestinationFactory.Any( + testutil.DestinationFactory.WithID(destinationID), + testutil.DestinationFactory.WithTenantID(tenantID), + testutil.DestinationFactory.WithConfig(map[string]string{"deployment": "dp_002"}), + ) + + require.NoError(t, store1.CreateDestination(ctx, destination1)) + require.NoError(t, store2.CreateDestination(ctx, destination2)) + + retrieved1, err := store1.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_001", retrieved1.Config["deployment"]) + + retrieved2, err := store2.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_002", retrieved2.Config["deployment"]) + + // Delete from store1 should not affect store2 + require.NoError(t, store1.DeleteDestination(ctx, tenantID, destinationID)) + + _, err = store1.RetrieveDestination(ctx, tenantID, destinationID) + require.ErrorIs(t, err, driver.ErrDestinationDeleted) + + retrieved2Again, err := store2.RetrieveDestination(ctx, tenantID, destinationID) + require.NoError(t, err) + assert.Equal(t, "dp_002", retrieved2Again.Config["deployment"]) +} diff --git a/internal/tenantstore/redistenantstore/serialization.go b/internal/tenantstore/redistenantstore/serialization.go new file mode 100644 index 00000000..6df3437e --- /dev/null +++ b/internal/tenantstore/redistenantstore/serialization.go @@ -0,0 +1,379 @@ +package redistenantstore + +import ( + "encoding/json" + "fmt" + "sort" + "strconv" + "time" + + "github.com/hookdeck/outpost/internal/models" + "github.com/hookdeck/outpost/internal/redis" + "github.com/hookdeck/outpost/internal/tenantstore/driver" +) + +// destinationSummary is a package-private summary used for Redis storage. 
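+// Summaries are serialized as JSON (via MarshalBinary) into the per-tenant
+// destination summary hash, so MatchEvent and parseTenantTopics can operate
+// without loading or decrypting full destination records.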
+type destinationSummary struct { + ID string `json:"id"` + Type string `json:"type"` + Topics models.Topics `json:"topics"` + Filter models.Filter `json:"filter,omitempty"` + Disabled bool `json:"disabled"` +} + +func newDestinationSummary(d models.Destination) *destinationSummary { + return &destinationSummary{ + ID: d.ID, + Type: d.Type, + Topics: d.Topics, + Filter: d.Filter, + Disabled: d.DisabledAt != nil, + } +} + +func (ds *destinationSummary) MarshalBinary() ([]byte, error) { + return json.Marshal(ds) +} + +func (ds *destinationSummary) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, ds) +} + +// parseTenantHash parses a Redis hash map into a Tenant struct. +func parseTenantHash(hash map[string]string) (*models.Tenant, error) { + if _, ok := hash["deleted_at"]; ok { + return nil, driver.ErrTenantDeleted + } + if hash["id"] == "" { + return nil, fmt.Errorf("missing id") + } + + t := &models.Tenant{} + t.ID = hash["id"] + + var err error + t.CreatedAt, err = parseTimestamp(hash["created_at"]) + if err != nil { + return nil, fmt.Errorf("invalid created_at: %w", err) + } + + if hash["updated_at"] != "" { + t.UpdatedAt, err = parseTimestamp(hash["updated_at"]) + if err != nil { + t.UpdatedAt = t.CreatedAt + } + } else { + t.UpdatedAt = t.CreatedAt + } + + if metadataStr, exists := hash["metadata"]; exists && metadataStr != "" { + err = t.Metadata.UnmarshalBinary([]byte(metadataStr)) + if err != nil { + return nil, fmt.Errorf("invalid metadata: %w", err) + } + } + + return t, nil +} + +// parseDestinationHash parses a Redis HGetAll command result into a Destination struct. +func parseDestinationHash(cmd *redis.MapStringStringCmd, tenantID string, cipher *aesCipher) (*models.Destination, error) { + hash, err := cmd.Result() + if err != nil { + return nil, err + } + if len(hash) == 0 { + return nil, redis.Nil + } + if _, exists := hash["deleted_at"]; exists { + return nil, driver.ErrDestinationDeleted + } + + d := &models.Destination{TenantID: tenantID} + d.ID = hash["id"] + d.Type = hash["type"] + + d.CreatedAt, err = parseTimestamp(hash["created_at"]) + if err != nil { + return nil, fmt.Errorf("invalid created_at: %w", err) + } + + if hash["updated_at"] != "" { + d.UpdatedAt, err = parseTimestamp(hash["updated_at"]) + if err != nil { + d.UpdatedAt = d.CreatedAt + } + } else { + d.UpdatedAt = d.CreatedAt + } + + if hash["disabled_at"] != "" { + disabledAt, err := parseTimestamp(hash["disabled_at"]) + if err == nil { + d.DisabledAt = &disabledAt + } + } + + if err := d.Topics.UnmarshalBinary([]byte(hash["topics"])); err != nil { + return nil, fmt.Errorf("invalid topics: %w", err) + } + if err := d.Config.UnmarshalBinary([]byte(hash["config"])); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + credentialsBytes, err := cipher.decrypt([]byte(hash["credentials"])) + if err != nil { + return nil, fmt.Errorf("invalid credentials: %w", err) + } + if err := d.Credentials.UnmarshalBinary(credentialsBytes); err != nil { + return nil, fmt.Errorf("invalid credentials: %w", err) + } + + if deliveryMetadataStr, exists := hash["delivery_metadata"]; exists && deliveryMetadataStr != "" { + deliveryMetadataBytes, err := cipher.decrypt([]byte(deliveryMetadataStr)) + if err != nil { + return nil, fmt.Errorf("invalid delivery_metadata: %w", err) + } + if err := d.DeliveryMetadata.UnmarshalBinary(deliveryMetadataBytes); err != nil { + return nil, fmt.Errorf("invalid delivery_metadata: %w", err) + } + } + + if metadataStr, exists := hash["metadata"]; exists && 
metadataStr != "" { + if err := d.Metadata.UnmarshalBinary([]byte(metadataStr)); err != nil { + return nil, fmt.Errorf("invalid metadata: %w", err) + } + } + + if filterStr, exists := hash["filter"]; exists && filterStr != "" { + if err := d.Filter.UnmarshalBinary([]byte(filterStr)); err != nil { + return nil, fmt.Errorf("invalid filter: %w", err) + } + } + + return d, nil +} + +// parseTimestamp parses a timestamp from either numeric (Unix milliseconds) or RFC3339 format. +func parseTimestamp(value string) (time.Time, error) { + if value == "" { + return time.Time{}, fmt.Errorf("missing timestamp") + } + + if ts, err := strconv.ParseInt(value, 10, 64); err == nil { + return time.UnixMilli(ts).UTC(), nil + } + + if t, err := time.Parse(time.RFC3339Nano, value); err == nil { + return t, nil + } + + return time.Parse(time.RFC3339, value) +} + +// parseSearchResult parses an FT.SEARCH result (RESP2 or RESP3) into a list of tenants and total count. +func parseSearchResult(result interface{}) ([]models.Tenant, int, error) { + if resultMap, ok := result.(map[interface{}]interface{}); ok { + return parseResp3SearchResult(resultMap) + } + + arr, ok := result.([]interface{}) + if !ok || len(arr) == 0 { + return []models.Tenant{}, 0, nil + } + + totalCount, ok := arr[0].(int64) + if !ok { + return nil, 0, fmt.Errorf("invalid search result: expected total count") + } + + tenants := make([]models.Tenant, 0, (len(arr)-1)/2) + + for i := 1; i < len(arr); i += 2 { + if i+1 >= len(arr) { + break + } + + hash := make(map[string]string) + + switch fields := arr[i+1].(type) { + case []interface{}: + for j := 0; j < len(fields)-1; j += 2 { + key, keyOk := fields[j].(string) + val, valOk := fields[j+1].(string) + if keyOk && valOk { + hash[key] = val + } + } + case map[interface{}]interface{}: + for k, v := range fields { + key, keyOk := k.(string) + if !keyOk { + continue + } + switch val := v.(type) { + case string: + hash[key] = val + case float64: + hash[key] = fmt.Sprintf("%.0f", val) + case int64: + hash[key] = fmt.Sprintf("%d", val) + } + } + default: + continue + } + + if _, deleted := hash["deleted_at"]; deleted { + continue + } + + tenant, err := parseTenantHash(hash) + if err != nil { + continue + } + + tenants = append(tenants, *tenant) + } + + return tenants, int(totalCount), nil +} + +// parseResp3SearchResult parses a RESP3 FT.SEARCH result into a list of tenants and total count. +func parseResp3SearchResult(resultMap map[interface{}]interface{}) ([]models.Tenant, int, error) { + totalCount := 0 + if tc, ok := resultMap["total_results"].(int64); ok { + totalCount = int(tc) + } + + results, ok := resultMap["results"].([]interface{}) + if !ok { + return []models.Tenant{}, totalCount, nil + } + + tenants := make([]models.Tenant, 0, len(results)) + + for _, r := range results { + docMap, ok := r.(map[interface{}]interface{}) + if !ok { + continue + } + + extraAttrs, ok := docMap["extra_attributes"].(map[interface{}]interface{}) + if !ok { + continue + } + + hash := make(map[string]string) + for k, v := range extraAttrs { + if keyStr, ok := k.(string); ok { + if valStr, ok := v.(string); ok { + hash[keyStr] = valStr + } + } + } + + if _, deleted := hash["deleted_at"]; deleted { + continue + } + + tenant, err := parseTenantHash(hash) + if err != nil { + continue + } + + tenants = append(tenants, *tenant) + } + + return tenants, totalCount, nil +} + +// parseListDestinationSummaryByTenantCmd parses a Redis HGetAll command result into destination summaries. 
+func parseListDestinationSummaryByTenantCmd(cmd *redis.MapStringStringCmd, opts driver.ListDestinationByTenantOpts) ([]destinationSummary, error) { + destinationSummaryListHash, err := cmd.Result() + if err != nil { + if err == redis.Nil { + return []destinationSummary{}, nil + } + return nil, err + } + destinationSummaryList := make([]destinationSummary, 0, len(destinationSummaryListHash)) + for _, destinationSummaryStr := range destinationSummaryListHash { + ds := destinationSummary{} + if err := ds.UnmarshalBinary([]byte(destinationSummaryStr)); err != nil { + return nil, err + } + included := true + if opts.Filter != nil { + included = matchDestinationFilter(opts.Filter, ds) + } + if included { + destinationSummaryList = append(destinationSummaryList, ds) + } + } + return destinationSummaryList, nil +} + +// parseTenantTopics extracts and deduplicates topics from a list of destination summaries. +func parseTenantTopics(destinationSummaryList []destinationSummary) []string { + all := false + topicsSet := make(map[string]struct{}) + for _, destination := range destinationSummaryList { + for _, topic := range destination.Topics { + if topic == "*" { + all = true + break + } + topicsSet[topic] = struct{}{} + } + } + + if all { + return []string{"*"} + } + + topics := make([]string, 0, len(topicsSet)) + for topic := range topicsSet { + topics = append(topics, topic) + } + + sort.Strings(topics) + return topics +} + +// matchDestinationFilter checks if a destination summary matches the given filter criteria. +func matchDestinationFilter(filter *driver.DestinationFilter, summary destinationSummary) bool { + if len(filter.Type) > 0 { + found := false + for _, t := range filter.Type { + if t == summary.Type { + found = true + break + } + } + if !found { + return false + } + } + if len(filter.Topics) > 0 { + filterMatchesAll := len(filter.Topics) == 1 && filter.Topics[0] == "*" + if !summary.Topics.MatchesAll() { + if filterMatchesAll { + return false + } + for _, topic := range filter.Topics { + found := false + for _, st := range summary.Topics { + if st == topic { + found = true + break + } + } + if !found { + return false + } + } + } + } + return true +} diff --git a/internal/tenantstore/tenantstore.go b/internal/tenantstore/tenantstore.go new file mode 100644 index 00000000..3e82b0b8 --- /dev/null +++ b/internal/tenantstore/tenantstore.go @@ -0,0 +1,66 @@ +// Package tenantstore provides the TenantStore facade for tenant and destination storage. +package tenantstore + +import ( + "github.com/hookdeck/outpost/internal/redis" + "github.com/hookdeck/outpost/internal/tenantstore/driver" + "github.com/hookdeck/outpost/internal/tenantstore/memtenantstore" + "github.com/hookdeck/outpost/internal/tenantstore/redistenantstore" +) + +// Type aliases re-exported from driver. +type TenantStore = driver.TenantStore +type ListTenantRequest = driver.ListTenantRequest +type SeekPagination = driver.SeekPagination +type TenantPaginatedResult = driver.TenantPaginatedResult +type ListDestinationByTenantOpts = driver.ListDestinationByTenantOpts +type DestinationFilter = driver.DestinationFilter + +// Error sentinels re-exported from driver. 
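+// Callers should compare against these with errors.Is so they can depend on
+// the tenantstore facade without importing the driver package directly.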
+var ( + ErrTenantNotFound = driver.ErrTenantNotFound + ErrTenantDeleted = driver.ErrTenantDeleted + ErrDuplicateDestination = driver.ErrDuplicateDestination + ErrDestinationNotFound = driver.ErrDestinationNotFound + ErrDestinationDeleted = driver.ErrDestinationDeleted + ErrMaxDestinationsPerTenantReached = driver.ErrMaxDestinationsPerTenantReached + ErrListTenantNotSupported = driver.ErrListTenantNotSupported + ErrInvalidCursor = driver.ErrInvalidCursor + ErrInvalidOrder = driver.ErrInvalidOrder + ErrConflictingCursors = driver.ErrConflictingCursors +) + +// WithDestinationFilter creates a ListDestinationByTenantOpts with the given filter. +var WithDestinationFilter = driver.WithDestinationFilter + +// Config holds the configuration for creating a TenantStore. +type Config struct { + RedisClient redis.Cmdable + Secret string + AvailableTopics []string + MaxDestinationsPerTenant int + DeploymentID string +} + +// New creates a new Redis-backed TenantStore. +func New(cfg Config) TenantStore { + var opts []redistenantstore.Option + if cfg.Secret != "" { + opts = append(opts, redistenantstore.WithSecret(cfg.Secret)) + } + if len(cfg.AvailableTopics) > 0 { + opts = append(opts, redistenantstore.WithAvailableTopics(cfg.AvailableTopics)) + } + if cfg.MaxDestinationsPerTenant > 0 { + opts = append(opts, redistenantstore.WithMaxDestinationsPerTenant(cfg.MaxDestinationsPerTenant)) + } + if cfg.DeploymentID != "" { + opts = append(opts, redistenantstore.WithDeploymentID(cfg.DeploymentID)) + } + return redistenantstore.New(cfg.RedisClient, opts...) +} + +// NewMemTenantStore creates an in-memory TenantStore for testing. +func NewMemTenantStore() TenantStore { + return memtenantstore.New() +} diff --git a/internal/util/testinfra/mock.go b/internal/util/testinfra/mock.go index 52aae012..c0acf0e8 100644 --- a/internal/util/testinfra/mock.go +++ b/internal/util/testinfra/mock.go @@ -43,7 +43,7 @@ func startMockServer(cfg *Config) { } type MockServerInfra struct { - sdk destinationmockserver.EntityStore + sdk destinationmockserver.MockStore } func NewMockServerInfra(baseURL string) *MockServerInfra {