Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use the latest user messages block instead of single message #585

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/archived.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/archived-npm-dummy","type":"npm","description":"Dummy archived to test with encoded package name on npm"}
{"name":"archived-npm-dummy","type":"npm","description":"Dummy archived to test with simple package name on npm"}
{"name":"@prefix/archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with encoded package name on pypi"}
{"name":"archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
{"name":"archived_pypi_dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
{"name":"@prefix/archived-maven-dummy","type":"maven","description":"Dummy archived to test with encoded package name on maven"}
{"name":"archived-maven-dummy","type":"maven","description":"Dummy archived to test with simple package name on maven"}
{"name":"github.com/archived-go-dummy","type":"npm","description":"Dummy archived to test with encoded package name on go"}
Expand Down
2 changes: 1 addition & 1 deletion data/deprecated.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on npm"}
{"name":"deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with simple package name on npm"}
{"name":"@prefix/deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with encoded package name on pypi"}
{"name":"deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
{"name":"deprecated_pypi_dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
{"name":"@prefix/deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with encoded package name on maven"}
{"name":"deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with simple package name on maven"}
{"name":"github.com/deprecated-go-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on go"}
Expand Down
2 changes: 1 addition & 1 deletion data/malicious.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with encoded package name on npm"}
{"name":"malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with simple package name on npm"}
{"name":"@prefix/malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with encoded package name on pypi"}
{"name":"malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
{"name":"malicious_pypi_dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
{"name":"@prefix/malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with encoded package name on maven"}
{"name":"malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with simple package name on maven"}
{"name":"github.com/malicious-go-dummy","type":"go","description":"Dummy malicious to test with encoded package name on go"}
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ tree-sitter-javascript = ">=0.23.1"
tree-sitter-python = ">=0.23.6"
tree-sitter-rust = ">=0.23.2"
sqlite-vec-sl-tmp = "^0.0.4"
pygments = "^2.19.1"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.0"
Expand Down
41 changes: 39 additions & 2 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,45 @@ def get_last_user_message(
return None
for i in reversed(range(len(request["messages"]))):
if request["messages"][i]["role"] == "user":
content = request["messages"][i]["content"]
return content, i
content = request["messages"][i]["content"] # type: ignore
return str(content), i

return None

@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
) -> Optional[str]:
"""
Get the last block of consecutive 'user' messages from the request.

Args:
request (ChatCompletionRequest): The chat completion request to process

Returns:
Optional[str]: A string containing all consecutive user messages in the
last user message block, separated by newlines, or None if
no user message block is found.
"""
if request.get("messages") is None:
return None

user_messages = []
messages = request["messages"]

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we include the assistant messages? This reads to me as if we would get all the user messages until the first system message?

Maybe it would be nice to have some tests so that the purpose is easier to see through them?

Copy link
Contributor Author

@yrobla yrobla Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when testing deeper with aider, i saw that some of the blocks were composed by this pattern:

system
user
assistant
user

so we need to extract all this block. I am fixing tests now and i can add a test with the aider pattern for it. I also tested against the other tools and does not affect, as they just have a single user message

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when testing deeper with aider, i saw that some of the blocks were composed by this pattern:

system user assistant user

so we need to extract all this block. I am fixing tests now and i can add a test with the aider pattern for it. I also tested against the other tools and does not affect, as they just have a single user message

I'm more concerned if we don't break any other extractions because other assistants might just send (assistant,user) in a single turn of a conversation..

if messages[i]["role"] == "user":
user_messages.append(messages[i]["content"]) # type: ignore
else:
# Stop when a message with a different role is encountered
if user_messages:
break

# Reverse the collected user messages to preserve the original order
if user_messages:
return "\n".join(reversed(user_messages))

return None

Expand Down
24 changes: 11 additions & 13 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import re

import structlog
from litellm import ChatCompletionRequest

Expand Down Expand Up @@ -59,38 +58,37 @@ async def process(
"""
Use RAG DB to add context to the user request
"""
# Get the latest user messages
user_messages = self.get_latest_user_messages(request)

# Nothing to do if the user_messages string is empty
if len(user_messages) == 0:
# Get the latest user message
user_message = self.get_last_user_message_block(request)
if not user_message:
return PipelineResult(request=request)

# Create storage engine object
storage_engine = StorageEngine()

# Extract any code snippets
snippets = extract_snippets(user_messages)
snippets = extract_snippets(user_message)

bad_snippet_packages = []
if len(snippets) > 0:
snippet_language = snippets[0].language
# Collect all packages referenced in the snippets
snippet_packages = []
for snippet in snippets:
snippet_packages.extend(
PackageExtractor.extract_packages(snippet.code, snippet.language)
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
)
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")

logger.info(f"Found {len(snippet_packages)} packages "
"for language {snippet_language} in code snippets.")
# Find bad packages in the snippets
bad_snippet_packages = await storage_engine.search(
language=snippets[0].language, packages=snippet_packages
)
language=snippet_language, packages=snippet_packages) # type: ignore
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")

# Remove code snippets from the user messages and search for bad packages
# in the rest of the user query/messsages
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)
user_messages = re.sub(r"```.*?```", "", user_message, flags=re.DOTALL)

# Vector search to find bad packages
bad_packages = await storage_engine.search(query=user_messages, distance=0.5, limit=100)
Expand Down Expand Up @@ -119,7 +117,7 @@ async def process(
# Add the context to the last user message
# Format: "Context: {context_str} \n Query: {last user message content}"
message = new_request["messages"][last_user_idx]
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' # type: ignore
message["content"] = context_msg

logger.debug("Final context message", context_message=context_msg)
Expand Down
11 changes: 8 additions & 3 deletions src/codegate/pipeline/extract_snippets/extract_snippets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from pygments.lexers import guess_lexer
from typing import List, Optional

import structlog
Expand Down Expand Up @@ -105,6 +106,11 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
filename = filename.strip()
# Determine language from the filename
lang = ecosystem_from_filepath(filename)
if lang is None:
# try to guess it from the code
lexer = guess_lexer(content)
if lexer and lexer.name:
lang = lexer.name.lower()

snippets.append(CodeSnippet(filepath=filename, code=content, language=lang))

Expand All @@ -129,10 +135,9 @@ async def process(
request: ChatCompletionRequest,
context: PipelineContext,
) -> PipelineResult:
last_user_message = self.get_last_user_message(request)
if not last_user_message:
msg_content = self.get_last_user_message_block(request)
if not msg_content:
return PipelineResult(request=request, context=context)
msg_content, _ = last_user_message
snippets = extract_snippets(msg_content)

logger.info(f"Extracted {len(snippets)} code snippets from the user message")
Expand Down
Loading