11import os
2+ import random
23from typing import Any , Dict , List , Optional
34
45import httpx
@@ -19,10 +20,12 @@ def __init__(self, api_key: str):
1920 )
2021
2122 @handle_async_errors
22- async def _get (self , path : str , params : Optional [Dict [str , Any ]] = None , timeout : int = None ) -> dict :
23+ async def _get (
24+ self , path : str , params : Optional [Dict [str , Any ]] = None , timeout : int = None
25+ ) -> dict :
2326 """
2427 Send an async GET request to the specified path.
25-
28+
2629 Args:
2730 path: API endpoint path
2831 params: Optional query parameters
@@ -73,6 +76,7 @@ def __init__(self, logs_resource):
7376 self .app_name : Optional [str ] = None
7477 self .environment : Optional [str ] = None
7578 self .tags : Dict [str , Any ] = {}
79+ self .sample_rate : float = 1.0
7680 self .hallucination_detection : bool = False
7781 self .inconsistency_detection : bool = False
7882 self ._configured = False
@@ -84,6 +88,7 @@ def init(
8488 app_name : str ,
8589 environment : str ,
8690 tags : Optional [Dict [str , Any ]] = {},
91+ sample_rate : float = 1.0 ,
8792 hallucination_detection : bool = False ,
8893 inconsistency_detection : bool = False ,
8994 hallucination_detection_sample_rate : float = 0 ,
@@ -95,12 +100,23 @@ def init(
95100 self .app_name = app_name
96101 self .environment = environment
97102 self .tags = tags or {}
103+ self .sample_rate = sample_rate
104+
105+ if not (0.0 <= self .sample_rate <= 1.0 ):
106+ raise QuotientAIError ("sample_rate must be between 0.0 and 1.0" )
107+
98108 self .hallucination_detection = hallucination_detection
99109 self .inconsistency_detection = inconsistency_detection
100110 self ._configured = True
101111 self .hallucination_detection_sample_rate = hallucination_detection_sample_rate
102112 return self
103113
114+ def _should_sample (self ) -> bool :
115+ """
116+ Determine if the log should be sampled based on the sample rate.
117+ """
118+ return random .random () < self .sample_rate
119+
104120 async def log (
105121 self ,
106122 * ,
@@ -139,21 +155,22 @@ async def log(
139155 else self .inconsistency_detection
140156 )
141157
142- log = await self .logs_resource .create (
143- app_name = self .app_name ,
144- environment = self .environment ,
145- user_query = user_query ,
146- model_output = model_output ,
147- documents = documents ,
148- message_history = message_history ,
149- instructions = instructions ,
150- tags = merged_tags ,
151- hallucination_detection = hallucination_detection ,
152- inconsistency_detection = inconsistency_detection ,
153- hallucination_detection_sample_rate = self .hallucination_detection_sample_rate ,
154- )
158+ if self ._should_sample ():
159+ await self .logs_resource .create (
160+ app_name = self .app_name ,
161+ environment = self .environment ,
162+ user_query = user_query ,
163+ model_output = model_output ,
164+ documents = documents ,
165+ message_history = message_history ,
166+ instructions = instructions ,
167+ tags = merged_tags ,
168+ hallucination_detection = hallucination_detection ,
169+ inconsistency_detection = inconsistency_detection ,
170+ hallucination_detection_sample_rate = self .hallucination_detection_sample_rate ,
171+ )
155172
156- return log
173+ return None
157174
158175
159176class AsyncQuotientAI :
0 commit comments