Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/src/cairo_coder/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
# Merge metrics
for key, value in metrics.items():
if isinstance(value, int | float):
result[model][key] = result[model].get(key, 0) + value
# Use (x or 0) to handle None values - dict.get() returns None
# if the key exists with value None, not the default
result[model][key] = (result[model].get(key) or 0) + value
elif isinstance(value, dict):
if key not in result[model] or result[model][key] is None:
result[model][key] = value.copy()
Expand All @@ -219,7 +221,7 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
for detail_key, detail_value in value.items():
if isinstance(detail_value, int | float):
result[model][key][detail_key] = (
result[model][key].get(detail_key, 0) + detail_value
(result[model][key].get(detail_key) or 0) + detail_value
)
return result

Expand Down
94 changes: 93 additions & 1 deletion python/tests/unit/test_rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
RagPipeline,
RagPipelineFactory,
)
from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType
from cairo_coder.core.types import (
Document,
DocumentSource,
Message,
Role,
StreamEventType,
combine_usage,
)
from cairo_coder.dspy.retrieval_judge import RetrievalJudge


Expand Down Expand Up @@ -853,3 +860,88 @@ def test_create_pipeline_with_custom_components(self, mock_vector_store_config):
assert pipeline.config.max_source_count == 20
assert pipeline.config.similarity_threshold == 0.6
assert pipeline.config.sources == [DocumentSource.CAIRO_BOOK]


class TestCombineUsage:
    """Parametrized checks for merging two usage mappings with combine_usage.

    The cases cover metrics whose value is None on one side (both at the
    top level of a model entry and inside a nested dict), plain summation
    of matching metrics, and empty inputs.
    """

    @pytest.mark.parametrize(
        "usage1,usage2,expected",
        [
            # A top-level metric is present but None in the first mapping;
            # the expected total is just the second mapping's value.
            pytest.param(
                {"gpt-4": {"prompt_tokens": None, "completion_tokens": 100}},
                {"gpt-4": {"prompt_tokens": 50, "completion_tokens": 50}},
                {"gpt-4": {"prompt_tokens": 50, "completion_tokens": 150}},
                id="none_values_in_usage_dict",
            ),
            # Same situation one level down: a None inside a nested dict.
            pytest.param(
                {"gpt-4": {"details": {"audio_tokens": None, "cached_tokens": 100}}},
                {"gpt-4": {"details": {"audio_tokens": 25, "cached_tokens": 50}}},
                {"gpt-4": {"details": {"audio_tokens": 25, "cached_tokens": 150}}},
                id="none_values_in_nested_dict",
            ),
            # Matching numeric metrics are summed key by key.
            pytest.param(
                {"gpt-4": {"prompt_tokens": 100, "completion_tokens": 50}},
                {"gpt-4": {"prompt_tokens": 200, "completion_tokens": 100}},
                {"gpt-4": {"prompt_tokens": 300, "completion_tokens": 150}},
                id="basic_combining",
            ),
            # Two empty mappings combine to an empty mapping.
            pytest.param({}, {}, {}, id="both_empty"),
            # An empty second mapping leaves the first one's data as-is.
            pytest.param(
                {"gpt-4": {"tokens": 100}},
                {},
                {"gpt-4": {"tokens": 100}},
                id="second_empty",
            ),
        ],
    )
    def test_combine_usage(self, usage1, usage2, expected):
        """Combining usage1 with usage2 must yield exactly the expected mapping."""
        assert combine_usage(usage1, usage2) == expected