Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dspy/clients/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import threading
from functools import wraps
from hashlib import sha256
import xxhash
from typing import Any

import cloudpickle
Expand Down Expand Up @@ -93,7 +93,7 @@ def transform_value(value):
return value

params = {k: transform_value(v) for k, v in request.items() if k not in ignored_args_for_cache_key}
return sha256(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest()
return xxhash.xxh64(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest()

def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> Any:

Expand Down
150 changes: 150 additions & 0 deletions tests/benchmarks/test_cache_key_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Benchmark for cache key generation performance (SHA256 vs xxhash)."""
import time
from hashlib import sha256

import orjson
import pytest
import xxhash


def benchmark_hash_function(hash_func, data_samples, iterations=1000):
    """Measure the mean per-call runtime of *hash_func* over *data_samples*.

    Performs 10 untimed warmup sweeps first, then times `iterations` full
    sweeps and returns the average seconds per single hash invocation.
    """
    # Warm up so allocator/cache effects don't skew the timed region.
    for _ in range(10):
        for sample in data_samples:
            hash_func(sample)

    # Timed region: `iterations` complete passes over every sample.
    started = time.perf_counter()
    for _ in range(iterations):
        for sample in data_samples:
            hash_func(sample)
    elapsed = time.perf_counter() - started

    total_calls = iterations * len(data_samples)
    return elapsed / total_calls


def create_test_data():
    """Build serialized request payloads that mimic real cache inputs.

    Returns orjson-encoded byte strings (keys sorted, matching the cache's
    own serialization) covering four representative request shapes: small,
    medium, large, and one with nested tool definitions.
    """
    payloads = [
        # Small request
        {"model": "gpt-4", "messages": [{"role": "user", "content": "test"}]},
        # Medium request
        {
            "model": "gpt-4o-mini",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?" * 10},
            ],
            "temperature": 0.7,
            "max_tokens": 100,
        },
        # Large request
        {
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Explain quantum computing in detail." * 50},
            ],
            "temperature": 0.7,
            "max_tokens": 1000,
            "top_p": 0.9,
            "frequency_penalty": 0.5,
        },
        # Request with nested structures
        {
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "test" * 20}
            ],
            "tools": [
                {"type": "function", "function": {"name": "get_weather", "description": "Get weather data"}},
                {"type": "function", "function": {"name": "search", "description": "Search the web"}},
            ],
        },
    ]
    return [orjson.dumps(p, option=orjson.OPT_SORT_KEYS) for p in payloads]


def test_sha256_performance():
    """Record baseline timing for the SHA256-based cache key hash."""
    samples = create_test_data()

    avg_time = benchmark_hash_function(
        lambda data: sha256(data).hexdigest(), samples, iterations=1000
    )

    print(f"\nSHA256 average time: {avg_time*1e6:.2f}µs per operation")
    # Sanity check: a real measurement is strictly positive.
    assert avg_time > 0


def test_xxhash_performance():
    """Record timing for the xxhash-based cache key hash."""
    samples = create_test_data()

    avg_time = benchmark_hash_function(
        lambda data: xxhash.xxh64(data).hexdigest(), samples, iterations=1000
    )

    print(f"\nxxhash average time: {avg_time*1e6:.2f}µs per operation")
    # Sanity check: a real measurement is strictly positive.
    assert avg_time > 0


def test_hash_performance_comparison():
    """Compare SHA256 vs xxhash cache key hashing speed.

    Takes the best (minimum) of three benchmark repetitions per hash so a
    single noisy run on a loaded CI machine cannot flip the assertions —
    a single wall-clock sample is too flaky to gate on directly.
    """
    data_samples = create_test_data()

    def sha256_hash(data):
        return sha256(data).hexdigest()

    def xxhash_hash(data):
        return xxhash.xxh64(data).hexdigest()

    # Best-of-3: the minimum is the least noise-contaminated measurement.
    sha256_time = min(
        benchmark_hash_function(sha256_hash, data_samples, iterations=1000)
        for _ in range(3)
    )
    xxhash_time = min(
        benchmark_hash_function(xxhash_hash, data_samples, iterations=1000)
        for _ in range(3)
    )

    speedup = sha256_time / xxhash_time

    print(f"\n{'='*70}")
    print("CACHE KEY GENERATION PERFORMANCE COMPARISON")
    print(f"{'='*70}")
    print(f"SHA256 (current): {sha256_time*1e6:.2f}µs per operation")
    print(f"xxhash (proposed): {xxhash_time*1e6:.2f}µs per operation")
    print(f"Speedup: {speedup:.2f}x faster")
    print(f"Time reduction: {((sha256_time - xxhash_time)/sha256_time)*100:.1f}%")
    print(f"{'='*70}")

    # xxhash should be significantly faster; the 2x floor is deliberately
    # conservative — xxh64 is typically far faster than SHA256 on payloads
    # of this size, so best-of-3 timings should clear it comfortably.
    assert xxhash_time < sha256_time, "xxhash should be faster than SHA256"
    assert speedup >= 2.0, f"Expected at least 2x speedup, got {speedup:.2f}x"


def test_real_cache_usage_pattern():
    """Benchmark cache key generation through the real dspy cache object.

    Times 100 ``cache_key`` calls on representative LM request dicts to
    record an end-to-end baseline (serialization + hashing), not just the
    raw hash primitive measured by the other tests in this file.
    """
    import dspy

    # Simulate cache key generation for typical LM requests.
    requests = [
        {"model": "gpt-4", "messages": [{"role": "user", "content": f"Question {i}"}], "temperature": 0.7}
        for i in range(100)
    ]

    cache = dspy.cache

    # Time whatever hash cache_key currently uses; the previous version
    # hard-coded "SHA256" in names and labels, which went stale the moment
    # the implementation switched to xxhash.
    start = time.perf_counter()
    for request in requests:
        _ = cache.cache_key(request)
    total = time.perf_counter() - start

    print("\n100 cache key generations:")
    print(f"  Total time: {total*1000:.2f}ms")
    print(f"  Average per key: {total*1000/100:.2f}ms")

    # Sanity check: the benchmark actually did work.
    assert total > 0

6 changes: 5 additions & 1 deletion tests/clients/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def test_cache_key_generation(cache):
request = {"prompt": "Hello", "model": "openai/gpt-4o-mini", "temperature": 0.7}
key = cache.cache_key(request)
assert isinstance(key, str)
assert len(key) == 64 # SHA-256 hash is 64 characters
assert len(key) == 16 # xxhash64 produces 16 character hex string

# Test determinism - same input should produce same key
key2 = cache.cache_key(request)
assert key == key2, "Cache keys must be deterministic"

# Test with pydantic model
class TestModel(pydantic.BaseModel):
Expand Down