22import os
33import re
44import threading
5+ import warnings
56from typing import Any , Literal , cast
67
78import litellm
@@ -61,8 +62,9 @@ def __init__(
6162 from the models available for inference.
6263 rollout_id: Optional integer used to differentiate cache entries for otherwise
6364 identical requests. Different values bypass DSPy's caches while still caching
64- future calls with the same inputs and rollout ID. This argument is stripped
65- before sending requests to the provider.
65+ future calls with the same inputs and rollout ID. Note that `rollout_id`
66+ only affects generation when `temperature` is non-zero. This argument is
67+ stripped before sending requests to the provider.
6668 """
6769 # Remember to update LM.copy() if you modify the constructor!
6870 self .model = model
@@ -75,6 +77,7 @@ def __init__(
7577 self .finetuning_model = finetuning_model
7678 self .launch_kwargs = launch_kwargs or {}
7779 self .train_kwargs = train_kwargs or {}
80+ self ._warned_zero_temp_rollout = False
7881
7982 # Handle model-specific configuration for different model families
8083 model_family = model .split ("/" )[- 1 ].lower () if "/" in model else model .lower ()
@@ -96,6 +99,20 @@ def __init__(
9699 if self .kwargs .get ("rollout_id" ) is None :
97100 self .kwargs .pop ("rollout_id" , None )
98101
102+ self ._warn_zero_temp_rollout (self .kwargs .get ("temperature" ), self .kwargs .get ("rollout_id" ))
103+
104+ def _warn_zero_temp_rollout (self , temperature : float | None , rollout_id ):
105+ if (
106+ not self ._warned_zero_temp_rollout
107+ and rollout_id is not None
108+ and (temperature is None or temperature == 0 )
109+ ):
110+ warnings .warn (
111+ "rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache." ,
112+ stacklevel = 3 ,
113+ )
114+ self ._warned_zero_temp_rollout = True
115+
99116 def _get_cached_completion_fn (self , completion_fn , cache ):
100117 ignored_args_for_cache_key = ["api_key" , "api_base" , "base_url" ]
101118 if cache :
@@ -115,6 +132,7 @@ def forward(self, prompt=None, messages=None, **kwargs):
115132
116133 messages = messages or [{"role" : "user" , "content" : prompt }]
117134 kwargs = {** self .kwargs , ** kwargs }
135+ self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
118136 if kwargs .get ("rollout_id" ) is None :
119137 kwargs .pop ("rollout_id" , None )
120138
@@ -145,6 +163,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
145163
146164 messages = messages or [{"role" : "user" , "content" : prompt }]
147165 kwargs = {** self .kwargs , ** kwargs }
166+ self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
148167 if kwargs .get ("rollout_id" ) is None :
149168 kwargs .pop ("rollout_id" , None )
150169
0 commit comments