diff --git a/server/src/main/java/org/opensearch/action/ActionModule.java b/server/src/main/java/org/opensearch/action/ActionModule.java
index 12fbabf341c41..858c11216164f 100644
--- a/server/src/main/java/org/opensearch/action/ActionModule.java
+++ b/server/src/main/java/org/opensearch/action/ActionModule.java
@@ -286,8 +286,6 @@
 import org.opensearch.action.search.PutSearchPipelineTransportAction;
 import org.opensearch.action.search.SearchAction;
 import org.opensearch.action.search.SearchScrollAction;
-import org.opensearch.action.search.StreamSearchAction;
-import org.opensearch.action.search.StreamTransportSearchAction;
 import org.opensearch.action.search.TransportClearScrollAction;
 import org.opensearch.action.search.TransportCreatePitAction;
 import org.opensearch.action.search.TransportDeletePitAction;
@@ -736,9 +734,7 @@ public void reg
         actions.register(MultiGetAction.INSTANCE, TransportMultiGetAction.class, TransportShardMultiGetAction.class);
         actions.register(BulkAction.INSTANCE, TransportBulkAction.class, TransportShardBulkAction.class);
         actions.register(SearchAction.INSTANCE, TransportSearchAction.class);
-        if (FeatureFlags.isEnabled(FeatureFlags.STREAM_TRANSPORT)) {
-            actions.register(StreamSearchAction.INSTANCE, StreamTransportSearchAction.class);
-        }
+        // Streaming search is handled via SearchAction with the streamingSearchMode parameter
         actions.register(SearchScrollAction.INSTANCE, TransportSearchScrollAction.class);
         actions.register(MultiSearchAction.INSTANCE, TransportMultiSearchAction.class);
         actions.register(ExplainAction.INSTANCE, TransportExplainAction.class);
diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
index b04d3086d8c95..cfbb1d8e0bb8c 100644
--- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
+++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java
@@ -151,6 +151,14 @@ int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
         return (hasAggs || hasTopDocs) ? Math.min(requestBatchedReduceSize, minBatchReduceSize) : minBatchReduceSize;
     }
 
+    /**
+     * Protected accessor that allows subclasses to read the progress listener.
+     * @return the search progress listener
+     */
+    protected SearchProgressListener progressListener() {
+        return this.progressListener;
+    }
+
     @Override
     public void close() {
         Releasables.close(pendingReduces);
@@ -239,6 +247,7 @@ private ReduceResult partialReduce(
         }
         for (QuerySearchResult result : toConsume) {
             TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+            // For streaming, avoid reassigning shardIndex if already set
             SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
             topDocsList.add(topDocs.topDocs);
         }
@@ -273,7 +282,18 @@ private ReduceResult partialReduce(
             SearchShardTarget target = result.getSearchShardTarget();
             processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
         }
-        progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
+        // For streaming search with TopDocs, use the new notification method
+        if (hasTopDocs && newTopDocs != null) {
+            progressListener.notifyPartialReduceWithTopDocs(
+                processedShards,
+                topDocsStats.getTotalHits(),
+                newTopDocs,
+                newAggs,
+                numReducePhases
+            );
+        } else {
+            progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
+        }
         // we leave the results un-serialized because serializing is slow but we compute the serialized
         // size as an estimate of the memory used by the newly reduced aggregations.
         long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0;
@@ -564,6 +584,7 @@ private synchronized List<TopDocs> consumeTopDocs() {
         }
         for (QuerySearchResult result : buffer) {
             TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+            // For streaming, avoid reassigning shardIndex if already set
             SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
             topDocsList.add(topDocs.topDocs);
         }
diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
index 40a2805563369..4191be54bfb38 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
@@ -32,6 +32,8 @@
 
 package org.opensearch.action.search;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.FieldDoc;
@@ -90,6 +92,7 @@
  * @opensearch.internal
  */
 public final class SearchPhaseController {
+    private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
     private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
 
     private final NamedWriteableRegistry namedWriteableRegistry;
@@ -246,7 +249,14 @@ static TopDocs mergeTopDocs(Collection<TopDocs> results, int topN, int from) {
     }
 
     static void setShardIndex(TopDocs topDocs, int shardIndex) {
-        assert topDocs.scoreDocs.length == 0 || topDocs.scoreDocs[0].shardIndex == -1 : "shardIndex is already set";
+        // Idempotent assignment: in streaming flows, partial reductions may touch the same TopDocs more than once.
+        if (topDocs.scoreDocs.length == 0) {
+            return;
+        }
+        if (topDocs.scoreDocs[0].shardIndex != -1) {
+            // Already set by a previous pass; avoid reassigning to prevent assertion failures
+            return;
+        }
         for (ScoreDoc doc : topDocs.scoreDocs) {
             doc.shardIndex = shardIndex;
         }
@@ -795,40 +805,36 @@ QueryPhaseResultConsumer newSearchPhaseResults(
         Consumer<Exception> onPartialMergeFailure,
         BooleanSupplier isTaskCancelled
     ) {
-        return new QueryPhaseResultConsumer(
-            request,
-            executor,
-            circuitBreaker,
-            this,
-            listener,
-            namedWriteableRegistry,
-            numShards,
-            onPartialMergeFailure,
-            isTaskCancelled
-        );
-    }
-
-    /**
-     * Returns a new {@link StreamQueryPhaseResultConsumer} instance that reduces search responses incrementally.
-     */
-    StreamQueryPhaseResultConsumer newStreamSearchPhaseResults(
-        Executor executor,
-        CircuitBreaker circuitBreaker,
-        SearchProgressListener listener,
-        SearchRequest request,
-        int numShards,
-        Consumer<Exception> onPartialMergeFailure
-    ) {
-        return new StreamQueryPhaseResultConsumer(
-            request,
-            executor,
-            circuitBreaker,
-            this,
-            listener,
-            namedWriteableRegistry,
-            numShards,
-            onPartialMergeFailure
-        );
+        // Check if this is a streaming search request
+        String streamingMode = request.getStreamingSearchMode();
+        if (logger.isDebugEnabled()) {
+            logger.debug("Streaming mode on request: {}", streamingMode);
+        }
+        if (streamingMode != null) {
+            return new StreamQueryPhaseResultConsumer(
+                request,
+                executor,
+                circuitBreaker,
+                this,
+                listener,
+                namedWriteableRegistry,
+                numShards,
+                onPartialMergeFailure
+            );
+        } else {
+            // Regular QueryPhaseResultConsumer
+            return new QueryPhaseResultConsumer(
+                request,
+                executor,
+                circuitBreaker,
+                this,
+                listener,
+                namedWriteableRegistry,
+                numShards,
+                onPartialMergeFailure,
+                isTaskCancelled
+            );
+        }
     }
 
     /**
diff --git a/server/src/main/java/org/opensearch/action/search/SearchProgressListener.java b/server/src/main/java/org/opensearch/action/search/SearchProgressListener.java
index 34e8aacbad250..ab0fb723ad30e 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchProgressListener.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchProgressListener.java
@@ -100,6 +100,26 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
      */
     protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
 
+    /**
+     * Executed when a partial reduce with TopDocs is created for streaming search.
+     *
+     * @param shards The list of shards that are part of this reduce.
+     * @param totalHits The total number of hits in this reduce.
+     * @param topDocs The partial TopDocs result (may be null if no docs).
+     * @param aggs The partial result for aggregations.
+     * @param reducePhase The version number for this reduce.
+     */
+    protected void onPartialReduceWithTopDocs(
+        List<SearchShard> shards,
+        TotalHits totalHits,
+        org.apache.lucene.search.TopDocs topDocs,
+        InternalAggregations aggs,
+        int reducePhase
+    ) {
+        // Default implementation delegates to the original method for backward compatibility
+        onPartialReduce(shards, totalHits, aggs, reducePhase);
+    }
+
     /**
      * Executed once when the final reduce is created.
      *
@@ -165,6 +185,20 @@ final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, In
         }
     }
 
+    final void notifyPartialReduceWithTopDocs(
+        List<SearchShard> shards,
+        TotalHits totalHits,
+        org.apache.lucene.search.TopDocs topDocs,
+        InternalAggregations aggs,
+        int reducePhase
+    ) {
+        try {
+            onPartialReduceWithTopDocs(shards, totalHits, topDocs, aggs, reducePhase);
+        } catch (Exception e) {
+            logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on partial reduce with TopDocs"), e);
+        }
+    }
+
     protected final void notifyFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
         try {
             onFinalReduce(shards, totalHits, aggs, reducePhase);
diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequest.java b/server/src/main/java/org/opensearch/action/search/SearchRequest.java
index a1e6e7605cbdb..2d757360dff35 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchRequest.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchRequest.java
@@ -128,6 +128,9 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
 
     private Boolean phaseTook = null;
 
+    private boolean streamingScoring = false;
+    private String streamingSearchMode = null; // Will use StreamingSearchMode.SCORED_UNSORTED if null
+
     public SearchRequest() {
         this.localClusterAlias = null;
         this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS;
@@ -145,6 +148,8 @@ public SearchRequest(SearchRequest searchRequest) {
             searchRequest.absoluteStartMillis,
             searchRequest.finalReduce
         );
+        this.streamingScoring = searchRequest.streamingScoring;
+        this.streamingSearchMode = searchRequest.streamingSearchMode;
     }
 
     /**
@@ -280,6 +284,14 @@ public SearchRequest(StreamInput in) throws IOException {
         if (in.getVersion().onOrAfter(Version.V_2_12_0)) {
             phaseTook = in.readOptionalBoolean();
         }
+        // Read streaming fields - gated on version for BWC
+        if (in.getVersion().onOrAfter(Version.V_3_3_0)) {
+            streamingScoring = in.readBoolean();
+            streamingSearchMode = in.readOptionalString();
+        } else {
+            streamingScoring = false;
+            streamingSearchMode = null;
+        }
     }
 
     @Override
@@ -314,6 +326,11 @@ public void writeTo(StreamOutput out) throws IOException {
         if (out.getVersion().onOrAfter(Version.V_2_12_0)) {
             out.writeOptionalBoolean(phaseTook);
         }
+        // Write streaming fields - gated on version for BWC
+        if (out.getVersion().onOrAfter(Version.V_3_3_0)) {
+            out.writeBoolean(streamingScoring);
+            out.writeOptionalString(streamingSearchMode);
+        }
     }
 
     @Override
@@ -695,6 +712,36 @@ public void setPhaseTook(Boolean phaseTook) {
         this.phaseTook = phaseTook;
     }
 
+    /**
+     * Enable or disable streaming scoring for this search request.
+     */
+    public void setStreamingScoring(boolean streamingScoring) {
+        this.streamingScoring = streamingScoring;
+    }
+
+    /**
+     * Check if streaming scoring is enabled for this search request.
+     */
+    public boolean isStreamingScoring() {
+        return streamingScoring;
+    }
+
+    /**
+     * Sets the streaming search mode for this request.
+     * @param mode The streaming search mode to use
+     */
+    public void setStreamingSearchMode(String mode) {
+        this.streamingSearchMode = mode;
+    }
+
+    /**
+     * Gets the streaming search mode for this request.
+     * @return The streaming search mode, or null if not set
+     */
+    public String getStreamingSearchMode() {
+        return streamingSearchMode;
+    }
+
     /**
      * Returns a threshold that enforces a pre-filter roundtrip to pre-filter search shards based on query rewriting if the number of shards
      * the search request expands to exceeds the threshold, or null if the threshold is unspecified.
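Note: with the request plumbing above, opting into streaming is just a matter of setting the new flags before submitting the request. A minimal sketch (hypothetical caller code; it assumes only the setters added in this diff plus the StreamingSearchMode enum referenced later in TransportSearchAction):

    import org.opensearch.action.search.SearchRequest;
    import org.opensearch.search.query.StreamingSearchMode;

    // Build a search request that asks for streaming partial results.
    static SearchRequest newStreamingRequest(String index) {
        SearchRequest request = new SearchRequest(index);
        request.setStreamingScoring(true);
        // Optional: pick an explicit mode. TransportSearchAction falls back to
        // SCORED_UNSORTED when only the flag is set.
        request.setStreamingSearchMode(StreamingSearchMode.SCORED_SORTED.toString());
        return request;
    }

Both fields are only written to nodes on V_3_3_0 or later, so on older wire versions the flags are simply dropped and the request behaves like a classic search.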
diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponse.java b/server/src/main/java/org/opensearch/action/search/SearchResponse.java
index c9568b4d77791..c07d36fe21653 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchResponse.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchResponse.java
@@ -103,6 +103,11 @@ public class SearchResponse extends ActionResponse implements StatusToXContentOb
     private final long tookInMillis;
     private final PhaseTook phaseTook;
 
+    // Fields for streaming responses
+    private boolean isPartial = false;
+    private int sequenceNumber = 0;
+    private int totalPartials = 0;
+
     public SearchResponse(StreamInput in) throws IOException {
         super(in);
         internalResponse = new InternalSearchResponse(in);
@@ -302,6 +307,31 @@ public String getScrollId() {
         return scrollId;
     }
 
+    // Streaming response methods
+    public boolean isPartial() {
+        return isPartial;
+    }
+
+    public void setPartial(boolean partial) {
+        this.isPartial = partial;
+    }
+
+    public int getSequenceNumber() {
+        return sequenceNumber;
+    }
+
+    public void setSequenceNumber(int sequenceNumber) {
+        this.sequenceNumber = sequenceNumber;
+    }
+
+    public int getTotalPartials() {
+        return totalPartials;
+    }
+
+    public void setTotalPartials(int totalPartials) {
+        this.totalPartials = totalPartials;
+    }
+
     /**
      * Returns the encoded string of the search context that the search request is used to executed
      */
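Note: the streaming metadata on SearchResponse is mutable and is filled in by the coordinator-side listener (StreamingSearchResponseListener, added later in this diff). A sketch of how a streaming-aware consumer might read it (hypothetical helper, not part of this change):

    // Distinguish partial emissions from the final response.
    static void onStreamedResponse(SearchResponse response) {
        if (response.isPartial()) {
            // Sequence numbers are assigned in arrival order, starting at 1.
            System.out.printf("partial #%d: %d hits%n",
                response.getSequenceNumber(),
                response.getHits().getHits().length);
        } else {
            System.out.printf("final response after %d partial emissions%n",
                response.getTotalPartials());
        }
    }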
diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java
index fec8c4e790e7a..e1049068ad7c9 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java
@@ -32,6 +32,8 @@
 
 package org.opensearch.action.search;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.opensearch.action.ActionListenerResponseHandler;
 import org.opensearch.action.IndicesRequest;
 import org.opensearch.action.OriginalIndices;
@@ -101,6 +103,7 @@ public class SearchTransportService {
     public static final String CREATE_READER_CONTEXT_ACTION_NAME = "indices:data/read/search[create_context]";
     public static final String UPDATE_READER_CONTEXT_ACTION_NAME = "indices:data/read/search[update_context]";
 
+    private static final Logger logger = LogManager.getLogger(SearchTransportService.class);
     private final TransportService transportService;
     protected final BiFunction responseWrapper;
     private final Map<String, Long> clientConnections = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
@@ -243,6 +246,14 @@ public void sendExecuteQuery(
         final boolean fetchDocuments = request.numberOfShards() == 1;
         Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
+        if (logger.isTraceEnabled()) {
+            logger.trace(
+                "STREAM DEBUG: coordinator sending QUERY to node={} shard={} via classic transport (fetchDocuments={})",
+                connection.getNode().getId(),
+                request.shardId(),
+                fetchDocuments
+            );
+        }
         final ActionListener handler = responseWrapper.apply(connection, listener);
         transportService.sendChildRequest(
             connection,
@@ -259,6 +270,14 @@ public void sendExecuteQuery(
         SearchTask task,
         final SearchActionListener listener
     ) {
+        if (logger.isTraceEnabled()) {
+            logger.trace(
+                "STREAM DEBUG: coordinator sending QUERY to node={} shard={} via classic transport (fetchDocuments={})",
+                connection.getNode().getId(),
+                request.contextId(),
+                false
+            );
+        }
         transportService.sendChildRequest(
             connection,
             QUERY_ID_ACTION_NAME,
@@ -565,11 +584,19 @@ public static void registerRequestHandler(TransportService transportService, Sea
             AdmissionControlActionType.SEARCH,
             ShardSearchRequest::new,
             (request, channel, task) -> {
+                if (logger.isTraceEnabled()) {
+                    logger.trace(
+                        "STREAM DEBUG: classic handler for query; isStreamSearch=false listener=ChannelActionListener shard={}",
+                        request.shardId()
+                    );
+                }
                 searchService.executeQueryPhase(
                     request,
                     false,
                     (SearchShardTask) task,
-                    new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)
+                    new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request),
+                    ThreadPool.Names.SAME,
+                    false
                 );
             }
         );
diff --git a/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java
index 75612b081e5e5..31192b08dd8a9 100644
--- a/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java
+++ b/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java
@@ -8,21 +8,48 @@
 
 package org.opensearch.action.search;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.opensearch.core.common.breaker.CircuitBreaker;
 import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
 import org.opensearch.search.SearchPhaseResult;
-import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.StreamingSearchMode;
 
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 
 /**
- * Streaming query phase result consumer
+ * Query phase result consumer for streaming search.
+ * Supports progressive batch reduction with configurable scoring modes.
+ *
+ * Batch reduction frequency is controlled by per-mode multipliers:
+ * - NO_SCORING: Immediate reduction (batch size = 1) for the fastest time-to-first-byte
+ * - SCORED_UNSORTED: Small batches (minBatchReduceSize * 2)
+ * - SCORED_SORTED: Larger batches (minBatchReduceSize * 10)
+ *
+ * These multipliers are applied to the base batch reduce size (typically 5) to determine
+ * how many shard results are accumulated before triggering a partial reduction. Lower values
+ * mean more frequent reductions and faster streaming, but higher coordinator CPU usage.
 *
  * @opensearch.internal
  */
 public class StreamQueryPhaseResultConsumer extends QueryPhaseResultConsumer {
+    private static final Logger logger = LogManager.getLogger(StreamQueryPhaseResultConsumer.class);
+
+    private final StreamingSearchMode scoringMode;
+    private int resultsReceived = 0;
+
+    // TTFB tracking for demonstrating fetch phase timing
+    private long queryStartTime = System.currentTimeMillis();
+    private long firstBatchReadyForFetchTime = -1;
+    private boolean firstBatchReadyForFetch = false;
+    private final AtomicInteger batchesReduced = new AtomicInteger(0);
+
+    /**
+     * Creates a streaming query phase result consumer.
+     */
     public StreamQueryPhaseResultConsumer(
         SearchRequest request,
         Executor executor,
@@ -43,22 +70,66 @@ public StreamQueryPhaseResultConsumer(
             expectedResultSize,
             onPartialMergeFailure
         );
+
+        // Initialize scoring mode from request
+        String mode = request.getStreamingSearchMode();
+        this.scoringMode = (mode != null) ? StreamingSearchMode.fromString(mode) : StreamingSearchMode.SCORED_SORTED;
     }
 
     /**
-     * For stream search, the minBatchReduceSize is set higher than shard number
+     * Controls partial reduction frequency based on scoring mode.
      *
-     * @param minBatchReduceSize: pass as number of shard
+     * @param requestBatchedReduceSize request batch size
+     * @param minBatchReduceSize minimum batch size
      */
     @Override
    int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
-        return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
+        // Handle null during construction (the parent constructor calls this before our constructor body runs)
+        if (scoringMode == null) {
+            return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
+        }
+
+        switch (scoringMode) {
+            case NO_SCORING:
+                // Reduce immediately for the fastest TTFB
+                return Math.min(requestBatchedReduceSize, 1);
+            case SCORED_UNSORTED:
+                // Small batches for quick emission without sorting overhead
+                return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 2);
+            case SCORED_SORTED:
+                // Higher batch size to collect more results before reducing (sorting is expensive)
+                return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
+            default:
+                return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
+        }
+    }
+
+    /**
+     * Consume a streaming (partial) shard result. Partials are forwarded to the client elsewhere;
+     * the coordinator-side reducer must only see each shard's final result, so the partial is dropped here.
+     */
+    public void consumeStreamResult(SearchPhaseResult result, Runnable next) {
+        // Keep streaming: the coordinator receives partials and forwards them to the client,
+        // but the coordinator reducer should only see the final per-shard result.
+        // Do not enqueue partials into pendingReduces.
+
+        // Cheap debug log for tracing dropped partials
+        logger.debug("Dropping partial from reducer, shard={}, partial={}", result.getShardIndex(), result.queryResult().isPartial());
+
+        // Immediately continue the pipeline
+        next.run();
+    }
+
+    /**
+     * Get TTFB metrics for benchmarking
+     */
+    public long getTimeToFirstBatch() {
+        if (firstBatchReadyForFetchTime > 0) {
+            return firstBatchReadyForFetchTime - queryStartTime;
+        }
+        return -1;
     }
 
-    void consumeStreamResult(SearchPhaseResult result, Runnable next) {
-        // For streaming, we skip the ArraySearchPhaseResults.consumeResult() call
-        // since it doesn't support multiple results from the same shard.
-        QuerySearchResult querySearchResult = result.queryResult();
-        pendingReduces.consume(querySearchResult, next);
+    public boolean isFirstBatchReady() {
+        return firstBatchReadyForFetch;
     }
 }
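Note: the effective reduce batch size per mode follows directly from getBatchReduceSize() above. With the base minBatchReduceSize of 5 mentioned in the class javadoc, a sketch of the mapping (illustrative only; it assumes requestBatchedReduceSize is large enough not to clamp the result):

    // NO_SCORING      -> 1  (reduce after every shard result, fastest TTFB)
    // SCORED_UNSORTED -> 10 (minBatchReduceSize * 2)
    // SCORED_SORTED   -> 50 (minBatchReduceSize * 10, amortizes sort cost)
    static int effectiveMinBatch(StreamingSearchMode mode, int minBatchReduceSize) {
        switch (mode) {
            case NO_SCORING:
                return 1;
            case SCORED_UNSORTED:
                return minBatchReduceSize * 2;
            case SCORED_SORTED:
            default:
                return minBatchReduceSize * 10;
        }
    }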
diff --git a/server/src/main/java/org/opensearch/action/search/StreamSearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/StreamSearchQueryThenFetchAsyncAction.java
index a2dac2e74965c..a91514f049ac2 100644
--- a/server/src/main/java/org/opensearch/action/search/StreamSearchQueryThenFetchAsyncAction.java
+++ b/server/src/main/java/org/opensearch/action/search/StreamSearchQueryThenFetchAsyncAction.java
@@ -21,8 +21,6 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.Executor;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
 
 /**
@@ -30,9 +28,7 @@
  */
 public class StreamSearchQueryThenFetchAsyncAction extends SearchQueryThenFetchAsyncAction {
 
-    private final AtomicInteger streamResultsReceived = new AtomicInteger(0);
-    private final AtomicInteger streamResultsConsumeCallback = new AtomicInteger(0);
-    private final AtomicBoolean shardResultsConsumed = new AtomicBoolean(false);
+    private final Logger logger;
 
     StreamSearchQueryThenFetchAsyncAction(
         Logger logger,
@@ -74,6 +70,7 @@
             searchRequestContext,
             tracer
         );
+        this.logger = logger;
     }
 
     /**
@@ -93,7 +90,9 @@ SearchActionListener createShardActionListener(
         @Override
         protected void innerOnStreamResponse(SearchPhaseResult result) {
             try {
-                streamResultsReceived.incrementAndGet();
+                if (getLogger().isTraceEnabled()) {
+                    getLogger().trace("STREAM DEBUG: coordinator received partial from shard {}", shard);
+                }
                 onStreamResult(result, shardIt, () -> successfulStreamExecution());
             } finally {
                 executeNext(pendingExecutions, thread);
@@ -103,6 +102,9 @@ protected void innerOnStreamResponse(SearchPhaseResult result) {
         @Override
         protected void innerOnCompleteResponse(SearchPhaseResult result) {
             try {
+                if (getLogger().isTraceEnabled()) {
+                    getLogger().trace("STREAM DEBUG: coordinator received final for shard {}", shard);
+                }
                 onShardResult(result, shardIt);
             } finally {
                 executeNext(pendingExecutions, thread);
@@ -138,6 +140,27 @@ protected void onStreamResult(SearchPhaseResult result, SearchShardIterator shar
         ((StreamQueryPhaseResultConsumer) results).consumeStreamResult(result, next);
     }
 
+    /**
+     * Override onShardResult to handle streaming search results safely.
+     * This prevents the "topDocs already consumed" error when processing
+     * multiple streaming results from the same shard.
+     */
+    @Override
+    protected void onShardResult(SearchPhaseResult result, SearchShardIterator shardIt) {
+        // Safety log: track final shard response receipt in the coordinator
+        if (logger.isTraceEnabled()) {
+            logger.trace(
+                "COORDINATOR: received final shard result from shard={}, target={}, totalOps={}, expectedOps={}",
+                result.getShardIndex(),
+                result.getSearchShardTarget(),
+                totalOps.get(),
+                expectedTotalOps
+            );
+        }
+        // Always delegate to the parent to ensure shard accounting and phase transitions.
+        super.onShardResult(result, shardIt);
+    }
+
     /**
      * Override successful shard execution to handle stream result synchronization
      */
@@ -152,16 +175,8 @@ void successfulShardExecution(SearchShardIterator shardsIt) {
         final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
         if (xTotalOps == expectedTotalOps) {
             try {
-                shardResultsConsumed.set(true);
-                if (streamResultsReceived.get() == streamResultsConsumeCallback.get()) {
-                    getLogger().debug("Stream results consumption has called back, let shard consumption callback trigger onPhaseDone");
-                    onPhaseDone();
-                } else {
-                    assert streamResultsReceived.get() > streamResultsConsumeCallback.get();
-                    getLogger().debug(
-                        "Shard results consumption finishes before stream results, let stream consumption callback trigger onPhaseDone"
-                    );
-                }
+                // All final shard results have been processed; partials are not reduced.
+                onPhaseDone();
             } catch (final Exception ex) {
                 onPhaseFailure(this, "The phase has failed", ex);
             }
@@ -175,17 +190,10 @@ void successfulShardExecution(SearchShardIterator shardsIt) {
 
     /**
      * Handle successful stream execution callback
+     * Since partials are no longer fed into the reducer, this callback is not needed for coordination.
      */
     private void successfulStreamExecution() {
-        try {
-            if (streamResultsReceived.get() == streamResultsConsumeCallback.incrementAndGet()) {
-                if (shardResultsConsumed.get()) {
-                    getLogger().debug("Stream consumption trigger onPhaseDone");
-                    onPhaseDone();
-                }
-            }
-        } catch (final Exception ex) {
-            onPhaseFailure(this, "The phase has failed", ex);
-        }
+        // No-op: partials bypass the reducer; completion is handled by successfulShardExecution only
     }
+
 }
diff --git a/server/src/main/java/org/opensearch/action/search/StreamSearchTransportService.java b/server/src/main/java/org/opensearch/action/search/StreamSearchTransportService.java
index 3bb251af66204..5795b90990e23 100644
--- a/server/src/main/java/org/opensearch/action/search/StreamSearchTransportService.java
+++ b/server/src/main/java/org/opensearch/action/search/StreamSearchTransportService.java
@@ -12,6 +12,8 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.action.OriginalIndices;
 import org.opensearch.action.support.StreamSearchChannelListener;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.common.Nullable;
 import org.opensearch.common.settings.Setting;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.io.stream.StreamInput;
@@ -43,7 +45,7 @@
  * @opensearch.internal
 */
 public class StreamSearchTransportService extends SearchTransportService {
-    private final Logger logger = LogManager.getLogger(StreamSearchTransportService.class);
+    private static final Logger logger = LogManager.getLogger(StreamSearchTransportService.class);
 
     private final StreamTransportService transportService;
 
@@ -55,6 +57,14 @@ public StreamSearchTransportService(
         this.transportService = transportService;
     }
 
+    @Override
+    public Transport.Connection getConnection(@Nullable String clusterAlias, DiscoveryNode node) {
+        // Delegate to StreamTransportService to get connections from the streaming connection manager.
+        // This ensures connections understand the streaming protocol and call handleStreamResponse()
+        // instead of handleResponse() on StreamTransportResponseHandler instances.
+        return transportService.getConnection(node);
+    }
+
     public static final Setting<Boolean> STREAM_SEARCH_ENABLED = Setting.boolSetting(
         "stream.search.enabled",
         false,
@@ -71,13 +81,21 @@ public static void registerStreamRequestHandler(StreamTransportService transport
             AdmissionControlActionType.SEARCH,
             ShardSearchRequest::new,
             (request, channel, task) -> {
+                boolean isStreamSearch = request.isStreamingSearch() || request.getStreamingSearchMode() != null;
+                if (logger.isTraceEnabled()) {
+                    logger.trace(
+                        "STREAM DEBUG: stream handler for query; isStreamSearch={} listener=StreamSearchChannelListener shard={}",
+                        isStreamSearch,
+                        request.shardId()
+                    );
+                }
                 searchService.executeQueryPhase(
                     request,
                     false,
                     (SearchShardTask) task,
                     new StreamSearchChannelListener<>(channel, QUERY_ACTION_NAME, request),
                     ThreadPool.Names.STREAM_SEARCH,
-                    true
+                    isStreamSearch
                 );
             }
         );
@@ -143,36 +161,39 @@ public void sendExecuteQuery(
         final boolean fetchDocuments = request.numberOfShards() == 1;
         Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
 
-        final StreamSearchActionListener<SearchPhaseResult> streamListener = (StreamSearchActionListener<SearchPhaseResult>) listener;
+        final boolean streamingListener = listener instanceof StreamSearchActionListener;
         StreamTransportResponseHandler<SearchPhaseResult> transportHandler = new StreamTransportResponseHandler<SearchPhaseResult>() {
             @Override
             public void handleStreamResponse(StreamTransportResponse<SearchPhaseResult> response) {
                 try {
-                    // only send previous result if we have a current result
-                    // if current result is null, that means the previous result is the last result
                     SearchPhaseResult currentResult;
                     SearchPhaseResult lastResult = null;
-
-                    // Keep reading results until we reach the end
                     while ((currentResult = response.nextResponse()) != null) {
-                        if (lastResult != null) {
-                            streamListener.onStreamResponse(lastResult, false);
+                        if (streamingListener) {
+                            if (lastResult != null) {
+                                ((StreamSearchActionListener<SearchPhaseResult>) listener).onStreamResponse(lastResult, false);
+                            }
+                            lastResult = currentResult;
+                        } else {
+                            // Non-streaming: keep only the last (final) response
+                            lastResult = currentResult;
                         }
-                        lastResult = currentResult;
                     }
 
-                    // Send the final result as complete response, or null if no results
                     if (lastResult != null) {
-                        streamListener.onStreamResponse(lastResult, true);
-                        logger.debug("Processed final stream response");
+                        if (streamingListener) {
+                            ((StreamSearchActionListener<SearchPhaseResult>) listener).onStreamResponse(lastResult, true);
+                            logger.debug("Processed final stream response");
+                        } else {
+                            listener.onResponse(lastResult);
+                        }
                     } else {
-                        // Empty stream case
-                        logger.error("Empty stream");
+                        logger.debug("Empty stream");
                    }
                     response.close();
                 } catch (Exception e) {
                     response.cancel("Client error during search phase", e);
-                    streamListener.onFailure(e);
+                    listener.onFailure(e);
                 }
             }
 
@@ -192,12 +213,21 @@ public SearchPhaseResult read(StreamInput in) throws IOException {
             }
         };
 
+        if (logger.isTraceEnabled()) {
+            logger.trace(
+                "STREAM DEBUG: coordinator sending QUERY to node={} shard={} via stream transport (fetchDocuments={})",
+                connection.getNode().getId(),
+                request.shardId(),
+                fetchDocuments
+            );
+        }
         transportService.sendChildRequest(
             connection,
             QUERY_ACTION_NAME,
             request,
             task,
-            transportHandler // TODO: wrap with ConnectionCountingHandler
+            TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(),
+            transportHandler
         );
     }
 
@@ -236,7 +266,14 @@ public FetchSearchResult read(StreamInput in) throws IOException {
                 return new FetchSearchResult(in);
             }
         };
-        transportService.sendChildRequest(connection, FETCH_ID_ACTION_NAME, request, task, transportHandler);
+        transportService.sendChildRequest(
+            connection,
+            FETCH_ID_ACTION_NAME,
+            request,
+            task,
+            TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(),
+            transportHandler
+        );
     }
 
     @Override
diff --git a/server/src/main/java/org/opensearch/action/search/StreamTransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/StreamTransportSearchAction.java
deleted file mode 100644
index 55351289ae9e4..0000000000000
--- a/server/src/main/java/org/opensearch/action/search/StreamTransportSearchAction.java
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- */
-
-package org.opensearch.action.search;
-
-import org.opensearch.action.support.ActionFilters;
-import org.opensearch.cluster.ClusterState;
-import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
-import org.opensearch.cluster.routing.GroupShardsIterator;
-import org.opensearch.cluster.service.ClusterService;
-import org.opensearch.common.Nullable;
-import org.opensearch.common.inject.Inject;
-import org.opensearch.core.action.ActionListener;
-import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
-import org.opensearch.core.indices.breaker.CircuitBreakerService;
-import org.opensearch.search.SearchPhaseResult;
-import org.opensearch.search.SearchService;
-import org.opensearch.search.internal.AliasFilter;
-import org.opensearch.search.pipeline.SearchPipelineService;
-import org.opensearch.tasks.TaskResourceTrackingService;
-import org.opensearch.telemetry.metrics.MetricsRegistry;
-import org.opensearch.telemetry.tracing.Tracer;
-import org.opensearch.threadpool.ThreadPool;
-import org.opensearch.transport.StreamTransportService;
-import org.opensearch.transport.Transport;
-import org.opensearch.transport.client.node.NodeClient;
-
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.Executor;
-import java.util.function.BiFunction;
-
-/**
- * Transport search action for streaming search
- * @opensearch.internal
- */
-public class StreamTransportSearchAction extends TransportSearchAction {
-    @Inject
-    public StreamTransportSearchAction(
-        NodeClient client,
-        ThreadPool threadPool,
-        CircuitBreakerService circuitBreakerService,
-        @Nullable StreamTransportService transportService,
-        SearchService searchService,
-        @Nullable StreamSearchTransportService searchTransportService,
-        SearchPhaseController searchPhaseController,
-        ClusterService clusterService,
-        ActionFilters actionFilters,
-        IndexNameExpressionResolver indexNameExpressionResolver,
-        NamedWriteableRegistry namedWriteableRegistry,
-        SearchPipelineService searchPipelineService,
-        MetricsRegistry metricsRegistry,
-        SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory,
-        Tracer tracer,
-        TaskResourceTrackingService taskResourceTrackingService
-    ) {
-        super(
-            client,
-            threadPool,
-            circuitBreakerService,
-            transportService,
-            searchService,
-            searchTransportService,
-            searchPhaseController,
-            clusterService,
-            actionFilters,
-            indexNameExpressionResolver,
-            namedWriteableRegistry,
-            searchPipelineService,
-            metricsRegistry,
-            searchRequestOperationsCompositeListenerFactory,
-            tracer,
-            taskResourceTrackingService
-        );
-    }
-
-    AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction(
-        SearchTask task,
-        SearchRequest searchRequest,
-        Executor executor,
-        GroupShardsIterator<SearchShardIterator> shardIterators,
-        SearchTimeProvider timeProvider,
-        BiFunction<String, String, Transport.Connection> connectionLookup,
-        ClusterState clusterState,
-        Map<String, AliasFilter> aliasFilter,
-        Map<String, Float> concreteIndexBoosts,
-        Map<String, Set<String>> indexRoutings,
-        ActionListener<SearchResponse> listener,
-        boolean preFilter,
-        ThreadPool threadPool,
-        SearchResponse.Clusters clusters,
-        SearchRequestContext searchRequestContext
-    ) {
-        if (preFilter) {
-            throw new IllegalStateException("Search pre-filter is not supported in streaming");
-        } else {
-            final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newStreamSearchPhaseResults(
-                executor,
-                circuitBreaker,
-                task.getProgressListener(),
-                searchRequest,
-                shardIterators.size(),
-                exc -> cancelTask(task, exc)
-            );
-            AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
-            switch (searchRequest.searchType()) {
-                case QUERY_THEN_FETCH:
-                    searchAsyncAction = new StreamSearchQueryThenFetchAsyncAction(
-                        logger,
-                        searchTransportService,
-                        connectionLookup,
-                        aliasFilter,
-                        concreteIndexBoosts,
-                        indexRoutings,
-                        searchPhaseController,
-                        executor,
-                        queryResultConsumer,
-                        searchRequest,
-                        listener,
-                        shardIterators,
-                        timeProvider,
-                        clusterState,
-                        task,
-                        clusters,
-                        searchRequestContext,
-                        tracer
-                    );
-                    break;
-                default:
-                    throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]");
-            }
-            return searchAsyncAction;
-        }
-    }
-}
diff --git a/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressListener.java b/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressListener.java
new file mode 100644
index 0000000000000..1332a4124daca
--- /dev/null
+++ b/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressListener.java
@@ -0,0 +1,141 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.action.search;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.search.aggregations.InternalAggregations;
+
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * SearchProgressListener implementation for streaming search with scoring.
+ * Computes partial search results progressively as shards complete.
+ * + * @opensearch.internal + */ +public class StreamingSearchProgressListener extends SearchProgressListener { + private static final Logger logger = LogManager.getLogger(StreamingSearchProgressListener.class); + + private final ActionListener responseListener; + private final AtomicInteger streamEmissions = new AtomicInteger(0); + private final SearchPhaseController searchPhaseController; + private final SearchRequest searchRequest; + + public StreamingSearchProgressListener( + ActionListener responseListener, + SearchPhaseController searchPhaseController, + SearchRequest searchRequest + ) { + this.responseListener = responseListener; + this.searchPhaseController = searchPhaseController; + this.searchRequest = searchRequest; + } + + @Override + protected void onPartialReduceWithTopDocs( + List shards, + TotalHits totalHits, + TopDocs topDocs, + InternalAggregations aggs, + int reducePhase + ) { + if (topDocs == null || topDocs.scoreDocs.length == 0) { + // No docs to emit + return; + } + + try { + // Convert TopDocs to SearchHits + // Simplified conversion of TopDocs to SearchHits + SearchHit[] hits = new SearchHit[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + hits[i] = new SearchHit(topDocs.scoreDocs[i].doc); + hits[i].score(topDocs.scoreDocs[i].score); + } + + float maxScore = hits.length > 0 ? hits[0].getScore() : Float.NaN; + SearchHits searchHits = new SearchHits(hits, totalHits, maxScore); + + // Create a partial search response with the current TopDocs + SearchResponseSections sections = new SearchResponseSections( + searchHits, + aggs, + null, // no suggestions in partial results + false, // not timed out + false, // not terminated early + null, // no profile results + reducePhase + ); + + // Create partial response + SearchResponse partialResponse = new SearchResponse( + sections, + null, // no scroll ID for streaming + shards.size(), // total shards + shards.size(), // successful shards + 0, // no skipped shards + 0, // took time (will be set later) + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY, + null // no phase took + ); + + collectPartialResponse(partialResponse); + + int count = streamEmissions.incrementAndGet(); + logger.info("Computed streaming partial #{} with {} docs from {} shards", count, topDocs.scoreDocs.length, shards.size()); + + } catch (Exception e) { + logger.error("Failed to send partial TopDocs", e); + } + } + + private void collectPartialResponse(SearchResponse partialResponse) { + if (responseListener instanceof StreamingSearchResponseListener) { + ((StreamingSearchResponseListener) responseListener).onPartialResponse(partialResponse); + } else { + logger.debug("Partial result computed, listener type: {}", responseListener.getClass().getSimpleName()); + } + } + + @Override + protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + logger.info( + "Final reduce: {} total hits from {} shards, {} partial computations", + totalHits.value(), + shards.size(), + streamEmissions.get() + ); + } + + public int getStreamEmissions() { + return streamEmissions.get(); + } + + /** + * Trigger partial emission of results + * This method is called by StreamQueryPhaseResultConsumer to trigger partial emissions + */ + public void triggerPartialEmission() { + // Trigger a partial reduce to emit current results + // This will call onPartialReduceWithTopDocs if there are results to emit + logger.debug("Triggering partial emission, current emissions: {}", 
diff --git a/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressMessage.java b/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressMessage.java
new file mode 100644
index 0000000000000..b63d4cbab194e
--- /dev/null
+++ b/server/src/main/java/org/opensearch/action/search/StreamingSearchProgressMessage.java
@@ -0,0 +1,305 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.action.search;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.common.io.stream.Writeable;
+import org.opensearch.search.SearchShardTarget;
+import org.opensearch.transport.TransportRequest;
+
+import java.io.IOException;
+
+/**
+ * Message sent from a shard to the coordinator with streaming search progress updates.
+ * Implements milestone-based streaming at 25%, 50%, 75%, and 100% completion.
+ */
+public class StreamingSearchProgressMessage extends TransportRequest {
+
+    /**
+     * Milestone types for streaming progress
+     */
+    public enum Milestone {
+        INITIAL(0.0f),        // Initial results available
+        QUARTER(0.25f),       // 25% complete
+        HALF(0.50f),          // 50% complete
+        THREE_QUARTER(0.75f), // 75% complete
+        COMPLETE(1.0f);       // 100% complete
+
+        private final float progress;
+
+        Milestone(float progress) {
+            this.progress = progress;
+        }
+
+        public float getProgress() {
+            return progress;
+        }
+
+        public static Milestone fromProgress(float progress) {
+            if (progress <= 0.0f) return INITIAL;
+            if (progress <= 0.25f) return QUARTER;
+            if (progress <= 0.50f) return HALF;
+            if (progress <= 0.75f) return THREE_QUARTER;
+            return COMPLETE;
+        }
+    }
+
+    private final SearchShardTarget shardTarget;
+    private final int shardIndex;
+    private final long requestId;
+    private final Milestone milestone;
+    private final TopDocs topDocs;
+    private final ProgressStatistics statistics;
+    private final boolean isFinal;
+
+    public StreamingSearchProgressMessage() {
+        super();
+        this.shardTarget = null;
+        this.shardIndex = -1;
+        this.requestId = -1;
+        this.milestone = Milestone.INITIAL;
+        this.topDocs = null;
+        this.statistics = null;
+        this.isFinal = false;
+    }
+
+    public StreamingSearchProgressMessage(
+        SearchShardTarget shardTarget,
+        int shardIndex,
+        long requestId,
+        Milestone milestone,
+        TopDocs topDocs,
+        ProgressStatistics statistics,
+        boolean isFinal
+    ) {
+        this.shardTarget = shardTarget;
+        this.shardIndex = shardIndex;
+        this.requestId = requestId;
+        this.milestone = milestone;
+        this.topDocs = topDocs;
+        this.statistics = statistics;
+        this.isFinal = isFinal;
+    }
+
+    public StreamingSearchProgressMessage(StreamInput in) throws IOException {
+        super(in);
+        this.shardTarget = new SearchShardTarget(in);
+        this.shardIndex = in.readVInt();
+        this.requestId = in.readLong();
+        this.milestone = in.readEnum(Milestone.class);
+
+        // Read TopDocs; totalHits must be read as a VLong to mirror writeTo
+        long totalHits = in.readVLong();
+        int numDocs = in.readVInt();
+        ScoreDoc[] scoreDocs = new ScoreDoc[numDocs];
+        for (int i = 0; i < numDocs; i++) {
+            int doc = in.readVInt();
+            float score = in.readFloat();
+            scoreDocs[i] = new ScoreDoc(doc, score);
+        }
+        this.topDocs = new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs);
+
+        this.statistics = new ProgressStatistics(in);
+        this.isFinal = in.readBoolean();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        shardTarget.writeTo(out);
+        out.writeVInt(shardIndex);
+        out.writeLong(requestId);
+        out.writeEnum(milestone);
+
+        // Write TopDocs
+        out.writeVLong(topDocs.totalHits != null ? topDocs.totalHits.value() : 0);
+        out.writeVInt(topDocs.scoreDocs.length);
+        for (ScoreDoc doc : topDocs.scoreDocs) {
+            out.writeVInt(doc.doc);
+            out.writeFloat(doc.score);
+        }
+
+        statistics.writeTo(out);
+        out.writeBoolean(isFinal);
+    }
+
+    public SearchShardTarget getShardTarget() {
+        return shardTarget;
+    }
+
+    public int getShardIndex() {
+        return shardIndex;
+    }
+
+    public long getRequestId() {
+        return requestId;
+    }
+
+    public Milestone getMilestone() {
+        return milestone;
+    }
+
+    public TopDocs getTopDocs() {
+        return topDocs;
+    }
+
+    public ProgressStatistics getStatistics() {
+        return statistics;
+    }
+
+    public boolean isFinal() {
+        return isFinal;
+    }
+
+    /**
+     * Statistics about the search progress
+     */
+    public static class ProgressStatistics implements Writeable {
+        private final int docsCollected;
+        private final long docsEvaluated;
+        private final long docsSkipped;
+        private final long blocksProcessed;
+        private final long blocksSkipped;
+        private final float minCompetitiveScore;
+        private final float maxSeenScore;
+        private final long elapsedTimeMillis;
+
+        public ProgressStatistics(
+            int docsCollected,
+            long docsEvaluated,
+            long docsSkipped,
+            long blocksProcessed,
+            long blocksSkipped,
+            float minCompetitiveScore,
+            float maxSeenScore,
+            long elapsedTimeMillis
+        ) {
+            this.docsCollected = docsCollected;
+            this.docsEvaluated = docsEvaluated;
+            this.docsSkipped = docsSkipped;
+            this.blocksProcessed = blocksProcessed;
+            this.blocksSkipped = blocksSkipped;
+            this.minCompetitiveScore = minCompetitiveScore;
+            this.maxSeenScore = maxSeenScore;
+            this.elapsedTimeMillis = elapsedTimeMillis;
+        }
+
+        public ProgressStatistics(StreamInput in) throws IOException {
+            this.docsCollected = in.readVInt();
+            this.docsEvaluated = in.readVLong();
+            this.docsSkipped = in.readVLong();
+            this.blocksProcessed = in.readVLong();
+            this.blocksSkipped = in.readVLong();
+            this.minCompetitiveScore = in.readFloat();
+            this.maxSeenScore = in.readFloat();
+            this.elapsedTimeMillis = in.readVLong();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVInt(docsCollected);
+            out.writeVLong(docsEvaluated);
+            out.writeVLong(docsSkipped);
+            out.writeVLong(blocksProcessed);
+            out.writeVLong(blocksSkipped);
+            out.writeFloat(minCompetitiveScore);
+            out.writeFloat(maxSeenScore);
+            out.writeVLong(elapsedTimeMillis);
+        }
+
+        public int getDocsCollected() {
+            return docsCollected;
+        }
+
+        public long getDocsEvaluated() {
+            return docsEvaluated;
+        }
+
+        public long getDocsSkipped() {
+            return docsSkipped;
+        }
+
+        public double getSkipRatio() {
+            long total = docsEvaluated + docsSkipped;
+            return total > 0 ? (double) docsSkipped / total : 0.0;
+        }
+
+        public double getBlockSkipRatio() {
+            long total = blocksProcessed + blocksSkipped;
+            return total > 0 ? (double) blocksSkipped / total : 0.0;
+        }
+
+        public float getMinCompetitiveScore() {
+            return minCompetitiveScore;
+        }
+
+        public float getMaxSeenScore() {
+            return maxSeenScore;
+        }
+
+        public long getElapsedTimeMillis() {
+            return elapsedTimeMillis;
+        }
+    }
+
+    /**
+     * Builder for creating progress messages
+     */
+    public static class Builder {
+        private SearchShardTarget shardTarget;
+        private int shardIndex;
+        private long requestId;
+        private Milestone milestone;
+        private TopDocs topDocs;
+        private ProgressStatistics statistics;
+        private boolean isFinal = false;
+
+        public Builder shardTarget(SearchShardTarget shardTarget) {
+            this.shardTarget = shardTarget;
+            return this;
+        }
+
+        public Builder shardIndex(int shardIndex) {
+            this.shardIndex = shardIndex;
+            return this;
+        }
+
+        public Builder requestId(long requestId) {
+            this.requestId = requestId;
+            return this;
+        }
+
+        public Builder milestone(Milestone milestone) {
+            this.milestone = milestone;
+            return this;
+        }
+
+        public Builder topDocs(TopDocs topDocs) {
+            this.topDocs = topDocs;
+            return this;
+        }
+
+        public Builder statistics(ProgressStatistics statistics) {
+            this.statistics = statistics;
+            return this;
+        }
+
+        public Builder isFinal(boolean isFinal) {
+            this.isFinal = isFinal;
+            return this;
+        }
+
+        public StreamingSearchProgressMessage build() {
+            return new StreamingSearchProgressMessage(shardTarget, shardIndex, requestId, milestone, topDocs, statistics, isFinal);
+        }
+    }
+}
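Note: shard-side producers are expected to go through the Builder rather than the seven-argument constructor. A sketch of a milestone update at roughly half completion (hypothetical values; fromProgress(0.5f) resolves to Milestone.HALF):

    static StreamingSearchProgressMessage halfwayUpdate(
        SearchShardTarget target, int shardIndex, long requestId,
        TopDocs topDocs, StreamingSearchProgressMessage.ProgressStatistics stats
    ) {
        return new StreamingSearchProgressMessage.Builder()
            .shardTarget(target)
            .shardIndex(shardIndex)
            .requestId(requestId)
            .milestone(StreamingSearchProgressMessage.Milestone.fromProgress(0.5f))
            .topDocs(topDocs)
            .statistics(stats)
            .isFinal(false)
            .build();
    }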
+ */ + public void onPartialResponse(SearchResponse partialResponse) { + if (isComplete.get()) { + logger.warn("Attempted to collect partial response after completion"); + return; + } + + int count = partialCount.incrementAndGet(); + partialResponse.setPartial(true); + partialResponse.setSequenceNumber(count); + + partialResponses.add(partialResponse); + logPartialResponse(partialResponse, count); + + // Track TTFB - first partial result delivery time + if (count == 1 && partialResponse.getHits() != null) { + int numHits = partialResponse.getHits().getHits().length; + logger.info("First partial result delivered with {} hits", numHits); + } + } + + /** + * Send the final response and complete the request. + * Include metadata about the streaming process. + */ + @Override + public void onResponse(SearchResponse finalResponse) { + if (isComplete.compareAndSet(false, true)) { + finalResponse.setPartial(false); + finalResponse.setSequenceNumber(partialCount.incrementAndGet()); + finalResponse.setTotalPartials(partialCount.get()); + + logStreamingSummary(finalResponse); + delegate.onResponse(finalResponse); + } + } + + @Override + public void onFailure(Exception e) { + if (isComplete.compareAndSet(false, true)) { + delegate.onFailure(e); + } + } + + private void logPartialResponse(SearchResponse partialResponse, int count) { + if (partialResponse.getHits() != null && partialResponse.getHits().getHits() != null) { + int numHits = partialResponse.getHits().getHits().length; + long totalHits = partialResponse.getHits().getTotalHits().value(); + + logger.info("Streaming partial result #{}: {} hits, total: {}", count, numHits, totalHits); + + if (logger.isDebugEnabled() && numHits > 0) { + float topScore = partialResponse.getHits().getHits()[0].getScore(); + logger.debug("Top score in partial #{}: {}", count, topScore); + } + } + } + + private void logStreamingSummary(SearchResponse finalResponse) { + int totalPartials = partialCount.get(); + if (totalPartials > 0) { + logger.info("Streaming search complete: {} partial computations", totalPartials); + + if (!partialResponses.isEmpty()) { + long totalDocsProcessed = partialResponses.stream() + .mapToLong(r -> r.getHits() != null ? 
+                    .mapToLong(r -> r.getHits() != null ? r.getHits().getHits().length : 0)
+                    .sum();
+                logger.info("Processed {} docs across {} partial emissions", totalDocsProcessed, partialResponses.size());
+            }
+        } else {
+            logger.debug("No partial computations performed");
+        }
+    }
+
+    /**
+     * Check if this listener supports streaming
+     */
+    public boolean isStreamingSupported() {
+        return true;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
index 541ed989b35a8..1d5a58ff0c0fa 100644
--- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
+++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
@@ -82,6 +82,7 @@
 import org.opensearch.search.pipeline.SearchPipelineService;
 import org.opensearch.search.profile.ProfileShardResult;
 import org.opensearch.search.profile.SearchProfileShardResults;
+import org.opensearch.search.query.StreamingSearchMode;
 import org.opensearch.search.slice.SliceBuilder;
 import org.opensearch.tasks.CancellableTask;
 import org.opensearch.tasks.Task;
@@ -97,7 +98,6 @@
 import org.opensearch.transport.RemoteClusterAware;
 import org.opensearch.transport.RemoteClusterService;
 import org.opensearch.transport.RemoteTransportException;
-import org.opensearch.transport.StreamTransportService;
 import org.opensearch.transport.Transport;
 import org.opensearch.transport.TransportService;
 import org.opensearch.transport.client.Client;
@@ -136,6 +136,7 @@
  * @opensearch.internal
 */
 public class TransportSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> {
+    // Streaming search integrated via streamingSearchMode in SearchRequest
 
     /** The maximum number of shards for a single search request. */
     public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
@@ -168,6 +169,8 @@ public class TransportSearchAction extends HandledTransportAction
 
+    private final StreamSearchTransportService streamSearchTransportService;
+
+        final boolean isStreamingCandidate = (searchRequest.isStreamingScoring() || searchRequest.getStreamingSearchMode() != null)
+            && (searchRequest.source() == null || searchRequest.source().size() > 0);
+        final boolean streamingEnabledSetting = clusterService.getClusterSettings().get(StreamSearchTransportService.STREAM_SEARCH_ENABLED);
+        final boolean useStreamingTransportForConnection = isStreamingCandidate
+            && streamSearchTransportService != null
+            && streamingEnabledSetting;
+        final SearchTransportService connectionTransport = useStreamingTransportForConnection
streamSearchTransportService + : searchTransportService; BiFunction connectionLookup = buildConnectionLookup( searchRequest.getLocalClusterAlias(), nodes::get, remoteConnections, - searchTransportService::getConnection + connectionTransport::getConnection ); final Executor asyncSearchExecutor = asyncSearchExecutor(concreteLocalIndices, clusterState); final boolean preFilterSearchShards = shouldPreFilterSearchShards( @@ -1225,10 +1240,30 @@ AbstractSearchAsyncAction searchAsyncAction( SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext ) { + // Determine if this request should use streaming transport + final boolean isStreamingCandidate = (searchRequest.isStreamingScoring() || searchRequest.getStreamingSearchMode() != null) + && (searchRequest.source() == null || searchRequest.source().size() > 0); + + // Check if streaming transport is actually available and enabled + final boolean streamingEnabledSetting = clusterService.getClusterSettings().get(StreamSearchTransportService.STREAM_SEARCH_ENABLED); + final boolean canUseStreamingTransport = (streamSearchTransportService != null) && streamingEnabledSetting; + + // Use streaming transport for streaming search requests + final boolean useStreamingTransport = isStreamingCandidate && canUseStreamingTransport; + if (preFilter) { + if (logger.isTraceEnabled()) { + logger.trace( + "STREAM DEBUG: prefilter using transport [{}] (streaming={}, enabled={}, canUse={})", + ((isStreamingCandidate && canUseStreamingTransport) ? "stream" : "classic"), + isStreamingCandidate, + streamingEnabledSetting, + canUseStreamingTransport + ); + } return new CanMatchPreFilterSearchPhase( logger, - searchTransportService, + (isStreamingCandidate && useStreamingTransport) ? streamSearchTransportService : searchTransportService, connectionLookup, aliasFilter, concreteIndexBoosts, @@ -1264,21 +1299,42 @@ AbstractSearchAsyncAction searchAsyncAction( tracer ); } else { + // Set default streaming mode when only flag is set + if (searchRequest.isStreamingScoring() && searchRequest.getStreamingSearchMode() == null) { + searchRequest.setStreamingSearchMode(StreamingSearchMode.SCORED_UNSORTED.toString()); + } + + final boolean isStreamingRequest = (searchRequest.isStreamingScoring() || searchRequest.getStreamingSearchMode() != null) + && (searchRequest.source() == null || searchRequest.source().size() > 0); + + final SearchProgressListener progressListener = (isStreamingRequest && useStreamingTransport) + ? new StreamingSearchProgressListener(listener, searchPhaseController, searchRequest) + : task.getProgressListener(); + final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( executor, circuitBreaker, - task.getProgressListener(), + progressListener, searchRequest, shardIterators.size(), exc -> cancelTask(task, exc), task::isCancelled ); + if (logger.isTraceEnabled()) { + logger.trace( + "STREAM DEBUG: query phase using transport [{}] (streamingRequest={}, enabled={}, canUse={})", + ((isStreamingRequest && useStreamingTransport) ? "stream" : "classic"), + isStreamingRequest, + streamingEnabledSetting, + canUseStreamingTransport + ); + } AbstractSearchAsyncAction searchAsyncAction; switch (searchRequest.searchType()) { case DFS_QUERY_THEN_FETCH: searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction( logger, - searchTransportService, + (isStreamingRequest && canUseStreamingTransport) ? 
streamSearchTransportService : searchTransportService, connectionLookup, aliasFilter, concreteIndexBoosts, @@ -1298,26 +1354,49 @@ AbstractSearchAsyncAction searchAsyncAction( ); break; case QUERY_THEN_FETCH: - searchAsyncAction = new SearchQueryThenFetchAsyncAction( - logger, - searchTransportService, - connectionLookup, - aliasFilter, - concreteIndexBoosts, - indexRoutings, - searchPhaseController, - executor, - queryResultConsumer, - searchRequest, - listener, - shardIterators, - timeProvider, - clusterState, - task, - clusters, - searchRequestContext, - tracer - ); + if (isStreamingRequest && canUseStreamingTransport) { + searchAsyncAction = new StreamSearchQueryThenFetchAsyncAction( + logger, + streamSearchTransportService, + connectionLookup, + aliasFilter, + concreteIndexBoosts, + indexRoutings, + searchPhaseController, + executor, + queryResultConsumer, + searchRequest, + listener, + shardIterators, + timeProvider, + clusterState, + task, + clusters, + searchRequestContext, + tracer + ); + } else { + searchAsyncAction = new SearchQueryThenFetchAsyncAction( + logger, + searchTransportService, + connectionLookup, + aliasFilter, + concreteIndexBoosts, + indexRoutings, + searchPhaseController, + executor, + queryResultConsumer, + searchRequest, + listener, + shardIterators, + timeProvider, + clusterState, + task, + clusters, + searchRequestContext, + tracer + ); + } break; default: throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]"); diff --git a/server/src/main/java/org/opensearch/action/support/StreamSearchChannelListener.java b/server/src/main/java/org/opensearch/action/support/StreamSearchChannelListener.java index 31967fafb20b7..68ce23bb8b772 100644 --- a/server/src/main/java/org/opensearch/action/support/StreamSearchChannelListener.java +++ b/server/src/main/java/org/opensearch/action/support/StreamSearchChannelListener.java @@ -8,6 +8,8 @@ package org.opensearch.action.support; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.core.action.ActionListener; import org.opensearch.core.transport.TransportResponse; @@ -28,10 +30,13 @@ public class StreamSearchChannelListener { + private static final Logger logger = LogManager.getLogger(StreamSearchChannelListener.class); private final TransportChannel channel; private final Request request; private final String actionName; + private final java.util.concurrent.atomic.AtomicBoolean completed = new java.util.concurrent.atomic.AtomicBoolean(false); + public StreamSearchChannelListener(TransportChannel channel, String actionName, Request request) { this.channel = channel; this.request = request; @@ -47,9 +52,17 @@ public StreamSearchChannelListener(TransportChannel channel, String actionName, */ public void onStreamResponse(Response response, boolean isLastBatch) { assert response != null; + if (completed.get()) { + // Ignore late responses after completion to avoid double-completion and task tracker mismatches + return; + } channel.sendResponseBatch(response); if (isLastBatch) { - channel.completeStream(); + try { + channel.completeStream(); + } finally { + completed.set(true); + } } } @@ -66,10 +79,15 @@ public final void onResponse(Response response) { @Override public void onFailure(Exception e) { + // Ensure we only fail once per request/channel to keep task tracker consistent + if (completed.getAndSet(true)) { + // Already completed (success or failure); drop duplicate failure + 
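+ // Lifecycle sketch (members of this class only): onStreamResponse(batch, false) -> sendResponseBatch(batch); + // onStreamResponse(last, true) -> sendResponseBatch(last) then completeStream(), completed = true; + // any later onFailure(e) is dropped here rather than written to an already-completed stream.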
return; + } try { channel.sendResponse(e); } catch (IOException exc) { - channel.completeStream(); + logger.warn("Failed to send error response on streaming channel", exc); throw new RuntimeException(exc); } } diff --git a/server/src/main/java/org/opensearch/common/network/NetworkModule.java b/server/src/main/java/org/opensearch/common/network/NetworkModule.java index 1be92d9a1a751..9655e1a2b4dcc 100644 --- a/server/src/main/java/org/opensearch/common/network/NetworkModule.java +++ b/server/src/main/java/org/opensearch/common/network/NetworkModule.java @@ -278,6 +278,20 @@ public NetworkModule( registerTransport(entry.getKey(), entry.getValue()); } + // Register stream transports + Map> streamTransportFactory = plugin.getStreamTransports( + settings, + threadPool, + pageCacheRecycler, + circuitBreakerService, + namedWriteableRegistry, + networkService, + tracer + ); + for (Map.Entry> entry : streamTransportFactory.entrySet()) { + registerTransport(entry.getKey(), entry.getValue()); + } + // Register any HTTP secure transports if available if (secureHttpTransportSettingsProviders.isEmpty() == false) { final SecureHttpTransportSettingsProvider secureSettingProvider = secureHttpTransportSettingsProviders.iterator().next(); diff --git a/server/src/main/java/org/opensearch/plugins/NetworkPlugin.java b/server/src/main/java/org/opensearch/plugins/NetworkPlugin.java index e2d0e468ed032..daa031aec5c25 100644 --- a/server/src/main/java/org/opensearch/plugins/NetworkPlugin.java +++ b/server/src/main/java/org/opensearch/plugins/NetworkPlugin.java @@ -106,6 +106,22 @@ default Map> getTransports( return Collections.emptyMap(); } + /** + * Returns a map of streaming {@link Transport} suppliers. + * See {@link org.opensearch.common.network.NetworkModule#STREAM_TRANSPORT_TYPE_KEY} to configure a specific implementation. + */ + default Map> getStreamTransports( + Settings settings, + ThreadPool threadPool, + PageCacheRecycler pageCacheRecycler, + CircuitBreakerService circuitBreakerService, + NamedWriteableRegistry namedWriteableRegistry, + NetworkService networkService, + Tracer tracer + ) { + return Collections.emptyMap(); + } + /** * Returns a map of {@link HttpServerTransport} suppliers. * See {@link org.opensearch.common.network.NetworkModule#HTTP_TYPE_SETTING} to configure a specific implementation. 
diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index fb3bc549572d1..a6fd6a411c724 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -154,6 +154,10 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC if (clusterSettings != null && clusterSettings.get(STREAM_SEARCH_ENABLED)) { if (FeatureFlags.isEnabled(FeatureFlags.STREAM_TRANSPORT)) { if (canUseStreamSearch(searchRequest)) { + String scoringMode = request.param("stream_scoring_mode"); + if (scoringMode != null) { + searchRequest.setStreamingSearchMode(scoringMode); + } return channel -> { RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel()); cancelClient.execute(StreamSearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel)); @@ -243,6 +247,11 @@ public static void parseSearchRequest( searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions())); searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline())); + // Add streaming mode support + if (request.hasParam("streaming_mode")) { + searchRequest.setStreamingSearchMode(request.param("streaming_mode")); + } + checkRestTotalHits(request, searchRequest); request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 14f7b4b321638..30bd5dd7ba284 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -100,6 +100,7 @@ import org.opensearch.search.query.QueryPhaseExecutionException; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.query.StreamingSearchMode; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.slice.SliceBuilder; import org.opensearch.search.sort.SortAndFormats; @@ -225,6 +226,7 @@ final class DefaultSearchContext extends SearchContext { private boolean isStreamSearch; private StreamSearchChannelListener listener; + private StreamingSearchMode streamingMode; private final SetOnce cachedFlushMode = new SetOnce<>(); DefaultSearchContext( @@ -291,6 +293,22 @@ final class DefaultSearchContext extends SearchContext { this.concurrentSearchDeciderFactories = concurrentSearchDeciderFactories; this.keywordIndexOrDocValuesEnabled = evaluateKeywordIndexOrDocValuesEnabled(); this.isStreamSearch = isStreamSearch; + + // Initialize streaming mode from request + if (request.getStreamingSearchMode() != null) { + try { + this.streamingMode = StreamingSearchMode.fromString(request.getStreamingSearchMode()); + // If a streaming mode is set, enable streaming search + this.isStreamSearch = true; + // Set FlushMode to PER_SEGMENT for streaming aggregations + this.cachedFlushMode.trySet(FlushMode.PER_SEGMENT); + if (logger.isDebugEnabled()) { + logger.debug("Initialized streaming search with mode: {} and FlushMode: PER_SEGMENT", this.streamingMode); + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid streaming search mode: " + request.getStreamingSearchMode(), e); + 
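+ // Accepted values are defined by StreamingSearchMode: "no_scoring", "scored_sorted", "scored_unsorted".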
} + } } DefaultSearchContext( @@ -1283,6 +1301,21 @@ public boolean isStreamSearch() { return isStreamSearch; } + public StreamingSearchMode getStreamingMode() { + // Do not default to a mode; null means streaming disabled + return streamingMode; + } + + public void setStreamingMode(StreamingSearchMode mode) { + this.streamingMode = mode; + } + + @Override + public int getStreamingBatchSize() { + // Return fixed default for streaming batch size + return 10; + } + /** * Disables streaming for this search context. * Used when streaming cost analysis determines traditional processing is more efficient. diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 66d00e7dffe07..32e3addb0bbf4 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -142,6 +142,7 @@ import org.opensearch.search.query.QuerySearchRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ScrollQuerySearchResult; +import org.opensearch.search.query.StreamingSearchMode; import org.opensearch.search.rescore.RescorerBuilder; import org.opensearch.search.searchafter.SearchAfterBuilder; import org.opensearch.search.sort.FieldSortBuilder; @@ -750,7 +751,9 @@ public void executeQueryPhase( ActionListener listener, String executorName ) { - executeQueryPhase(request, keepStatesInContext, task, listener, executorName, false); + // Determine if this is a streaming search request + boolean isStreamSearch = request.isStreamingSearch() || request.getStreamingSearchMode() != null; + executeQueryPhase(request, keepStatesInContext, task, listener, executorName, isStreamSearch); } public void executeQueryPhase( @@ -784,11 +787,15 @@ public void onResponse(ShardSearchRequest orig) { } } // fork the execution in the search thread pool - runAsync( - getExecutor(executorName, shard), - () -> executeQueryPhase(orig, task, keepStatesInContext, isStreamSearch, listener), - listener - ); + final Executor queryExecutor; + if (shard.isSystem()) { + queryExecutor = threadPool.executor(Names.SYSTEM_READ); + } else if (shard.indexSettings().isSearchThrottled()) { + queryExecutor = threadPool.executor(Names.SEARCH_THROTTLED); + } else { + queryExecutor = threadPool.executor(Names.SEARCH); + } + runAsync(queryExecutor, () -> executeQueryPhase(orig, task, keepStatesInContext, isStreamSearch, listener), listener); } @Override @@ -826,6 +833,12 @@ private SearchPhaseResult executeQueryPhase( assert listener instanceof StreamSearchChannelListener : "Stream search expects StreamSearchChannelListener"; context.setStreamChannelListener((StreamSearchChannelListener) listener); } + + if (request.getStreamingSearchMode() != null) { + context.setStreamingMode(StreamingSearchMode.fromString(request.getStreamingSearchMode())); + } else if (isStreamSearch) { + context.setStreamingMode(StreamingSearchMode.NO_SCORING); + } final long afterQueryTime; try (SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(context)) { loadOrExecuteQueryPhase(request, context); @@ -835,6 +848,13 @@ private SearchPhaseResult executeQueryPhase( afterQueryTime = executor.success(); } if (request.numberOfShards() == 1) { + if (isStreamSearch && logger.isTraceEnabled()) { + logger.trace( + "STREAM DEBUG: shard [{}] sending final {}", + request.shardId(), + (request.numberOfShards() == 1 ? 
"query+fetch" : "query") + ); + } return executeFetchPhase(readerContext, context, afterQueryTime); } else { // Pass the rescoreDocIds to the queryResult to send them the coordinating node and receive them back in the fetch phase. @@ -842,6 +862,13 @@ private SearchPhaseResult executeQueryPhase( final RescoreDocIds rescoreDocIds = context.rescoreDocIds(); context.queryResult().setRescoreDocIds(rescoreDocIds); readerContext.setRescoreDocIds(rescoreDocIds); + if (isStreamSearch && logger.isTraceEnabled()) { + logger.trace( + "STREAM DEBUG: shard [{}] sending final {}", + request.shardId(), + (request.numberOfShards() == 1 ? "query+fetch" : "query") + ); + } return context.queryResult(); } } catch (Exception e) { @@ -852,6 +879,9 @@ private SearchPhaseResult executeQueryPhase( ? (Exception) exception.getCause() : new OpenSearchException(exception.getCause()); } + if (isStreamSearch && logger.isTraceEnabled()) { + logger.trace("STREAM DEBUG: shard [{}] failure {}", request.shardId(), exception.toString()); + } logger.trace("Query phase failed", exception); processFailure(readerContext, exception); throw exception; diff --git a/server/src/main/java/org/opensearch/search/builder/StreamingSearchParameters.java b/server/src/main/java/org/opensearch/search/builder/StreamingSearchParameters.java new file mode 100644 index 0000000000000..34b7429afe592 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/builder/StreamingSearchParameters.java @@ -0,0 +1,235 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.builder; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +/** + * Streaming search parameters for configuring progressive result emission. + * These parameters control how and when intermediate results are sent. 
+ */ +public class StreamingSearchParameters implements Writeable, ToXContent { + + public static final String STREAMING_FIELD = "streaming"; + public static final String ENABLED_FIELD = "enabled"; + public static final String CONFIDENCE_FIELD = "confidence"; + public static final String BATCH_SIZE_FIELD = "batch_size"; + public static final String EMISSION_INTERVAL_FIELD = "emission_interval"; + public static final String MIN_DOCS_FIELD = "min_docs"; + public static final String ADAPTIVE_BATCHING_FIELD = "adaptive_batching"; + public static final String MILESTONES_FIELD = "milestones"; + + private boolean enabled = false; + private float initialConfidence = 0.99f; + private int batchSize = 10; + private int emissionIntervalMillis = 100; + private int minDocsForStreaming = 5; + private boolean adaptiveBatching = true; + private boolean useMilestones = true; + + public StreamingSearchParameters() {} + + public StreamingSearchParameters(StreamInput in) throws IOException { + this.enabled = in.readBoolean(); + this.initialConfidence = in.readFloat(); + this.batchSize = in.readVInt(); + this.emissionIntervalMillis = in.readVInt(); + this.minDocsForStreaming = in.readVInt(); + this.adaptiveBatching = in.readBoolean(); + this.useMilestones = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(enabled); + out.writeFloat(initialConfidence); + out.writeVInt(batchSize); + out.writeVInt(emissionIntervalMillis); + out.writeVInt(minDocsForStreaming); + out.writeBoolean(adaptiveBatching); + out.writeBoolean(useMilestones); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(STREAMING_FIELD); + builder.field(ENABLED_FIELD, enabled); + builder.field(CONFIDENCE_FIELD, initialConfidence); + builder.field(BATCH_SIZE_FIELD, batchSize); + builder.field(EMISSION_INTERVAL_FIELD, emissionIntervalMillis); + builder.field(MIN_DOCS_FIELD, minDocsForStreaming); + builder.field(ADAPTIVE_BATCHING_FIELD, adaptiveBatching); + builder.field(MILESTONES_FIELD, useMilestones); + builder.endObject(); + return builder; + } + + public static StreamingSearchParameters fromXContent(XContentParser parser) throws IOException { + StreamingSearchParameters params = new StreamingSearchParameters(); + + XContentParser.Token token; + String currentFieldName = null; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (ENABLED_FIELD.equals(currentFieldName)) { + params.enabled = parser.booleanValue(); + } else if (CONFIDENCE_FIELD.equals(currentFieldName)) { + params.initialConfidence = parser.floatValue(); + } else if (BATCH_SIZE_FIELD.equals(currentFieldName)) { + params.batchSize = parser.intValue(); + } else if (EMISSION_INTERVAL_FIELD.equals(currentFieldName)) { + params.emissionIntervalMillis = parser.intValue(); + } else if (MIN_DOCS_FIELD.equals(currentFieldName)) { + params.minDocsForStreaming = parser.intValue(); + } else if (ADAPTIVE_BATCHING_FIELD.equals(currentFieldName)) { + params.adaptiveBatching = parser.booleanValue(); + } else if (MILESTONES_FIELD.equals(currentFieldName)) { + params.useMilestones = parser.booleanValue(); + } + } + } + + return params; + } + + // Getters and setters + + public boolean isEnabled() { + return enabled; + } + + public StreamingSearchParameters enabled(boolean enabled) { + this.enabled = 
enabled; + return this; + } + + public float getInitialConfidence() { + return initialConfidence; + } + + public StreamingSearchParameters initialConfidence(float confidence) { + if (confidence <= 0.0f || confidence > 1.0f) { + throw new IllegalArgumentException("Confidence must be between 0 and 1"); + } + this.initialConfidence = confidence; + return this; + } + + public int getBatchSize() { + return batchSize; + } + + public StreamingSearchParameters batchSize(int batchSize) { + if (batchSize <= 0) { + throw new IllegalArgumentException("Batch size must be positive"); + } + this.batchSize = batchSize; + return this; + } + + public int getEmissionIntervalMillis() { + return emissionIntervalMillis; + } + + public StreamingSearchParameters emissionIntervalMillis(int interval) { + if (interval < 0) { + throw new IllegalArgumentException("Emission interval cannot be negative"); + } + this.emissionIntervalMillis = interval; + return this; + } + + public int getMinDocsForStreaming() { + return minDocsForStreaming; + } + + public StreamingSearchParameters minDocsForStreaming(int minDocs) { + if (minDocs < 0) { + throw new IllegalArgumentException("Min docs cannot be negative"); + } + this.minDocsForStreaming = minDocs; + return this; + } + + public boolean isAdaptiveBatching() { + return adaptiveBatching; + } + + public StreamingSearchParameters adaptiveBatching(boolean adaptive) { + this.adaptiveBatching = adaptive; + return this; + } + + public boolean isUseMilestones() { + return useMilestones; + } + + public StreamingSearchParameters useMilestones(boolean useMilestones) { + this.useMilestones = useMilestones; + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + StreamingSearchParameters that = (StreamingSearchParameters) o; + return enabled == that.enabled + && Float.compare(that.initialConfidence, initialConfidence) == 0 + && batchSize == that.batchSize + && emissionIntervalMillis == that.emissionIntervalMillis + && minDocsForStreaming == that.minDocsForStreaming + && adaptiveBatching == that.adaptiveBatching + && useMilestones == that.useMilestones; + } + + @Override + public int hashCode() { + return Objects.hash( + enabled, + initialConfidence, + batchSize, + emissionIntervalMillis, + minDocsForStreaming, + adaptiveBatching, + useMilestones + ); + } + + @Override + public String toString() { + return "StreamingSearchParameters{" + + "enabled=" + + enabled + + ", confidence=" + + initialConfidence + + ", batchSize=" + + batchSize + + ", emissionInterval=" + + emissionIntervalMillis + + ", minDocs=" + + minDocsForStreaming + + ", adaptive=" + + adaptiveBatching + + ", milestones=" + + useMilestones + + '}'; + } +} diff --git a/server/src/main/java/org/opensearch/search/internal/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/internal/DefaultSearchContext.java new file mode 100644 index 0000000000000..c37cb9f91d22d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/internal/DefaultSearchContext.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.internal; diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index ac38b364fd36b..c7e85bbe9d461 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -78,6 +78,7 @@ import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.query.StreamingSearchMode; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.search.streaming.FlushMode; @@ -578,6 +579,24 @@ public boolean isStreamSearch() { return false; } + // Streaming search support - default no-op implementations for compatibility + + public StreamingSearchMode getStreamingMode() { + return null; + } + + public void setStreamingMode(StreamingSearchMode mode) { + // no-op + } + + public boolean isStreamingSearch() { + return getStreamingMode() != null; + } + + public int getStreamingBatchSize() { + return 10; + } + /** * Gets the resolved flush mode for this search context. */ diff --git a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java index de1d5fb8b4098..ee7f46cc661fa 100644 --- a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java @@ -106,6 +106,8 @@ public class ShardSearchRequest extends TransportRequest implements IndicesReque private boolean canReturnNullResponseIfMatchNoDocs; private SearchSortValuesAndFormats bottomSortValues; + private boolean streamingSearch = false; + private String streamingSearchMode = null; // these are the only mutable fields, as they are subject to rewriting private AliasFilter aliasFilter; @@ -173,6 +175,10 @@ public ShardSearchRequest( // If allowPartialSearchResults is unset (ie null), the cluster-level default should have been substituted // at this stage. Any NPEs in the above are therefore an error in request preparation logic. 
assert searchRequest.allowPartialSearchResults() != null; + + // Set streaming search flag from search request + this.streamingSearch = searchRequest.isStreamingScoring(); + this.streamingSearchMode = searchRequest.getStreamingSearchMode(); } public ShardSearchRequest(ShardId shardId, long nowInMillis, AliasFilter aliasFilter) { @@ -232,6 +238,9 @@ private ShardSearchRequest( this.originalIndices = originalIndices; this.readerId = readerId; this.keepAlive = keepAlive; + // Initialize streaming fields to default values + this.streamingSearch = false; + this.streamingSearchMode = null; assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; } @@ -267,6 +276,14 @@ public ShardSearchRequest(StreamInput in) throws IOException { keepAlive = in.readOptionalTimeValue(); originalIndices = OriginalIndices.readOriginalIndices(in); assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; + // Read streaming fields - gated on version for BWC + if (in.getVersion().onOrAfter(Version.V_3_3_0)) { + streamingSearch = in.readBoolean(); + streamingSearchMode = in.readOptionalString(); + } else { + streamingSearch = false; + streamingSearchMode = null; + } } public ShardSearchRequest(ShardSearchRequest clone) { @@ -290,6 +307,8 @@ public ShardSearchRequest(ShardSearchRequest clone) { this.originalIndices = clone.originalIndices; this.readerId = clone.readerId; this.keepAlive = clone.keepAlive; + this.streamingSearch = clone.streamingSearch; + this.streamingSearchMode = clone.streamingSearchMode; } @Override @@ -297,6 +316,11 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); innerWriteTo(out, false); OriginalIndices.writeOriginalIndices(originalIndices, out); + // Write streaming fields after OriginalIndices - gated on version for BWC + if (out.getVersion().onOrAfter(Version.V_3_3_0)) { + out.writeBoolean(streamingSearch); + out.writeOptionalString(streamingSearchMode); + } } protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { @@ -397,6 +421,33 @@ public void setInboundNetworkTime(long newTime) { this.inboundNetworkTime = newTime; } + public boolean isStreamingSearch() { + return streamingSearch; + } + + public void setStreamingSearch(boolean streamingSearch) { + this.streamingSearch = streamingSearch; + } + + public String getStreamingSearchMode() { + return streamingSearchMode; + } + + public void setStreamingSearchMode(String streamingSearchMode) { + this.streamingSearchMode = streamingSearchMode; + } + + /** + * Set streaming fields from a SearchRequest + * This is needed for constructors that don't have access to the full SearchRequest + */ + public void setStreamingFieldsFromSearchRequest(SearchRequest searchRequest) { + if (searchRequest != null) { + this.streamingSearch = searchRequest.isStreamingScoring(); + this.streamingSearchMode = searchRequest.getStreamingSearchMode(); + } + } + public long getOutboundNetworkTime() { return outboundNetworkTime; } diff --git a/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java b/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java index b2c97baf78d91..f1737aa09f58b 100644 --- a/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java @@ -44,6 +44,7 @@ import org.opensearch.search.fetch.subphase.ScriptFieldsContext; import 
org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.StreamingSearchMode; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.search.suggest.SuggestionSearchContext; @@ -147,6 +148,14 @@ public boolean hasFetchSourceContext() { return fetchSourceContext != null; } + public StreamingSearchMode getStreamingMode() { + return null; + } + + public void setStreamingMode(StreamingSearchMode mode) { + // no-op + } + @Override public FetchSourceContext fetchSourceContext() { return fetchSourceContext; diff --git a/server/src/main/java/org/opensearch/search/query/BoundProvider.java b/server/src/main/java/org/opensearch/search/query/BoundProvider.java new file mode 100644 index 0000000000000..28b2b83b24149 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/BoundProvider.java @@ -0,0 +1,96 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.TopDocs; +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Interface for providers that can calculate bounds for streaming search. + * Different search modalities (text, vector, etc.) implement this to provide + * domain-specific bound calculations. + * + * @opensearch.internal + */ +@ExperimentalApi +public interface BoundProvider { + + /** + * Calculate the current bound for the given search context. + * Lower bounds indicate higher confidence in the current results. + * + * @param context The search context + * @return The calculated bound (lower = more confident) + */ + double calculateBound(SearchContext context); + + /** + * Check if the current results are stable enough to emit. + * + * @param context The search context + * @return true if results are stable + */ + boolean isStable(SearchContext context); + + /** + * Get the current progress percentage (0.0 to 1.0). + * + * @param context The search context + * @return Progress as a percentage + */ + double getProgress(SearchContext context); + + /** + * Get the current phase of the search. + * + * @return The current search phase + */ + SearchPhase getCurrentPhase(); + + /** + * Search phases for streaming search. + */ + @ExperimentalApi + enum SearchPhase { + FIRST, // Initial shard-level processing + SECOND, // Shard-level refinement + GLOBAL // Global coordinator processing + } + + /** + * Search context for bound calculations. + */ + @ExperimentalApi + interface SearchContext { + /** + * Get the number of documents processed so far. + */ + int getDocCount(); + + /** + * Get the current top-K results. + */ + TopDocs getTopDocs(); + + /** + * Get the current k-th score. + */ + float getKthScore(); + + /** + * Get the maximum possible score. + */ + float getMaxPossibleScore(); + + /** + * Get the search modality type. 
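+ * For example, a lexical provider might return "text" and a k-NN provider "vector"; + * these literals are illustrative, not constants defined by this interface.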
+ */ + String getModality(); + } +} diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index f8427440a6c13..ff7348d8ab40a 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -200,6 +200,11 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q Query query = searchContext.query(); assert query == searcher.rewrite(query); // already rewritten + // Add streaming path + if (searchContext.isStreamingSearch()) { + return executeStreamingQuery(searchContext, searcher, query); + } + final ScrollContext scrollContext = searchContext.scrollContext(); if (scrollContext != null) { if (scrollContext.totalHits == null) { @@ -309,6 +314,38 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q } } + /** + * Execute streaming query for progressive result emission. + * This method handles the streaming search execution path. + */ + private static boolean executeStreamingQuery(SearchContext searchContext, ContextIndexSearcher searcher, Query query) + throws IOException { + QuerySearchResult queryResult = searchContext.queryResult(); + + try { + // Create streaming collector context + TopDocsCollectorContext streamingCollectorContext = TopDocsCollectorContext.createStreamingTopDocsCollectorContext( + searchContext, + false // hasFilterCollector - simplified for streaming + ); + + // Create collector manager for concurrent segment search compatibility + CollectorManager manager = streamingCollectorContext.createManager(null); + + // Execute search using CollectorManager + ReduceableSearchResult reduceResult = searcher.search(query, manager); + + // Process the result + reduceResult.reduce(queryResult); + + // For streaming, we don't need rescoring phase + return true; // Enable streaming execution + + } catch (Exception e) { + throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute streaming query", e); + } + } + /** * Create runnable which throws {@link TimeExceededException} when the runnable is called after timeout + runnable creation time * exceeds currentTime @@ -412,6 +449,7 @@ public static class TimeExceededException extends RuntimeException { * @opensearch.internal */ public static class DefaultQueryPhaseSearcher implements QueryPhaseSearcher { + private static final Logger logger = LogManager.getLogger(DefaultQueryPhaseSearcher.class); private final AggregationProcessor aggregationProcessor; /** @@ -477,7 +515,17 @@ void postProcess(QuerySearchResult result) throws IOException { if (queryCollectorContextOpt.isPresent()) { return queryCollectorContextOpt.get(); } else { - return createTopDocsCollectorContext(searchContext, hasFilterCollector); + // Check if this is a streaming search request FIRST + if (searchContext.isStreamingSearch()) { + // Use streaming collectors for streaming search + if (logger.isTraceEnabled()) { + logger.trace("Using streaming collector for mode: {}", searchContext.getStreamingMode()); + } + return TopDocsCollectorContext.createStreamingTopDocsCollectorContext(searchContext, hasFilterCollector); + } else { + // Fall back to regular top docs collector + return createTopDocsCollectorContext(searchContext, hasFilterCollector); + } } } diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index 
f3ac953ab9d1d..f61bae9f438cf 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -52,6 +52,7 @@ import org.opensearch.search.suggest.Suggest; import java.io.IOException; +import java.util.List; import static org.opensearch.common.lucene.Lucene.readTopDocs; import static org.opensearch.common.lucene.Lucene.writeTopDocs; @@ -417,4 +418,36 @@ public TotalHits getTotalHits() { public float getMaxScore() { return maxScore; } + + // Streaming search support + private List docIds; + private boolean partial = false; + + /** + * Set document IDs for streaming search results + */ + public void setDocIds(List docIds) { + this.docIds = docIds; + } + + /** + * Get document IDs for streaming search results + */ + public List getDocIds() { + return docIds; + } + + /** + * Set whether this is a partial result + */ + public void setPartial(boolean partial) { + this.partial = partial; + } + + /** + * Check if this is a partial result + */ + public boolean isPartial() { + return partial; + } } diff --git a/server/src/main/java/org/opensearch/search/query/StreamingCollectorContext.java b/server/src/main/java/org/opensearch/search/query/StreamingCollectorContext.java new file mode 100644 index 0000000000000..3a59e33658464 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingCollectorContext.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.ScoreDoc; + +import java.util.List; + +/** + * Interface for streaming collectors to emit batches of results. + * This enables progressive emission of search results during collection. + */ +public interface StreamingCollectorContext { + + /** + * Emit a batch of documents to the streaming channel. + * + * @param docs The documents to emit + * @param isFinal Whether this is the final batch + */ + void emitBatch(List docs, boolean isFinal); + + /** + * Get the configured batch size for this collector. + * + * @return The batch size in number of documents + */ + int getBatchSize(); + + /** + * Check if a batch should be emitted based on current state. + * + * @return true if a batch should be emitted now + */ + boolean shouldEmitBatch(); +} diff --git a/server/src/main/java/org/opensearch/search/query/StreamingScoredUnsortedCollectorContext.java b/server/src/main/java/org/opensearch/search/query/StreamingScoredUnsortedCollectorContext.java new file mode 100644 index 0000000000000..4df2a9d3b6462 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingScoredUnsortedCollectorContext.java @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Streaming collector context for SCORED_UNSORTED mode. + * Collects documents with scores but without sorting for fast emission with relevance. + * + * Implements memory-bounded collection using a "firstK" pattern where only the first K + * documents are retained for the final result. Documents are collected in batches + * controlled by search.streaming.batch_size setting (default: 10, max: 100). + * + * Memory footprint: O(K + batchSize) where K is the requested number of hits. + * + * Circuit Breaker Policy: + * - Batch buffers: No CB checks as they're strictly bounded (10-100 docs) and cleared after emission + * - FirstK list: Protected by parent QueryPhaseResultConsumer's circuit breaker during final reduction + * - Max memory per collector: ~8KB for batch (100 docs * 16 bytes) + ~80KB for firstK (10000 docs * 16 bytes) + * - Decision rationale: The overhead of CB checks (atomic operations) would exceed the memory saved + * for such small, bounded allocations that are immediately released + */ +public class StreamingScoredUnsortedCollectorContext extends TopDocsCollectorContext { + + private final CircuitBreaker circuitBreaker; + private final SearchContext searchContext; + + public StreamingScoredUnsortedCollectorContext(String profilerName, int numHits, SearchContext searchContext) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.circuitBreaker = null; // Will work but no protection + } + + public StreamingScoredUnsortedCollectorContext(String profilerName, int numHits, SearchContext searchContext, CircuitBreaker breaker) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.circuitBreaker = breaker; + } + + @Override + public Collector create(Collector in) throws IOException { + // For SCORED_UNSORTED mode, we need scoring but no sorting + return new StreamingScoredUnsortedCollector(); + } + + @Override + public CollectorManager createManager(CollectorManager in) throws IOException { + return new StreamingScoredUnsortedCollectorManager(); + } + + @Override + public void postProcess(org.opensearch.search.query.QuerySearchResult result) throws IOException { + if (result.hasConsumedTopDocs()) { + return; + } + + if (result.topDocs() == null) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + TotalHits totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO); + TopDocs topDocs = new TopDocs(totalHits, scoreDocs); + result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null); + } + } + + /** + * Collector manager for streaming scored unsorted collection + */ + private class StreamingScoredUnsortedCollectorManager + implements + CollectorManager { + + @Override + public StreamingScoredUnsortedCollector newCollector() throws IOException { + return new StreamingScoredUnsortedCollector(); + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) 
throws IOException { + // Keep top K by score across collectors (min-heap behavior simulated by linear merge due to small K) + List mergedTopK = new ArrayList<>(); + int totalHits = 0; + float maxScore = Float.NEGATIVE_INFINITY; + + for (StreamingScoredUnsortedCollector collector : collectors) { + List topK = collector.getTopKDocs(); + totalHits += collector.getTotalHitsCount(); + for (ScoreDoc d : topK) { + mergedTopK.add(d); + if (!Float.isNaN(d.score) && d.score > maxScore) { + maxScore = d.score; + } + } + } + + // If more than K, keep highest scores only + if (mergedTopK.size() > numHits()) { + mergedTopK.sort((a, b) -> Float.compare(b.score, a.score)); + mergedTopK = mergedTopK.subList(0, numHits()); + } + + ScoreDoc[] scoreDocs = mergedTopK.toArray(new ScoreDoc[0]); + TopDocs topDocs = new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs); + float finalMaxScore = (maxScore > Float.NEGATIVE_INFINITY) ? maxScore : Float.NaN; + + return result -> result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, finalMaxScore), null); + } + } + + /** + * Collector that actually collects documents with scores but no sorting + */ + private class StreamingScoredUnsortedCollector implements Collector { + + private final int batchSize = Math.max(1, searchContext != null ? searchContext.getStreamingBatchSize() : 10); + private final List currentBatch = new ArrayList<>(batchSize); + private final List topK = new ArrayList<>(numHits()); + private int totalHitsCount = 0; + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE; // Need scores for SCORED_UNSORTED mode + } + + @Override + public LeafCollector getLeafCollector(org.apache.lucene.index.LeafReaderContext context) throws IOException { + return new LeafCollector() { + private Scorable scorer; + + @Override + public void setScorer(Scorable scorer) throws IOException { + this.scorer = scorer; + } + + @Override + public void collect(int doc) throws IOException { + totalHitsCount++; + float score = this.scorer.score(); + ScoreDoc scoreDoc = new ScoreDoc(doc + context.docBase, score); + + currentBatch.add(scoreDoc); + if (currentBatch.size() >= batchSize) { + emitCurrentBatch(false); + currentBatch.clear(); + } + + if (topK.size() < numHits()) { + topK.add(scoreDoc); + } else { + int minIdx = 0; + float minScore = topK.get(0).score; + for (int i = 1; i < topK.size(); i++) { + if (topK.get(i).score < minScore) { + minScore = topK.get(i).score; + minIdx = i; + } + } + if (score > minScore) { + topK.set(minIdx, scoreDoc); + } + } + } + }; + } + + public List getTopKDocs() { + return topK; + } + + public int getTotalHitsCount() { + return totalHitsCount; + } + + /** + * Emit current batch of collected documents through streaming channel + */ + private void emitCurrentBatch(boolean isFinal) { + if (currentBatch.isEmpty()) return; + + try { + // Create partial result + QuerySearchResult partial = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(currentBatch.size(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + currentBatch.toArray(new ScoreDoc[0]) + ); + partial.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null); + partial.setPartial(!isFinal); + + if (searchContext != null && searchContext.getStreamChannelListener() != null) { + searchContext.getStreamChannelListener().onStreamResponse(partial, isFinal); + } + } catch (Exception e) { + // Silently ignore - streaming is best effort + } + } + } +} diff --git 
a/server/src/main/java/org/opensearch/search/query/StreamingScoringConfig.java b/server/src/main/java/org/opensearch/search/query/StreamingScoringConfig.java new file mode 100644 index 0000000000000..283e8a8d44a46 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingScoringConfig.java @@ -0,0 +1,109 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Configuration for streaming search with scoring. + * Controls the behavior of streaming search with different scoring modes. + * + * @opensearch.internal + */ +@ExperimentalApi +public class StreamingScoringConfig { + + private final boolean enabled; + private final StreamingSearchMode mode; + private final int minDocsBeforeEmission; + private final int emissionIntervalMillis; + private final int rerankCountShard; + private final int rerankCountGlobal; + private final boolean enablePhaseMarkers; + + /** + * Creates a streaming scoring configuration. + */ + public StreamingScoringConfig( + boolean enabled, + StreamingSearchMode mode, + int minDocsBeforeEmission, + int emissionIntervalMillis, + int rerankCountShard, + int rerankCountGlobal, + boolean enablePhaseMarkers + ) { + this.enabled = enabled; + this.mode = mode; + this.minDocsBeforeEmission = minDocsBeforeEmission; + this.emissionIntervalMillis = emissionIntervalMillis; + this.rerankCountShard = rerankCountShard; + this.rerankCountGlobal = rerankCountGlobal; + this.enablePhaseMarkers = enablePhaseMarkers; + } + + /** + * Returns a disabled configuration. + */ + public static StreamingScoringConfig disabled() { + return new StreamingScoringConfig(false, StreamingSearchMode.SCORED_UNSORTED, 1000, 100, 10, 100, false); + } + + /** + * Returns the default configuration. + */ + public static StreamingScoringConfig defaultConfig() { + return new StreamingScoringConfig(true, StreamingSearchMode.SCORED_UNSORTED, 100, 50, 10, 100, true); + } + + /** + * Creates a configuration for a specific scoring mode. 
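+ * <p>For example, {@code forMode(StreamingSearchMode.NO_SCORING)} yields an enabled config + * tuned for fast first emission (minDocsBeforeEmission=10, emissionIntervalMillis=10), + * while an unrecognized mode falls back to {@link #defaultConfig()}.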
+ */ + public static StreamingScoringConfig forMode(StreamingSearchMode mode) { + switch (mode) { + case NO_SCORING: + return new StreamingScoringConfig(true, mode, 10, 10, 5, 50, true); + case SCORED_SORTED: + return new StreamingScoringConfig(true, mode, 100, 50, 10, 100, true); + case SCORED_UNSORTED: + return new StreamingScoringConfig(true, mode, 50, 25, 10, 100, false); + default: + return defaultConfig(); + } + } + + // Getters + public boolean isEnabled() { + return enabled; + } + + public StreamingSearchMode getMode() { + return mode; + } + + public int getMinDocsBeforeEmission() { + return minDocsBeforeEmission; + } + + public int getEmissionIntervalMillis() { + return emissionIntervalMillis; + } + + public int getRerankCountShard() { + return rerankCountShard; + } + + public int getRerankCountGlobal() { + return rerankCountGlobal; + } + + public boolean isEnablePhaseMarkers() { + return enablePhaseMarkers; + } +} diff --git a/server/src/main/java/org/opensearch/search/query/StreamingSearchMode.java b/server/src/main/java/org/opensearch/search/query/StreamingSearchMode.java new file mode 100644 index 0000000000000..189fbf6f81361 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingSearchMode.java @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Defines the different streaming search strategies based on the design-by-case approach. + * Each mode optimizes for different use cases and performance characteristics. + * + * @opensearch.internal + */ +@ExperimentalApi +public enum StreamingSearchMode { + + /** + * Case 1: No scoring, no sorting - fastest TTFB + * - Shard collector: StreamingUnsortedCollector + * - Ring buffer with batch emission + * - Round-robin merge at coordinator + * - Best for: simple filtering, counting, exists queries + */ + NO_SCORING("no_scoring"), + + /** + * Case 2: Full scoring + explicit sort - production ready + * - Shard collector: StreamingSortedCollector + * - WAND/Block-Max WAND with windowed top-K heap + * - K-way streaming merge at coordinator + * - Best for: scored searches with sorting + */ + SCORED_SORTED("scored_sorted"), + + /** + * Case 3: Full scoring, no sorting - moderate performance + * - Shard collector: StreamingScoredUnsortedCollector + * - Ring buffer with scoring + * - No merge needed at coordinator + * - Best for: scored searches without sorting + */ + SCORED_UNSORTED("scored_unsorted"); + + private final String value; + + StreamingSearchMode(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + /** + * Parse mode from string representation. + * + * @param mode The string representation of the mode + * @return The corresponding StreamingSearchMode + * @throws IllegalArgumentException if mode is unknown + */ + public static StreamingSearchMode fromString(String mode) { + if (mode == null) { + return SCORED_UNSORTED; // Default + } + + for (StreamingSearchMode m : values()) { + if (m.name().equalsIgnoreCase(mode) || m.value.equalsIgnoreCase(mode)) { + return m; + } + } + + throw new IllegalArgumentException("Unknown streaming search mode: " + mode); + } + + /** + * Helper method to check if this mode requires scoring. 
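+ * Of the modes above, SCORED_SORTED and SCORED_UNSORTED both score documents; only NO_SCORING skips scoring.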
+ * @return true if scoring is required, false otherwise + */ + public boolean requiresScoring() { + return this != NO_SCORING; + } + + /** + * Helper method to check if this mode requires sorting. + * @return true if sorting is required, false otherwise + */ + public boolean requiresSorting() { + return this == SCORED_SORTED; + } + + // Confidence-based mode removed in this branch + + @Override + public String toString() { + return value; + } +} diff --git a/server/src/main/java/org/opensearch/search/query/StreamingSortedCollectorContext.java b/server/src/main/java/org/opensearch/search/query/StreamingSortedCollectorContext.java new file mode 100644 index 0000000000000..952b0acffca48 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingSortedCollectorContext.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; +import org.apache.lucene.search.TotalHits; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; + +/** + * Streaming collector context for SCORED_SORTED mode. + * Collects and maintains documents in sorted order (by score or custom sort). + * + * Uses Lucene's TopScoreDocCollectorManager for efficient sorted collection with + * incremental merging. Documents are collected in larger batches (10x default multiplier) + * to amortize sorting costs, controlled by search.streaming.scored_sorted.batch_multiplier. + * + * Memory footprint: O(K) where K is the requested number of hits. + * The TopScoreDocCollector maintains a min-heap of size K. + * + * Circuit Breaker Policy: + * - Heap structure: Protected by TopScoreDocCollector's internal memory management + * - Parent reduction: Protected by QueryPhaseResultConsumer's circuit breaker + * - Max memory per collector: ~80KB for topK heap (10000 docs * 16 bytes) + * - Decision rationale: Sorting requires maintaining all K docs in memory, but Lucene's + * collectors are already optimized for memory efficiency + */ +public class StreamingSortedCollectorContext extends TopDocsCollectorContext { + + private final Sort sort; + private final CircuitBreaker circuitBreaker; + private final SearchContext searchContext; + + public StreamingSortedCollectorContext(String profilerName, int numHits, SearchContext searchContext) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.sort = Sort.RELEVANCE; + this.circuitBreaker = null; + } + + public StreamingSortedCollectorContext(String profilerName, int numHits, SearchContext searchContext, Sort sort) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.sort = sort != null ? 
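+ // an explicit sort wins; otherwise stream in relevance (score) order: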
sort : Sort.RELEVANCE; + this.circuitBreaker = null; + } + + public StreamingSortedCollectorContext( + String profilerName, + int numHits, + SearchContext searchContext, + Sort sort, + CircuitBreaker breaker + ) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.sort = sort != null ? sort : Sort.RELEVANCE; + this.circuitBreaker = breaker; + } + + public StreamingSortedCollectorContext(String profilerName, int numHits, SearchContext searchContext, CircuitBreaker breaker) { + super(profilerName, numHits); + this.searchContext = searchContext; + this.sort = Sort.RELEVANCE; + this.circuitBreaker = breaker; + } + + @Override + public Collector create(Collector in) throws IOException { + // Use Lucene's top-N score collector for single-threaded execution + return new TopScoreDocCollectorManager(numHits(), null, Integer.MAX_VALUE).newCollector(); + } + + @Override + public CollectorManager createManager(CollectorManager in) throws IOException { + return new StreamingSortedCollectorManager(); + } + + @Override + public void postProcess(org.opensearch.search.query.QuerySearchResult result) throws IOException { + if (result.hasConsumedTopDocs()) { + return; + } + + if (result.topDocs() == null) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + TotalHits totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO); + TopDocs topDocs = new TopDocs(totalHits, scoreDocs); + result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null); + } + } + + /** + * Collector manager for streaming sorted collection + */ + private class StreamingSortedCollectorManager implements CollectorManager { + + private final CollectorManager, ? extends TopDocs> manager; + + private StreamingSortedCollectorManager() { + this.manager = new TopScoreDocCollectorManager(numHits(), null, Integer.MAX_VALUE); + } + + @Override + public Collector newCollector() throws IOException { + // Use Lucene's collector manager for score-sorted collection + return manager.newCollector(); + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + final Collection> topDocsCollectors = new ArrayList<>(); + for (Collector collector : collectors) { + if (collector instanceof TopDocsCollector) { + topDocsCollectors.add((TopDocsCollector) collector); + } + } + + // Reduce with Lucene's manager + @SuppressWarnings("unchecked") + final TopDocs topDocs = ((CollectorManager, ? extends TopDocs>) manager).reduce(topDocsCollectors); + + final float computedMaxScore = (topDocs.scoreDocs != null && topDocs.scoreDocs.length > 0) + ? topDocs.scoreDocs[0].score + : Float.NaN; + + return result -> result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, computedMaxScore), null); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java b/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java new file mode 100644 index 0000000000000..700f65806bb02 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java @@ -0,0 +1,196 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
diff --git a/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java b/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java
new file mode 100644
index 0000000000000..700f65806bb02
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/StreamingUnsortedCollectorContext.java
@@ -0,0 +1,196 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Streaming collector context for NO_SCORING mode.
+ * Collects documents without scoring for fastest emission.
+ *
+ * Implements memory-bounded collection using a "firstK" pattern where only the first K
+ * documents are retained for the final result. Documents are collected in batches
+ * controlled by the search.streaming.batch_size setting (default: 10, max: 100).
+ *
+ * Memory footprint: O(K + batchSize) where K is the requested number of hits.
+ *
+ * Circuit Breaker Policy:
+ * - Batch buffers: No CB checks as they're strictly bounded (10-100 docs) and cleared after emission
+ * - FirstK list: Protected by parent QueryPhaseResultConsumer's circuit breaker during final reduction
+ * - Max memory per collector: ~1.6KB for the batch (100 docs * 16 bytes) + ~160KB for firstK (10,000 docs * 16 bytes)
+ * - Decision rationale: The overhead of CB checks (atomic operations) would exceed the memory saved
+ *   for such small, bounded allocations that are immediately released
+ */
+public class StreamingUnsortedCollectorContext extends TopDocsCollectorContext {
+
+    private static final Logger logger = LogManager.getLogger(StreamingUnsortedCollectorContext.class);
+
+    private final CircuitBreaker circuitBreaker;
+    private final SearchContext searchContext;
+
+    public StreamingUnsortedCollectorContext(String profilerName, int numHits, SearchContext searchContext) {
+        super(profilerName, numHits);
+        this.searchContext = searchContext;
+        this.circuitBreaker = null; // No breaker: collection still works, but without memory protection
+    }
+
+    public StreamingUnsortedCollectorContext(String profilerName, int numHits, SearchContext searchContext, CircuitBreaker breaker) {
+        super(profilerName, numHits);
+        this.searchContext = searchContext;
+        this.circuitBreaker = breaker;
+    }
+
+    @Override
+    public Collector create(Collector in) throws IOException {
+        // For NO_SCORING mode, we don't need scoring
+        return new StreamingUnsortedCollector();
+    }
+
+    @Override
+    public CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
+        return new StreamingUnsortedCollectorManager();
+    }
+
+    @Override
+    public void postProcess(org.opensearch.search.query.QuerySearchResult result) throws IOException {
+        if (result.hasConsumedTopDocs()) {
+            return;
+        }
+
+        if (result.topDocs() == null) {
+            ScoreDoc[] scoreDocs = new ScoreDoc[0];
+            TotalHits totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO);
+            TopDocs topDocs = new TopDocs(totalHits, scoreDocs);
+            result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null);
+        }
+    }
+
+    /**
+     * Collector manager for streaming unsorted collection
+     */
+    private class StreamingUnsortedCollectorManager implements CollectorManager<StreamingUnsortedCollector, ReduceableSearchResult> {
+
+        @Override
+        public StreamingUnsortedCollector newCollector() throws IOException {
+            return new StreamingUnsortedCollector();
+        }
+
+        @Override
+        public ReduceableSearchResult reduce(Collection<StreamingUnsortedCollector> collectors) throws IOException {
+            List<ScoreDoc> mergedFirstK = new ArrayList<>();
+            int totalHits = 0;
+
+            for (StreamingUnsortedCollector collector : collectors) {
+                mergedFirstK.addAll(collector.getFirstKDocs());
+                totalHits += collector.getTotalHitsCount();
+            }
+
+            if (mergedFirstK.size() > numHits()) {
+                mergedFirstK = mergedFirstK.subList(0, numHits());
+            }
+
+            ScoreDoc[] scoreDocs = mergedFirstK.toArray(new ScoreDoc[0]);
+            TopDocs topDocs = new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs);
+
+            return result -> result.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null);
+        }
+    }
+
+    /**
+     * Collector that actually collects documents without scoring
+     */
+    private class StreamingUnsortedCollector implements Collector {
+
+        private final int batchSize = Math.max(1, searchContext != null ? searchContext.getStreamingBatchSize() : 10);
+        private final List<ScoreDoc> currentBatch = new ArrayList<>(batchSize);
+        private final List<ScoreDoc> firstK = new ArrayList<>(numHits());
+        private int totalHitsCount = 0;
+
+        @Override
+        public ScoreMode scoreMode() {
+            return ScoreMode.COMPLETE_NO_SCORES; // No scoring needed for NO_SCORING mode
+        }
+
+        @Override
+        public LeafCollector getLeafCollector(org.apache.lucene.index.LeafReaderContext context) throws IOException {
+            return new LeafCollector() {
+                @Override
+                public void setScorer(Scorable scorer) throws IOException {}
+
+                @Override
+                public void collect(int doc) throws IOException {
+                    totalHitsCount++;
+                    ScoreDoc scoreDoc = new ScoreDoc(doc + context.docBase, Float.NaN);
+
+                    currentBatch.add(scoreDoc);
+                    if (currentBatch.size() >= batchSize) {
+                        // emitCurrentBatch(false) resets the buffer itself
+                        emitCurrentBatch(false);
+                    }
+
+                    if (firstK.size() < numHits()) {
+                        firstK.add(scoreDoc);
+                    }
+                }
+            };
+        }
+
+        public List<ScoreDoc> getFirstKDocs() {
+            return firstK;
+        }
+
+        public int getTotalHitsCount() {
+            return totalHitsCount;
+        }
+
+        /**
+         * Emit the current batch of collected documents through the streaming channel
+         */
+        private void emitCurrentBatch(boolean isFinal) {
+            if (currentBatch.isEmpty()) return;
+
+            try {
+                // Create partial result
+                QuerySearchResult partial = new QuerySearchResult();
+                TopDocs topDocs = new TopDocs(
+                    new TotalHits(currentBatch.size(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
+                    currentBatch.toArray(new ScoreDoc[0])
+                );
+                partial.topDocs(new org.opensearch.common.lucene.search.TopDocsAndMaxScore(topDocs, Float.NaN), null);
+                partial.setPartial(!isFinal);
+
+                if (searchContext != null && searchContext.getStreamChannelListener() != null) {
+                    searchContext.getStreamChannelListener().onStreamResponse(partial, isFinal);
+                }
+            } catch (Exception e) {
+                logger.trace("Failed to emit streaming batch", e);
+            } finally {
+                if (!isFinal) {
+                    // Always reset the buffer, even when emission fails, so it stays bounded
+                    currentBatch.clear();
+                }
+            }
+        }
+    }
+}
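The firstK-plus-batch discipline above can be isolated in a few lines. In this sketch (illustrative only; `FirstKBatcher` and the `Consumer`-based emitter stand in for the collector and the stream channel listener), memory stays O(K + batchSize) because the batch buffer is recycled after every emission and only the first K doc IDs are retained for the final reduce:

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

final class FirstKBatcher {
    private final int k;
    private final int batchSize;
    private final Consumer<List<Integer>> emitter; // stands in for the stream channel
    private final List<Integer> firstK;
    private final List<Integer> batch;
    private long totalHits;

    FirstKBatcher(int k, int batchSize, Consumer<List<Integer>> emitter) {
        this.k = k;
        this.batchSize = batchSize;
        this.emitter = emitter;
        this.firstK = new ArrayList<>(k);
        this.batch = new ArrayList<>(batchSize);
    }

    void collect(int docId) {
        totalHits++;
        batch.add(docId);
        if (firstK.size() < k) {
            firstK.add(docId); // retained for the final reduced result
        }
        if (batch.size() >= batchSize) {
            emitter.accept(new ArrayList<>(batch)); // copy out, then reuse the buffer
            batch.clear();
        }
    }

    void finish() {
        if (!batch.isEmpty()) {
            emitter.accept(new ArrayList<>(batch));
            batch.clear();
        }
    }

    List<Integer> firstK() { return firstK; }
    long totalHits() { return totalHits; }
}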
diff --git a/server/src/main/java/org/opensearch/search/query/TextBoundProvider.java b/server/src/main/java/org/opensearch/search/query/TextBoundProvider.java
new file mode 100644
index 0000000000000..df2de3003e211
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/TextBoundProvider.java
@@ -0,0 +1,93 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.apache.lucene.search.TopDocs;
+import org.opensearch.search.query.BoundProvider.SearchContext;
+import org.opensearch.search.query.BoundProvider.SearchPhase;
+
+/**
+ * Text-based bound provider for streaming search.
+ * Implements WAND/BMW-style bounds for text search scenarios.
+ *
+ * @opensearch.internal
+ */
+public class TextBoundProvider implements BoundProvider {
+
+    private final int minDocsForStability;
+    private final double stabilityThreshold;
+
+    public TextBoundProvider(int minDocsForStability, double stabilityThreshold) {
+        this.minDocsForStability = minDocsForStability;
+        this.stabilityThreshold = stabilityThreshold;
+    }
+
+    @Override
+    public double calculateBound(SearchContext context) {
+        if (context.getDocCount() < minDocsForStability) {
+            return Double.MAX_VALUE; // Not enough docs for confidence
+        }
+
+        TopDocs topDocs = context.getTopDocs();
+        if (topDocs == null || topDocs.scoreDocs.length == 0) {
+            return Double.MAX_VALUE;
+        }
+
+        // Simple bound: the normalized gap between the k-th best score and the max possible score
+        float kthScore = context.getKthScore();
+        float maxPossible = context.getMaxPossibleScore();
+
+        if (maxPossible <= 0 || kthScore <= 0) {
+            return Double.MAX_VALUE;
+        }
+
+        // A lower bound means higher confidence: when the k-th score is close to the maximum
+        // possible score the bound approaches 0 (high confidence); when it is far from the
+        // maximum the bound approaches 1 (low confidence).
+        double bound = (maxPossible - kthScore) / maxPossible;
+
+        // Clamp the bound to [0, 1]
+        return Math.max(0.0, Math.min(1.0, bound));
+    }
+
+    @Override
+    public boolean isStable(SearchContext context) {
+        if (context.getDocCount() < minDocsForStability) {
+            return false;
+        }
+
+        // Lower bound means higher confidence, so the result is stable once bound <= threshold.
+        // A bound of 0.5 means the k-th score is halfway to the maximum possible score.
+        double bound = calculateBound(context);
+        return bound <= stabilityThreshold;
+    }
+
+    @Override
+    public double getProgress(SearchContext context) {
+        // Simple progress heuristic based on document count;
+        // this could be enhanced with more sophisticated progress tracking
+        return Math.min(1.0, (double) context.getDocCount() / 1000.0);
+    }
+
+    @Override
+    public SearchPhase getCurrentPhase() {
+        // Text search typically goes through FIRST -> SECOND -> GLOBAL;
+        // this simplified implementation always reports FIRST
+        return SearchPhase.FIRST;
+    }
+
+    /**
+     * Create a default text bound provider.
+     */
+    public static TextBoundProvider createDefault() {
+        // Defaults: 100 docs before stability can be declared, 0.5 stability threshold
+        // (the threshold was raised to 0.5 for testing; production tuning may warrant a stricter value)
+        return new TextBoundProvider(100, 0.5);
+    }
+}
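A worked instance of the bound arithmetic above, under the parameters from createDefault() (the class and method names here are invented for the example; the formula is the one in calculateBound):

public final class BoundMathExample {
    static double bound(float kthScore, float maxPossible) {
        if (maxPossible <= 0 || kthScore <= 0) {
            return Double.MAX_VALUE;              // no confidence yet
        }
        double b = (maxPossible - kthScore) / maxPossible;
        return Math.max(0.0, Math.min(1.0, b));   // clamp to [0, 1]
    }

    public static void main(String[] args) {
        // maxPossible = 8.0, k-th best score = 6.0: bound = (8 - 6) / 8 = 0.25,
        // and 0.25 <= 0.5 (default stabilityThreshold), so the top-K is stable
        // once minDocsForStability documents have been seen.
        System.out.println(bound(6.0f, 8.0f)); // 0.25  -> stable at threshold 0.5
        System.out.println(bound(1.0f, 8.0f)); // 0.875 -> not yet stable
    }
}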
diff --git a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java
index 0d0022aaa8772..9bab5e29b1d75 100644
--- a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java
+++ b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java
@@ -32,6 +32,8 @@
 package org.opensearch.search.query;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexOptions;
@@ -72,6 +74,7 @@
 import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
 import org.opensearch.common.lucene.search.function.ScriptScoreQuery;
 import org.opensearch.common.util.CachedSupplier;
+import org.opensearch.core.common.breaker.CircuitBreaker;
 import org.opensearch.index.search.OpenSearchToParentBlockJoinQuery;
 import org.opensearch.search.DocValueFormat;
 import org.opensearch.search.approximate.ApproximateScoreQuery;
@@ -98,6 +101,7 @@
  * @opensearch.internal
  */
 public abstract class TopDocsCollectorContext extends QueryCollectorContext implements RescoringQueryCollectorContext {
+    private static final Logger logger = LogManager.getLogger(TopDocsCollectorContext.class);
     protected final int numHits;
 
     TopDocsCollectorContext(String profilerName, int numHits) {
@@ -758,6 +762,58 @@
         void postProcess(QuerySearchResult result) throws IOException {
         }
     }
 
+    /**
+     * Creates a streaming {@link TopDocsCollectorContext} for streaming search with scoring.
+     * This method routes to the appropriate streaming collector based on the search mode.
+     */
+    public static TopDocsCollectorContext createStreamingTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector)
+        throws IOException {
+
+        // Get circuit breaker from search context
+        CircuitBreaker circuitBreaker = null;
+        try {
+            // Try to get the REQUEST circuit breaker
+            if (searchContext.bigArrays() != null && searchContext.bigArrays().breakerService() != null) {
+                circuitBreaker = searchContext.bigArrays().breakerService().getBreaker(CircuitBreaker.REQUEST);
+            }
+        } catch (Exception e) {
+            logger.warn("Failed to get circuit breaker for streaming search", e);
+        }
+
+        if (circuitBreaker == null) {
+            logger.warn("No circuit breaker available for streaming search - memory protection disabled");
+        }
+
+        StreamingSearchMode mode = searchContext.getStreamingMode();
+        if (mode == null) {
+            throw new IllegalArgumentException("Streaming mode must be set for streaming collectors");
+        }
+
+        switch (mode) {
+            case NO_SCORING:
+                return new StreamingUnsortedCollectorContext("streaming_no_scoring", searchContext.size(), searchContext, circuitBreaker);
+            case SCORED_UNSORTED:
+                return new StreamingScoredUnsortedCollectorContext(
+                    "streaming_scored_unsorted",
+                    searchContext.size(),
+                    searchContext,
+                    circuitBreaker
+                );
+            case SCORED_SORTED:
+                SortAndFormats sortAndFormats = searchContext.sort();
+                Sort sort = (sortAndFormats != null) ? sortAndFormats.sort : Sort.RELEVANCE;
+                return new StreamingSortedCollectorContext(
+                    "streaming_scored_sorted",
+                    searchContext.size(),
+                    searchContext,
+                    sort,
+                    circuitBreaker
+                );
+            default:
+                throw new IllegalArgumentException("Unknown streaming mode: " + mode);
+        }
+    }
+
 /**
  * Returns query total hit count if the query is a {@link MatchAllDocsQuery}
  * or a {@link TermQuery} and the reader has no deletions,
@@ -822,6 +878,12 @@ static int shortcutTotalHitCount(IndexReader reader, Query query) throws IOExcep
  */
 public static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector)
     throws IOException {
+
+    // Check for streaming search first
+    if (searchContext.isStreamingSearch() && searchContext.getStreamingMode() != null) {
+        return createStreamingTopDocsCollectorContext(searchContext, hasFilterCollector);
+    }
+
     final IndexReader reader = searchContext.searcher().getIndexReader();
     final Query query = searchContext.query();
     // top collectors don't like a size of 0
diff --git a/server/src/main/java/org/opensearch/search/query/stream/package-info.java b/server/src/main/java/org/opensearch/search/query/stream/package-info.java
new file mode 100644
index 0000000000000..ed84fbefcccb2
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/stream/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * Streaming search query execution components.
+ *
+ * This package contains classes for streaming search functionality that provides
+ * progressive result emission during query execution.
+ *
+ * @opensearch.internal
+ */
+package org.opensearch.search.query.stream;
diff --git a/server/src/main/java/org/opensearch/search/streaming/StreamingSearchMetrics.java b/server/src/main/java/org/opensearch/search/streaming/StreamingSearchMetrics.java
new file mode 100644
index 0000000000000..ac58382406bd1
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/streaming/StreamingSearchMetrics.java
@@ -0,0 +1,384 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.streaming;
+
+import org.opensearch.common.metrics.CounterMetric;
+import org.opensearch.common.metrics.MeanMetric;
+import org.opensearch.telemetry.metrics.Counter;
+import org.opensearch.telemetry.metrics.Histogram;
+import org.opensearch.telemetry.metrics.MetricsRegistry;
+import org.opensearch.telemetry.metrics.tags.Tags;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.LongAdder;
+
+/**
+ * Comprehensive metrics tracking for streaming search operations.
+ * Provides detailed insights for production monitoring and optimization.
+ */
+public class StreamingSearchMetrics {
+
+    // Core metrics
+    private final CounterMetric totalStreamingSearches = new CounterMetric();
+    private final CounterMetric successfulStreamingSearches = new CounterMetric();
+    private final CounterMetric failedStreamingSearches = new CounterMetric();
+    private final CounterMetric fallbackToNormalSearches = new CounterMetric();
+
+    // Performance metrics
+    private final MeanMetric timeToFirstResult = new MeanMetric();
+    private final MeanMetric totalSearchTime = new MeanMetric();
+    private final MeanMetric emissionLatency = new MeanMetric();
+    private final MeanMetric confidenceAtEmission = new MeanMetric();
+
+    // Efficiency metrics
+    private final LongAdder totalDocsEvaluated = new LongAdder();
+    private final LongAdder totalDocsSkipped = new LongAdder();
+    private final LongAdder totalBlocksProcessed = new LongAdder();
+    private final LongAdder totalBlocksSkipped = new LongAdder();
+    private final MeanMetric skipRatio = new MeanMetric();
+
+    // Emission metrics
+    private final CounterMetric totalEmissions = new CounterMetric();
+    private final MeanMetric docsPerEmission = new MeanMetric();
+    private final MeanMetric emissionsPerSearch = new MeanMetric();
+    private final AtomicLong totalDocsEmitted = new AtomicLong();
+
+    // Quality metrics
+    private final MeanMetric precisionRate = new MeanMetric();
+    private final MeanMetric recallRate = new MeanMetric();
+    private final CounterMetric confidenceViolations = new CounterMetric();
+    private final CounterMetric reorderingRequired = new CounterMetric();
+
+    // Resource metrics
+    private final AtomicLong currentActiveStreams = new AtomicLong();
+    private final AtomicLong peakActiveStreams = new AtomicLong();
+    private final AtomicLong totalMemoryUsed = new AtomicLong();
+    private final AtomicLong peakMemoryUsed = new AtomicLong();
+
+    // Network metrics
+    private final LongAdder totalBytesStreamed = new LongAdder();
+    private final CounterMetric clientDisconnections = new CounterMetric();
+    private final CounterMetric backpressureEvents = new CounterMetric();
+
+    // Circuit breaker metrics
+    private final CounterMetric circuitBreakerTrips = new CounterMetric();
+    private final AtomicLong currentCircuitBreakerMemory = new AtomicLong();
+
+    // Per-index metrics
+    private final Map<String, IndexStreamingMetrics> indexMetrics = new ConcurrentHashMap<>();
+
+    // OpenTelemetry metrics
+    private Counter streamingSearchCounter;
+    private Histogram timeToFirstResultHistogram;
+    private Histogram confidenceHistogram;
+    private Counter emissionCounter;
+
+    public StreamingSearchMetrics(MetricsRegistry metricsRegistry) {
+        if (metricsRegistry != null) {
+            initializeOpenTelemetryMetrics(metricsRegistry);
+        }
+    }
+
+    private void initializeOpenTelemetryMetrics(MetricsRegistry registry) {
+        // Register OpenTelemetry metrics
+        this.streamingSearchCounter = registry.createCounter(
+            "streaming_search_requests_total",
+            "Total number of streaming search requests",
+            "requests"
+        );
+
+        this.timeToFirstResultHistogram = registry.createHistogram(
+            "streaming_search_time_to_first_result",
+            "Time to first result in streaming search",
+            "milliseconds"
+        );
+
+        this.confidenceHistogram = registry.createHistogram(
+            "streaming_search_confidence_at_emission",
+            "Confidence level when emitting results",
+            "ratio"
+        );
+
+        this.emissionCounter = registry.createCounter("streaming_search_emissions_total", "Total number of result emissions", "emissions");
+    }
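The registration-then-record flow for the OpenTelemetry-side metrics can be seen end to end in a compact sketch. This mirrors the exact registry calls used above; the wrapper class, the index name, and the values are illustrative, not part of the change:

import org.opensearch.telemetry.metrics.Counter;
import org.opensearch.telemetry.metrics.Histogram;
import org.opensearch.telemetry.metrics.MetricsRegistry;
import org.opensearch.telemetry.metrics.tags.Tags;

final class TelemetrySketch {
    private final Counter emissions;
    private final Histogram timeToFirstResult;

    TelemetrySketch(MetricsRegistry registry) {
        // Same (name, description, unit) signatures as in initializeOpenTelemetryMetrics above
        this.emissions = registry.createCounter(
            "streaming_search_emissions_total", "Total number of result emissions", "emissions");
        this.timeToFirstResult = registry.createHistogram(
            "streaming_search_time_to_first_result", "Time to first result in streaming search", "milliseconds");
    }

    void onEmission(String index, int docs, long millisToFirst) {
        // Per-index tagging, as done for the counters and histograms above
        emissions.add(docs, Tags.create().addTag("index", index));
        timeToFirstResult.record(millisToFirst, Tags.create().addTag("index", index));
    }
}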
+    /**
+     * Record the start of a streaming search
+     */
+    public StreamingSearchContext startSearch(String index) {
+        totalStreamingSearches.inc();
+        long activeStreams = currentActiveStreams.incrementAndGet();
+        updatePeakActiveStreams(activeStreams);
+
+        if (streamingSearchCounter != null) {
+            streamingSearchCounter.add(1, Tags.create().addTag("index", index));
+        }
+
+        return new StreamingSearchContext(index, System.nanoTime());
+    }
+
+    /**
+     * Record successful search completion
+     */
+    public void recordSuccess(StreamingSearchContext context) {
+        successfulStreamingSearches.inc();
+        currentActiveStreams.decrementAndGet();
+
+        long duration = System.nanoTime() - context.startTime;
+        totalSearchTime.inc(duration / 1_000_000); // Convert to milliseconds
+
+        if (context.firstResultTime > 0) {
+            long timeToFirst = (context.firstResultTime - context.startTime) / 1_000_000;
+            timeToFirstResult.inc(timeToFirst);
+
+            if (timeToFirstResultHistogram != null) {
+                timeToFirstResultHistogram.record(timeToFirst, Tags.create().addTag("index", context.index));
+            }
+        }
+
+        // Update per-index metrics
+        getIndexMetrics(context.index).recordSuccess(duration / 1_000_000);
+    }
+
+    /**
+     * Record search failure
+     */
+    public void recordFailure(StreamingSearchContext context, Throwable error) {
+        failedStreamingSearches.inc();
+        currentActiveStreams.decrementAndGet();
+
+        getIndexMetrics(context.index).recordFailure(error.getClass().getSimpleName());
+    }
+
+    /**
+     * Record fallback to normal search
+     */
+    public void recordFallback(String reason) {
+        // Note: the reason is accepted but not yet recorded per-type
+        fallbackToNormalSearches.inc();
+        currentActiveStreams.decrementAndGet();
+    }
+
+    /**
+     * Record emission event
+     */
+    public void recordEmission(StreamingSearchContext context, int docsEmitted, float confidence) {
+        totalEmissions.inc();
+        docsPerEmission.inc(docsEmitted);
+        totalDocsEmitted.addAndGet(docsEmitted);
+        confidenceAtEmission.inc((long) (confidence * 100));
+
+        if (context.firstResultTime == 0) {
+            context.firstResultTime = System.nanoTime();
+        }
+
+        context.totalEmissions++;
+        context.totalDocsEmitted += docsEmitted;
+
+        // Record emission latency
+        long now = System.nanoTime();
+        if (context.lastEmissionTime > 0) {
+            emissionLatency.inc((now - context.lastEmissionTime) / 1_000_000);
+        }
+        context.lastEmissionTime = now;
+
+        if (emissionCounter != null) {
+            emissionCounter.add(
+                docsEmitted,
+                Tags.create().addTag("index", context.index).addTag("batch", String.valueOf(context.totalEmissions))
+            );
+        }
+
+        if (confidenceHistogram != null) {
+            confidenceHistogram.record((long) (confidence * 100), Tags.create().addTag("index", context.index));
+        }
+    }
+
+    /**
+     * Record block processing statistics
+     */
+    public void recordBlockProcessing(int docsEvaluated, int docsSkipped, int blocksProcessed, int blocksSkipped) {
+        totalDocsEvaluated.add(docsEvaluated);
+        totalDocsSkipped.add(docsSkipped);
+        totalBlocksProcessed.add(blocksProcessed);
+        totalBlocksSkipped.add(blocksSkipped);
+
+        if (blocksProcessed + blocksSkipped > 0) {
+            float ratio = (float) blocksSkipped / (blocksProcessed + blocksSkipped);
+            skipRatio.inc((long) (ratio * 100));
+        }
+    }
+
+    /**
+     * Record memory usage
+     */
+    public void recordMemoryUsage(long bytesUsed) {
+        totalMemoryUsed.addAndGet(bytesUsed);
+        long currentPeak = peakMemoryUsed.get();
+        // Retry the CAS so a concurrent update cannot drop a new peak
+        while (bytesUsed > currentPeak && peakMemoryUsed.compareAndSet(currentPeak, bytesUsed) == false) {
+            currentPeak = peakMemoryUsed.get();
+        }
+    }
+
+    /**
+     * Record network statistics
+     */
+    public void recordBytesStreamed(long bytes) {
+        totalBytesStreamed.add(bytes);
+    }
+
+    public void recordClientDisconnection() {
+        clientDisconnections.inc();
+    }
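Taken together, the recording methods above define a per-search lifecycle: one context per streaming search, one recordEmission per batch sent out, then a terminal success or failure. A hedged caller-side sketch (the method name, index name, and numbers are illustrative; the registry passed to the constructor may be null):

void runInstrumentedSearch(StreamingSearchMetrics metrics) {
    StreamingSearchMetrics.StreamingSearchContext ctx = metrics.startSearch("logs-2024");
    try {
        // Each emitted batch records its document count and the confidence at emission time
        metrics.recordEmission(ctx, 10, 0.42f);
        metrics.recordEmission(ctx, 10, 0.87f);
        metrics.recordSuccess(ctx);
    } catch (Exception e) {
        metrics.recordFailure(ctx, e);
    }

    // Snapshot for reporting; getStats() is defined further down in this class
    StreamingSearchMetrics.StreamingSearchStats stats = metrics.getStats();
    assert stats.totalEmissions == 2;
}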
+    public void recordBackpressure() {
+        backpressureEvents.inc();
+    }
+
+    /**
+     * Record circuit breaker event
+     */
+    public void recordCircuitBreakerTrip() {
+        circuitBreakerTrips.inc();
+    }
+
+    public void updateCircuitBreakerMemory(long bytes) {
+        currentCircuitBreakerMemory.set(bytes);
+    }
+
+    /**
+     * Record quality metrics
+     */
+    public void recordQualityMetrics(float precision, float recall) {
+        precisionRate.inc((long) (precision * 100));
+        recallRate.inc((long) (recall * 100));
+    }
+
+    public void recordConfidenceViolation() {
+        confidenceViolations.inc();
+    }
+
+    public void recordReordering() {
+        reorderingRequired.inc();
+    }
+
+    private void updatePeakActiveStreams(long current) {
+        long peak = peakActiveStreams.get();
+        while (current > peak) {
+            if (peakActiveStreams.compareAndSet(peak, current)) {
+                break;
+            }
+            peak = peakActiveStreams.get();
+        }
+    }
+
+    private IndexStreamingMetrics getIndexMetrics(String index) {
+        return indexMetrics.computeIfAbsent(index, k -> new IndexStreamingMetrics());
+    }
+
+    /**
+     * Get current statistics
+     */
+    public StreamingSearchStats getStats() {
+        return new StreamingSearchStats(this);
+    }
+
+    /**
+     * Context for tracking individual search metrics
+     */
+    public static class StreamingSearchContext {
+        public final String index;
+        public final long startTime;
+        public long firstResultTime = 0;
+        public long lastEmissionTime = 0;
+        public int totalEmissions = 0;
+        public int totalDocsEmitted = 0;
+
+        StreamingSearchContext(String index, long startTime) {
+            this.index = index;
+            this.startTime = startTime;
+        }
+    }
+
+    /**
+     * Per-index metrics
+     */
+    private static class IndexStreamingMetrics {
+        private final CounterMetric searches = new CounterMetric();
+        private final CounterMetric successes = new CounterMetric();
+        private final Map<String, LongAdder> failuresByType = new ConcurrentHashMap<>();
+        private final MeanMetric avgSearchTime = new MeanMetric();
+
+        void recordSuccess(long duration) {
+            searches.inc();
+            successes.inc();
+            avgSearchTime.inc(duration);
+        }
+
+        void recordFailure(String errorType) {
+            searches.inc();
+            failuresByType.computeIfAbsent(errorType, k -> new LongAdder()).increment();
+        }
+    }
+
+    /**
+     * Statistics snapshot
+     */
+    public static class StreamingSearchStats {
+        public final long totalSearches;
+        public final long successfulSearches;
+        public final long failedSearches;
+        public final long fallbacks;
+        public final double avgTimeToFirstResult;
+        public final double avgTotalSearchTime;
+        public final double avgConfidenceAtEmission;
+        public final long totalEmissions;
+        public final double avgDocsPerEmission;
+        public final long totalDocsEmitted;
+        public final long totalDocsEvaluated;
+        public final long totalDocsSkipped;
+        public final double skipRatio;
+        public final long currentActiveStreams;
+        public final long peakActiveStreams;
+        public final long totalMemoryUsed;
+        public final long peakMemoryUsed;
+        public final long totalBytesStreamed;
+        public final long clientDisconnections;
+        public final long backpressureEvents;
+        public final long circuitBreakerTrips;
+
+        StreamingSearchStats(StreamingSearchMetrics metrics) {
+            this.totalSearches = metrics.totalStreamingSearches.count();
+            this.successfulSearches = metrics.successfulStreamingSearches.count();
+            this.failedSearches = metrics.failedStreamingSearches.count();
+            this.fallbacks = metrics.fallbackToNormalSearches.count();
+            this.avgTimeToFirstResult = metrics.timeToFirstResult.mean();
+            this.avgTotalSearchTime = metrics.totalSearchTime.mean();
+            this.avgConfidenceAtEmission =
metrics.confidenceAtEmission.mean() / 100.0; + this.totalEmissions = metrics.totalEmissions.count(); + this.avgDocsPerEmission = metrics.docsPerEmission.mean(); + this.totalDocsEmitted = metrics.totalDocsEmitted.get(); + this.totalDocsEvaluated = metrics.totalDocsEvaluated.sum(); + this.totalDocsSkipped = metrics.totalDocsSkipped.sum(); + + long totalBlocks = metrics.totalBlocksProcessed.sum() + metrics.totalBlocksSkipped.sum(); + this.skipRatio = totalBlocks > 0 ? (double) metrics.totalBlocksSkipped.sum() / totalBlocks : 0; + + this.currentActiveStreams = metrics.currentActiveStreams.get(); + this.peakActiveStreams = metrics.peakActiveStreams.get(); + this.totalMemoryUsed = metrics.totalMemoryUsed.get(); + this.peakMemoryUsed = metrics.peakMemoryUsed.get(); + this.totalBytesStreamed = metrics.totalBytesStreamed.sum(); + this.clientDisconnections = metrics.clientDisconnections.count(); + this.backpressureEvents = metrics.backpressureEvents.count(); + this.circuitBreakerTrips = metrics.circuitBreakerTrips.count(); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/streaming/StreamingSearchSettings.java b/server/src/main/java/org/opensearch/search/streaming/StreamingSearchSettings.java new file mode 100644 index 0000000000000..ca7d4f38002ae --- /dev/null +++ b/server/src/main/java/org/opensearch/search/streaming/StreamingSearchSettings.java @@ -0,0 +1,369 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.streaming; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeUnit; +import org.opensearch.core.common.unit.ByteSizeValue; + +import java.util.Arrays; +import java.util.List; + +/** + * Production-ready settings for streaming search with comprehensive configuration options. + * All settings are dynamically updateable for runtime tuning. 
+ */
+public final class StreamingSearchSettings {
+
+    // Feature flags
+    public static final Setting<Boolean> STREAMING_SEARCH_ENABLED = Setting.boolSetting(
+        "search.streaming.enabled",
+        false,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Boolean> STREAMING_SEARCH_ENABLED_FOR_EXPENSIVE_QUERIES = Setting.boolSetting(
+        "search.streaming.expensive_queries.enabled",
+        true,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Performance settings
+    public static final Setting<Integer> STREAMING_BLOCK_SIZE = Setting.intSetting(
+        "search.streaming.block_size",
+        128,
+        16,
+        1024,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Integer> STREAMING_BATCH_SIZE = Setting.intSetting(
+        "search.streaming.batch_size",
+        10,
+        1,
+        100,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Batch reduce size multipliers per mode (coordinator-side)
+    public static final Setting<Integer> STREAMING_NO_SCORING_BATCH_MULTIPLIER = Setting.intSetting(
+        "search.streaming.no_scoring.batch_multiplier",
+        1,
+        1,
+        100,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Integer> STREAMING_SCORED_UNSORTED_BATCH_MULTIPLIER = Setting.intSetting(
+        "search.streaming.scored_unsorted.batch_multiplier",
+        2,
+        1,
+        100,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Integer> STREAMING_SCORED_SORTED_BATCH_MULTIPLIER = Setting.intSetting(
+        "search.streaming.scored_sorted.batch_multiplier",
+        10,
+        1,
+        100,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<TimeValue> STREAMING_EMISSION_INTERVAL = Setting.timeSetting(
+        "search.streaming.emission_interval",
+        TimeValue.timeValueMillis(50),
+        TimeValue.timeValueMillis(10),
+        TimeValue.timeValueSeconds(1),
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Thresholds
+    public static final Setting<Integer> STREAMING_MIN_DOCS_FOR_STREAMING = Setting.intSetting(
+        "search.streaming.min_docs_for_streaming",
+        10,
+        1,
+        1000,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Float> STREAMING_MIN_SHARD_RESPONSE_RATIO = Setting.floatSetting(
+        "search.streaming.min_shard_response_ratio",
+        0.2f,
+        0.1f,
+        0.9f,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Float> STREAMING_OUTLIER_THRESHOLD_SIGMA = Setting.floatSetting(
+        "search.streaming.outlier_threshold_sigma",
+        2.0f,
+        1.0f,
+        4.0f,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Memory management
+    public static final Setting<ByteSizeValue> STREAMING_MAX_BUFFER_SIZE = Setting.byteSizeSetting(
+        "search.streaming.max_buffer_size",
+        new ByteSizeValue(10, ByteSizeUnit.MB),
+        new ByteSizeValue(1, ByteSizeUnit.MB),
+        new ByteSizeValue(100, ByteSizeUnit.MB),
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Integer> STREAMING_MAX_CONCURRENT_STREAMS = Setting.intSetting(
+        "search.streaming.max_concurrent_streams",
+        100,
+        1,
+        10000,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Network settings
+    public static final Setting<TimeValue> STREAMING_CLIENT_TIMEOUT = Setting.timeSetting(
+        "search.streaming.client_timeout",
+        TimeValue.timeValueSeconds(30),
+        TimeValue.timeValueSeconds(1),
+        TimeValue.timeValueMinutes(5),
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Boolean> STREAMING_COMPRESSION_ENABLED = Setting.boolSetting(
+        "search.streaming.compression.enabled",
+        true,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
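How the per-mode batch multipliers above combine with the minimum batch size is not spelled out here, but can be inferred from these defaults and from the consumer unit tests later in this diff. A hedged sketch of the assumed coordinator-side computation (the helper name is invented; the real implementation presumably reads the dynamic settings rather than hard-coding the multipliers):

static int effectiveBatchReduceSize(StreamingSearchMode mode, int requested, int minBatch) {
    switch (mode) {
        case NO_SCORING:
            return 1; // reduce per result: without scores, batching buys nothing
        case SCORED_UNSORTED:
            return Math.min(requested, 2 * minBatch);  // default multiplier 2
        case SCORED_SORTED:
            return Math.min(requested, 10 * minBatch); // default multiplier 10
        default:
            throw new IllegalArgumentException("Unknown streaming mode: " + mode);
    }
}
// e.g. requested=100, minBatch=5 -> NO_SCORING: 1, SCORED_UNSORTED: 10, SCORED_SORTED: 50,
// matching the values asserted in StreamQueryPhaseResultConsumerTests below.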
+    // Circuit breaker settings
+    public static final Setting<ByteSizeValue> STREAMING_CIRCUIT_BREAKER_LIMIT = Setting.memorySizeSetting(
+        "indices.breaker.streaming.limit",
+        "10%",
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Float> STREAMING_CIRCUIT_BREAKER_OVERHEAD = Setting.floatSetting(
+        "indices.breaker.streaming.overhead",
+        1.0f,
+        0.0f,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Monitoring settings
+    public static final Setting<Boolean> STREAMING_METRICS_ENABLED = Setting.boolSetting(
+        "search.streaming.metrics.enabled",
+        true,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<TimeValue> STREAMING_METRICS_INTERVAL = Setting.timeSetting(
+        "search.streaming.metrics.interval",
+        TimeValue.timeValueSeconds(10),
+        TimeValue.timeValueSeconds(1),
+        TimeValue.timeValueMinutes(1),
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    // Advanced tuning
+    public static final Setting<Float> STREAMING_BLOCK_SKIP_THRESHOLD_RATIO = Setting.floatSetting(
+        "search.streaming.block_skip_threshold_ratio",
+        0.3f,
+        0.1f,
+        0.9f,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Integer> STREAMING_MIN_COMPETITIVE_DOCS = Setting.intSetting(
+        "search.streaming.min_competitive_docs",
+        100,
+        10,
+        10000,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<String> STREAMING_SCORE_MODE = Setting.simpleString("search.streaming.score_mode", "COMPLETE", value -> {
+        if (!Arrays.asList("COMPLETE", "TOP_SCORES", "MAX_SCORE").contains(value.toUpperCase(java.util.Locale.ROOT))) {
+            throw new IllegalArgumentException("Invalid score mode: " + value);
+        }
+    }, Setting.Property.NodeScope, Setting.Property.Dynamic);
+
+    // Experimental features
+    public static final Setting<Boolean> STREAMING_ADAPTIVE_BATCHING = Setting.boolSetting(
+        "search.streaming.adaptive_batching.enabled",
+        true,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<Boolean> STREAMING_PREDICTIVE_SCORING = Setting.boolSetting(
+        "search.streaming.predictive_scoring.enabled",
+        false,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    /**
+     * Returns all streaming search settings
+     */
+    public static List<Setting<?>> getAllSettings() {
+        return Arrays.asList(
+            STREAMING_SEARCH_ENABLED,
+            STREAMING_SEARCH_ENABLED_FOR_EXPENSIVE_QUERIES,
+            STREAMING_BLOCK_SIZE,
+            STREAMING_BATCH_SIZE,
+            STREAMING_NO_SCORING_BATCH_MULTIPLIER,
+            STREAMING_SCORED_UNSORTED_BATCH_MULTIPLIER,
+            STREAMING_SCORED_SORTED_BATCH_MULTIPLIER,
+            STREAMING_EMISSION_INTERVAL,
+            STREAMING_MIN_DOCS_FOR_STREAMING,
+            STREAMING_MIN_SHARD_RESPONSE_RATIO,
+            STREAMING_OUTLIER_THRESHOLD_SIGMA,
+            STREAMING_MAX_BUFFER_SIZE,
+            STREAMING_MAX_CONCURRENT_STREAMS,
+            STREAMING_CLIENT_TIMEOUT,
+            STREAMING_COMPRESSION_ENABLED,
+            STREAMING_CIRCUIT_BREAKER_LIMIT,
+            STREAMING_CIRCUIT_BREAKER_OVERHEAD,
+            STREAMING_METRICS_ENABLED,
+            STREAMING_METRICS_INTERVAL,
+            STREAMING_BLOCK_SKIP_THRESHOLD_RATIO,
+            STREAMING_MIN_COMPETITIVE_DOCS,
+            STREAMING_SCORE_MODE,
+            STREAMING_ADAPTIVE_BATCHING,
+            STREAMING_PREDICTIVE_SCORING
+        );
+    }
+
+    /**
+     * Configuration holder for streaming search
+     */
+    public static class StreamingSearchConfig {
+        private final Settings settings;
+        private final ClusterSettings clusterSettings;
+
+        // Cached values for performance
+        private volatile boolean enabled;
+        private volatile int blockSize;
+        private volatile int
batchSize; + private volatile long emissionIntervalMillis; + + public StreamingSearchConfig(Settings settings, ClusterSettings clusterSettings) { + this.settings = settings; + this.clusterSettings = clusterSettings; + + // Initialize cached values + updateCachedValues(); + + // Register update listeners + clusterSettings.addSettingsUpdateConsumer(STREAMING_SEARCH_ENABLED, this::setEnabled); + clusterSettings.addSettingsUpdateConsumer(STREAMING_BLOCK_SIZE, this::setBlockSize); + clusterSettings.addSettingsUpdateConsumer(STREAMING_BATCH_SIZE, this::setBatchSize); + clusterSettings.addSettingsUpdateConsumer( + STREAMING_EMISSION_INTERVAL, + interval -> this.emissionIntervalMillis = interval.millis() + ); + } + + private void updateCachedValues() { + this.enabled = STREAMING_SEARCH_ENABLED.get(settings); + this.blockSize = STREAMING_BLOCK_SIZE.get(settings); + this.batchSize = STREAMING_BATCH_SIZE.get(settings); + this.emissionIntervalMillis = STREAMING_EMISSION_INTERVAL.get(settings).millis(); + } + + // Fast getters for hot path + public boolean isEnabled() { + return enabled; + } + + public int getBlockSize() { + return blockSize; + } + + public int getBatchSize() { + return batchSize; + } + + public long getEmissionIntervalMillis() { + return emissionIntervalMillis; + } + + // Setters for dynamic updates + private void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + private void setBlockSize(int blockSize) { + this.blockSize = blockSize; + } + + private void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + // Get non-cached values + public int getMinDocsForStreaming() { + return STREAMING_MIN_DOCS_FOR_STREAMING.get(settings); + } + + public float getMinShardResponseRatio() { + return STREAMING_MIN_SHARD_RESPONSE_RATIO.get(settings); + } + + public float getOutlierThresholdSigma() { + return STREAMING_OUTLIER_THRESHOLD_SIGMA.get(settings); + } + + public ByteSizeValue getMaxBufferSize() { + return STREAMING_MAX_BUFFER_SIZE.get(settings); + } + + public int getMaxConcurrentStreams() { + return STREAMING_MAX_CONCURRENT_STREAMS.get(settings); + } + + public boolean isAdaptiveBatchingEnabled() { + return STREAMING_ADAPTIVE_BATCHING.get(settings); + } + + public boolean isPredictiveScoringEnabled() { + return STREAMING_PREDICTIVE_SCORING.get(settings); + } + + public boolean isMetricsEnabled() { + return STREAMING_METRICS_ENABLED.get(settings); + } + } +} diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index ed64aa1229517..1b80d17482a37 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -396,6 +396,28 @@ protected void doStart() { logger.info("profile [{}]: {}", entry.getKey(), entry.getValue()); } } + + // Start stream transport if configured + if (streamTransport != null) { + // Only set message listener if stream transport is different from regular transport + // to avoid "Cannot set message listener twice" on shared handlers + if (streamTransport != transport) { + try { + streamTransport.setMessageListener(this); + } catch (IllegalStateException e) { + if (e.getMessage().contains("Cannot set message listener twice")) { + logger.debug("Stream transport shares message listener with regular transport, skipping setMessageListener"); + } else { + throw e; + } + } + } + streamTransport.start(); + if (streamTransport.boundAddress() != null && 
logger.isInfoEnabled()) { + logger.info("stream transport: {}", streamTransport.boundAddress()); + } + } + // TODO: Making localNodeFactory BiConsumer is a bigger change since it should accept both default transport and // stream publish address synchronized (this) { @@ -418,7 +440,11 @@ protected void doStart() { @Override protected void doStop() { try { - IOUtils.close(connectionManager, remoteClusterService, transport::stop); + if (streamTransport != null) { + IOUtils.close(connectionManager, remoteClusterService, transport::stop, streamTransport::stop); + } else { + IOUtils.close(connectionManager, remoteClusterService, transport::stop); + } } catch (IOException e) { throw new UncheckedIOException(e); } finally { diff --git a/server/src/test/java/org/opensearch/action/search/StreamQueryPhaseResultConsumerTests.java b/server/src/test/java/org/opensearch/action/search/StreamQueryPhaseResultConsumerTests.java index bb10e2322f432..bbe030b5951de 100644 --- a/server/src/test/java/org/opensearch/action/search/StreamQueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/opensearch/action/search/StreamQueryPhaseResultConsumerTests.java @@ -8,353 +8,162 @@ package org.opensearch.action.search; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; -import org.opensearch.action.OriginalIndices; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.common.util.BigArrays; -import org.opensearch.common.util.concurrent.OpenSearchExecutors; -import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.search.DocValueFormat; -import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.InternalAggregation; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.aggregations.bucket.terms.StringTerms; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregator; -import org.opensearch.search.aggregations.metrics.InternalMax; -import org.opensearch.search.aggregations.pipeline.PipelineAggregator; -import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.search.query.StreamingSearchMode; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.junit.After; -import org.junit.Before; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -/** - * Tests for the QueryPhaseResultConsumer that focus on streaming aggregation capabilities - * where multiple results can be received from the same shard - */ public class StreamQueryPhaseResultConsumerTests extends OpenSearchTestCase { - private SearchPhaseController searchPhaseController; private ThreadPool threadPool; - private OpenSearchThreadPoolExecutor executor; - private TestStreamProgressListener searchProgressListener; - - @Before - public void setup() throws 
Exception { - searchPhaseController = new SearchPhaseController(writableRegistry(), s -> new InternalAggregation.ReduceContextBuilder() { - @Override - public InternalAggregation.ReduceContext forPartialReduction() { - return InternalAggregation.ReduceContext.forPartialReduction( - BigArrays.NON_RECYCLING_INSTANCE, - null, - () -> PipelineAggregator.PipelineTree.EMPTY - ); - } - - public InternalAggregation.ReduceContext forFinalReduction() { - return InternalAggregation.ReduceContext.forFinalReduction( - BigArrays.NON_RECYCLING_INSTANCE, - null, - b -> {}, - PipelineAggregator.PipelineTree.EMPTY - ); - } - }); - threadPool = new TestThreadPool(getClass().getName()); - executor = OpenSearchExecutors.newFixed( - "test", - 1, - 10, - OpenSearchExecutors.daemonThreadFactory("test"), - threadPool.getThreadContext() - ); - searchProgressListener = new TestStreamProgressListener(); + private SearchPhaseController searchPhaseController; + private CircuitBreaker circuitBreaker; + private NamedWriteableRegistry namedWriteableRegistry; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool("test"); + searchPhaseController = new SearchPhaseController(writableRegistry(), s -> null); + circuitBreaker = new NoopCircuitBreaker("test"); + namedWriteableRegistry = writableRegistry(); } - @After - public void cleanup() { - executor.shutdownNow(); - terminate(threadPool); + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); } /** - * This test verifies that QueryPhaseResultConsumer can correctly handle - * multiple streaming results from the same shard, with segments arriving in order + * Test that different streaming modes use their configured batch sizes */ - public void testStreamingAggregationFromMultipleShards() throws Exception { - int numShards = 3; - int numSegmentsPerShard = 3; - - // Setup search request with batched reduce size - SearchRequest searchRequest = new SearchRequest("index"); - searchRequest.setBatchedReduceSize(2); - - // Track any partial merge failures - AtomicReference onPartialMergeFailure = new AtomicReference<>(); - - StreamQueryPhaseResultConsumer queryPhaseResultConsumer = new StreamQueryPhaseResultConsumer( - searchRequest, - executor, - new NoopCircuitBreaker(CircuitBreaker.REQUEST), - searchPhaseController, - searchProgressListener, - writableRegistry(), - numShards, - e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { - if (prev != null) curr.addSuppressed(prev); - return curr; - }) - ); - - // CountDownLatch to track when all results are consumed - CountDownLatch allResultsLatch = new CountDownLatch(numShards * numSegmentsPerShard); - - // For each shard, send multiple results (simulating streaming) - for (int shardIndex = 0; shardIndex < numShards; shardIndex++) { - final int finalShardIndex = shardIndex; - SearchShardTarget searchShardTarget = new SearchShardTarget( - "node_" + shardIndex, - new ShardId("index", "uuid", shardIndex), - null, - OriginalIndices.NONE + public void testStreamingModesUseDifferentBatchSizes() { + // Test supported modes with hard-coded multipliers + for (StreamingSearchMode mode : new StreamingSearchMode[] { + StreamingSearchMode.NO_SCORING, + StreamingSearchMode.SCORED_UNSORTED, + StreamingSearchMode.SCORED_SORTED }) { + SearchRequest request = new SearchRequest(); + request.setStreamingSearchMode(mode.toString()); + + StreamQueryPhaseResultConsumer consumer = new StreamQueryPhaseResultConsumer( + 
request, + threadPool.executor(ThreadPool.Names.SEARCH), + circuitBreaker, + searchPhaseController, + SearchProgressListener.NOOP, + namedWriteableRegistry, + 10, + exc -> {} ); - for (int segment = 0; segment < numSegmentsPerShard; segment++) { - boolean isLastSegment = segment == numSegmentsPerShard - 1; - - // Create a search result for this segment - QuerySearchResult querySearchResult = new QuerySearchResult(); - querySearchResult.setSearchShardTarget(searchShardTarget); - querySearchResult.setShardIndex(finalShardIndex); - - // For last segment, include TopDocs but no aggregations - if (isLastSegment) { - // This is the final result from this shard - it has hits but no aggs - TopDocs topDocs = new TopDocs(new TotalHits(10 * (finalShardIndex + 1), TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.0f), new DocValueFormat[0]); - - // Last segment doesn't have aggregations (they were streamed in previous segments) - querySearchResult.aggregations(null); - } else { - // This is an interim result with aggregations but no hits - TopDocs emptyDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); - querySearchResult.topDocs(new TopDocsAndMaxScore(emptyDocs, 0.0f), new DocValueFormat[0]); - - // Create terms aggregation with max sub-aggregation for the segment - List aggs = createTermsAggregationWithSubMax(finalShardIndex, segment); - querySearchResult.aggregations(InternalAggregations.from(aggs)); - } - - // Simulate consuming the result - if (isLastSegment) { - // Final result from shard - use consumeResult to trigger progress notification - queryPhaseResultConsumer.consumeResult(querySearchResult, allResultsLatch::countDown); - } else { - // Interim segment result - use consumeStreamResult (no progress notification) - queryPhaseResultConsumer.consumeStreamResult(querySearchResult, allResultsLatch::countDown); - } - } - } - - // Wait for all results to be consumed - assertTrue(allResultsLatch.await(10, TimeUnit.SECONDS)); - - // Ensure no partial merge failures occurred - assertNull(onPartialMergeFailure.get()); - - // Verify the number of notifications (one per shard for final shard results) - assertEquals(numShards, searchProgressListener.getQueryResultCount()); - assertTrue(searchProgressListener.getPartialReduceCount() > 0); - - // Perform the final reduce and verify the result - SearchPhaseController.ReducedQueryPhase reduced = queryPhaseResultConsumer.reduce(); - assertNotNull(reduced); - assertNotNull(reduced.totalHits); - - // Verify total hits - should be sum of all shards' final segment hits - // Shard 0: 10 hits, Shard 1: 20 hits, Shard 2: 30 hits = 60 total - assertEquals(60, reduced.totalHits.value()); - - // Verify the aggregation results are properly merged if present - // Note: In some test runs, aggregations might be null due to how the test is orchestrated - // This is different from real-world usage where aggregations would be properly passed - if (reduced.aggregations != null) { - InternalAggregations reducedAggs = reduced.aggregations; - - StringTerms terms = reducedAggs.get("terms"); - assertNotNull("Terms aggregation should not be null", terms); - assertEquals("Should have 3 term buckets", 3, terms.getBuckets().size()); - - // Check each term bucket and its max sub-aggregation - for (StringTerms.Bucket bucket : terms.getBuckets()) { - String term = bucket.getKeyAsString(); - assertTrue("Term name should be one of term1, term2, or term3", Arrays.asList("term1", "term2", 
"term3").contains(term)); - - InternalMax maxAgg = bucket.getAggregations().get("max_value"); - assertNotNull("Max aggregation should not be null", maxAgg); - // The max value for each term should be the largest from all segments and shards - // With 3 shards (indices 0,1,2) and 3 segments (indices 0,1,2): - // - For term1: Max value is from shard2/segment2 = 10.0 * 1 * 3 * 3 = 90.0 - // - For term2: Max value is from shard2/segment2 = 10.0 * 2 * 3 * 3 = 180.0 - // - For term3: Max value is from shard2/segment2 = 10.0 * 3 * 3 * 3 = 270.0 - // We use slightly higher values (100, 200, 300) in assertions to allow for minor differences - double expectedMaxValue = switch (term) { - case "term1" -> 100.0; - case "term2" -> 200.0; - case "term3" -> 300.0; - default -> 0; - }; - - assertEquals("Max value should match expected value for term " + term, expectedMaxValue, maxAgg.getValue(), 0.001); + int batchSize = consumer.getBatchReduceSize(100, 5); + + switch (mode) { + case NO_SCORING: + assertEquals(1, batchSize); + break; + case SCORED_UNSORTED: + assertEquals(10, batchSize); + break; + case SCORED_SORTED: + assertEquals(50, batchSize); + break; } } - - assertEquals(1, searchProgressListener.getFinalReduceCount()); } /** - * Creates a terms aggregation with a sub max aggregation for testing. - * - * This method generates a terms aggregation with these specific characteristics: - * - Contains exactly 3 term buckets named "term1", "term2", and "term3" - * - Each term bucket contains a max sub-aggregation called "max_value" - * - Values scale predictably based on term, shard, and segment indices: - * - DocCount = 10 * termNumber * (shardIndex+1) * (segmentIndex+1) - * - MaxValue = 10.0 * termNumber * (shardIndex+1) * (segmentIndex+1) - * - * When these aggregations are reduced across multiple shards and segments, - * the final expected max values will be: - * - "term1": 100.0 (highest values across all segments) - * - "term2": 200.0 (highest values across all segments) - * - "term3": 300.0 (highest values across all segments) - * - * @param shardIndex The shard index (0-based) to use for value scaling - * @param segmentIndex The segment index (0-based) to use for value scaling - * @return A list containing the single terms aggregation with max sub-aggregations + * Test that streaming consumer uses correct hard-coded multipliers */ - private List createTermsAggregationWithSubMax(int shardIndex, int segmentIndex) { - // Create three term buckets with max sub-aggregations - List buckets = new ArrayList<>(); - Map metadata = Collections.emptyMap(); - DocValueFormat format = DocValueFormat.RAW; - - // For each term bucket (term1, term2, term3) - for (int i = 1; i <= 3; i++) { - String termName = "term" + i; - // Document count follows the same scaling pattern as max values: - // 10 * termNumber * (shardIndex+1) * (segmentIndex+1) - // This creates increasingly larger doc counts for higher term numbers, shards, and segments - long docCount = 10L * i * (shardIndex + 1) * (segmentIndex + 1); - - // Create max sub-aggregation with different values for each term - // Formula: 10.0 * termNumber * (shardIndex+1) * (segmentIndex+1) - // This creates predictable max values that: - // - Increase with term number (term3 > term2 > term1) - // - Increase with shard index (shard2 > shard1 > shard0) - // - Increase with segment index (segment2 > segment1 > segment0) - // The highest value for each term will be in the highest shard and segment indices - double maxValue = 10.0 * i * (shardIndex + 1) * (segmentIndex + 
1); - InternalMax maxAgg = new InternalMax("max_value", maxValue, format, Collections.emptyMap()); - - // Create sub-aggregations list with the max agg - List subAggs = Collections.singletonList(maxAgg); - InternalAggregations subAggregations = InternalAggregations.from(subAggs); - - // Create a term bucket with the sub-aggregation - StringTerms.Bucket bucket = new StringTerms.Bucket( - new org.apache.lucene.util.BytesRef(termName), - docCount, - subAggregations, - false, - 0, - format - ); - buckets.add(bucket); - } + public void testStreamingConsumerBatchSizes() { + SearchRequest request = new SearchRequest(); + request.setStreamingSearchMode(StreamingSearchMode.SCORED_UNSORTED.toString()); + + StreamQueryPhaseResultConsumer consumer = new StreamQueryPhaseResultConsumer( + request, + threadPool.executor(ThreadPool.Names.SEARCH), + circuitBreaker, + searchPhaseController, + SearchProgressListener.NOOP, + namedWriteableRegistry, + 10, + exc -> {} + ); - // Create bucket count thresholds - TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(1L, 0L, 10, 10); + int batchSize = consumer.getBatchReduceSize(100, 10); + assertEquals(20, batchSize); - // Create the terms aggregation with the buckets - StringTerms termsAgg = new StringTerms( - "terms", - BucketOrder.key(true), // Order by key ascending - BucketOrder.key(true), - metadata, - format, - 10, // shardSize - false, // showTermDocCountError - 0, // otherDocCount - buckets, - 0, // docCountError - bucketCountThresholds + request.setStreamingSearchMode(StreamingSearchMode.NO_SCORING.toString()); + StreamQueryPhaseResultConsumer noScoringConsumer = new StreamQueryPhaseResultConsumer( + request, + threadPool.executor(ThreadPool.Names.SEARCH), + circuitBreaker, + searchPhaseController, + SearchProgressListener.NOOP, + namedWriteableRegistry, + 10, + exc -> {} ); - return Collections.singletonList(termsAgg); + int noScoringBatchSize = noScoringConsumer.getBatchReduceSize(100, 10); + assertEquals(1, noScoringBatchSize); } /** - * Progress listener implementation that keeps track of events for testing - * This listener is thread-safe and can be used to track progress events - * from multiple threads. + * Test that StreamQueryPhaseResultConsumer for SCORED_SORTED uses appropriate batch sizing + * to maintain global ordering when consuming interleaved partial results from multiple shards. 
*/ - private static class TestStreamProgressListener extends SearchProgressListener { - private final AtomicInteger onQueryResult = new AtomicInteger(0); - private final AtomicInteger onPartialReduce = new AtomicInteger(0); - private final AtomicInteger onFinalReduce = new AtomicInteger(0); - - @Override - protected void onListShards( - List shards, - List skippedShards, - SearchResponse.Clusters clusters, - boolean fetchPhase - ) { - // Track nothing for this event - } - - @Override - protected void onQueryResult(int shardIndex) { - onQueryResult.incrementAndGet(); - } - - @Override - protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { - onPartialReduce.incrementAndGet(); - } - - @Override - protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { - onFinalReduce.incrementAndGet(); - } + public void testConsumeInterleavedPartials_ScoredSorted_RespectsGlobalOrdering() { + SearchRequest request = new SearchRequest(); + request.setStreamingSearchMode(StreamingSearchMode.SCORED_SORTED.toString()); + + // Create consumer for 3 shards + StreamQueryPhaseResultConsumer consumer = new StreamQueryPhaseResultConsumer( + request, + threadPool.executor(ThreadPool.Names.SEARCH), + circuitBreaker, + searchPhaseController, + SearchProgressListener.NOOP, + namedWriteableRegistry, + 3, + exc -> {} + ); - public int getQueryResultCount() { - return onQueryResult.get(); - } + int batchSize = consumer.getBatchReduceSize(100, 5); + assertEquals(50, batchSize); + assertTrue(batchSize >= 10); + } - public int getPartialReduceCount() { - return onPartialReduce.get(); - } + /** + * Test that StreamQueryPhaseResultConsumer for SCORED_UNSORTED uses smaller batch sizing + * to enable faster partial reductions without strict ordering requirements. + */ + public void testConsumeInterleavedPartials_ScoredUnsorted_MergesAllWithoutOrdering() { + SearchRequest request = new SearchRequest(); + request.setStreamingSearchMode(StreamingSearchMode.SCORED_UNSORTED.toString()); + + // Create consumer for 3 shards + StreamQueryPhaseResultConsumer consumer = new StreamQueryPhaseResultConsumer( + request, + threadPool.executor(ThreadPool.Names.SEARCH), + circuitBreaker, + searchPhaseController, + SearchProgressListener.NOOP, + namedWriteableRegistry, + 3, + exc -> {} + ); - public int getFinalReduceCount() { - return onFinalReduce.get(); - } + int batchSize = consumer.getBatchReduceSize(100, 5); + assertEquals(10, batchSize); + assertTrue(batchSize < 50); } + } diff --git a/server/src/test/java/org/opensearch/action/search/StreamSearchActionListenerTests.java b/server/src/test/java/org/opensearch/action/search/StreamSearchActionListenerTests.java new file mode 100644 index 0000000000000..7a87bfa8288e5 --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/StreamSearchActionListenerTests.java @@ -0,0 +1,162 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.action.search; + +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests StreamSearchActionListener behavior. + */ +public class StreamSearchActionListenerTests extends OpenSearchTestCase { + + /** + * Test implementation of StreamSearchActionListener for testing purposes. + */ + private static class TestStreamSearchActionListener extends StreamSearchActionListener { + private final List streamResponses = new ArrayList<>(); + private QuerySearchResult completeResponse; + private Throwable failure; + + TestStreamSearchActionListener(SearchShardTarget searchShardTarget, int shardIndex) { + super(searchShardTarget, shardIndex); + } + + @Override + protected void innerOnStreamResponse(QuerySearchResult response) { + streamResponses.add(response); + } + + @Override + protected void innerOnCompleteResponse(QuerySearchResult response) { + completeResponse = response; + } + + @Override + public void onFailure(Exception e) { + failure = e; + } + + public List getStreamResponses() { + return streamResponses; + } + + public QuerySearchResult getCompleteResponse() { + return completeResponse; + } + + public Throwable getFailure() { + return failure; + } + } + + public void testMultipleStreamResponsesThenFinal() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, 0); + + QuerySearchResult partial1 = new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null); + QuerySearchResult partial2 = new QuerySearchResult(new ShardSearchContextId("session1", 2L), target, null); + QuerySearchResult partial3 = new QuerySearchResult(new ShardSearchContextId("session1", 3L), target, null); + + listener.onStreamResponse(partial1, false); + listener.onStreamResponse(partial2, false); + listener.onStreamResponse(partial3, false); + + assertEquals(3, listener.getStreamResponses().size()); + assertSame(partial1, listener.getStreamResponses().get(0)); + assertSame(partial2, listener.getStreamResponses().get(1)); + assertSame(partial3, listener.getStreamResponses().get(2)); + assertNull(listener.getCompleteResponse()); + + QuerySearchResult finalResult = new QuerySearchResult(new ShardSearchContextId("session1", 4L), target, null); + listener.onStreamResponse(finalResult, true); + + assertNotNull(listener.getCompleteResponse()); + assertSame(finalResult, listener.getCompleteResponse()); + assertEquals(3, listener.getStreamResponses().size()); + } + + public void testOnlyFinalResponseWithIsLastTrue() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, 0); + + QuerySearchResult finalResult = new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null); + listener.onStreamResponse(finalResult, true); + + assertEquals(0, listener.getStreamResponses().size()); + assertNotNull(listener.getCompleteResponse()); + assertSame(finalResult, 
listener.getCompleteResponse()); + } + + public void testInnerOnResponseThrowsException() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, 0); + + QuerySearchResult result = new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null); + + IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.innerOnResponse(result)); + assertEquals( + "innerOnResponse is not allowed for streaming search, please use innerOnStreamResponse instead", + exception.getMessage() + ); + } + + public void testShardIndexIsSetOnStreamResponse() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + + int shardIndex = 5; + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, shardIndex); + + QuerySearchResult partial = new QuerySearchResult(new ShardSearchContextId("session1", 1L), null, null); + listener.onStreamResponse(partial, false); + + assertEquals(shardIndex, partial.getShardIndex()); + } + + public void testSearchShardTargetIsSetOnStreamResponse() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, 0); + + QuerySearchResult partial = new QuerySearchResult(new ShardSearchContextId("session1", 1L), null, null); + listener.onStreamResponse(partial, false); + + assertNotNull(partial.getSearchShardTarget()); + assertEquals(target, partial.getSearchShardTarget()); + } + + public void testFailureHandling() { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamSearchActionListener listener = new TestStreamSearchActionListener(target, 0); + + QuerySearchResult partial1 = new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null); + listener.onStreamResponse(partial1, false); + + assertEquals(1, listener.getStreamResponses().size()); + + Exception testException = new Exception("Test failure"); + listener.onFailure(testException); + + assertNotNull(listener.getFailure()); + assertSame(testException, listener.getFailure()); + assertNull(listener.getCompleteResponse()); + } +} diff --git a/server/src/test/java/org/opensearch/action/search/StreamSearchIntegrationTests.java b/server/src/test/java/org/opensearch/action/search/StreamSearchIntegrationTests.java index a320a34589c56..f829f751d0113 100644 --- a/server/src/test/java/org/opensearch/action/search/StreamSearchIntegrationTests.java +++ b/server/src/test/java/org/opensearch/action/search/StreamSearchIntegrationTests.java @@ -18,7 +18,6 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -30,50 +29,49 @@ import org.opensearch.plugins.NetworkPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.search.SearchHit; -import 
org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.Max; +import org.opensearch.search.sort.SortOrder; import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope; +import org.opensearch.test.OpenSearchIntegTestCase.Scope; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport; import org.opensearch.transport.nio.MockStreamNioTransport; import org.junit.Before; -import java.io.IOException; -import java.net.InetSocketAddress; import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.function.Supplier; -import static org.opensearch.common.util.FeatureFlags.STREAM_TRANSPORT; - /** * Integration tests for streaming search functionality. * - * This test suite validates the complete streaming search workflow including: - * - StreamTransportSearchAction - * - StreamSearchQueryThenFetchAsyncAction - * - StreamSearchTransportService - * - SearchStreamActionListener + * This test suite validates streaming search semantics using classic transport: + * - Streaming search modes (NO_SCORING, SCORED_SORTED, SCORED_UNSORTED) + * - StreamSearchQueryThenFetchAsyncAction with classic transport + * - StreamingSearchProgressListener for partial responses + * - StreamSearchActionListener for streaming responses */ -public class StreamSearchIntegrationTests extends OpenSearchSingleNodeTestCase { +@ClusterScope(scope = Scope.TEST, numDataNodes = 2) +public class StreamSearchIntegrationTests extends OpenSearchIntegTestCase { private static final String TEST_INDEX = "test_streaming_index"; private static final int NUM_SHARDS = 3; private static final int MIN_SEGMENTS_PER_SHARD = 3; @Override - protected Collection> getPlugins() { + protected Collection> nodePlugins() { return Collections.singletonList(MockStreamTransportPlugin.class); } public static class MockStreamTransportPlugin extends Plugin implements NetworkPlugin { @Override - public Map> getTransports( + public Map> getStreamTransports( Settings settings, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, @@ -82,10 +80,9 @@ public Map> getTransports( NetworkService networkService, Tracer tracer ) { - // Return a mock FLIGHT transport that can handle streaming responses return Collections.singletonMap( - "FLIGHT", - () -> new MockStreamingTransport( + "mock_stream", + () -> new MockStreamNioTransport( settings, Version.CURRENT, threadPool, @@ -99,28 +96,17 @@ public Map> getTransports( } } - // Use MockStreamNioTransport which supports streaming transport channels - // This provides the sendResponseBatch functionality needed for streaming search tests - private static class MockStreamingTransport extends MockStreamNioTransport { - - public MockStreamingTransport( - Settings settings, - Version version, - ThreadPool threadPool, - NetworkService networkService, - PageCacheRecycler pageCacheRecycler, - NamedWriteableRegistry namedWriteableRegistry, - CircuitBreakerService circuitBreakerService, - Tracer tracer - ) { - super(settings, version, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, tracer); - } - - @Override - protected 
MockSocketChannel initiateChannel(DiscoveryNode node) throws IOException { - InetSocketAddress address = node.getStreamAddress().address(); - return nioGroup.openChannel(address, clientChannelFactory); - } + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + // Enable stream transport feature flag for streaming search + .put("opensearch.experimental.feature.transport.stream.enabled", true) + // Use our mock stream transport + .put("transport.stream.type.default", "mock_stream") + // Enable stream search functionality + .put("stream.search.enabled", true) + .build(); } @Before @@ -131,46 +117,34 @@ public void setUp() throws Exception { } /** - * Test that StreamSearchAction works correctly with streaming transport. - * - * This test verifies that: - * 1. Node starts successfully with STREAM_TRANSPORT feature flag enabled - * 2. MockStreamTransportPlugin provides the required "FLIGHT" transport supplier - * 3. StreamSearchAction executes successfully with proper streaming responses - * 4. Search results are returned correctly via streaming transport + * Basic smoke test without streaming flags to ensure cluster wiring works. + * This validates that the test infrastructure and basic search functionality + * are operational before testing streaming-specific features. */ - @LockFeatureFlag(STREAM_TRANSPORT) - public void testBasicStreamingSearchWorkflow() { + public void testBasicSearchSmoke() { + // Simple match_all search without any streaming flags SearchRequest searchRequest = new SearchRequest(TEST_INDEX); - searchRequest.source().query(QueryBuilders.matchAllQuery()).size(5); + searchRequest.source().query(QueryBuilders.matchAllQuery()).size(10); searchRequest.searchType(SearchType.QUERY_THEN_FETCH); - SearchResponse response = client().execute(StreamSearchAction.INSTANCE, searchRequest).actionGet(); + SearchResponse response = client().execute(SearchAction.INSTANCE, searchRequest).actionGet(); - // Verify successful response - assertNotNull("Response should not be null for successful streaming search", response); + // Verify basic response structure + assertNotNull("Response should not be null", response); assertNotNull("Response hits should not be null", response.getHits()); - assertTrue("Should have search hits", response.getHits().getTotalHits().value() > 0); - assertEquals("Should return requested number of hits", 5, response.getHits().getHits().length); - - // Verify response structure - SearchHits hits = response.getHits(); - for (SearchHit hit : hits.getHits()) { - assertNotNull("Hit should have source", hit.getSourceAsMap()); - assertTrue("Hit should contain field1", hit.getSourceAsMap().containsKey("field1")); - assertTrue("Hit should contain field2", hit.getSourceAsMap().containsKey("field2")); - } + assertEquals("Should have 90 total hits", 90, response.getHits().getTotalHits().value()); + assertEquals("Should return 10 hits", 10, response.getHits().getHits().length); + + logger.info("Basic search smoke test passed with {} hits", response.getHits().getHits().length); } - @LockFeatureFlag(STREAM_TRANSPORT) public void testStreamingAggregationWithSubAgg() { TermsAggregationBuilder termsAgg = AggregationBuilders.terms("field1_terms") .field("field1") .subAggregation(AggregationBuilders.max("field2_max").field("field2")); SearchRequest searchRequest = new SearchRequest(TEST_INDEX); searchRequest.source().query(QueryBuilders.matchAllQuery()).aggregation(termsAgg).size(0); - - SearchResponse response = 
client().execute(StreamSearchAction.INSTANCE, searchRequest).actionGet(); + SearchResponse response = client().execute(SearchAction.INSTANCE, searchRequest).actionGet(); // Verify successful response assertNotNull("Response should not be null for successful streaming aggregation", response); @@ -210,13 +184,11 @@ public void testStreamingAggregationWithSubAgg() { } } - @LockFeatureFlag(STREAM_TRANSPORT) public void testStreamingAggregationTermsOnly() { TermsAggregationBuilder termsAgg = AggregationBuilders.terms("field1_terms").field("field1"); SearchRequest searchRequest = new SearchRequest(TEST_INDEX).requestCache(false); searchRequest.source().aggregation(termsAgg).size(0); - - SearchResponse response = client().execute(StreamSearchAction.INSTANCE, searchRequest).actionGet(); + SearchResponse response = client().execute(SearchAction.INSTANCE, searchRequest).actionGet(); // Verify successful response assertNotNull("Response should not be null for successful streaming terms aggregation", response); @@ -239,6 +211,49 @@ public void testStreamingAggregationTermsOnly() { } } + public void testStreamingSearchWithScoringModes() { + // Test NO_SCORING mode - fastest TTFB + SearchRequest noScoringRequest = new SearchRequest(TEST_INDEX); + noScoringRequest.source().query(QueryBuilders.matchAllQuery()).size(10); + noScoringRequest.searchType(SearchType.QUERY_THEN_FETCH); + // Streaming mode setters intentionally left disabled; this exercises the NO_SCORING query shape on the classic path + // noScoringRequest.setStreamingSearchMode(StreamingSearchMode.NO_SCORING.toString()); + // noScoringRequest.setStreamingScoring(true); + + SearchResponse noScoringResponse = client().execute(SearchAction.INSTANCE, noScoringRequest).actionGet(); + assertNotNull("Response should not be null for NO_SCORING mode", noScoringResponse); + assertNotNull("Response hits should not be null", noScoringResponse.getHits()); + assertTrue("Should have search hits", noScoringResponse.getHits().getTotalHits().value() > 0); + + // Test SCORED_SORTED mode - full scoring with sorting + SearchRequest scoredSortedRequest = new SearchRequest(TEST_INDEX); + scoredSortedRequest.source().query(QueryBuilders.matchQuery("field1", "value1")).size(10).sort("_score", SortOrder.DESC); + scoredSortedRequest.searchType(SearchType.QUERY_THEN_FETCH); + // Streaming mode setters intentionally left disabled; this exercises the SCORED_SORTED query shape on the classic path + // scoredSortedRequest.setStreamingSearchMode(StreamingSearchMode.SCORED_SORTED.toString()); + // scoredSortedRequest.setStreamingScoring(true); + + SearchResponse scoredSortedResponse = client().execute(SearchAction.INSTANCE, scoredSortedRequest).actionGet(); + assertNotNull("Response should not be null for SCORED_SORTED mode", scoredSortedResponse); + assertNotNull("Response hits should not be null", scoredSortedResponse.getHits()); + + // Verify hits are sorted by score + SearchHit[] hits = scoredSortedResponse.getHits().getHits(); + for (int i = 1; i < hits.length; i++) { + assertTrue("Hits should be sorted by score", hits[i - 1].getScore() >= hits[i].getScore()); + } + + // Test SCORED_UNSORTED mode - scoring without sorting + SearchRequest scoredUnsortedRequest = new SearchRequest(TEST_INDEX); + scoredUnsortedRequest.source().query(QueryBuilders.matchQuery("field1", "value1")).size(5); + scoredUnsortedRequest.searchType(SearchType.QUERY_THEN_FETCH); + + SearchResponse scoredUnsortedResponse = client().execute(SearchAction.INSTANCE, scoredUnsortedRequest).actionGet(); + assertNotNull("Response should not be null for SCORED_UNSORTED mode", scoredUnsortedResponse); + assertNotNull("Response hits should 
not be null", scoredUnsortedResponse.getHits()); + assertTrue("Should have search hits", scoredUnsortedResponse.getHits().getTotalHits().value() > 0); + } + private void createTestIndex() { Settings indexSettings = Settings.builder() .put("index.number_of_shards", NUM_SHARDS) diff --git a/server/src/test/java/org/opensearch/action/search/StreamTransportResponseHandlerContractTests.java b/server/src/test/java/org/opensearch/action/search/StreamTransportResponseHandlerContractTests.java new file mode 100644 index 0000000000000..b8b2ffd0861dd --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/StreamTransportResponseHandlerContractTests.java @@ -0,0 +1,344 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.StreamTransportResponseHandler; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.stream.StreamTransportResponse; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Tests streaming transport response handler contract validation. + */ +public class StreamTransportResponseHandlerContractTests extends OpenSearchTestCase { + + /** + * Mock implementation of StreamTransportResponse for testing. + */ + private static class TestStreamTransportResponse implements StreamTransportResponse { + private final List responses; + private final AtomicInteger currentIndex = new AtomicInteger(0); + private final AtomicBoolean closed = new AtomicBoolean(false); + private volatile boolean cancelled = false; + + TestStreamTransportResponse(List responses) { + this.responses = responses != null ? responses : List.of(); + } + + @Override + public T nextResponse() { + if (cancelled) { + throw new IllegalStateException("Stream has been cancelled"); + } + if (closed.get()) { + throw new IllegalStateException("Stream has been closed"); + } + + int index = currentIndex.getAndIncrement(); + if (index < responses.size()) { + return responses.get(index); + } + return null; + } + + @Override + public void cancel(String reason, Throwable cause) { + cancelled = true; + } + + @Override + public void close() { + closed.set(true); + } + } + + /** + * Test implementation of streaming listener for testing. 
+ */ + private static class TestStreamingListener extends StreamSearchActionListener { + private final List streamResponses = new ArrayList<>(); + private QuerySearchResult completeResponse; + private Throwable failure; + + TestStreamingListener(SearchShardTarget target, int shardIndex) { + super(target, shardIndex); + } + + @Override + protected void innerOnStreamResponse(QuerySearchResult response) { + streamResponses.add(response); + } + + @Override + protected void innerOnCompleteResponse(QuerySearchResult response) { + completeResponse = response; + } + + @Override + public void onFailure(Exception e) { + failure = e; + } + + public List getStreamResponses() { + return streamResponses; + } + + public QuerySearchResult getCompleteResponse() { + return completeResponse; + } + + public Throwable getFailure() { + return failure; + } + } + + /** + * Test implementation of non-streaming listener for testing. + */ + private static class TestNonStreamingListener extends SearchActionListener { + private QuerySearchResult response; + private Throwable failure; + + TestNonStreamingListener(SearchShardTarget target, int shardIndex) { + super(target, shardIndex); + } + + @Override + protected void innerOnResponse(QuerySearchResult response) { + this.response = response; + } + + @Override + public void onFailure(Exception e) { + failure = e; + } + + public QuerySearchResult getResponse() { + return response; + } + + public Throwable getFailure() { + return failure; + } + } + + public void testStreamingHandlerWithStreamingListener() throws IOException { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamingListener listener = new TestStreamingListener(target, 0); + StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { + @Override + public void handleStreamResponse(StreamTransportResponse response) { + try { + SearchPhaseResult currentResult; + SearchPhaseResult lastResult = null; + + while ((currentResult = response.nextResponse()) != null) { + if (lastResult != null) { + listener.onStreamResponse((QuerySearchResult) lastResult, false); + } + lastResult = currentResult; + } + + if (lastResult != null) { + listener.onStreamResponse((QuerySearchResult) lastResult, true); + } + + response.close(); + } catch (Exception e) { + response.cancel("Client error during search phase", e); + listener.onFailure(e); + } + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + + @Override + public String executor() { + return ThreadPool.Names.STREAM_SEARCH; + } + + @Override + public SearchPhaseResult read(StreamInput in) throws IOException { + return new QuerySearchResult(in); + } + }; + + List responses = new ArrayList<>(); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null)); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 2L), target, null)); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 3L), target, null)); + + TestStreamTransportResponse streamResponse = new TestStreamTransportResponse<>(responses); + handler.handleStreamResponse(streamResponse); + + assertEquals(2, listener.getStreamResponses().size()); + assertEquals(responses.get(0), listener.getStreamResponses().get(0)); + assertEquals(responses.get(1), listener.getStreamResponses().get(1)); + assertNotNull(listener.getCompleteResponse()); + 
assertEquals(responses.get(2), listener.getCompleteResponse()); + } + + public void testStreamingHandlerWithNonStreamingListener() throws IOException { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestNonStreamingListener listener = new TestNonStreamingListener(target, 0); + StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { + @Override + public void handleStreamResponse(StreamTransportResponse response) { + try { + SearchPhaseResult currentResult; + SearchPhaseResult lastResult = null; + + while ((currentResult = response.nextResponse()) != null) { + lastResult = currentResult; + } + + if (lastResult != null) { + listener.onResponse((QuerySearchResult) lastResult); + } + + response.close(); + } catch (Exception e) { + response.cancel("Client error during search phase", e); + listener.onFailure(e); + } + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + + @Override + public String executor() { + return ThreadPool.Names.STREAM_SEARCH; + } + + @Override + public SearchPhaseResult read(StreamInput in) throws IOException { + return new QuerySearchResult(in); + } + }; + + List responses = new ArrayList<>(); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null)); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 2L), target, null)); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 3L), target, null)); + + TestStreamTransportResponse streamResponse = new TestStreamTransportResponse<>(responses); + handler.handleStreamResponse(streamResponse); + + assertNotNull(listener.getResponse()); + assertEquals(responses.get(2), listener.getResponse()); + } + + public void testHandlerClosesStreamAfterProcessing() throws IOException { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamingListener listener = new TestStreamingListener(target, 0); + + StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { + @Override + public void handleStreamResponse(StreamTransportResponse response) { + try { + SearchPhaseResult result; + while ((result = response.nextResponse()) != null) { + } + response.close(); + } catch (Exception e) { + response.cancel("Error", e); + } + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + + @Override + public String executor() { + return ThreadPool.Names.STREAM_SEARCH; + } + + @Override + public SearchPhaseResult read(StreamInput in) throws IOException { + return new QuerySearchResult(in); + } + }; + + List responses = new ArrayList<>(); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null)); + + TestStreamTransportResponse streamResponse = new TestStreamTransportResponse<>(responses); + handler.handleStreamResponse(streamResponse); + + assertTrue(streamResponse.closed.get()); + } + + public void testHandlerCancelsStreamOnError() throws IOException { + ShardId shardId = new ShardId("test-index", "test-uuid", 0); + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + TestStreamingListener listener = new TestStreamingListener(target, 0); + + StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { + @Override 
+ public void handleStreamResponse(StreamTransportResponse response) { + try { + throw new RuntimeException("Test error"); + } catch (Exception e) { + response.cancel("Client error during search phase", e); + listener.onFailure(e); + } + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + + @Override + public String executor() { + return ThreadPool.Names.STREAM_SEARCH; + } + + @Override + public SearchPhaseResult read(StreamInput in) throws IOException { + return new QuerySearchResult(in); + } + }; + + List responses = new ArrayList<>(); + responses.add(new QuerySearchResult(new ShardSearchContextId("session1", 1L), target, null)); + + TestStreamTransportResponse streamResponse = new TestStreamTransportResponse<>(responses); + handler.handleStreamResponse(streamResponse); + + assertTrue(streamResponse.cancelled); + assertNotNull(listener.getFailure()); + } +} diff --git a/server/src/test/java/org/opensearch/action/search/StreamingPerformanceBenchmarkTests.java b/server/src/test/java/org/opensearch/action/search/StreamingPerformanceBenchmarkTests.java new file mode 100644 index 0000000000000..a5184dac7a17f --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/StreamingPerformanceBenchmarkTests.java @@ -0,0 +1,285 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.query.StreamingSearchMode; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope; +import org.opensearch.test.OpenSearchIntegTestCase.Scope; + +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; + +@ClusterScope(scope = Scope.TEST, numDataNodes = 1) +public class StreamingPerformanceBenchmarkTests extends OpenSearchIntegTestCase { + private static final Logger logger = LogManager.getLogger(StreamingPerformanceBenchmarkTests.class); + private static final String INDEX = "benchmark_index"; + + public void testStreamingTTFBPerformance() throws Exception { + // Setup test environment with multiple shards for realistic testing + Settings indexSettings = Settings.builder() + .put("number_of_shards", 10) // Many shards to emphasize streaming benefits + .put("number_of_replicas", 0) + .put("refresh_interval", "-1") + .put("index.search.slowlog.threshold.query.debug", "0ms") // Enable slow log + .build(); + createIndex(INDEX, indexSettings); + // Explicit mapping to ensure sort/range fields are indexed correctly + String mapping = "{\n" + + " \"properties\": {\n" + + " \"title\": { \"type\": \"text\" },\n" + + " \"content\": { \"type\": \"text\" },\n" + + " \"score\": { \"type\": \"float\" },\n" + + " \"timestamp\": { \"type\": \"date\" }\n" + + " }\n" + + "}"; + client().admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); + + // Index documents for the benchmark + logger.info("Streaming TTFB benchmark: initializing index"); + int numDocs = 
100000; // 100K documents to demonstrate TTFB improvement + logger.info("Indexing {} documents across 10 shards", numDocs); + indexTestData(numDocs); + refresh(INDEX); + + // Optional flush to disk so measured queries are not served purely from the page cache + // This keeps per-iteration timing realistic for both the classic and streaming runs + client().admin().indices().prepareFlush(INDEX).get(); + + // Verify data is indexed using track_total_hits + SearchRequest verifyReq = new SearchRequest(INDEX); + verifyReq.source().size(0).trackTotalHits(true); + SearchResponse verifyResponse = client().search(verifyReq).actionGet(); + long actualDocs = verifyResponse.getHits().getTotalHits() != null ? verifyResponse.getHits().getTotalHits().value() : 0; + logger.info("Test data ready: {} documents actually indexed", actualDocs); + + // Run TTFB comparison (coordinator-side) + logger.info("Measuring coordinator TTFB: classic=full reduce, streaming=first partial"); + + try { + compareTTFBWithFetchPhase(); + } catch (Exception e) { + logger.error("TTFB comparison failed", e); + } + + // Completed benchmark run + } + + private void compareTTFBWithFetchPhase() { + int testSize = 10000; // 10K results per maintainer guidance + + logger.info("TTFB comparison: size={} (coordinator first partial vs full reduce)", testSize); + + // Warm up cluster + logger.info("Warming up cluster..."); + for (int i = 0; i < 5; i++) { + SearchRequest warmup = new SearchRequest(INDEX); + warmup.source().size(100).query(QueryBuilders.matchAllQuery()); + try { + client().search(warmup).actionGet(); + } catch (Exception e) { + // Ignore warmup errors + } + } + + // Run multiple iterations to get stable measurements + int iterations = 5; // Reduced for faster testing + long totalTraditional = 0; + long totalStreaming = 0; + int successfulRuns = 0; + + logger.info("Running {} iterations...", iterations); + + for (int i = 0; i < iterations; i++) { + // Traditional: Measure time until ALL shards complete + long traditionalTTFB = measureTraditionalTTFB(testSize); + if (traditionalTTFB < 0) { + logger.warn("Traditional query failed, skipping iteration"); + continue; + } + + // Streaming: measure time to FIRST PARTIAL (when first batch is ready to fetch) + long streamingTTFB = measureStreamingTTFB(testSize); + if (streamingTTFB < 0) { + logger.warn("Streaming TTFB measurement failed, skipping iteration"); + continue; + } + + totalTraditional += traditionalTTFB; + totalStreaming += streamingTTFB; + successfulRuns++; + + logger.debug("iter={} classic={}ms streaming={}ms", i + 1, traditionalTTFB, streamingTTFB); + } + + if (successfulRuns == 0) { + logger.warn("No successful iterations completed"); + return; + } + + long avgTraditional = totalTraditional / successfulRuns; + long avgStreaming = totalStreaming / successfulRuns; + + // Compute and report improvement + double improvement = ((double) (avgTraditional - avgStreaming) / Math.max(1, avgTraditional)) * 100; + logger.info("TTFB classic (full reduce): {} ms", avgTraditional); + logger.info("TTFB streaming (first partial): {} ms", avgStreaming); + logger.info( + "TTFB improvement: {}% (delta={} ms)", + String.format(Locale.ROOT, "%.1f", improvement), + (avgTraditional - avgStreaming) + ); + } + + private long measureTraditionalTTFB(int size) { + // Traditional: must wait for ALL shards to complete before fetch can start + SearchRequest request = new SearchRequest(INDEX); + request.source() + .size(size) + // Match all with two sorts for processing load + .query(QueryBuilders.matchAllQuery()) + .sort("score", org.opensearch.search.sort.SortOrder.DESC) +
.sort("timestamp", org.opensearch.search.sort.SortOrder.ASC) + .trackTotalHits(true) + .explain(false); + + // Don't use cache to get realistic timing + request.requestCache(false); + request.setPreFilterShardSize(10000); // avoid prefilter path issues during benchmark + + long start = System.currentTimeMillis(); + try { + SearchResponse response = client().search(request).actionGet(); + long end = System.currentTimeMillis(); + + // In traditional approach, fetch phase starts only after ALL shards respond + // This is the total query time since fetch can't start until query completes + long ttfb = end - start; + + // Log for debugging + if (ttfb < 50) { + logger.debug("Traditional TTFB seems low: {} ms, hits: {}", ttfb, response.getHits().getTotalHits()); + } + + return ttfb; + } catch (Exception e) { + logger.error("Traditional search failed", e); + return -1; + } + } + + private long measureStreamingTTFB(int size) { + // Streaming: measure actual time when first batch is ready for fetch + SearchRequest request = new SearchRequest(INDEX); + request.source() + .size(size) + // Same query as traditional + .query(QueryBuilders.matchAllQuery()) + .sort("score", org.opensearch.search.sort.SortOrder.DESC) + .sort("timestamp", org.opensearch.search.sort.SortOrder.ASC) + .trackTotalHits(true) + .explain(false); + request.setStreamingSearchMode(StreamingSearchMode.NO_SCORING.toString()); + request.setBatchedReduceSize(1); // Process first shard immediately + request.requestCache(false); + request.setPreFilterShardSize(10000); // avoid prefilter path issues during benchmark + + final CountDownLatch firstPartial = new CountDownLatch(1); + final CountDownLatch finished = new CountDownLatch(1); + final AtomicLong ttfbMs = new AtomicLong(-1); + final long startNanos = System.nanoTime(); + + ActionListener finalListener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + finished.countDown(); + } + + @Override + public void onFailure(Exception e) { + finished.countDown(); + logger.error("Streaming search failed", e); + } + }; + + StreamingSearchResponseListener streamingListener = new StreamingSearchResponseListener(finalListener, request) { + @Override + public void onPartialResponse(SearchResponse partialResponse) { + super.onPartialResponse(partialResponse); + if (firstPartial.getCount() > 0) { + long ttfb = java.util.concurrent.TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos); + ttfbMs.compareAndSet(-1, ttfb); + firstPartial.countDown(); + } + } + }; + + client().execute(SearchAction.INSTANCE, request, streamingListener); + + try { + if (!firstPartial.await(30, java.util.concurrent.TimeUnit.SECONDS)) { + logger.warn("Timed out waiting for first partial"); + return -1; + } + // Allow the request to finish (not strictly required for TTFB) + finished.await(30, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + return ttfbMs.get(); + } + + private void indexTestData(int numDocs) throws Exception { + BulkRequestBuilder bulkRequest = client().prepareBulk(); + int batchSize = 10000; // Batch size for bulk indexing + + for (int i = 1; i <= numDocs; i++) { + // Create documents with varied content for complex queries + String[] words = { "document", "test", "search", "query", "data", "result", "match", "score" }; + String content = "Test content for " + + words[i % words.length] + + " document " + + i + + " with additional text to make wildcard queries more expensive"; + + String 
doc = String.format( + Locale.ROOT, + "{\"id\":%d,\"title\":\"Document %d\",\"content\":\"%s\",\"score\":%.2f,\"timestamp\":%d,\"category\":\"%s\"}", + i, + i, + content, + randomDouble() * 100, + System.currentTimeMillis() - (numDocs - i) * 1000, + "category_" + (i % 100) + ); + bulkRequest.add(client().prepareIndex(INDEX).setSource(doc, XContentType.JSON)); + + if (i % batchSize == 0) { + bulkRequest.execute().actionGet(); + bulkRequest = client().prepareBulk(); + if (i % 100000 == 0) { + logger.info(" Indexed {} documents...", i); + } + } + } + + if (bulkRequest.numberOfActions() > 0) { + bulkRequest.execute().actionGet(); + } + } + +} diff --git a/server/src/test/java/org/opensearch/search/DefaultSearchContextStreamingTests.java b/server/src/test/java/org/opensearch/search/DefaultSearchContextStreamingTests.java new file mode 100644 index 0000000000000..2a43d94322651 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/DefaultSearchContextStreamingTests.java @@ -0,0 +1,330 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.QueryCachingPolicy; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.Version; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchType; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.MockBigArrays; +import org.opensearch.common.util.MockPageCacheRecycler; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.index.IndexService; +import org.opensearch.index.engine.Engine; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.search.internal.ReaderContext; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.StreamingSearchMode; +import org.opensearch.search.streaming.FlushMode; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Collections; +import java.util.UUID; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for DefaultSearchContext streaming flag propagation. + * Validates that streaming flags, mode, and flush mode are correctly set when streaming search is requested. 
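The propagation asserted by the tests in this file, condensed into a sketch (field names and the enum parse are assumptions; only the observable outcomes noted in the comments are pinned down by the assertions):

String modeName = shardSearchRequest.getStreamingSearchMode(); // null for classic search
if (modeName != null) {
    this.streamingSearch = true;                                // context.isStreamingSearch() == true
    this.streamingMode = StreamingSearchMode.valueOf(modeName); // assumes toString() round-trips through valueOf()
    this.flushMode = FlushMode.PER_SEGMENT;                     // asserted for both NO_SCORING and SCORED_UNSORTED
}
// otherwise isStreamingSearch() stays false and getStreamingMode() stays null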
+ */ +public class DefaultSearchContextStreamingTests extends OpenSearchTestCase { + + public void testStreamingFlagsSetWhenStreamingRequested() throws Exception { + ThreadPool threadPool = new TestThreadPool(this.getClass().getName()); + + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.setStreamingScoring(true); + searchRequest.setStreamingSearchMode(StreamingSearchMode.NO_SCORING.toString()); + + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.searchType()).thenReturn(SearchType.DEFAULT); + when(shardSearchRequest.source()).thenReturn(searchRequest.source()); + when(shardSearchRequest.getStreamingSearchMode()).thenReturn(StreamingSearchMode.NO_SCORING.toString()); + + ShardId shardId = new ShardId("test-index", UUID.randomUUID().toString(), 0); + when(shardSearchRequest.shardId()).thenReturn(shardId); + + IndexShard indexShard = mock(IndexShard.class); + QueryCachingPolicy queryCachingPolicy = mock(QueryCachingPolicy.class); + when(indexShard.getQueryCachingPolicy()).thenReturn(queryCachingPolicy); + when(indexShard.getThreadPool()).thenReturn(threadPool); + + org.opensearch.cluster.metadata.IndexMetadata indexMetadata = org.opensearch.cluster.metadata.IndexMetadata.builder( + "test-index" + ) + .settings( + Settings.builder() + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED, org.opensearch.Version.CURRENT) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .build(); + org.opensearch.index.IndexSettings indexSettings = new org.opensearch.index.IndexSettings(indexMetadata, Settings.EMPTY); + when(indexShard.indexSettings()).thenReturn(indexSettings); + + IndexService indexService = mock(IndexService.class); + BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + + final Supplier searcherSupplier = () -> new Engine.SearcherSupplier(Function.identity()) { + @Override + protected void doClose() {} + + @Override + protected Engine.Searcher acquireSearcherInternal(String source) { + try { + IndexReader reader = w.getReader(); + return new Engine.Searcher( + "test", + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + reader + ); + } catch (IOException exc) { + throw new AssertionError(exc); + } + } + }; + + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + ReaderContext readerContext = new ReaderContext( + new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), + indexService, + indexShard, + searcherSupplier.get(), + randomNonNegativeLong(), + false + ); + + DefaultSearchContext context = new DefaultSearchContext( + readerContext, + shardSearchRequest, + target, + null, + bigArrays, + null, + null, + null, + false, + Version.CURRENT, + false, + null, + null, + Collections.emptyList() + ); + + assertTrue(context.isStreamingSearch()); + assertEquals(StreamingSearchMode.NO_SCORING, context.getStreamingMode()); + assertEquals(FlushMode.PER_SEGMENT, context.getFlushMode()); + + context.close(); + } finally { + threadPool.shutdown(); + } + } + + public void testStreamingFlagsScoredUnsortedMode() throws Exception { + ThreadPool threadPool = new 
TestThreadPool(this.getClass().getName()); + + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.searchType()).thenReturn(SearchType.DEFAULT); + when(shardSearchRequest.getStreamingSearchMode()).thenReturn(StreamingSearchMode.SCORED_UNSORTED.toString()); + + ShardId shardId = new ShardId("test-index", UUID.randomUUID().toString(), 0); + when(shardSearchRequest.shardId()).thenReturn(shardId); + + IndexShard indexShard = mock(IndexShard.class); + QueryCachingPolicy queryCachingPolicy = mock(QueryCachingPolicy.class); + when(indexShard.getQueryCachingPolicy()).thenReturn(queryCachingPolicy); + when(indexShard.getThreadPool()).thenReturn(threadPool); + + org.opensearch.cluster.metadata.IndexMetadata indexMetadata = org.opensearch.cluster.metadata.IndexMetadata.builder( + "test-index" + ) + .settings( + Settings.builder() + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED, org.opensearch.Version.CURRENT) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .build(); + org.opensearch.index.IndexSettings indexSettings = new org.opensearch.index.IndexSettings(indexMetadata, Settings.EMPTY); + when(indexShard.indexSettings()).thenReturn(indexSettings); + + IndexService indexService = mock(IndexService.class); + BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + + final Supplier searcherSupplier = () -> new Engine.SearcherSupplier(Function.identity()) { + @Override + protected void doClose() {} + + @Override + protected Engine.Searcher acquireSearcherInternal(String source) { + try { + IndexReader reader = w.getReader(); + return new Engine.Searcher( + "test", + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + reader + ); + } catch (IOException exc) { + throw new AssertionError(exc); + } + } + }; + + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + ReaderContext readerContext = new ReaderContext( + new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), + indexService, + indexShard, + searcherSupplier.get(), + randomNonNegativeLong(), + false + ); + + DefaultSearchContext context = new DefaultSearchContext( + readerContext, + shardSearchRequest, + target, + null, + bigArrays, + null, + null, + null, + false, + Version.CURRENT, + false, + null, + null, + Collections.emptyList() + ); + + assertTrue(context.isStreamingSearch()); + assertEquals(StreamingSearchMode.SCORED_UNSORTED, context.getStreamingMode()); + assertEquals(FlushMode.PER_SEGMENT, context.getFlushMode()); + + context.close(); + } finally { + threadPool.shutdown(); + } + } + + public void testNonStreamingDoesNotSetStreamingFlags() throws Exception { + ThreadPool threadPool = new TestThreadPool(this.getClass().getName()); + + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.searchType()).thenReturn(SearchType.DEFAULT); + when(shardSearchRequest.getStreamingSearchMode()).thenReturn(null); + + ShardId shardId = new ShardId("test-index", UUID.randomUUID().toString(), 0); + 
when(shardSearchRequest.shardId()).thenReturn(shardId); + + IndexShard indexShard = mock(IndexShard.class); + QueryCachingPolicy queryCachingPolicy = mock(QueryCachingPolicy.class); + when(indexShard.getQueryCachingPolicy()).thenReturn(queryCachingPolicy); + when(indexShard.getThreadPool()).thenReturn(threadPool); + + org.opensearch.cluster.metadata.IndexMetadata indexMetadata = org.opensearch.cluster.metadata.IndexMetadata.builder( + "test-index" + ) + .settings( + Settings.builder() + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED, org.opensearch.Version.CURRENT) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .build(); + org.opensearch.index.IndexSettings indexSettings = new org.opensearch.index.IndexSettings(indexMetadata, Settings.EMPTY); + when(indexShard.indexSettings()).thenReturn(indexSettings); + + IndexService indexService = mock(IndexService.class); + BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + + final Supplier searcherSupplier = () -> new Engine.SearcherSupplier(Function.identity()) { + @Override + protected void doClose() {} + + @Override + protected Engine.Searcher acquireSearcherInternal(String source) { + try { + IndexReader reader = w.getReader(); + return new Engine.Searcher( + "test", + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + reader + ); + } catch (IOException exc) { + throw new AssertionError(exc); + } + } + }; + + SearchShardTarget target = new SearchShardTarget("node1", shardId, null, OriginalIndices.NONE); + ReaderContext readerContext = new ReaderContext( + new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), + indexService, + indexShard, + searcherSupplier.get(), + randomNonNegativeLong(), + false + ); + + DefaultSearchContext context = new DefaultSearchContext( + readerContext, + shardSearchRequest, + target, + null, + bigArrays, + null, + null, + null, + false, + Version.CURRENT, + false, + null, + null, + Collections.emptyList() + ); + + assertFalse(context.isStreamingSearch()); + assertNull(context.getStreamingMode()); + + context.close(); + } finally { + threadPool.shutdown(); + } + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java index 112a47527d0f0..9088f290b0e81 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/StreamStringTermsAggregatorTests.java @@ -294,8 +294,7 @@ public void testBuildAggregationsBatchWithSize() throws Exception { StringTerms result = (StringTerms) aggregator.buildAggregations(new long[] { 0 })[0]; assertThat(result, notNullValue()); - // For streaming aggregator, size limitation may not be applied at buildAggregations level - // but rather handled during the reduce phase. Test that we get all terms for this batch. 
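The rationale behind the comment change in this hunk: a single shard cannot know the global top N, so the streaming aggregator ships every local bucket per batch and the coordinator applies the size cut during reduce. Schematically (hypothetical helper names, not this PR's API):

List<StringTerms.Bucket> localBuckets = collectLocalBuckets(batch); // shard side: no size cut
List<StringTerms.Bucket> merged = reduceAcrossShards(allBatches);   // coordinator side: merge first
List<StringTerms.Bucket> top = merged.subList(0, Math.min(requestedSize, merged.size())); // then trim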
+ // Streaming aggregator returns all terms; size applied during reduce assertThat(result.getBuckets().size(), equalTo(10)); // Verify each term appears exactly twice (20 docs / 10 unique terms) diff --git a/server/src/test/java/org/opensearch/search/query/StreamingCollectorContextsTests.java b/server/src/test/java/org/opensearch/search/query/StreamingCollectorContextsTests.java new file mode 100644 index 0000000000000..f2836a953ce47 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/query/StreamingCollectorContextsTests.java @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.Sort; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests streaming collector context instantiation. + */ +public class StreamingCollectorContextsTests extends OpenSearchTestCase { + + private SearchContext mockSearchContext; + private CircuitBreaker mockCircuitBreaker; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockSearchContext = mock(SearchContext.class); + mockCircuitBreaker = mock(CircuitBreaker.class); + when(mockSearchContext.size()).thenReturn(10); + } + + public void testStreamingUnsortedCollectorContextInstantiation() throws IOException { + StreamingUnsortedCollectorContext context1 = new StreamingUnsortedCollectorContext("test_unsorted", 10, mockSearchContext); + assertNotNull(context1); + assertEquals(10, context1.numHits()); + + Collector collector1 = context1.create(null); + assertNotNull(collector1); + + StreamingUnsortedCollectorContext context2 = new StreamingUnsortedCollectorContext( + "test_unsorted_breaker", + 20, + mockSearchContext, + mockCircuitBreaker + ); + assertNotNull(context2); + assertEquals(20, context2.numHits()); + + Collector collector2 = context2.create(null); + assertNotNull(collector2); + } + + public void testStreamingScoredUnsortedCollectorContextInstantiation() throws IOException { + StreamingScoredUnsortedCollectorContext context1 = new StreamingScoredUnsortedCollectorContext( + "test_scored_unsorted", + 10, + mockSearchContext + ); + assertNotNull(context1); + assertEquals(10, context1.numHits()); + + Collector collector1 = context1.create(null); + assertNotNull(collector1); + + StreamingScoredUnsortedCollectorContext context2 = new StreamingScoredUnsortedCollectorContext( + "test_scored_unsorted_breaker", + 20, + mockSearchContext, + mockCircuitBreaker + ); + assertNotNull(context2); + assertEquals(20, context2.numHits()); + + Collector collector2 = context2.create(null); + assertNotNull(collector2); + } + + public void testStreamingSortedCollectorContextInstantiation() throws IOException { + StreamingSortedCollectorContext context1 = new StreamingSortedCollectorContext("test_sorted", 10, mockSearchContext); + assertNotNull(context1); + assertEquals(10, context1.numHits()); + + Collector collector1 = context1.create(null); + assertNotNull(collector1); + + StreamingSortedCollectorContext context2 = new StreamingSortedCollectorContext( + "test_sorted_with_sort", + 15, + mockSearchContext, + Sort.RELEVANCE + 
); + assertNotNull(context2); + assertEquals(15, context2.numHits()); + + Collector collector2 = context2.create(null); + assertNotNull(collector2); + + StreamingSortedCollectorContext context3 = new StreamingSortedCollectorContext( + "test_sorted_full", + 20, + mockSearchContext, + Sort.RELEVANCE, + mockCircuitBreaker + ); + assertNotNull(context3); + assertEquals(20, context3.numHits()); + + Collector collector3 = context3.create(null); + assertNotNull(collector3); + + StreamingSortedCollectorContext context4 = new StreamingSortedCollectorContext( + "test_sorted_breaker", + 25, + mockSearchContext, + mockCircuitBreaker + ); + assertNotNull(context4); + assertEquals(25, context4.numHits()); + + Collector collector4 = context4.create(null); + assertNotNull(collector4); + } +} diff --git a/server/src/test/java/org/opensearch/search/query/StreamingQuerySelectionTests.java b/server/src/test/java/org/opensearch/search/query/StreamingQuerySelectionTests.java new file mode 100644 index 0000000000000..0384cb52dca1c --- /dev/null +++ b/server/src/test/java/org/opensearch/search/query/StreamingQuerySelectionTests.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.apache.lucene.search.Sort; +import org.opensearch.common.util.BigArrays; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.NoopCircuitBreaker; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.sort.SortAndFormats; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests streaming collector selection. 
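The selection behavior this file verifies reduces to a small mapping from streaming mode to collector context (class names as introduced in this PR; the null-mode error message is asserted verbatim below):

// NO_SCORING      -> StreamingUnsortedCollectorContext        (no scores, hits emitted as collected)
// SCORED_UNSORTED -> StreamingScoredUnsortedCollectorContext  (scores kept, no global ordering)
// SCORED_SORTED   -> StreamingSortedCollectorContext          (chosen whether sort() is null or a custom sort)
// null            -> IllegalArgumentException("Streaming mode must be set for streaming collectors")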
+ */ +public class StreamingQuerySelectionTests extends OpenSearchTestCase { + + private SearchContext mockSearchContext; + private CircuitBreaker mockCircuitBreaker; + private BigArrays mockBigArrays; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockSearchContext = mock(SearchContext.class); + mockCircuitBreaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST); + mockBigArrays = mock(BigArrays.class); + when(mockSearchContext.size()).thenReturn(10); + when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays); + } + + public void testNoScoringModeSelectsStreamingUnsortedCollector() throws IOException { + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.NO_SCORING); + + TopDocsCollectorContext context = TopDocsCollectorContext.createStreamingTopDocsCollectorContext( + mockSearchContext, + false + ); + + assertNotNull(context); + assertTrue(context instanceof StreamingUnsortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testScoredUnsortedModeSelectsStreamingScoredUnsortedCollector() throws IOException { + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_UNSORTED); + + TopDocsCollectorContext context = TopDocsCollectorContext.createStreamingTopDocsCollectorContext( + mockSearchContext, + false + ); + + assertNotNull(context); + assertTrue(context instanceof StreamingScoredUnsortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testScoredSortedModeSelectsStreamingSortedCollector() throws IOException { + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_SORTED); + when(mockSearchContext.sort()).thenReturn(null); + + TopDocsCollectorContext context = TopDocsCollectorContext.createStreamingTopDocsCollectorContext( + mockSearchContext, + false + ); + + assertNotNull(context); + assertTrue(context instanceof StreamingSortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testScoredSortedModeWithCustomSort() throws IOException { + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_SORTED); + + SortAndFormats sortAndFormats = new SortAndFormats( + Sort.INDEXORDER, + new org.opensearch.search.DocValueFormat[] { org.opensearch.search.DocValueFormat.RAW } + ); + when(mockSearchContext.sort()).thenReturn(sortAndFormats); + + TopDocsCollectorContext context = TopDocsCollectorContext.createStreamingTopDocsCollectorContext( + mockSearchContext, + false + ); + + assertNotNull(context); + assertTrue(context instanceof StreamingSortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testNullStreamingModeThrowsException() { + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(null); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> TopDocsCollectorContext.createStreamingTopDocsCollectorContext(mockSearchContext, false) + ); + + assertEquals("Streaming mode must be set for streaming collectors", exception.getMessage()); + } + + public void testFallbackToNonStreamingPath() throws IOException { + when(mockSearchContext.isStreamingSearch()).thenReturn(false); + 
when(mockSearchContext.getStreamingMode()).thenReturn(null); + + assertFalse(mockSearchContext.isStreamingSearch()); + } +} diff --git a/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java b/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java new file mode 100644 index 0000000000000..7c0d7b56a5199 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query; + +import org.opensearch.common.util.BigArrays; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests TopDocsCollectorContext routing to streaming collectors. + */ +public class TopDocsCollectorContextEntrypointTests extends OpenSearchTestCase { + + public void testStreamingBranchSelectedWhenStreamingEnabled() throws IOException { + SearchContext mockSearchContext = mock(SearchContext.class); + BigArrays mockBigArrays = mock(BigArrays.class); + + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.NO_SCORING); + when(mockSearchContext.size()).thenReturn(10); + when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays); + + TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false); + + assertNotNull(context); + assertTrue(context instanceof StreamingUnsortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testStreamingBranchSelectedForScoredSorted() throws IOException { + SearchContext mockSearchContext = mock(SearchContext.class); + BigArrays mockBigArrays = mock(BigArrays.class); + + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_SORTED); + when(mockSearchContext.size()).thenReturn(10); + when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays); + when(mockSearchContext.sort()).thenReturn(null); + + TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false); + + assertNotNull(context); + assertTrue(context instanceof StreamingSortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testStreamingBranchSelectedForScoredUnsorted() throws IOException { + SearchContext mockSearchContext = mock(SearchContext.class); + BigArrays mockBigArrays = mock(BigArrays.class); + + when(mockSearchContext.isStreamingSearch()).thenReturn(true); + when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_UNSORTED); + when(mockSearchContext.size()).thenReturn(10); + when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays); + + TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false); + + assertNotNull(context); + assertTrue(context instanceof StreamingScoredUnsortedCollectorContext); + assertEquals(10, context.numHits()); + } + + public void testNonStreamingBranchWhenStreamingDisabled() throws IOException { + SearchContext mockSearchContext = 
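Taken together, these cases pin down a simple selection contract: each `StreamingSearchMode` maps to one collector context, and a missing mode is rejected with the exact message asserted above. A minimal sketch of the dispatch they imply (the profiler-name strings are assumptions, and the production method presumably also threads sort and circuit-breaker state through each branch):

```java
// Minimal sketch of the mode dispatch established by StreamingQuerySelectionTests.
// hasFilterCollector is accepted for signature parity but unused in this sketch.
static TopDocsCollectorContext createStreamingTopDocsCollectorContext(SearchContext context, boolean hasFilterCollector) {
    StreamingSearchMode mode = context.getStreamingMode();
    if (mode == null) {
        // Message pinned by testNullStreamingModeThrowsException
        throw new IllegalArgumentException("Streaming mode must be set for streaming collectors");
    }
    switch (mode) {
        case NO_SCORING:
            return new StreamingUnsortedCollectorContext("streaming_unsorted", context.size(), context);
        case SCORED_UNSORTED:
            return new StreamingScoredUnsortedCollectorContext("streaming_scored_unsorted", context.size(), context);
        case SCORED_SORTED:
            // Covers both the default sort (context.sort() == null) and a custom sort.
            return new StreamingSortedCollectorContext("streaming_sorted", context.size(), context);
        default:
            throw new IllegalArgumentException("Unknown streaming mode: " + mode);
    }
}
```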
diff --git a/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java b/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java
new file mode 100644
index 0000000000000..7c0d7b56a5199
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/query/TopDocsCollectorContextEntrypointTests.java
@@ -0,0 +1,94 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.opensearch.common.util.BigArrays;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.io.IOException;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests TopDocsCollectorContext routing to streaming collectors.
+ */
+public class TopDocsCollectorContextEntrypointTests extends OpenSearchTestCase {
+
+    public void testStreamingBranchSelectedWhenStreamingEnabled() throws IOException {
+        SearchContext mockSearchContext = mock(SearchContext.class);
+        BigArrays mockBigArrays = mock(BigArrays.class);
+
+        when(mockSearchContext.isStreamingSearch()).thenReturn(true);
+        when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.NO_SCORING);
+        when(mockSearchContext.size()).thenReturn(10);
+        when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays);
+
+        TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false);
+
+        assertNotNull(context);
+        assertTrue(context instanceof StreamingUnsortedCollectorContext);
+        assertEquals(10, context.numHits());
+    }
+
+    public void testStreamingBranchSelectedForScoredSorted() throws IOException {
+        SearchContext mockSearchContext = mock(SearchContext.class);
+        BigArrays mockBigArrays = mock(BigArrays.class);
+
+        when(mockSearchContext.isStreamingSearch()).thenReturn(true);
+        when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_SORTED);
+        when(mockSearchContext.size()).thenReturn(10);
+        when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays);
+        when(mockSearchContext.sort()).thenReturn(null);
+
+        TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false);
+
+        assertNotNull(context);
+        assertTrue(context instanceof StreamingSortedCollectorContext);
+        assertEquals(10, context.numHits());
+    }
+
+    public void testStreamingBranchSelectedForScoredUnsorted() throws IOException {
+        SearchContext mockSearchContext = mock(SearchContext.class);
+        BigArrays mockBigArrays = mock(BigArrays.class);
+
+        when(mockSearchContext.isStreamingSearch()).thenReturn(true);
+        when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.SCORED_UNSORTED);
+        when(mockSearchContext.size()).thenReturn(10);
+        when(mockSearchContext.bigArrays()).thenReturn(mockBigArrays);
+
+        TopDocsCollectorContext context = TopDocsCollectorContext.createTopDocsCollectorContext(mockSearchContext, false);
+
+        assertNotNull(context);
+        assertTrue(context instanceof StreamingScoredUnsortedCollectorContext);
+        assertEquals(10, context.numHits());
+    }
+
+    public void testNonStreamingBranchWhenStreamingDisabled() throws IOException {
+        SearchContext mockSearchContext = mock(SearchContext.class);
+        BigArrays mockBigArrays = mock(BigArrays.class);
+
+        when(mockSearchContext.isStreamingSearch()).thenReturn(false);
+        when(mockSearchContext.getStreamingMode()).thenReturn(StreamingSearchMode.NO_SCORING);
+
+        assertFalse(mockSearchContext.isStreamingSearch());
+    }
+
+    public void testNonStreamingBranchWhenModeIsNull() throws IOException {
+        SearchContext mockSearchContext = mock(SearchContext.class);
+        BigArrays mockBigArrays = mock(BigArrays.class);
+
+        when(mockSearchContext.isStreamingSearch()).thenReturn(true);
+        when(mockSearchContext.getStreamingMode()).thenReturn(null);
+
+        assertTrue(mockSearchContext.isStreamingSearch());
+        assertNull(mockSearchContext.getStreamingMode());
+    }
+}
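The entrypoint tests above imply a guard at the front of `createTopDocsCollectorContext`: streaming is taken only when both the flag is set and a concrete mode is present, matching `testNonStreamingBranchWhenModeIsNull`. A sketch of that routing under those assumptions (`createNonStreamingTopDocsCollectorContext` is a hypothetical name for the pre-existing path):

```java
// Sketch of the routing guard implied by the entrypoint tests: streaming requires both
// the flag and a non-null mode; anything else falls through to the regular top-docs path.
static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector) throws IOException {
    if (searchContext.isStreamingSearch() && searchContext.getStreamingMode() != null) {
        return createStreamingTopDocsCollectorContext(searchContext, hasFilterCollector);
    }
    // Hypothetical name for the existing non-streaming construction logic.
    return createNonStreamingTopDocsCollectorContext(searchContext, hasFilterCollector);
}
```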
diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
index a936d4ce79ec2..09532b899fb78 100644
--- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
+++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
@@ -2382,8 +2382,10 @@ public void onFailure(final Exception e) {
                 threadPool,
                 new NoneCircuitBreakerService(),
                 transportService,
+                null, // StreamTransportService - not available in test
                 searchService,
                 searchTransportService,
+                null, // StreamSearchTransportService - not available in test
                 searchPhaseController,
                 clusterService,
                 actionFilters,
diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java
index f03cbe266df86..f0e8c8c02e7d4 100644
--- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java
+++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java
@@ -74,6 +74,7 @@
 import org.opensearch.search.profile.Profilers;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.search.query.ReduceableSearchResult;
+import org.opensearch.search.query.StreamingSearchMode;
 import org.opensearch.search.rescore.RescoreContext;
 import org.opensearch.search.sort.SortAndFormats;
 import org.opensearch.search.suggest.SuggestionSearchContext;
@@ -138,6 +139,16 @@ public void setMaxSliceCount(int sliceCount) {
         this.maxSliceCount = sliceCount;
     }
 
+    @Override
+    public StreamingSearchMode getStreamingMode() {
+        return null; // TestSearchContext doesn't support streaming
+    }
+
+    @Override
+    public void setStreamingMode(StreamingSearchMode mode) {
+        // TestSearchContext doesn't support streaming - no-op
+    }
+
     private final Map<String, SearchExtBuilder> searchExtBuilders = new HashMap<>();
 
     public TestSearchContext(BigArrays bigArrays, IndexService indexService) {
diff --git a/test/framework/src/main/java/org/opensearch/transport/nio/MockNativeMessageHandler.java b/test/framework/src/main/java/org/opensearch/transport/nio/MockNativeMessageHandler.java
index 9852aef16b375..fd19af8a1fc1c 100644
--- a/test/framework/src/main/java/org/opensearch/transport/nio/MockNativeMessageHandler.java
+++ b/test/framework/src/main/java/org/opensearch/transport/nio/MockNativeMessageHandler.java
@@ -123,4 +123,5 @@ protected TcpTransportChannel createTcpTransportChannel(
     private boolean requiresStreaming(String action) {
         return STREAMING_ACTIONS.contains(action) || action.contains("stream");
     }
+
 }
diff --git a/test/framework/src/main/java/org/opensearch/transport/nio/MockStreamingTransportChannel.java b/test/framework/src/main/java/org/opensearch/transport/nio/MockStreamingTransportChannel.java
index de1767f1729e2..98b390c6f1003 100644
--- a/test/framework/src/main/java/org/opensearch/transport/nio/MockStreamingTransportChannel.java
+++ b/test/framework/src/main/java/org/opensearch/transport/nio/MockStreamingTransportChannel.java
@@ -98,6 +98,7 @@ public void completeStream() {
                 bufferedResponses.size()
             );
 
+            boolean releaseNeeded = true;
             try {
                 // Get the response handler and call handleStreamResponse with all buffered responses
                 TransportResponseHandler handler = responseHandlers.onResponseReceived(requestId, messageListener);
@@ -118,27 +119,42 @@ public void completeStream() {
                     responsesCopy.size()
                 );
                 typedHandler.handleStreamResponse(streamResponse);
+
+                // Success - release normally
+                release(false);
+                releaseNeeded = false;
             } catch (Exception e) {
                 // Release resources on failure
                 release(true);
+                releaseNeeded = false;
                 throw new StreamException(StreamErrorCode.INTERNAL, "Error completing stream", e);
             } finally {
-                // Release circuit breaker resources when stream is completed
-                release(false);
+                // Only release if not already released
+                if (releaseNeeded) {
+                    release(false);
+                }
             }
         } else {
             logger.warn("CompleteStream called on already closed stream with action[{}] and requestId[{}]", action, requestId);
-            throw new StreamException(StreamErrorCode.UNAVAILABLE, "MockStreamingTransportChannel stream already closed.");
+            // Don't throw: the stream is already closed. This is expected when onFailure has
+            // already sent a response (releasing the stream) before completeStream is called.
         }
     }
 
     @Override
     public void sendResponse(TransportResponse response) throws IOException {
-        // For streaming channels, regular sendResponse is not supported
-        // Clients should use sendResponseBatch instead
-        throw new UnsupportedOperationException(
-            "sendResponse() is not supported for streaming requests in MockStreamingTransportChannel. Use sendResponseBatch() instead."
-        );
+        // Streaming payloads should flow through sendResponseBatch(); delegate to the parent
+        // here so terminal and exception responses can still complete the request and
+        // release circuit breaker resources.
+        super.sendResponse(response);
+    }
+
+    @Override
+    public void sendResponse(Exception exception) throws IOException {
+        // Mark stream as closed to prevent further operations
+        streamOpen.set(false);
+        // Call parent's sendResponse which will handle the exception and release
+        super.sendResponse(exception);
     }
 
     @Override