Skip to content

Commit fbcd1b1

Browse files
mspronestipmprones
and
pmprones
authored
chore(models): simplify conditions and fix return types (SWE-agent#216)
* chore(models): simplify conditions and fix return types * undo formatting --------- Co-authored-by: pmprones <[email protected]>
1 parent e052bef commit fbcd1b1

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,6 @@ website/frontend/build
189189
trajectories/*
190190

191191
.vscode/**
192+
193+
# PyCharm
194+
.idea/

sweagent/agent/models.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
wait_random_exponential,
1717
retry_if_not_exception_type,
1818
)
19-
from typing import Optional
19+
from typing import Optional, Union
2020

2121
logger = logging.getLogger("api_models")
2222

@@ -49,7 +49,6 @@ def __add__(self, other):
4949
field.name: getattr(self, field.name) + getattr(other, field.name)
5050
for field in fields(self)
5151
})
52-
5352
def replace(self, other):
5453
if not isinstance(other, APIStats):
5554
raise TypeError("Can only replace APIStats with APIStats")
@@ -148,19 +147,13 @@ def update_stats(self, input_tokens: int, output_tokens: int) -> float:
148147
)
149148

150149
# Check whether total cost or instance cost limits have been exceeded
151-
if (
152-
self.args.total_cost_limit > 0
153-
and self.stats.total_cost >= self.args.total_cost_limit
154-
):
150+
if 0 < self.args.total_cost_limit <= self.stats.total_cost:
155151
logger.warning(
156152
f"Cost {self.stats.total_cost:.2f} exceeds limit {self.args.total_cost_limit:.2f}"
157153
)
158154
raise CostLimitExceededError("Total cost limit exceeded")
159155

160-
if (
161-
self.args.per_instance_cost_limit > 0
162-
and self.stats.instance_cost >= self.args.per_instance_cost_limit
163-
):
156+
if 0 < self.args.per_instance_cost_limit <= self.stats.instance_cost:
164157
logger.warning(
165158
f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.args.per_instance_cost_limit:.2f}"
166159
)
@@ -233,7 +226,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
233226

234227
def history_to_messages(
235228
self, history: list[dict[str, str]], is_demonstration: bool = False
236-
) -> list[dict[str, str]]:
229+
) -> Union[str, list[dict[str, str]]]:
237230
"""
238231
Create `messages` by filtering out all keys except for role/content per `history` turn
239232
"""
@@ -273,6 +266,7 @@ def query(self, history: list[dict[str, str]]) -> str:
273266
self.update_stats(input_tokens, output_tokens)
274267
return response.choices[0].message.content
275268

269+
276270
class AnthropicModel(BaseModel):
277271
MODELS = {
278272
"claude-instant": {
@@ -326,7 +320,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
326320

327321
def history_to_messages(
328322
self, history: list[dict[str, str]], is_demonstration: bool = False
329-
) -> list[dict[str, str]]:
323+
) -> Union[str, list[dict[str, str]]]:
330324
"""
331325
Create `prompt` by filtering out all keys except for role/content per `history` turn
332326
Reference: https://docs.anthropic.com/claude/reference/complete_post
@@ -440,7 +434,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
440434

441435
def history_to_messages(
442436
self, history: list[dict[str, str]], is_demonstration: bool = False
443-
) -> list[dict[str, str]]:
437+
) -> Union[str, list[dict[str, str]]]:
444438
"""
445439
Create `messages` by filtering out all keys except for role/content per `history` turn
446440
"""
@@ -516,7 +510,7 @@ class TogetherModel(BaseModel):
516510
"max_context": 32768,
517511
"cost_per_input_token": 6e-07,
518512
"cost_per_output_token": 6e-07,
519-
},
513+
},
520514
}
521515

522516
SHORTCUTS = {
@@ -593,7 +587,7 @@ def __init__(self, args: ModelArguments, commands: list[Command]):
593587

594588
def history_to_messages(
595589
self, history: list[dict[str, str]], is_demonstration: bool = False
596-
) -> list[dict[str, str]]:
590+
) -> Union[str, list[dict[str, str]]]:
597591
"""
598592
Create `messages` by filtering out all keys except for role/content per `history` turn
599593
"""
@@ -652,7 +646,7 @@ def query(self, history: list[dict[str, str]]) -> str:
652646
break
653647
thought_all += thought
654648
thought = input("... ")
655-
649+
656650
action = super().query(history, action_prompt="Action: ")
657651

658652
return f"{thought_all}\n```\n{action}\n```"

0 commit comments

Comments
 (0)