1
1
import typing as t
2
2
from functools import partial
3
+ from inspect import iscoroutine
4
+ from contextlib import AsyncContextDecorator , ContextDecorator
5
+
3
6
from . import track_event , run_context , run_manager , logging , logger , user_props_ctx , user_ctx , traceback , tags_ctx , filter_params
4
7
5
8
try :
6
9
from anthropic import Anthropic , AsyncAnthropic
7
10
from anthropic .types import Message
11
+ from anthropic .lib .streaming import MessageStreamManager , AsyncMessageStreamManager
8
12
except ImportError :
9
13
raise ImportError ("Anthrophic SDK not installed!" ) from None
10
14
11
15
16
+ class sync_context_wrapper (ContextDecorator ):
17
+
18
+ def __init__ (self , stream ):
19
+ self .__stream = stream
20
+
21
+ def __enter__ (self ):
22
+ return self .__stream
23
+
24
+ def __exit__ (self , * _ ):
25
+ return
26
+
27
+
28
+ class async_context_wrapper (AsyncContextDecorator ):
29
+
30
+ def __init__ (self , stream ):
31
+ self .__stream = stream
32
+
33
+ async def __aenter__ (self ):
34
+ return self .__stream
35
+
36
+ async def __aexit__ (self , * _ ):
37
+ return
38
+
39
+
12
40
def __input_parser (kwargs : t .Dict ):
13
41
return {"input" : kwargs .get ("messages" ), "name" : kwargs .get ("model" )}
14
42
@@ -35,42 +63,44 @@ def __output_parser(output: t.Union[Message], stream: bool = False):
35
63
36
64
def __stream_handler (method , run_id , name , type , * args , ** kwargs ):
37
65
messages = []
66
+ original_stream = None
38
67
stream = method (* args , ** kwargs )
39
68
69
+ if isinstance (stream , MessageStreamManager ):
70
+ original_stream = stream
71
+ stream = original_stream .__enter__ ()
72
+
40
73
for event in stream :
41
74
if event .type == "message_start" :
42
- # print(event.message.model)
43
75
messages .append ({
44
76
"role" : event .message .role ,
45
77
"model" : event .message .model
46
78
})
47
79
if event .type == "message_delta" :
48
- # print("*", event.usage.output_tokens)
49
80
if len (messages ) >= 1 :
50
81
message = messages [- 1 ]
51
82
message ["usage" ] = {"tokens" : event .usage .output_tokens }
52
83
53
84
if event .type == "message_stop" : pass
54
85
if event .type == "content_block_start" :
55
- # print("* START")
56
- # print(event.content_block.text)
57
86
if len (messages ) >= 1 :
58
87
message = messages [- 1 ]
59
88
message ["output" ] = event .content_block .text
60
89
61
90
if event .type == "content_block_delta" :
62
- # print(event.delta.text, end="")
63
91
if len (messages ) >= 1 :
64
92
message = messages [- 1 ]
65
93
message ["output" ] = message .get ("output" ,
66
94
"" ) + event .delta .text
67
95
68
96
if event .type == "content_block_stop" :
69
- # print("* END")
70
97
pass
71
98
72
99
yield event
73
100
101
+ if original_stream :
102
+ original_stream .__exit__ (None , None , None )
103
+
74
104
track_event (
75
105
type ,
76
106
"end" ,
@@ -86,42 +116,47 @@ def __stream_handler(method, run_id, name, type, *args, **kwargs):
86
116
87
117
async def __async_stream_handler (method , run_id , name , type , * args , ** kwargs ):
88
118
messages = []
89
- stream = await method (* args , ** kwargs )
119
+ original_stream = None
120
+ stream = method (* args , ** kwargs )
121
+
122
+ if iscoroutine (stream ):
123
+ stream = await stream
124
+
125
+ if isinstance (stream , AsyncMessageStreamManager ):
126
+ original_stream = stream
127
+ stream = await original_stream .__aenter__ ()
90
128
91
129
async for event in stream :
92
130
if event .type == "message_start" :
93
- # print(event.message.model)
94
131
messages .append ({
95
132
"role" : event .message .role ,
96
133
"model" : event .message .model
97
134
})
98
135
if event .type == "message_delta" :
99
- # print("*", event.usage.output_tokens)
100
136
if len (messages ) >= 1 :
101
137
message = messages [- 1 ]
102
138
message ["usage" ] = {"tokens" : event .usage .output_tokens }
103
139
104
140
if event .type == "message_stop" : pass
105
141
if event .type == "content_block_start" :
106
- # print("* START")
107
- # print(event.content_block.text)
108
142
if len (messages ) >= 1 :
109
143
message = messages [- 1 ]
110
144
message ["output" ] = event .content_block .text
111
145
112
146
if event .type == "content_block_delta" :
113
- # print(event.delta.text, end="")
114
147
if len (messages ) >= 1 :
115
148
message = messages [- 1 ]
116
149
message ["output" ] = message .get ("output" ,
117
150
"" ) + event .delta .text
118
151
119
152
if event .type == "content_block_stop" :
120
- # print("* END")
121
153
pass
122
154
123
155
yield event
124
156
157
+ if original_stream :
158
+ await original_stream .__aexit__ (None , None , None )
159
+
125
160
track_event (
126
161
type ,
127
162
"end" ,
@@ -136,9 +171,7 @@ async def __async_stream_handler(method, run_id, name, type, *args, **kwargs):
136
171
137
172
138
173
def __metadata_parser (metadata ):
139
- return {
140
- x : metadata [x ] for x in metadata if x in ["user_id" ]
141
- }
174
+ return {x : metadata [x ] for x in metadata if x in ["user_id" ]}
142
175
143
176
144
177
def __wrap_sync (method : t .Callable ,
@@ -152,6 +185,7 @@ def __wrap_sync(method: t.Callable,
152
185
output_parser = __output_parser ,
153
186
stream_handler = __stream_handler ,
154
187
metadata_parser = __metadata_parser ,
188
+ contextify_stream : t .Optional [t .Callable ] = None ,
155
189
* args ,
156
190
** kwargs ):
157
191
output = None
@@ -189,10 +223,13 @@ def __wrap_sync(method: t.Callable,
189
223
except Exception as e :
190
224
logging .exception (e )
191
225
192
- if kwargs .get ("stream" ) == True :
193
- return stream_handler (method , run .id , name
194
- or parsed_input ["name" ], type , * args ,
195
- ** kwargs )
226
+ if contextify_stream or kwargs .get ("stream" ) == True :
227
+ generator = stream_handler (method , run .id , name
228
+ or parsed_input ["name" ], type ,
229
+ * args , ** kwargs )
230
+ if contextify_stream :
231
+ return contextify_stream (generator )
232
+ else : return generator
196
233
197
234
try :
198
235
output = method (* args , ** kwargs )
@@ -241,6 +278,7 @@ async def __wrap_async(method: t.Callable,
241
278
output_parser = __output_parser ,
242
279
stream_handler = __async_stream_handler ,
243
280
metadata_parser = __metadata_parser ,
281
+ contextify_stream : t .Optional [bool ] = False ,
244
282
* args ,
245
283
** kwargs ):
246
284
output = None
@@ -274,14 +312,17 @@ async def __wrap_async(method: t.Callable,
274
312
or tags_ctx .get ()),
275
313
template_id = (kwargs .get ("extra_headers" , {}).get (
276
314
"Template-Id" , None )),
277
- is_openai = True )
315
+ is_openai = False )
278
316
except Exception as e :
279
317
logging .exception (e )
280
318
281
- if kwargs .get ("stream" ) == True :
282
- return stream_handler (method , run .id , name
283
- or parsed_input ["name" ], type ,
284
- * args , ** kwargs )
319
+ if contextify_stream or kwargs .get ("stream" ) == True :
320
+ generator = stream_handler (method , run .id , name
321
+ or parsed_input ["name" ], type ,
322
+ * args , ** kwargs )
323
+ if contextify_stream :
324
+ return contextify_stream (generator )
325
+ else : return generator
285
326
286
327
try :
287
328
output = await method (* args , ** kwargs )
@@ -325,11 +366,22 @@ async def __wrap_async(method: t.Callable,
325
366
326
367
def monitor (client : "ClientType" ) -> "ClientType" :
327
368
if isinstance (client , Anthropic ):
328
- client .messages .create = partial (__wrap_sync , client .messages .create ,
329
- "llm" )
369
+ client .messages .create = partial (__wrap_sync ,
370
+ client .messages .create ,
371
+ type = "llm" )
372
+ client .messages .stream = partial (__wrap_sync ,
373
+ client .messages .stream ,
374
+ type = "llm" ,
375
+ contextify_stream = sync_context_wrapper )
330
376
elif isinstance (client , AsyncAnthropic ):
331
- client .messages .create = partial (__wrap_async , client .messages .create ,
332
- "llm" )
377
+ client .messages .create = partial (__wrap_async ,
378
+ client .messages .create ,
379
+ type = "llm" )
380
+ client .messages .stream = partial (__wrap_sync ,
381
+ client .messages .stream ,
382
+ type = "llm" ,
383
+ stream_handler = __async_stream_handler ,
384
+ contextify_stream = async_context_wrapper )
333
385
else :
334
386
raise Exception (
335
387
"Invalid argument. Expected instance of Anthropic Client" )
0 commit comments