
Commit 950d6e1

Merge branch 'main' into set-token-limit

2 parents 7eff0c0 + 200e26c

File tree

6 files changed: +97 -14 lines changed


CHANGELOG.md

Lines changed: 2 additions & 1 deletion

@@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### New features
 
-* The `Chat` class gains a `.token_count()` method to help estimate input tokens before sending it to the LLM. (#23)
+* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).
+* `Chat` gains a `.token_count()` method to help estimate the token cost of new input. (#23)
 
 ### Bug fixes
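To make the distinction concrete, here is a short sketch of the two modes. The turns and token pairs are illustrative, mirroring the new tests in this commit:

```python
from chatlas import ChatOpenAI, Turn

# Assistant turns record (input, output) token counts as reported by the
# provider; the (2, 10) pair here is illustrative.
chat = ChatOpenAI(
    turns=[
        Turn(role="user", contents="Hi"),
        Turn(role="assistant", contents="Hello", tokens=(2, 10)),
    ]
)

# "cumulative" (the default): raw per-turn counts; summing gives the
# conversation's overall token cost.
print(chat.tokens(values="cumulative"))  # [None, (2, 10)]

# "discrete": per-turn costs; summing estimates what submitting the
# current turns will cost on the next request.
print(chat.tokens(values="discrete"))  # [2, 10]
```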

chatlas/_chat.py

Lines changed: 52 additions & 10 deletions

@@ -16,6 +16,7 @@
     Optional,
     Sequence,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel
@@ -177,14 +178,42 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    def tokens(self) -> list[int]:
+    @overload
+    def tokens(self) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["cumulative"],
+    ) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["discrete"],
+    ) -> list[int]: ...
+
+    def tokens(
+        self,
+        values: Literal["cumulative", "discrete"] = "cumulative",
+    ) -> list[int] | list[tuple[int, int] | None]:
         """
         Get the tokens for each turn in the chat.
 
+        Parameters
+        ----------
+        values
+            If "cumulative" (the default), the result can be summed to get the
+            chat's overall token usage (helpful for computing the overall cost
+            of the chat). If "discrete", the result can be summed to get the
+            number of tokens the turns will cost to generate the next response
+            (helpful for estimating the cost of the next response, or for
+            determining if you are about to exceed the token limit).
+
         Returns
         -------
         list[int]
-            A list of token counts for each turn in the chat. Note that the
+            A list of token counts for each (non-system) turn in the chat. The
             1st turn includes the token count for the system prompt (if any).
 
         Raises
@@ -199,6 +228,9 @@ def tokens(self) -> list[int]:
 
         turns = self.get_turns(include_system_prompt=False)
 
+        if values == "cumulative":
+            return [turn.tokens for turn in turns]
+
         if len(turns) == 0:
             return []
 
@@ -220,21 +252,25 @@
         )
 
         if turns[0].role != "user":
-            raise ValueError("Expected the first turn to have role='user'. " + err_info)
+            raise ValueError(
+                "Expected the 1st non-system turn to have role='user'. " + err_info
+            )
 
         if turns[1].role != "assistant":
             raise ValueError(
-                "Expected the 2nd turn to have role='assistant'. " + err_info
+                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
             )
 
         if turns[1].tokens is None:
             raise ValueError(
                 "Expected the 1st assistant turn to contain token counts. " + err_info
             )
 
-        tokens: list[int] = [
+        res: list[int] = [
+            # Implied token count for the 1st user input
             turns[1].tokens[0],
-            sum(turns[1].tokens),
+            # The token count for the 1st assistant response
+            turns[1].tokens[1],
         ]
         for i in range(1, len(turns) - 1, 2):
             ti = turns[i]
@@ -248,7 +284,7 @@ def tokens(self) -> list[int]:
                     "Expected role='assistant' turns to contain token counts."
                     + err_info
                 )
-            tokens.extend(
+            res.extend(
                 [
                     # Implied token count for the user input
                     tj.tokens[0] - sum(ti.tokens),
@@ -257,7 +293,7 @@
                 ]
             )
 
-        return tokens
+        return res
 
     def token_count(
         self,
@@ -285,12 +321,18 @@ def token_count(
         int
             The token count for the input.
 
+        Note
+        ----
+        Remember that the token count is an estimate. Also, models based on
+        `ChatOpenAI()` currently do not take tools into account when
+        estimating token counts.
+
        Examples
         --------
         ```python
-        from chatlas import ChatOpenAI
+        from chatlas import ChatAnthropic
 
-        chat = ChatOpenAI()
+        chat = ChatAnthropic()
         # Estimate the token count before sending the input
         print(chat.token_count("What is 2 + 2?"))
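For a standalone view of the arithmetic behind `values="discrete"`, here is a sketch using the same numbers as the new test below (provider-reported pairs `(2, 10)` and `(14, 10)`); this is plain Python mirroring the logic in the diff, not library code:

```python
# Provider-reported (input, output) token pairs on the assistant turns.
assistant_tokens = [(2, 10), (14, 10)]

# The 1st user input's cost is implied by the 1st assistant turn's input count.
discrete = [
    assistant_tokens[0][0],  # 1st user input: 2
    assistant_tokens[0][1],  # 1st assistant response: 10
]

for prev, cur in zip(assistant_tokens, assistant_tokens[1:]):
    # Each later input count covers the whole history, so the new user
    # input's cost is the current input count minus the previous exchange.
    discrete.append(cur[0] - sum(prev))  # 14 - (2 + 10) = 2
    discrete.append(cur[1])              # assistant response: 10

print(discrete)  # [2, 10, 2, 10]
```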

chatlas/_openai.py

Lines changed: 3 additions & 1 deletion

@@ -295,10 +295,12 @@ def _chat_perform_args(
             "stream": stream,
             "messages": self._as_message_param(turns),
             "model": self._model,
-            "seed": self._seed,
             **(kwargs or {}),
         }
 
+        if self._seed is not None:
+            kwargs_full["seed"] = self._seed
+
         if tool_schemas:
             kwargs_full["tools"] = tool_schemas
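The `seed` change makes the parameter conditional, presumably so `"seed": None` is never serialized into requests to endpoints that don't accept it. A minimal sketch of the pattern (the function name here is hypothetical, not the library's API):

```python
def build_request_kwargs(model: str, stream: bool, seed: int | None = None) -> dict:
    kwargs_full = {
        "stream": stream,
        "model": model,
    }
    # Attach optional parameters only when actually set, so the request
    # body never carries a null seed.
    if seed is not None:
        kwargs_full["seed"] = seed
    return kwargs_full


print(build_request_kwargs("gpt-4o-mini", stream=False))
# {'stream': False, 'model': 'gpt-4o-mini'}
print(build_request_kwargs("gpt-4o-mini", stream=False, seed=42))
# {'stream': False, 'model': 'gpt-4o-mini', 'seed': 42}
```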

docs/get-started.qmd

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ Learn more in the article on [structured data extraction](structured-data.qmd).
 
 LLMs can also be useful to solve general programming problems. For example:
 
-* You can use LLMs to explain code, or even ask them to [generate a diagram](https://bsky.app/profile/daviddiviny.bsky.social/post/3lb6kjaen4c2u).
+* You can use LLMs to explain code, or even ask them to [generate a diagram](https://bsky.app/profile/daviddiviny.com/post/3lb6kjaen4c2u).
 
 * You can ask an LLM to analyse your code for potential code smells or security issues. You can do this a function at a time, or explore including the entire source code for your package or script in the prompt.

tests/test_provider_openai.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ def test_openai_simple_request():
     chat.chat("What is 1 + 1?")
     turn = chat.get_last_turn()
     assert turn is not None
-    assert turn.tokens == (27, 1)
+    assert turn.tokens == (27, 2)
     assert turn.finish_reason == "stop"

tests/test_tokens.py

Lines changed: 38 additions & 0 deletions

@@ -1,7 +1,45 @@
+from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn
 from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
 from chatlas._tokens import token_usage, tokens_log, tokens_reset
 
 
+def test_tokens_method():
+    chat = ChatOpenAI()
+    assert chat.tokens(values="discrete") == []
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10]
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(14, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10, 2, 10]
+    assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]
+
+
+def test_token_count_method():
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    assert chat.token_count("What is 1 + 1?") == 31
+
+    chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+    assert chat.token_count("What is 1 + 1?") == 16
+
+    chat = ChatGoogle(model="gemini-1.5-flash")
+    assert chat.token_count("What is 1 + 1?") == 9
+
+
 def test_usage_is_none():
     tokens_reset()
     assert token_usage() is None
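The `token_count` assertions above depend on each provider's tokenizer, so exact values vary by model. Combining the two new methods gives the kind of pre-flight budget check the `set-token-limit` branch name suggests; a sketch, with an illustrative limit:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o-mini")
TOKEN_LIMIT = 128_000  # illustrative context-window budget

prompt = "What is 1 + 1?"
# Projected input cost: resubmitting the current turns (discrete sum)
# plus the estimated cost of the new prompt.
projected = sum(chat.tokens(values="discrete")) + chat.token_count(prompt)
if projected < TOKEN_LIMIT:
    chat.chat(prompt)
```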
