Skip to content

Commit 70b53f0

Browse files
authored
Merge branch 'main' into issue-209
2 parents d772f86 + c93e106 commit 70b53f0

File tree

17 files changed

+1532
-354
lines changed

17 files changed

+1532
-354
lines changed

src/codegate/api/v1_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pydantic
66

77
from codegate.db import models as db_models
8-
from codegate.pipeline.base import CodeSnippet
8+
from codegate.extract_snippets.message_extractor import CodeSnippet
99
from codegate.providers.base import BaseProvider
1010
from codegate.providers.registry import ProviderRegistry
1111

src/codegate/clients/clients.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ class ClientType(Enum):
1111
KODU = "kodu" # Kodu client
1212
COPILOT = "copilot" # Copilot client
1313
OPEN_INTERPRETER = "open_interpreter" # Open Interpreter client
14+
AIDER = "aider" # Aider client
src/codegate/extract_snippets/body_extractor.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Optional
3+
4+
from codegate.extract_snippets.message_extractor import (
5+
AiderCodeSnippetExtractor,
6+
ClineCodeSnippetExtractor,
7+
CodeSnippetExtractor,
8+
DefaultCodeSnippetExtractor,
9+
OpenInterpreterCodeSnippetExtractor,
10+
)
11+
12+
13+
class BodyCodeSnippetExtractorError(Exception):
    """Error raised by body snippet extractors, e.g. when no code extractor is set."""
15+
16+
17+
class BodyCodeSnippetExtractor(ABC):
    """
    Base class for extractors that pull filenames out of a whole request body.

    Concrete subclasses assign `_snippet_extractor` and implement
    `extract_unique_filenames`.
    """

    def __init__(self):
        # Placeholder only: each concrete subclass is expected to replace this
        # with the message-level extractor matching its client.
        self._snippet_extractor: Optional[CodeSnippetExtractor] = None

    def _extract_from_user_messages(self, data: dict) -> set[str]:
        """
        Collect the filenames of the code snippets found in the user messages of `data`.

        Only entries of `data["messages"]` whose `role` is `"user"` are inspected; their
        `content` is handed to the configured snippet extractor and the keys of the
        returned mapping are gathered.

        Raises BodyCodeSnippetExtractorError when no snippet extractor has been set.
        """
        extractor = self._snippet_extractor
        if extractor is None:
            raise BodyCodeSnippetExtractorError("Code Extractor not set.")

        unique_filenames: set[str] = set()
        for message in data.get("messages", []):
            if message.get("role", "") != "user":
                continue
            snippets = extractor.extract_unique_snippets(message.get("content"))
            unique_filenames.update(snippets.keys())
        return unique_filenames

    @abstractmethod
    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the unique filenames found in the data sent by a client (Cline, Continue, ...)
        """
48+
49+
50+
class ContinueBodySnippetExtractor(BodyCodeSnippetExtractor):
    """Body extractor that uses the default snippet extractor on user messages."""

    def __init__(self):
        # Run the base initializer first so any state it sets up is present
        # before the client-specific extractor is installed.
        super().__init__()
        self._snippet_extractor = DefaultCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """Return the unique snippet filenames found in the user messages of `data`."""
        return self._extract_from_user_messages(data)
57+
58+
59+
class AiderBodySnippetExtractor(BodyCodeSnippetExtractor):
    """Body extractor that uses the Aider snippet extractor on user messages."""

    def __init__(self):
        # Run the base initializer first so any state it sets up is present
        # before the client-specific extractor is installed.
        super().__init__()
        self._snippet_extractor = AiderCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """Return the unique snippet filenames found in the user messages of `data`."""
        return self._extract_from_user_messages(data)
66+
67+
68+
class ClineBodySnippetExtractor(BodyCodeSnippetExtractor):
    """Body extractor that uses the Cline snippet extractor on user messages."""

    def __init__(self):
        # Run the base initializer first so any state it sets up is present
        # before the client-specific extractor is installed.
        super().__init__()
        self._snippet_extractor = ClineCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """Return the unique snippet filenames found in the user messages of `data`."""
        return self._extract_from_user_messages(data)
75+
76+
77+
class OpenInterpreterBodySnippetExtractor(BodyCodeSnippetExtractor):
    """
    Body extractor for Open Interpreter requests.

    Instead of looking at user messages, it looks at each assistant tool call that is
    immediately followed by a tool result and runs the snippet extractor over the tool
    arguments joined with the tool result.
    """

    def __init__(self):
        # Run the base initializer first so any state it sets up is present
        # before the client-specific extractor is installed.
        super().__init__()
        self._snippet_extractor = OpenInterpreterCodeSnippetExtractor()

    def _is_msg_tool_call(self, msg: dict) -> bool:
        """Return True if `msg` is an assistant message with a non-empty `tool_calls`."""
        # bool() so the declared return type holds; a bare `and` would leak the operand.
        return msg.get("role", "") == "assistant" and bool(msg.get("tool_calls", []))

    def _is_msg_tool_result(self, msg: dict) -> bool:
        """Return True if `msg` is a tool message with non-empty `content`."""
        # bool() so the declared return type holds; a bare `and` would leak the operand.
        return msg.get("role", "") == "tool" and bool(msg.get("content", ""))

    def _extract_args_from_tool_call(self, msg: dict) -> str:
        """
        Extract the arguments from the tool call message.

        Only the first entry of `tool_calls` is read; an empty string is returned when
        there are no tool calls.
        """
        tool_calls = msg.get("tool_calls", [])
        if not tool_calls:
            return ""
        return tool_calls[0].get("function", {}).get("arguments", "")

    def _extract_result_from_tool_result(self, msg: dict) -> str:
        """
        Extract the result from the tool result message.
        """
        return msg.get("content", "")

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the unique snippet filenames found in tool-call / tool-result message pairs.

        Returns an empty set when `data` has no messages.
        """
        # `or []` keeps the empty-set result for a missing, empty or None message list.
        messages = data.get("messages") or []

        filenames: set[str] = set()
        # Walk consecutive (message, following message) pairs.
        for msg, next_msg in zip(messages, messages[1:]):
            if not (self._is_msg_tool_call(msg) and self._is_msg_tool_result(next_msg)):
                continue
            tool_args = self._extract_args_from_tool_call(msg)
            tool_response = self._extract_result_from_tool_result(next_msg)
            extracted_snippets = self._snippet_extractor.extract_unique_snippets(
                f"{tool_args}\n{tool_response}"
            )
            filenames.update(extracted_snippets.keys())
        return filenames
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from codegate.clients.clients import ClientType
2+
from codegate.extract_snippets.body_extractor import (
3+
AiderBodySnippetExtractor,
4+
BodyCodeSnippetExtractor,
5+
ClineBodySnippetExtractor,
6+
ContinueBodySnippetExtractor,
7+
OpenInterpreterBodySnippetExtractor,
8+
)
9+
from codegate.extract_snippets.message_extractor import (
10+
AiderCodeSnippetExtractor,
11+
ClineCodeSnippetExtractor,
12+
CodeSnippetExtractor,
13+
DefaultCodeSnippetExtractor,
14+
OpenInterpreterCodeSnippetExtractor,
15+
)
16+
17+
18+
class BodyCodeExtractorFactory:
    """Factory that picks the body-level snippet extractor for a detected client."""

    @staticmethod
    def create_snippet_extractor(detected_client: ClientType) -> BodyCodeSnippetExtractor:
        """
        Return a new body snippet extractor for `detected_client`.

        Clients without a dedicated entry fall back to `ContinueBodySnippetExtractor`.
        """
        # Map to classes, not instances, so only the selected extractor is constructed.
        mapping_client_extractor = {
            ClientType.GENERIC: ContinueBodySnippetExtractor,
            ClientType.CLINE: ClineBodySnippetExtractor,
            ClientType.AIDER: AiderBodySnippetExtractor,
            ClientType.OPEN_INTERPRETER: OpenInterpreterBodySnippetExtractor,
        }
        extractor_cls = mapping_client_extractor.get(
            detected_client, ContinueBodySnippetExtractor
        )
        return extractor_cls()
29+
30+
31+
class MessageCodeExtractorFactory:
    """Factory that picks the message-level snippet extractor for a detected client."""

    @staticmethod
    def create_snippet_extractor(detected_client: ClientType) -> CodeSnippetExtractor:
        """
        Return a new message snippet extractor for `detected_client`.

        Clients without a dedicated entry fall back to `DefaultCodeSnippetExtractor`.
        """
        # Map to classes, not instances, so only the selected extractor is constructed.
        mapping_client_extractor = {
            ClientType.GENERIC: DefaultCodeSnippetExtractor,
            ClientType.CLINE: ClineCodeSnippetExtractor,
            ClientType.AIDER: AiderCodeSnippetExtractor,
            ClientType.OPEN_INTERPRETER: OpenInterpreterCodeSnippetExtractor,
        }
        extractor_cls = mapping_client_extractor.get(
            detected_client, DefaultCodeSnippetExtractor
        )
        return extractor_cls()

0 commit comments

Comments
 (0)