Skip to content

Commit 77c4ff3

Browse files
committed
feat: better support for reasoning/thinking
1 parent 985bf34 commit 77c4ff3

File tree

3 files changed

+88
-14
lines changed

3 files changed

+88
-14
lines changed

chatlas/_provider_anthropic.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ContentJson,
1818
ContentPDF,
1919
ContentText,
20+
ContentThinking,
2021
ContentToolRequest,
2122
ContentToolResult,
2223
ContentToolResultImage,
@@ -41,6 +42,8 @@
4142
MessageParam,
4243
RawMessageStreamEvent,
4344
TextBlock,
45+
ThinkingBlock,
46+
ThinkingBlockParam,
4447
ToolParam,
4548
ToolUseBlock,
4649
)
@@ -50,6 +53,7 @@
5053
from anthropic.types.messages.batch_create_params import Request as BatchRequest
5154
from anthropic.types.model_param import ModelParam
5255
from anthropic.types.text_block_param import TextBlockParam
56+
from anthropic.types.thinking_config_enabled_param import ThinkingConfigEnabledParam
5357
from anthropic.types.tool_result_block_param import ToolResultBlockParam
5458
from anthropic.types.tool_use_block_param import ToolUseBlockParam
5559

@@ -61,6 +65,7 @@
6165
ToolUseBlockParam,
6266
ToolResultBlockParam,
6367
DocumentBlockParam,
68+
ThinkingBlockParam,
6469
]
6570
else:
6671
Message = object
@@ -71,8 +76,9 @@ def ChatAnthropic(
7176
*,
7277
system_prompt: Optional[str] = None,
7378
model: "Optional[ModelParam]" = None,
74-
api_key: Optional[str] = None,
7579
max_tokens: int = 4096,
80+
reasoning: Optional["int | ThinkingConfigEnabledParam"] = None,
81+
api_key: Optional[str] = None,
7682
kwargs: Optional["ChatClientArgs"] = None,
7783
) -> Chat["SubmitInputArgs", Message]:
7884
"""
@@ -119,12 +125,19 @@ def ChatAnthropic(
119125
The model to use for the chat. The default, None, will pick a reasonable
120126
default, and warn you about it. We strongly recommend explicitly
121127
choosing a model for all but the most casual use.
128+
max_tokens
129+
Maximum number of tokens to generate before stopping.
130+
reasoning
131+
Determines how many tokens Claude may use for reasoning. Must be
132+
≥1024 and less than `max_tokens`. Larger budgets can enable more
133+
thorough analysis for complex problems, improving response quality. See
134+
[extended
135+
thinking](https://docs.claude.com/en/docs/build-with-claude/extended-thinking)
136+
for details.
122137
api_key
123138
The API key to use for authentication. You generally should not supply
124139
this directly, but instead set the `ANTHROPIC_API_KEY` environment
125140
variable.
126-
max_tokens
127-
Maximum number of tokens to generate before stopping.
128141
kwargs
129142
Additional arguments to pass to the `anthropic.Anthropic()` client
130143
constructor.
@@ -174,6 +187,12 @@ def ChatAnthropic(
174187
if model is None:
175188
model = log_model_default("claude-sonnet-4-0")
176189

190+
kwargs_chat: "SubmitInputArgs" = {}
191+
if reasoning is not None:
192+
if isinstance(reasoning, int):
193+
reasoning = {"type": "enabled", "budget_tokens": reasoning}
194+
kwargs_chat = {"thinking": reasoning}
195+
177196
return Chat(
178197
provider=AnthropicProvider(
179198
api_key=api_key,
@@ -182,6 +201,7 @@ def ChatAnthropic(
182201
kwargs=kwargs,
183202
),
184203
system_prompt=system_prompt,
204+
kwargs_chat=kwargs_chat,
185205
)
186206

187207

@@ -396,6 +416,12 @@ def stream_merge_chunks(self, completion, chunk):
396416
if not isinstance(this_content.input, str):
397417
this_content.input = "" # type: ignore
398418
this_content.input += json_delta # type: ignore
419+
elif chunk.delta.type == "thinking_delta":
420+
this_content = cast("ThinkingBlock", this_content)
421+
this_content.thinking += chunk.delta.thinking
422+
elif chunk.delta.type == "signature_delta":
423+
this_content = cast("ThinkingBlock", this_content)
424+
this_content.signature += chunk.delta.signature
399425
elif chunk.type == "content_block_stop":
400426
this_content = completion.content[chunk.index]
401427
if this_content.type == "tool_use" and isinstance(this_content.input, str):
@@ -588,6 +614,13 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
588614
res["content"] = content.get_model_value() # type: ignore
589615

590616
return res
617+
elif isinstance(content, ContentThinking):
618+
extra = content.extra or {}
619+
return {
620+
"type": "thinking",
621+
"thinking": content.thinking,
622+
"signature": extra.get("signature", ""),
623+
}
591624

592625
raise ValueError(f"Unknown content type: {type(content)}")
593626

@@ -641,6 +674,13 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn:
641674
arguments=content.input,
642675
)
643676
)
677+
elif content.type == "thinking":
678+
contents.append(
679+
ContentThinking(
680+
thinking=content.thinking,
681+
extra={"signature": content.signature},
682+
)
683+
)
644684

645685
return Turn(
646686
"assistant",

chatlas/_provider_google.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
GenerateContentResponseDict,
3535
Part,
3636
PartDict,
37+
ThinkingConfigDict,
3738
)
3839

3940
from .types.google import ChatClientArgs, SubmitInputArgs
@@ -45,6 +46,7 @@ def ChatGoogle(
4546
*,
4647
system_prompt: Optional[str] = None,
4748
model: Optional[str] = None,
49+
reasoning: Optional["int | ThinkingConfigDict"] = None,
4850
api_key: Optional[str] = None,
4951
kwargs: Optional["ChatClientArgs"] = None,
5052
) -> Chat["SubmitInputArgs", GenerateContentResponse]:
@@ -86,6 +88,10 @@ def ChatGoogle(
8688
The model to use for the chat. The default, None, will pick a reasonable
8789
default, and warn you about it. We strongly recommend explicitly choosing
8890
a model for all but the most casual use.
91+
reasoning
92+
If provided, enables reasoning (a.k.a. "thoughts") in the model's
93+
responses. This can be an integer number of tokens to use for reasoning,
94+
or a full `ThinkingConfigDict` to customize the reasoning behavior.
8995
api_key
9096
The API key to use for authentication. You generally should not supply
9197
this directly, but instead set the `GOOGLE_API_KEY` environment variable.
@@ -137,14 +143,20 @@ def ChatGoogle(
137143
if model is None:
138144
model = log_model_default("gemini-2.5-flash")
139145

146+
kwargs_chat: "SubmitInputArgs" = {}
147+
if reasoning is not None:
148+
if isinstance(reasoning, int):
149+
reasoning = {"thinking_budget": reasoning, "include_thoughts": True}
150+
kwargs_chat["config"] = {"thinking_config": reasoning}
151+
140152
return Chat(
141153
provider=GoogleProvider(
142154
model=model,
143155
api_key=api_key,
144-
name="Google/Gemini",
145156
kwargs=kwargs,
146157
),
147158
system_prompt=system_prompt,
159+
kwargs_chat=kwargs_chat,
148160
)
149161

150162

@@ -367,7 +379,7 @@ def value_tokens(self, completion):
367379
cached = usage.cached_content_token_count or 0
368380
return (
369381
(usage.prompt_token_count or 0) - cached,
370-
usage.candidates_token_count or 0,
382+
(usage.candidates_token_count or 0) + (usage.thoughts_token_count or 0),
371383
usage.cached_content_token_count or 0,
372384
)
373385

chatlas/_provider_openai.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
)
3636
from openai.types.responses.easy_input_message_param import EasyInputMessageParam
3737
from openai.types.responses.tool_param import ToolParam
38+
from openai.types.shared.reasoning_effort import ReasoningEffort
39+
from openai.types.shared_params.reasoning import Reasoning
3840
from openai.types.shared_params.responses_model import ResponsesModel
3941

4042
from .types.openai import ChatClientArgs
@@ -47,8 +49,9 @@ def ChatOpenAI(
4749
*,
4850
system_prompt: Optional[str] = None,
4951
model: "Optional[ResponsesModel | str]" = None,
50-
api_key: Optional[str] = None,
5152
base_url: str = "https://api.openai.com/v1",
53+
reasoning: "Optional[ReasoningEffort | Reasoning]" = None,
54+
api_key: Optional[str] = None,
5255
kwargs: Optional["ChatClientArgs"] = None,
5356
) -> Chat["SubmitInputArgs", Response]:
5457
"""
@@ -87,12 +90,15 @@ def ChatOpenAI(
8790
The model to use for the chat. The default, None, will pick a reasonable
8891
default, and warn you about it. We strongly recommend explicitly
8992
choosing a model for all but the most casual use.
93+
base_url
94+
The base URL to the endpoint; the default uses OpenAI.
95+
reasoning
96+
The reasoning effort to use (for reasoning-capable models like the o and
97+
gpt-5 series).
9098
api_key
9199
The API key to use for authentication. You generally should not supply
92100
this directly, but instead set the `OPENAI_API_KEY` environment
93101
variable.
94-
base_url
95-
The base URL to the endpoint; the default uses OpenAI.
96102
kwargs
97103
Additional arguments to pass to the `openai.OpenAI()` client
98104
constructor.
@@ -146,6 +152,14 @@ def ChatOpenAI(
146152
if model is None:
147153
model = log_model_default("gpt-4.1")
148154

155+
kwargs_chat: "SubmitInputArgs" = {}
156+
if reasoning is not None:
157+
if not is_reasoning_model(model):
158+
warnings.warn(f"Model {model} is not reasoning-capable", UserWarning)
159+
if isinstance(reasoning, str):
160+
reasoning = {"effort": reasoning, "summary": "auto"}
161+
kwargs_chat = {"reasoning": reasoning}
162+
149163
return Chat(
150164
provider=OpenAIProvider(
151165
api_key=api_key,
@@ -154,6 +168,7 @@ def ChatOpenAI(
154168
kwargs=kwargs,
155169
),
156170
system_prompt=system_prompt,
171+
kwargs_chat=kwargs_chat,
157172
)
158173

159174

@@ -239,7 +254,7 @@ def _chat_perform_args(
239254

240255
# Request reasoning content for reasoning models
241256
include = []
242-
if self._is_reasoning(self.model):
257+
if is_reasoning_model(self.model):
243258
include.append("reasoning.encrypted_content")
244259

245260
if "log_probs" in kwargs_full:
@@ -254,7 +269,14 @@ def _chat_perform_args(
254269

255270
def stream_text(self, chunk):
256271
if chunk.type == "response.output_text.delta":
272+
# https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text/delta
273+
return chunk.delta
274+
if chunk.type == "response.reasoning_summary_text.delta":
275+
# https://platform.openai.com/docs/api-reference/responses-streaming/response/reasoning_summary_text/delta
257276
return chunk.delta
277+
if chunk.type == "response.reasoning_summary_text.done":
278+
# https://platform.openai.com/docs/api-reference/responses-streaming/response/reasoning_summary_text/done
279+
return "\n\n"
258280
return None
259281

260282
def stream_merge_chunks(self, completion, chunk):
@@ -337,11 +359,6 @@ def _response_as_turn(completion: Response, has_data_model: bool) -> Turn:
337359
completion=completion,
338360
)
339361

340-
@staticmethod
341-
def _is_reasoning(model: str) -> bool:
342-
# https://platform.openai.com/docs/models/compare
343-
return model.startswith("o") or model.startswith("gpt-5")
344-
345362
@staticmethod
346363
def _turns_as_inputs(turns: list[Turn]) -> "list[ResponseInputItemParam]":
347364
res: "list[ResponseInputItemParam]" = []
@@ -456,3 +473,8 @@ def as_input_param(content: Content, role: Role) -> "ResponseInputItemParam":
456473

457474
def as_message(x: "ResponseInputContentParam", role: Role) -> "EasyInputMessageParam":
458475
return {"role": role, "content": [x]}
476+
477+
478+
def is_reasoning_model(model: str) -> bool:
479+
# https://platform.openai.com/docs/models/compare
480+
return model.startswith("o") or model.startswith("gpt-5")

0 commit comments

Comments
 (0)