Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions .github/actions/setup-test-environment/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,26 @@ runs:
run: |
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
unset LLAMA_STACK_CLIENT_DIR
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
# Check if PR is targeting a release branch
TARGET_BRANCH="${{ github.base_ref }}"

if [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
echo "PR targets release branch: $TARGET_BRANCH"
echo "Checking if matching branch exists in llama-stack-client-python..."

# Check if the branch exists in the client repo
if git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$TARGET_BRANCH" > /dev/null 2>&1; then
echo "Installing llama-stack-client-python from matching branch: $TARGET_BRANCH"
uv pip install --force-reinstall git+https://github.com/llamastack/llama-stack-client-python.git@$TARGET_BRANCH
else
echo "::error::Branch $TARGET_BRANCH not found in llama-stack-client-python repository"
echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
exit 1
fi
fi
# For main branch, client is already installed by setup-runner
fi
# For published version, client is already installed by setup-runner

echo "Building Llama Stack"

Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/integration-auth-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication

on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'distributions/**'
- 'llama_stack/**'
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/integration-sql-store-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore

on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'llama_stack/providers/utils/sqlstore/**'
- 'tests/integration/sqlstore/**'
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode

on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
types: [opened, synchronize, reopened]
paths:
- 'llama_stack/**'
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/integration-vector-io-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers

on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'llama_stack/**'
- '!llama_stack/ui/**'
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ run-name: Run pre-commit checks
on:
pull_request:
push:
branches: [main]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'

concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ run-name: Run the unit test suite

on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'llama_stack/**'
- '!llama_stack/ui/**'
Expand Down
62 changes: 53 additions & 9 deletions llama_stack/core/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from collections.abc import AsyncGenerator
from contextvars import ContextVar

from llama_stack.providers.utils.telemetry.tracing import CURRENT_TRACE_CONTEXT

_MISSING = object()


def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
Expand All @@ -21,20 +25,60 @@ def preserve_contexts_async_generator[T](

async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
# Restore context values before any await
for context_var in context_vars:
context_var.set(initial_context_values[context_var.name])
previous_values: dict[ContextVar, object] = {}
tokens: dict[ContextVar, object] = {}

item = await gen.__anext__()
# Restore ALL context values before any await and capture previous state
# This is needed to propagate context across async generator boundaries
for context_var in context_vars:
try:
previous_values[context_var] = context_var.get()
except LookupError:
previous_values[context_var] = _MISSING
tokens[context_var] = context_var.set(initial_context_values[context_var.name])

# Update our tracked values with any changes made during this iteration
for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get()
def _restore_context_var(context_var: ContextVar, *, _tokens=tokens, _prev=previous_values) -> None:
token = _tokens.get(context_var)
previous_value = _prev.get(context_var, _MISSING)
if token is not None:
try:
context_var.reset(token)
return
except (RuntimeError, ValueError):
pass

yield item
if previous_value is _MISSING:
context_var.set(None)
else:
context_var.set(previous_value)

try:
item = await gen.__anext__()
except StopAsyncIteration:
# Restore all context vars before exiting to prevent leaks
# Use _restore_context_var for all vars to properly restore to previous values
for context_var in context_vars:
_restore_context_var(context_var)
break
except Exception:
# Restore all context vars on exception
for context_var in context_vars:
_restore_context_var(context_var)
raise

try:
yield item
# Update our tracked values with any changes made during this iteration
# Only for non-trace context vars - trace context must persist across yields
# to allow nested span tracking for telemetry
for context_var in context_vars:
if context_var is not CURRENT_TRACE_CONTEXT:
initial_context_values[context_var.name] = context_var.get()
finally:
# Restore non-trace context vars after each yield to prevent leaks between requests
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
for context_var in context_vars:
if context_var is not CURRENT_TRACE_CONTEXT:
_restore_context_var(context_var)

return wrapper()
33 changes: 4 additions & 29 deletions llama_stack/ui/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions tests/unit/core/test_provider_data_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from contextlib import contextmanager
from contextvars import ContextVar

from llama_stack.core.utils.context import preserve_contexts_async_generator

# Define provider data context variable and context manager locally
PROVIDER_DATA_VAR = ContextVar("provider_data", default=None)


@contextmanager
def request_provider_data_context(headers):
    """Bind provider data parsed from *headers* to PROVIDER_DATA_VAR for the block.

    The ``X-LlamaStack-Provider-Data`` header, when present and non-empty, is
    decoded as JSON; otherwise an empty dict is bound. The previous value of
    the context variable is always restored on exit.
    """
    raw = headers.get("X-LlamaStack-Provider-Data")
    if raw:
        parsed = json.loads(raw)
    else:
        parsed = {}
    reset_token = PROVIDER_DATA_VAR.set(parsed)
    try:
        yield
    finally:
        PROVIDER_DATA_VAR.reset(reset_token)


def create_sse_event(data):
    """Serialize *data* as JSON and wrap it in a server-sent-events data frame."""
    payload = json.dumps(data)
    return "data: " + payload + "\n\n"


async def sse_generator(event_gen_coroutine):
    """Await the coroutine that produces an event generator, then stream each
    event as an SSE frame, yielding control to the event loop after each one.
    """
    source = await event_gen_coroutine
    async for event in source:
        yield create_sse_event(event)
        # Give the loop a chance to run other tasks between frames.
        await asyncio.sleep(0)


async def async_event_gen():
    """Build an async generator that emits the current provider-data value once."""

    async def _single_event():
        # The context variable is read at iteration time, not at creation time.
        yield PROVIDER_DATA_VAR.get()

    return _single_event()


async def test_provider_data_context_cleared_between_sse_requests():
    # Provider data set for one SSE "request" must not leak into the next one.
    headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
    with request_provider_data_context(headers):
        # The wrapped generator captures the context values at creation time.
        gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])

    # Consumed outside the context manager: events still see the captured value...
    events1 = [event async for event in gen1]
    assert events1 == [create_sse_event({"api_key": "abc"})]
    # ...but after exhaustion the context var is restored (default is None).
    assert PROVIDER_DATA_VAR.get() is None

    # A second generator created with no provider data must observe None,
    # not the value left over from the first request.
    gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
    events2 = [event async for event in gen2]
    assert events2 == [create_sse_event(None)]
    assert PROVIDER_DATA_VAR.get() is None
Loading
Loading