diff --git a/python/src/cairo_coder/core/types.py b/python/src/cairo_coder/core/types.py index 8486f50..623dd4e 100644 --- a/python/src/cairo_coder/core/types.py +++ b/python/src/cairo_coder/core/types.py @@ -210,7 +210,9 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage: # Merge metrics for key, value in metrics.items(): if isinstance(value, int | float): - result[model][key] = result[model].get(key, 0) + value + # Use (x or 0) to handle None values - dict.get() returns None + # if the key exists with value None, not the default + result[model][key] = (result[model].get(key) or 0) + value elif isinstance(value, dict): if key not in result[model] or result[model][key] is None: result[model][key] = value.copy() @@ -219,7 +221,7 @@ def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage: for detail_key, detail_value in value.items(): if isinstance(detail_value, int | float): result[model][key][detail_key] = ( - result[model][key].get(detail_key, 0) + detail_value + (result[model][key].get(detail_key) or 0) + detail_value ) return result diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py index 63c9147..9e51db8 100644 --- a/python/tests/unit/test_rag_pipeline.py +++ b/python/tests/unit/test_rag_pipeline.py @@ -14,7 +14,14 @@ RagPipeline, RagPipelineFactory, ) -from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType +from cairo_coder.core.types import ( + Document, + DocumentSource, + Message, + Role, + StreamEventType, + combine_usage, +) from cairo_coder.dspy.retrieval_judge import RetrievalJudge @@ -853,3 +860,88 @@ def test_create_pipeline_with_custom_components(self, mock_vector_store_config): assert pipeline.config.max_source_count == 20 assert pipeline.config.similarity_threshold == 0.6 assert pipeline.config.sources == [DocumentSource.CAIRO_BOOK] + + +class TestCombineUsage: + """Tests for the combine_usage function.""" + + @pytest.mark.parametrize( + "usage1,usage2,expected", + [ + # Test handling None values in usage dicts + pytest.param( + { + "gpt-4": { + "prompt_tokens": None, # Key exists but value is None + "completion_tokens": 100, + } + }, + { + "gpt-4": { + "prompt_tokens": 50, + "completion_tokens": 50, + } + }, + { + "gpt-4": { + "prompt_tokens": 50, + "completion_tokens": 150, + } + }, + id="none_values_in_usage_dict", + ), + # Test handling None values in nested dicts + pytest.param( + { + "gpt-4": { + "details": { + "audio_tokens": None, # Key exists but value is None + "cached_tokens": 100, + } + } + }, + { + "gpt-4": { + "details": { + "audio_tokens": 25, + "cached_tokens": 50, + } + } + }, + { + "gpt-4": { + "details": { + "audio_tokens": 25, + "cached_tokens": 150, + } + } + }, + id="none_values_in_nested_dict", + ), + # Test basic combining of usage dicts + pytest.param( + {"gpt-4": {"prompt_tokens": 100, "completion_tokens": 50}}, + {"gpt-4": {"prompt_tokens": 200, "completion_tokens": 100}}, + {"gpt-4": {"prompt_tokens": 300, "completion_tokens": 150}}, + id="basic_combining", + ), + # Test combining with empty dicts + pytest.param( + {}, + {}, + {}, + id="both_empty", + ), + # Test combining with one empty dict + pytest.param( + {"gpt-4": {"tokens": 100}}, + {}, + {"gpt-4": {"tokens": 100}}, + id="second_empty", + ), + ], + ) + def test_combine_usage(self, usage1, usage2, expected): + """Test combine_usage with various input scenarios.""" + result = combine_usage(usage1, usage2) + assert result == expected