Skip to content

Commit a38a8a4

Browse files
authored
Fix Ask AI conversation ID handling (#2574)
1 parent 1c347c9 commit a38a8a4

File tree

13 files changed

+151
-126
lines changed

13 files changed

+151
-126
lines changed

src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace Elastic.Documentation.Api.Core.AskAi;
1010

1111
public class AskAiUsecase(
12-
IAskAiGateway<Stream> askAiGateway,
12+
IAskAiGateway askAiGateway,
1313
IStreamTransformer streamTransformer,
1414
ILogger<AskAiUsecase> logger)
1515
{
@@ -24,6 +24,7 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
2424
_ = activity?.SetTag("gen_ai.agent.id", streamTransformer.AgentId); // docs-agent or docs_assistant
2525
if (askAiRequest.ConversationId is not null)
2626
_ = activity?.SetTag("gen_ai.conversation.id", askAiRequest.ConversationId.ToString());
27+
2728
var inputMessages = new[]
2829
{
2930
new InputMessage("user", [new MessagePart("text", askAiRequest.Message)])
@@ -33,9 +34,22 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
3334
var sanitizedMessage = askAiRequest.Message?.Replace("\r", "").Replace("\n", "");
3435
logger.LogInformation("AskAI input message: <{ask_ai.input.message}>", sanitizedMessage);
3536
logger.LogInformation("Streaming AskAI response");
36-
var rawStream = await askAiGateway.AskAi(askAiRequest, ctx);
37-
// The stream transformer will handle disposing the activity when streaming completes
38-
var transformedStream = await streamTransformer.TransformAsync(rawStream, askAiRequest.ConversationId?.ToString(), activity, ctx);
37+
38+
// Gateway handles conversation ID generation if needed
39+
var response = await askAiGateway.AskAi(askAiRequest, ctx);
40+
41+
// Use generated ID if available, otherwise use the original request ID
42+
var conversationId = response.GeneratedConversationId ?? askAiRequest.ConversationId;
43+
if (conversationId is not null)
44+
_ = activity?.SetTag("gen_ai.conversation.id", conversationId.ToString());
45+
46+
// The stream transformer takes ownership of the activity and disposes it when streaming completes.
47+
// This is necessary because streaming happens asynchronously after this method returns.
48+
var transformedStream = await streamTransformer.TransformAsync(
49+
response.Stream,
50+
response.GeneratedConversationId,
51+
activity,
52+
ctx);
3953
return transformedStream;
4054
}
4155
}

src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@
44

55
namespace Elastic.Documentation.Api.Core.AskAi;
66

7-
public interface IAskAiGateway<T>
7+
/// <summary>
8+
/// Response from an AI gateway containing the stream and conversation metadata
9+
/// </summary>
10+
/// <param name="Stream">The SSE response stream</param>
11+
/// <param name="GeneratedConversationId">
12+
/// Non-null ONLY if the gateway generated a new conversation ID for this request.
13+
/// When set, the transformer should emit a ConversationStart event with this ID.
14+
/// Null means either: (1) user provided an ID (continuing conversation), or (2) gateway doesn't generate IDs (e.g., Agent Builder).
15+
/// </param>
16+
public record AskAiGatewayResponse(Stream Stream, Guid? GeneratedConversationId);
17+
18+
public interface IAskAiGateway
819
{
9-
Task<T> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
20+
Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
1021
}

src/api/Elastic.Documentation.Api.Core/AskAi/IStreamTransformer.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ public interface IStreamTransformer
2323
/// Transforms a raw SSE stream into a stream of AskAiEvent objects
2424
/// </summary>
2525
/// <param name="rawStream">Raw SSE stream from gateway (Agent Builder, LLM Gateway, etc.)</param>
26-
/// <param name="conversationId">Thread/conversation ID (if known)</param>
26+
/// <param name="generatedConversationId">
27+
/// Non-null if the gateway generated a new conversation ID (LLM Gateway only).
28+
/// When set, transformer should emit ConversationStart event with this ID.
29+
/// </param>
2730
/// <param name="parentActivity">Parent activity to track the streaming operation (will be disposed when stream completes)</param>
2831
/// <param name="cancellationToken">Cancellation token</param>
2932
/// <returns>Stream containing SSE-formatted AskAiEvent objects</returns>
30-
Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
33+
Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
3134
}

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AgentBuilderAskAiGateway.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
1414

15-
public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway<Stream>
15+
public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway
1616
{
1717
/// <summary>
1818
/// Model name used by Agent Builder (from AgentId)
@@ -23,8 +23,10 @@ public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kiban
2323
/// Provider name for tracing
2424
/// </summary>
2525
public const string ProviderName = "agent-builder";
26-
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
26+
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
2727
{
28+
// Agent Builder returns the conversation ID in the stream via conversation_id_set event
29+
// We don't generate IDs - Agent Builder handles that in the stream
2830
var agentBuilderPayload = new AgentBuilderPayload(
2931
askAiRequest.Message,
3032
"docs-agent",
@@ -55,7 +57,9 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
5557
logger.LogInformation("Response Content-Length: {ContentLength}", response.Content.Headers.ContentLength?.ToString(CultureInfo.InvariantCulture));
5658

5759
// Agent Builder already returns SSE format, just return the stream directly
58-
return await response.Content.ReadAsStreamAsync(ctx);
60+
// The conversation ID will be extracted from the stream by the transformer
61+
var stream = await response.Content.ReadAsStreamAsync(ctx);
62+
return new AskAiGatewayResponse(stream, GeneratedConversationId: null);
5963
}
6064
}
6165

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AskAiGatewayFactory.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
1414
public class AskAiGatewayFactory(
1515
IServiceProvider serviceProvider,
1616
AskAiProviderResolver providerResolver,
17-
ILogger<AskAiGatewayFactory> logger) : IAskAiGateway<Stream>
17+
ILogger<AskAiGatewayFactory> logger) : IAskAiGateway
1818
{
19-
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
19+
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
2020
{
2121
var provider = providerResolver.ResolveProvider();
2222

23-
IAskAiGateway<Stream> gateway = provider switch
23+
IAskAiGateway gateway = provider switch
2424
{
2525
"LlmGateway" => serviceProvider.GetRequiredService<LlmGatewayAskAiGateway>(),
2626
"AgentBuilder" => serviceProvider.GetRequiredService<AgentBuilderAskAiGateway>(),

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
1212

13-
public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
13+
public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway
1414
{
1515
/// <summary>
1616
/// Model name used by LLM Gateway (from PlatformContext.UseCase)
@@ -21,9 +21,13 @@ public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider t
2121
/// Provider name for tracing
2222
/// </summary>
2323
public const string ProviderName = "llm-gateway";
24-
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
24+
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
2525
{
26-
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
26+
// LLM Gateway requires a ThreadId - generate one if not provided
27+
var generatedId = askAiRequest.ConversationId is null ? Guid.NewGuid() : (Guid?)null;
28+
var threadId = askAiRequest.ConversationId ?? generatedId!.Value;
29+
30+
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest, threadId);
2731
var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
2832
using var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl);
2933
request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
@@ -46,7 +50,8 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
4650

4751
// Return the response stream directly - this enables true streaming
4852
// The stream will be consumed as data arrives from the LLM Gateway
49-
return await response.Content.ReadAsStreamAsync(ctx);
53+
var stream = await response.Content.ReadAsStreamAsync(ctx);
54+
return new AskAiGatewayResponse(stream, generatedId);
5055
}
5156
}
5257

@@ -57,7 +62,7 @@ public record LlmGatewayRequest(
5762
string ThreadId
5863
)
5964
{
60-
public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
65+
public static LlmGatewayRequest CreateFromRequest(AskAiRequest request, Guid conversationId) =>
6166
new(
6267
UserContext: new UserContext("elastic-docs-v3@invalid"),
6368
PlatformContext: new PlatformContext("docs_site", "docs_assistant", []),
@@ -66,7 +71,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
6671
// new ChatInput("user", AskAiRequest.SystemPrompt),
6772
new ChatInput("user", request.Message)
6873
],
69-
ThreadId: request.ConversationId?.ToString() ?? Guid.NewGuid().ToString()
74+
ThreadId: conversationId.ToString()
7075
);
7176
}
7277

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayStreamTransformer.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,33 @@ public class LlmGatewayStreamTransformer(ILogger<LlmGatewayStreamTransformer> lo
1919
protected override string GetAgentProvider() => LlmGatewayAskAiGateway.ProviderName;
2020

2121
/// <summary>
22-
/// Override to emit ConversationStart event when conversationId is null (new conversation)
22+
/// Override to emit ConversationStart event for new conversations.
23+
/// LLM Gateway doesn't return a conversation ID, so we emit one to match Agent Builder behavior.
24+
/// The generatedConversationId is the ID generated by the gateway and used as ThreadId.
2325
/// </summary>
24-
protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
26+
protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
2527
{
26-
// If conversationId is null, generate a new one and emit ConversationStart event
27-
// This matches the ThreadId format used in LlmGatewayAskAiGateway
28-
var actualConversationId = conversationId;
29-
if (conversationId == null)
28+
// Emit ConversationStart event only if a new conversation ID was generated
29+
if (generatedConversationId is not null)
3030
{
31-
actualConversationId = Guid.NewGuid().ToString();
31+
var conversationId = generatedConversationId.Value.ToString();
3232
var timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
3333
var conversationStartEvent = new AskAiEvent.ConversationStart(
3434
Id: Guid.NewGuid().ToString(),
3535
Timestamp: timestamp,
36-
ConversationId: actualConversationId
36+
ConversationId: conversationId
3737
);
3838

3939
// Set activity tags for the new conversation
40-
_ = parentActivity?.SetTag("gen_ai.conversation.id", actualConversationId);
41-
Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", actualConversationId);
40+
_ = parentActivity?.SetTag("gen_ai.conversation.id", conversationId);
41+
Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", conversationId);
4242

4343
// Write the ConversationStart event to the stream
4444
await WriteEventAsync(conversationStartEvent, writer, cancellationToken);
4545
}
4646

47-
// Continue with normal stream processing using the actual conversation ID
48-
await base.ProcessStreamAsync(reader, writer, actualConversationId, parentActivity, cancellationToken);
47+
// Continue with normal stream processing
48+
await base.ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
4949
}
5050
protected override AskAiEvent? TransformJsonEvent(string? eventType, JsonElement json)
5151
{

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerBase.cs

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public abstract class StreamTransformerBase(ILogger logger) : IStreamTransformer
4242
/// </summary>
4343
public string AgentProvider => GetAgentProvider();
4444

45-
public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Activity? parentActivity, Cancel cancellationToken = default)
45+
public Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, Activity? parentActivity, Cancel cancellationToken = default)
4646
{
4747
// Configure pipe for low-latency streaming
4848
var pipeOptions = new PipeOptions(
@@ -61,7 +61,7 @@ public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Act
6161
// Note: We intentionally don't await this task as we need to return the stream immediately
6262
// The pipe handles synchronization and backpressure between producer and consumer
6363
// Pass parent activity - it will be disposed when streaming completes
64-
_ = ProcessPipeAsync(reader, pipe.Writer, conversationId, parentActivity, cancellationToken);
64+
_ = ProcessPipeAsync(reader, pipe.Writer, generatedConversationId, parentActivity, cancellationToken);
6565

6666
// Return the read side of the pipe as a stream
6767
return Task.FromResult(pipe.Reader.AsStream());
@@ -71,48 +71,42 @@ public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Act
7171
/// Process the pipe reader and write transformed events to the pipe writer.
7272
/// This runs concurrently with the consumer reading from the output stream.
7373
/// </summary>
74-
private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
74+
private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
7575
{
76+
using var activityScope = parentActivity;
7677
try
7778
{
79+
await ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
80+
}
81+
catch (OperationCanceledException ex)
82+
{
83+
Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
84+
}
85+
catch (Exception ex)
86+
{
87+
Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
88+
_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
7889
try
7990
{
80-
await ProcessStreamAsync(reader, writer, conversationId, parentActivity, cancellationToken);
91+
// Complete writer first, then reader - but don't try to complete reader
92+
// if the exception came from reading (would cause "read operation pending" error)
93+
await writer.CompleteAsync(ex);
8194
}
82-
catch (OperationCanceledException ex)
95+
catch (Exception completeEx)
8396
{
84-
Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
85-
}
86-
catch (Exception ex)
87-
{
88-
Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
89-
_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
90-
try
91-
{
92-
// Complete writer first, then reader - but don't try to complete reader
93-
// if the exception came from reading (would cause "read operation pending" error)
94-
await writer.CompleteAsync(ex);
95-
}
96-
catch (Exception completeEx)
97-
{
98-
Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
99-
}
100-
return;
97+
Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
10198
}
99+
return;
100+
}
102101

103-
// Normal completion - ensure cleanup happens
104-
try
105-
{
106-
await writer.CompleteAsync();
107-
}
108-
catch (Exception ex)
109-
{
110-
Logger.LogError(ex, "Error completing pipe after successful transformation");
111-
}
102+
// Normal completion - ensure cleanup happens
103+
try
104+
{
105+
await writer.CompleteAsync();
112106
}
113-
finally
107+
catch (Exception ex)
114108
{
115-
parentActivity?.Dispose();
109+
Logger.LogError(ex, "Error completing pipe after successful transformation");
116110
}
117111
}
118112

@@ -122,7 +116,7 @@ private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string
122116
/// Default implementation parses SSE events and JSON, then calls TransformJsonEvent.
123117
/// </summary>
124118
/// <returns>Stream processing result with metrics and captured output</returns>
125-
protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
119+
protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
126120
{
127121
using var activity = StreamTransformerActivitySource.StartActivity("process ask_ai stream", ActivityKind.Internal);
128122

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ private IStreamTransformer GetTransformer()
3939
public string AgentId => GetTransformer().AgentId;
4040
public string AgentProvider => GetTransformer().AgentProvider;
4141

42-
public async Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
42+
public async Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
4343
{
4444
var transformer = GetTransformer();
45-
return await transformer.TransformAsync(rawStream, conversationId, parentActivity, cancellationToken);
45+
return await transformer.TransformAsync(rawStream, generatedConversationId, parentActivity, cancellationToken);
4646
}
4747
}

src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
173173
logger?.LogInformation("Both stream transformers registered as concrete types");
174174

175175
// Register factories as interface implementations
176-
_ = services.AddScoped<IAskAiGateway<Stream>, AskAiGatewayFactory>();
176+
_ = services.AddScoped<IAskAiGateway, AskAiGatewayFactory>();
177177
_ = services.AddScoped<IStreamTransformer, StreamTransformerFactory>();
178178
logger?.LogInformation("Gateway and transformer factories registered successfully - provider switchable via X-AI-Provider header");
179179

0 commit comments

Comments
 (0)