This repository was archived by the owner on Feb 14, 2025. It is now read-only.

Commit 986ef7b

feat: Added support for stream helper function
1 parent cf19102 commit 986ef7b
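This commit lets lunary's monitor() instrument the Anthropic SDK's stream helper, client.messages.stream(...), which is used as a context manager, in addition to the already-supported client.messages.create(...). A minimal usage sketch, based on the example added in this commit (assuming monitor is imported from the lunary.anthrophic module changed below):

    from anthropic import Anthropic
    from lunary.anthrophic import monitor  # import path assumed from this repo's layout

    client = Anthropic()
    monitor(client)  # after this commit, this also patches client.messages.stream

    with client.messages.stream(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Hello, Claude"}],
        model="claude-3-opus-20240229",
    ) as stream:
        for event in stream:
            print(event)  # events are re-yielded to the caller while lunary records the run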

File tree

2 files changed (+120 -30 lines):

examples/anthrophic.py
lunary/anthrophic.py


examples/anthrophic.py

+39 -1
@@ -69,6 +69,41 @@ async def test_async_streaming():
         print(event)
 
 
+def test_sync_stream_helper():
+    client = Anthropic()
+    monitor(client)
+
+    with client.messages.stream(
+        max_tokens=1024,
+        messages=[{
+            "role": "user",
+            "content": "Hello, Claude",
+        }],
+        model="claude-3-opus-20240229",
+    ) as stream:
+        for event in stream:
+            print(event)
+
+async def test_async_stream_helper():
+    client = monitor(AsyncAnthropic())
+
+    async with client.messages.stream(
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "Say hello there!",
+            }
+        ],
+        model="claude-3-opus-20240229",
+    ) as stream:
+        async for event in stream:
+            print(event)
+
+        message = await stream.get_final_message()
+        print(message.to_json())
+
+
 def test_extra_arguments():
     client = Anthropic()
     monitor(client)
@@ -100,4 +135,7 @@ def test_extra_arguments():
 # test_sync_streaming()
 # test_asyncio.run(async_streaming())
 
-test_extra_arguments()
+# test_extra_arguments()
+
+# test_sync_stream_helper()
+asyncio.run(test_async_stream_helper())
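For contrast, the pre-existing test_sync_streaming / test_async_streaming tests in this file exercise the raw streaming path, where stream=True makes messages.create return an iterator directly instead of a context manager; the wrapper in lunary/anthrophic.py keeps handling that case via its kwargs.get("stream") check. A rough sketch of that older form, shown here only for comparison:

    client = Anthropic()
    monitor(client)

    stream = client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Hello, Claude"}],
        model="claude-3-opus-20240229",
        stream=True,
    )
    for event in stream:
        print(event)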

lunary/anthrophic.py

+81 -29
@@ -1,14 +1,42 @@
 import typing as t
 from functools import partial
+from inspect import iscoroutine
+from contextlib import AsyncContextDecorator, ContextDecorator
+
 from . import track_event, run_context, run_manager, logging, logger, user_props_ctx, user_ctx, traceback, tags_ctx, filter_params
 
 try:
     from anthropic import Anthropic, AsyncAnthropic
     from anthropic.types import Message
+    from anthropic.lib.streaming import MessageStreamManager, AsyncMessageStreamManager
 except ImportError:
     raise ImportError("Anthrophic SDK not installed!") from None
 
 
+class sync_context_wrapper(ContextDecorator):
+
+    def __init__(self, stream):
+        self.__stream = stream
+
+    def __enter__(self):
+        return self.__stream
+
+    def __exit__(self, *_):
+        return
+
+
+class async_context_wrapper(AsyncContextDecorator):
+
+    def __init__(self, stream):
+        self.__stream = stream
+
+    async def __aenter__(self):
+        return self.__stream
+
+    async def __aexit__(self, *_):
+        return
+
+
 def __input_parser(kwargs: t.Dict):
     return {"input": kwargs.get("messages"), "name": kwargs.get("model")}
@@ -35,42 +63,44 @@ def __output_parser(output: t.Union[Message], stream: bool = False):
 
 def __stream_handler(method, run_id, name, type, *args, **kwargs):
     messages = []
+    original_stream = None
     stream = method(*args, **kwargs)
 
+    if isinstance(stream, MessageStreamManager):
+        original_stream = stream
+        stream = original_stream.__enter__()
+
     for event in stream:
         if event.type == "message_start":
-            # print(event.message.model)
             messages.append({
                 "role": event.message.role,
                 "model": event.message.model
             })
         if event.type == "message_delta":
-            # print("*", event.usage.output_tokens)
             if len(messages) >= 1:
                 message = messages[-1]
                 message["usage"] = {"tokens": event.usage.output_tokens}
 
         if event.type == "message_stop": pass
         if event.type == "content_block_start":
-            # print("* START")
-            # print(event.content_block.text)
             if len(messages) >= 1:
                 message = messages[-1]
                 message["output"] = event.content_block.text
 
         if event.type == "content_block_delta":
-            # print(event.delta.text, end="")
             if len(messages) >= 1:
                 message = messages[-1]
                 message["output"] = message.get("output",
                                                 "") + event.delta.text
 
         if event.type == "content_block_stop":
-            # print("* END")
             pass
 
         yield event
 
+    if original_stream:
+        original_stream.__exit__(None, None, None)
+
     track_event(
         type,
         "end",
@@ -86,42 +116,47 @@ def __stream_handler(method, run_id, name, type, *args, **kwargs):
 
 async def __async_stream_handler(method, run_id, name, type, *args, **kwargs):
     messages = []
-    stream = await method(*args, **kwargs)
+    original_stream = None
+    stream = method(*args, **kwargs)
+
+    if iscoroutine(stream):
+        stream = await stream
+
+    if isinstance(stream, AsyncMessageStreamManager):
+        original_stream = stream
+        stream = await original_stream.__aenter__()
 
     async for event in stream:
         if event.type == "message_start":
-            # print(event.message.model)
             messages.append({
                 "role": event.message.role,
                 "model": event.message.model
            })
         if event.type == "message_delta":
-            # print("*", event.usage.output_tokens)
             if len(messages) >= 1:
                 message = messages[-1]
                 message["usage"] = {"tokens": event.usage.output_tokens}
 
         if event.type == "message_stop": pass
         if event.type == "content_block_start":
-            # print("* START")
-            # print(event.content_block.text)
             if len(messages) >= 1:
                 message = messages[-1]
                 message["output"] = event.content_block.text
 
         if event.type == "content_block_delta":
-            # print(event.delta.text, end="")
             if len(messages) >= 1:
                 message = messages[-1]
                 message["output"] = message.get("output",
                                                 "") + event.delta.text
 
         if event.type == "content_block_stop":
-            # print("* END")
             pass
 
         yield event
 
+    if original_stream:
+        await original_stream.__aexit__(None, None, None)
+
     track_event(
         type,
         "end",
@@ -136,9 +171,7 @@ async def __async_stream_handler(method, run_id, name, type, *args, **kwargs):
 
 
 def __metadata_parser(metadata):
-    return {
-        x: metadata[x] for x in metadata if x in ["user_id"]
-    }
+    return {x: metadata[x] for x in metadata if x in ["user_id"]}
 
 
 def __wrap_sync(method: t.Callable,
@@ -152,6 +185,7 @@ def __wrap_sync(method: t.Callable,
                 output_parser=__output_parser,
                 stream_handler=__stream_handler,
                 metadata_parser=__metadata_parser,
+                contextify_stream: t.Optional[t.Callable] = None,
                 *args,
                 **kwargs):
     output = None
@@ -189,10 +223,13 @@ def __wrap_sync(method: t.Callable,
     except Exception as e:
         logging.exception(e)
 
-    if kwargs.get("stream") == True:
-        return stream_handler(method, run.id, name
-                              or parsed_input["name"], type, *args,
-                              **kwargs)
+    if contextify_stream or kwargs.get("stream") == True:
+        generator = stream_handler(method, run.id, name
+                                   or parsed_input["name"], type,
+                                   *args, **kwargs)
+        if contextify_stream:
+            return contextify_stream(generator)
+        else: return generator
 
     try:
         output = method(*args, **kwargs)
@@ -241,6 +278,7 @@ async def __wrap_async(method: t.Callable,
                        output_parser=__output_parser,
                        stream_handler=__async_stream_handler,
                        metadata_parser=__metadata_parser,
+                       contextify_stream: t.Optional[bool] = False,
                        *args,
                        **kwargs):
     output = None
@@ -274,14 +312,17 @@ async def __wrap_async(method: t.Callable,
                       or tags_ctx.get()),
            template_id=(kwargs.get("extra_headers", {}).get(
                "Template-Id", None)),
-            is_openai=True)
+            is_openai=False)
     except Exception as e:
         logging.exception(e)
 
-    if kwargs.get("stream") == True:
-        return stream_handler(method, run.id, name
-                              or parsed_input["name"], type,
-                              *args, **kwargs)
+    if contextify_stream or kwargs.get("stream") == True:
+        generator = stream_handler(method, run.id, name
+                                   or parsed_input["name"], type,
+                                   *args, **kwargs)
+        if contextify_stream:
+            return contextify_stream(generator)
+        else: return generator
 
     try:
         output = await method(*args, **kwargs)
@@ -325,11 +366,22 @@ async def __wrap_async(method: t.Callable,
 
 def monitor(client: "ClientType") -> "ClientType":
     if isinstance(client, Anthropic):
-        client.messages.create = partial(__wrap_sync, client.messages.create,
-                                         "llm")
+        client.messages.create = partial(__wrap_sync,
+                                         client.messages.create,
+                                         type="llm")
+        client.messages.stream = partial(__wrap_sync,
+                                         client.messages.stream,
+                                         type="llm",
+                                         contextify_stream=sync_context_wrapper)
     elif isinstance(client, AsyncAnthropic):
-        client.messages.create = partial(__wrap_async, client.messages.create,
-                                         "llm")
+        client.messages.create = partial(__wrap_async,
+                                         client.messages.create,
+                                         type="llm")
+        client.messages.stream = partial(__wrap_sync,
+                                         client.messages.stream,
+                                         type="llm",
+                                         stream_handler=__async_stream_handler,
+                                         contextify_stream=async_context_wrapper)
     else:
         raise Exception(
             "Invalid argument. Expected instance of Anthropic Client")
