Skip to content

Commit 72f0a54

Browse files
authored
fix: model usage computation (#94)
1 parent 251b82d commit 72f0a54

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

python/src/cairo_coder/core/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
210210
# Merge metrics
211211
for key, value in metrics.items():
212212
if isinstance(value, int | float):
213-
result[model][key] = result[model].get(key, 0) + value
213+
# Use (x or 0) to handle None values - dict.get() returns None
214+
# if the key exists with value None, not the default
215+
result[model][key] = (result[model].get(key) or 0) + value
214216
elif isinstance(value, dict):
215217
if key not in result[model] or result[model][key] is None:
216218
result[model][key] = value.copy()
@@ -219,7 +221,7 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
219221
for detail_key, detail_value in value.items():
220222
if isinstance(detail_value, int | float):
221223
result[model][key][detail_key] = (
222-
result[model][key].get(detail_key, 0) + detail_value
224+
(result[model][key].get(detail_key) or 0) + detail_value
223225
)
224226
return result
225227

python/tests/unit/test_rag_pipeline.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414
RagPipeline,
1515
RagPipelineFactory,
1616
)
17-
from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType
17+
from cairo_coder.core.types import (
18+
Document,
19+
DocumentSource,
20+
Message,
21+
Role,
22+
StreamEventType,
23+
combine_usage,
24+
)
1825
from cairo_coder.dspy.retrieval_judge import RetrievalJudge
1926

2027

@@ -853,3 +860,88 @@ def test_create_pipeline_with_custom_components(self, mock_vector_store_config):
853860
assert pipeline.config.max_source_count == 20
854861
assert pipeline.config.similarity_threshold == 0.6
855862
assert pipeline.config.sources == [DocumentSource.CAIRO_BOOK]
863+
864+
865+
class TestCombineUsage:
866+
"""Tests for the combine_usage function."""
867+
868+
@pytest.mark.parametrize(
869+
"usage1,usage2,expected",
870+
[
871+
# Test handling None values in usage dicts
872+
pytest.param(
873+
{
874+
"gpt-4": {
875+
"prompt_tokens": None, # Key exists but value is None
876+
"completion_tokens": 100,
877+
}
878+
},
879+
{
880+
"gpt-4": {
881+
"prompt_tokens": 50,
882+
"completion_tokens": 50,
883+
}
884+
},
885+
{
886+
"gpt-4": {
887+
"prompt_tokens": 50,
888+
"completion_tokens": 150,
889+
}
890+
},
891+
id="none_values_in_usage_dict",
892+
),
893+
# Test handling None values in nested dicts
894+
pytest.param(
895+
{
896+
"gpt-4": {
897+
"details": {
898+
"audio_tokens": None, # Key exists but value is None
899+
"cached_tokens": 100,
900+
}
901+
}
902+
},
903+
{
904+
"gpt-4": {
905+
"details": {
906+
"audio_tokens": 25,
907+
"cached_tokens": 50,
908+
}
909+
}
910+
},
911+
{
912+
"gpt-4": {
913+
"details": {
914+
"audio_tokens": 25,
915+
"cached_tokens": 150,
916+
}
917+
}
918+
},
919+
id="none_values_in_nested_dict",
920+
),
921+
# Test basic combining of usage dicts
922+
pytest.param(
923+
{"gpt-4": {"prompt_tokens": 100, "completion_tokens": 50}},
924+
{"gpt-4": {"prompt_tokens": 200, "completion_tokens": 100}},
925+
{"gpt-4": {"prompt_tokens": 300, "completion_tokens": 150}},
926+
id="basic_combining",
927+
),
928+
# Test combining with empty dicts
929+
pytest.param(
930+
{},
931+
{},
932+
{},
933+
id="both_empty",
934+
),
935+
# Test combining with one empty dict
936+
pytest.param(
937+
{"gpt-4": {"tokens": 100}},
938+
{},
939+
{"gpt-4": {"tokens": 100}},
940+
id="second_empty",
941+
),
942+
],
943+
)
944+
def test_combine_usage(self, usage1, usage2, expected):
945+
"""Test combine_usage with various input scenarios."""
946+
result = combine_usage(usage1, usage2)
947+
assert result == expected

0 commit comments

Comments
 (0)