|
16 | 16 | wait_random_exponential,
|
17 | 17 | retry_if_not_exception_type,
|
18 | 18 | )
|
19 |
| -from typing import Optional |
| 19 | +from typing import Optional, Union |
20 | 20 |
|
21 | 21 | logger = logging.getLogger("api_models")
|
22 | 22 |
|
@@ -49,7 +49,6 @@ def __add__(self, other):
|
49 | 49 | field.name: getattr(self, field.name) + getattr(other, field.name)
|
50 | 50 | for field in fields(self)
|
51 | 51 | })
|
52 |
| - |
53 | 52 | def replace(self, other):
|
54 | 53 | if not isinstance(other, APIStats):
|
55 | 54 | raise TypeError("Can only replace APIStats with APIStats")
|
@@ -148,19 +147,13 @@ def update_stats(self, input_tokens: int, output_tokens: int) -> float:
|
148 | 147 | )
|
149 | 148 |
|
150 | 149 | # Check whether total cost or instance cost limits have been exceeded
|
151 |
| - if ( |
152 |
| - self.args.total_cost_limit > 0 |
153 |
| - and self.stats.total_cost >= self.args.total_cost_limit |
154 |
| - ): |
| 150 | + if 0 < self.args.total_cost_limit <= self.stats.total_cost: |
155 | 151 | logger.warning(
|
156 | 152 | f"Cost {self.stats.total_cost:.2f} exceeds limit {self.args.total_cost_limit:.2f}"
|
157 | 153 | )
|
158 | 154 | raise CostLimitExceededError("Total cost limit exceeded")
|
159 | 155 |
|
160 |
| - if ( |
161 |
| - self.args.per_instance_cost_limit > 0 |
162 |
| - and self.stats.instance_cost >= self.args.per_instance_cost_limit |
163 |
| - ): |
| 156 | + if 0 < self.args.per_instance_cost_limit <= self.stats.instance_cost: |
164 | 157 | logger.warning(
|
165 | 158 | f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.args.per_instance_cost_limit:.2f}"
|
166 | 159 | )
|
@@ -233,7 +226,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
|
233 | 226 |
|
234 | 227 | def history_to_messages(
|
235 | 228 | self, history: list[dict[str, str]], is_demonstration: bool = False
|
236 |
| - ) -> list[dict[str, str]]: |
| 229 | + ) -> Union[str, list[dict[str, str]]]: |
237 | 230 | """
|
238 | 231 | Create `messages` by filtering out all keys except for role/content per `history` turn
|
239 | 232 | """
|
@@ -273,6 +266,7 @@ def query(self, history: list[dict[str, str]]) -> str:
|
273 | 266 | self.update_stats(input_tokens, output_tokens)
|
274 | 267 | return response.choices[0].message.content
|
275 | 268 |
|
| 269 | + |
276 | 270 | class AnthropicModel(BaseModel):
|
277 | 271 | MODELS = {
|
278 | 272 | "claude-instant": {
|
@@ -326,7 +320,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
|
326 | 320 |
|
327 | 321 | def history_to_messages(
|
328 | 322 | self, history: list[dict[str, str]], is_demonstration: bool = False
|
329 |
| - ) -> list[dict[str, str]]: |
| 323 | + ) -> Union[str, list[dict[str, str]]]: |
330 | 324 | """
|
331 | 325 | Create `prompt` by filtering out all keys except for role/content per `history` turn
|
332 | 326 | Reference: https://docs.anthropic.com/claude/reference/complete_post
|
@@ -440,7 +434,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
|
440 | 434 |
|
441 | 435 | def history_to_messages(
|
442 | 436 | self, history: list[dict[str, str]], is_demonstration: bool = False
|
443 |
| - ) -> list[dict[str, str]]: |
| 437 | + ) -> Union[str, list[dict[str, str]]]: |
444 | 438 | """
|
445 | 439 | Create `messages` by filtering out all keys except for role/content per `history` turn
|
446 | 440 | """
|
@@ -516,7 +510,7 @@ class TogetherModel(BaseModel):
|
516 | 510 | "max_context": 32768,
|
517 | 511 | "cost_per_input_token": 6e-07,
|
518 | 512 | "cost_per_output_token": 6e-07,
|
519 |
| - }, |
| 513 | + }, |
520 | 514 | }
|
521 | 515 |
|
522 | 516 | SHORTCUTS = {
|
@@ -593,7 +587,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
|
593 | 587 |
|
594 | 588 | def history_to_messages(
|
595 | 589 | self, history: list[dict[str, str]], is_demonstration: bool = False
|
596 |
| - ) -> list[dict[str, str]]: |
| 590 | + ) -> Union[str, list[dict[str, str]]]: |
597 | 591 | """
|
598 | 592 | Create `messages` by filtering out all keys except for role/content per `history` turn
|
599 | 593 | """
|
@@ -652,7 +646,7 @@ def query(self, history: list[dict[str, str]]) -> str:
|
652 | 646 | break
|
653 | 647 | thought_all += thought
|
654 | 648 | thought = input("... ")
|
655 |
| - |
| 649 | + |
656 | 650 | action = super().query(history, action_prompt="Action: ")
|
657 | 651 |
|
658 | 652 | return f"{thought_all}\n```\n{action}\n```"
|
|
0 commit comments