diff --git a/plugin/s3spanstore/athena_query_cache.go b/plugin/s3spanstore/athena_query_cache.go
new file mode 100644
index 0000000..967ac4f
--- /dev/null
+++ b/plugin/s3spanstore/athena_query_cache.go
@@ -0,0 +1,95 @@
+package s3spanstore
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/service/athena"
+	"github.com/aws/aws-sdk-go-v2/service/athena/types"
+	"golang.org/x/sync/errgroup"
+)
+
+type AthenaQueryCache struct {
+	svc       AthenaAPI
+	workGroup string
+}
+
+func NewAthenaQueryCache(svc AthenaAPI, workGroup string) *AthenaQueryCache {
+	return &AthenaQueryCache{svc: svc, workGroup: workGroup}
+}
+
+func (c *AthenaQueryCache) Lookup(ctx context.Context, key string, ttl time.Duration) (*types.QueryExecution, error) {
+	paginator := athena.NewListQueryExecutionsPaginator(c.svc, &athena.ListQueryExecutionsInput{
+		WorkGroup: &c.workGroup,
+	})
+	queryExecutionIds := []string{}
+	for paginator.HasMorePages() {
+		output, err := paginator.NextPage(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to list athena query executions: %w", err)
+		}
+
+		queryExecutionIds = append(queryExecutionIds, output.QueryExecutionIds...)
+	}
+
+	queryExecutionIdChunks := chunks(queryExecutionIds, 50)
+	g, getQueryExecutionCtx := errgroup.WithContext(ctx)
+
+	ttlTime := time.Now().Add(-ttl)
+	var mu sync.Mutex
+
+	var latestQueryExecution *types.QueryExecution
+
+	for _, value := range queryExecutionIdChunks {
+		value := value
+		g.Go(func() error {
+			result, err := c.svc.BatchGetQueryExecution(getQueryExecutionCtx, &athena.BatchGetQueryExecutionInput{
+				QueryExecutionIds: value,
+			})
+			if err != nil {
+				return err
+			}
+
+			for _, v := range result.QueryExecutions {
+				v := v // copy the loop variable so taking &v below doesn't alias it across iterations
+				// Different query
+				if !strings.Contains(*v.Query, key) {
+					continue
+				}
+
+				// Query hasn't completed yet
+				if v.Status.CompletionDateTime == nil {
+					continue
+				}
+
+				// Query already expired
+				if v.Status.CompletionDateTime.Before(ttlTime) {
+					continue
+				}
+
+				mu.Lock()
+
+				// Keep the most recently completed query execution
+				if latestQueryExecution == nil {
+					latestQueryExecution = &v
+				} else {
+					if v.Status.CompletionDateTime.After(*latestQueryExecution.Status.CompletionDateTime) {
+						latestQueryExecution = &v
+					}
+				}
+
+				mu.Unlock()
+			}
+
+			return nil
+		})
+	}
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+
+	return latestQueryExecution, nil
+}
diff --git a/plugin/s3spanstore/reader.go b/plugin/s3spanstore/reader.go
index adb7d1c..f8db13a 100644
--- a/plugin/s3spanstore/reader.go
+++ b/plugin/s3spanstore/reader.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/aws/aws-sdk-go-v2/service/athena"
@@ -16,7 +15,6 @@ import (
 	"github.com/jaegertracing/jaeger/storage/spanstore"
 	"github.com/johanneswuerbach/jaeger-s3/plugin/config"
 	"github.com/opentracing/opentracing-go"
-	"golang.org/x/sync/errgroup"
 )
 
 // mockgen -destination=./plugin/s3spanstore/mocks/mock_athena.go -package=mocks github.com/johanneswuerbach/jaeger-s3/plugin/s3spanstore AthenaAPI
@@ -53,6 +51,7 @@ func NewReader(logger hclog.Logger, svc AthenaAPI, cfg config.Athena) (*Reader,
 		maxSpanAge:           maxSpanAge,
 		dependenciesQueryTTL: dependenciesQueryTTL,
 		servicesQueryTTL:     servicesQueryTTL,
+		athenaQueryCache:     NewAthenaQueryCache(svc, cfg.WorkGroup),
 	}, nil
 }
 
@@ -63,6 +62,7 @@ type Reader struct {
 	maxSpanAge           time.Duration
 	dependenciesQueryTTL time.Duration
 	servicesQueryTTL     time.Duration
+	athenaQueryCache     *AthenaQueryCache
 }
 
 const (
@@ -328,122 +328,70 @@ func (r *Reader) GetDependencies(ctx context.Context, endTs time.Time, lookback
 }
 
 func (r *Reader) queryAthenaCached(ctx context.Context, queryString string, lookupString string, ttl time.Duration) ([]types.Row, error) {
-	paginator := athena.NewListQueryExecutionsPaginator(r.svc, &athena.ListQueryExecutionsInput{
-		WorkGroup: &r.cfg.WorkGroup,
-	})
-	queryExecutionIds := []string{}
-	for paginator.HasMorePages() {
-		output, err := paginator.NextPage(ctx)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get athena query result: %w", err)
-		}
-
-		queryExecutionIds = append(queryExecutionIds, output.QueryExecutionIds...)
-	}
-
-	queryExecutionIdChunks := chunks(queryExecutionIds, 50)
-	g, getQueryExecutionCtx := errgroup.WithContext(ctx)
-
-	ttlTime := time.Now().Add(-ttl)
-	var mu sync.Mutex
-
-	latestCompletionDateTime := time.Now()
-	latestQueryExecutionId := ""
-
-	for _, value := range queryExecutionIdChunks {
-		value := value
-		g.Go(func() error {
-			result, err := r.svc.BatchGetQueryExecution(getQueryExecutionCtx, &athena.BatchGetQueryExecutionInput{
-				QueryExecutionIds: value,
-			})
-			if err != nil {
-				return err
-			}
-
-			for _, v := range result.QueryExecutions {
-				// Different query
-				if !strings.Contains(*v.Query, lookupString) {
-					continue
-				}
-
-				// Query didn't completed
-				if v.Status.CompletionDateTime == nil {
-					continue
-				}
-
-				// Query already expired
-				if v.Status.CompletionDateTime.Before(ttlTime) {
-					continue
-				}
-
-				mu.Lock()
-
-				// Store the latest query result
-				if latestQueryExecutionId == "" {
-					latestQueryExecutionId = *v.QueryExecutionId
-					latestCompletionDateTime = *v.Status.CompletionDateTime
-				} else {
-					if v.Status.CompletionDateTime.After(latestCompletionDateTime) {
-						latestQueryExecutionId = *v.QueryExecutionId
-						latestCompletionDateTime = *v.Status.CompletionDateTime
-					}
-				}
-
-				mu.Unlock()
-			}
-
-			return nil
-		})
-	}
-	if err := g.Wait(); err != nil {
-		return nil, err
+	queryExecution, err := r.athenaQueryCache.Lookup(ctx, lookupString, ttl)
+	if err != nil {
+		return nil, fmt.Errorf("failed to lookup cached athena query: %w", err)
 	}
 
-	if latestQueryExecutionId != "" {
-		return r.fetchQueryResult(ctx, latestQueryExecutionId)
+	if queryExecution != nil {
+		return r.waitAndFetchQueryResult(ctx, queryExecution)
 	}
 
 	return r.queryAthena(ctx, queryString)
 }
 
-func (s *Reader) queryAthena(ctx context.Context, queryString string) ([]types.Row, error) {
-	output, err := s.svc.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
+func (r *Reader) queryAthena(ctx context.Context, queryString string) ([]types.Row, error) {
+	output, err := r.svc.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
 		QueryString: &queryString,
 		QueryExecutionContext: &types.QueryExecutionContext{
-			Database: &s.cfg.DatabaseName,
+			Database: &r.cfg.DatabaseName,
 		},
 		ResultConfiguration: &types.ResultConfiguration{
-			OutputLocation: &s.cfg.OutputLocation,
+			OutputLocation: &r.cfg.OutputLocation,
 		},
-		WorkGroup: &s.cfg.WorkGroup,
+		WorkGroup: &r.cfg.WorkGroup,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to start athena query: %w", err)
 	}
 
+	status, err := r.svc.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
+		QueryExecutionId: output.QueryExecutionId,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to get athena query execution: %w", err)
+	}
+
+	return r.waitAndFetchQueryResult(ctx, status.QueryExecution)
+}
+
+func (r *Reader) waitAndFetchQueryResult(ctx context.Context, queryExecution *types.QueryExecution) ([]types.Row, error) {
 	// Poll until the query completed
 	for {
-		status, err := s.svc.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
-			QueryExecutionId: output.QueryExecutionId,
+		if queryExecution.Status.CompletionDateTime != nil {
+			break
+		}
+
+		time.Sleep(100 * time.Millisecond)
+
+		status, err := r.svc.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
+			QueryExecutionId: queryExecution.QueryExecutionId,
 		})
 		if err != nil {
 			return nil, fmt.Errorf("failed to get athena query execution: %w", err)
 		}
 
-		if status.QueryExecution.Status.CompletionDateTime != nil {
-			break
-		}
-		time.Sleep(100 * time.Millisecond)
+		queryExecution = status.QueryExecution
 	}
 
-	return s.fetchQueryResult(ctx, *output.QueryExecutionId)
+	return r.fetchQueryResult(ctx, queryExecution.QueryExecutionId)
 }
 
-func (r *Reader) fetchQueryResult(ctx context.Context, queryExecutionId string) ([]types.Row, error) {
+func (r *Reader) fetchQueryResult(ctx context.Context, queryExecutionId *string) ([]types.Row, error) {
 	// Get query results
 	paginator := athena.NewGetQueryResultsPaginator(r.svc, &athena.GetQueryResultsInput{
-		QueryExecutionId: &queryExecutionId,
+		QueryExecutionId: queryExecutionId,
 	})
 	rows := []types.Row{}
 	for paginator.HasMorePages() {
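Note: `chunks` is a helper that already exists elsewhere in this package and is not part of this diff. As a rough sketch of the slicing it needs to perform (using a hypothetical `chunkIDs` name, since the real helper's signature may differ), given that Athena's BatchGetQueryExecution accepts at most 50 query execution IDs per call:

package main

import "fmt"

// chunkIDs splits ids into batches of at most size elements, mirroring how the
// cache splits query execution IDs before calling BatchGetQueryExecution.
// Hypothetical name; the repository's own chunks helper may differ.
func chunkIDs(ids []string, size int) [][]string {
	var batches [][]string
	for len(ids) > size {
		batches = append(batches, ids[:size])
		ids = ids[size:]
	}
	if len(ids) > 0 {
		batches = append(batches, ids)
	}
	return batches
}

func main() {
	ids := []string{"a", "b", "c", "d", "e"}
	fmt.Println(chunkIDs(ids, 2)) // [[a b] [c d] [e]]
}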