2121 get_regex_logits_processor ,
2222)
2323from outlines .backends .base import LogitsProcessorType
24+ from outlines .outputs import Output
25+ from outlines .tools import get_formatted_tools , ToolsInput
2426from outlines .types import CFG , JsonSchema
2527from outlines .types .dsl import python_types_to_terms , to_regex
2628
@@ -35,7 +37,13 @@ class BlackBoxGenerator:
3537 """
3638 output_type : Optional [Any ]
3739
38- def __init__ (self , model : BlackBoxModel , output_type : Optional [Any ]):
40+ def __init__ (
41+ self ,
42+ model : BlackBoxModel ,
43+ output_type : Optional [Any ],
44+ * ,
45+ tools : Optional [ToolsInput ] = None ,
46+ ):
3947 """
4048 Parameters
4149 ----------
@@ -47,8 +55,9 @@ def __init__(self, model: BlackBoxModel, output_type: Optional[Any]):
4755 """
4856 self .model = model
4957 self .output_type = output_type
58+ self .tools = get_formatted_tools (tools )
5059
51- def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
60+ def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
5261 """Generate a response from the model.
5362
5463 Parameters
@@ -65,10 +74,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
6574
6675 """
6776 return self .model .generate (
68- prompt , self .output_type , ** inference_kwargs
77+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
6978 )
7079
71- def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
80+ def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
7281 """Generate a batch of responses from the model.
7382
7483 Parameters
@@ -85,7 +94,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
8594
8695 """
8796 return self .model .generate_batch (
88- prompts , self .output_type , ** inference_kwargs
97+ prompts , self .output_type , tools = self . tools , ** inference_kwargs
8998 )
9099
91100 def stream (self , prompt : Any , ** inference_kwargs ) -> Iterator [Any ]:
@@ -105,7 +114,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
105114
106115 """
107116 return self .model .generate_stream (
108- prompt , self .output_type , ** inference_kwargs
117+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
109118 )
110119
111120
@@ -119,7 +128,13 @@ class AsyncBlackBoxGenerator:
119128 """
120129 output_type : Optional [Any ]
121130
122- def __init__ (self , model : AsyncBlackBoxModel , output_type : Optional [Any ]):
131+ def __init__ (
132+ self ,
133+ model : AsyncBlackBoxModel ,
134+ output_type : Optional [Any ],
135+ * ,
136+ tools : Optional [ToolsInput ] = None ,
137+ ):
123138 """
124139 Parameters
125140 ----------
@@ -131,8 +146,9 @@ def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]):
131146 """
132147 self .model = model
133148 self .output_type = output_type
149+ self .tools = get_formatted_tools (tools )
134150
135- async def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
151+ async def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
136152 """Generate a response from the model.
137153
138154 Parameters
@@ -149,10 +165,10 @@ async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
149165
150166 """
151167 return await self .model .generate (
152- prompt , self .output_type , ** inference_kwargs
168+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
153169 )
154170
155- async def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
171+ async def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
156172 """Generate a batch of responses from the model.
157173
158174 Parameters
@@ -169,7 +185,7 @@ async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
169185
170186 """
171187 return await self .model .generate_batch (
172- prompts , self .output_type , ** inference_kwargs
188+ prompts , self .output_type , tools = self . tools , ** inference_kwargs
173189 )
174190
175191 async def stream (self , prompt : Any , ** inference_kwargs ) -> AsyncIterator [Any ]:
@@ -189,7 +205,7 @@ async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
189205
190206 """
191207 async for chunk in self .model .generate_stream ( # pragma: no cover
192- prompt , self .output_type , ** inference_kwargs
208+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
193209 ):
194210 yield chunk
195211
@@ -218,6 +234,8 @@ def __init__(
218234 model : SteerableModel ,
219235 output_type : Optional [Any ],
220236 backend_name : Optional [str ] = None ,
237+ * ,
238+ tools : Optional [ToolsInput ] = None ,
221239 ):
222240 """
223241 Parameters
@@ -231,6 +249,7 @@ def __init__(
231249
232250 """
233251 self .model = model
252+ self .tools = get_formatted_tools (tools )
234253 if output_type is None :
235254 self .logits_processor = None
236255 else :
@@ -258,7 +277,11 @@ def __init__(
258277
259278 @classmethod
260279 def from_processor (
261- cls , model : SteerableModel , processor : LogitsProcessorType
280+ cls ,
281+ model : SteerableModel ,
282+ processor : LogitsProcessorType ,
283+ * ,
284+ tools : Optional [ToolsInput ] = None ,
262285 ):
263286 """Create a generator from a logits processor.
264287
@@ -270,13 +293,12 @@ def from_processor(
270293 An instance of a logits processor.
271294
272295 """
273- instance = cls .__new__ (cls )
274- instance .model = model
296+ instance = cls (model , None , tools = tools )
275297 instance .logits_processor = processor
276298
277299 return instance
278300
279- def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
301+ def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
280302 """Generate a response from the model.
281303
282304 Parameters
@@ -295,10 +317,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
295317 if self .logits_processor is not None :
296318 self .logits_processor .reset ()
297319 return self .model .generate (
298- prompt , self .logits_processor , ** inference_kwargs
320+ prompt , self .logits_processor , tools = self . tools , ** inference_kwargs
299321 )
300322
301- def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
323+ def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
302324 """Generate a batch of responses from the model.
303325
304326 Parameters
@@ -317,7 +339,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
317339 if self .logits_processor is not None :
318340 self .logits_processor .reset ()
319341 return self .model .generate_batch (
320- prompts , self .logits_processor , ** inference_kwargs
342+ prompts , self .logits_processor , tools = self . tools , ** inference_kwargs
321343 )
322344
323345 def stream (self , prompt : Any , ** inference_kwargs ) -> Iterator [Any ]:
@@ -339,7 +361,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
339361 if self .logits_processor is not None :
340362 self .logits_processor .reset ()
341363 return self .model .generate_stream (
342- prompt , self .logits_processor , ** inference_kwargs
364+ prompt , self .logits_processor , tools = self . tools , ** inference_kwargs
343365 )
344366
345367
@@ -348,6 +370,7 @@ def Generator(
348370 output_type : Optional [Any ] = None ,
349371 backend : Optional [str ] = None ,
350372 * ,
373+ tools : Optional [ToolsInput ] = None ,
351374 processor : Optional [LogitsProcessorType ] = None ,
352375) -> Union [SteerableGenerator , BlackBoxGenerator , AsyncBlackBoxGenerator ]:
353376 """Create a generator for the given model and output parameters.
@@ -387,18 +410,18 @@ def Generator(
387410
388411 if isinstance (model , SteerableModel ): # type: ignore
389412 if processor is not None :
390- return SteerableGenerator .from_processor (model , processor ) # type: ignore
413+ return SteerableGenerator .from_processor (model , processor , tools = tools ) # type: ignore
391414 else :
392- return SteerableGenerator (model , output_type , backend ) # type: ignore
415+ return SteerableGenerator (model , output_type , backend , tools = tools ) # type: ignore
393416 else :
394417 if processor is not None :
395418 raise NotImplementedError (
396419 "This model does not support logits processors"
397420 )
398421 if isinstance (model , AsyncBlackBoxModel ): # type: ignore
399- return AsyncBlackBoxGenerator (model , output_type ) # type: ignore
422+ return AsyncBlackBoxGenerator (model , output_type , tools = tools ) # type: ignore
400423 elif isinstance (model , BlackBoxModel ): # type: ignore
401- return BlackBoxGenerator (model , output_type ) # type: ignore
424+ return BlackBoxGenerator (model , output_type , tools = tools ) # type: ignore
402425 else :
403426 raise ValueError (
404427 "The model argument must be an instance of "
0 commit comments