Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pr_agent/servers/azuredevops_server_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# The server listens for incoming webhooks from Azure DevOps Server and forwards them to the PR Agent.
# ADO webhook documentation: https://learn.microsoft.com/en-us/azure/devops/service-hooks/services/webhooks?view=azure-devops

import copy
import json
import os
import re
Expand All @@ -17,11 +18,12 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware

from pr_agent.agent.pr_agent import PRAgent, command2class
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers import get_git_provider_with_context
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from pr_agent.git_providers.utils import apply_repo_settings
Expand Down Expand Up @@ -58,7 +60,7 @@ def handle_line_comment(body: str, thread_id: int, provider: AzureDevopsProvider
thread_context = provider.get_thread_context(thread_id)
if not thread_context:
return body

path = thread_context.file_path
if thread_context.left_file_end or thread_context.left_file_start:
start_line = thread_context.left_file_start.line
Expand All @@ -71,7 +73,7 @@ def handle_line_comment(body: str, thread_id: int, provider: AzureDevopsProvider
else:
get_logger().info("No line range found in thread context", artifact={"thread_context": thread_context})
return body

question = body[5:].lstrip() # remove 4 chars: '/ask '
return f"/ask_line --line_start={start_line} --line_end={end_line} --side={side} --file_name={path} --comment_id={thread_id} {question}"

Expand All @@ -80,7 +82,7 @@ def handle_line_comment(body: str, thread_id: int, provider: AzureDevopsProvider
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
if WEBHOOK_USERNAME is None or WEBHOOK_PASSWORD is None:
return

is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
if not (is_user_ok and is_pass_ok):
Expand Down Expand Up @@ -172,6 +174,7 @@ async def handle_request_azure(data, log_context):
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
log_context = {"server_type": "azure_devops_server"}
data = await request.json()
context["settings"] = copy.deepcopy(global_settings)
# get_logger().info(json.dumps(data))

background_tasks.add_task(handle_request_azure, data, log_context)
Expand Down
20 changes: 12 additions & 8 deletions pr_agent/servers/bitbucket_server_webhook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import json
import os
import re
Expand All @@ -13,11 +14,12 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware

from pr_agent.agent.pr_agent import PRAgent
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.utils import verify_signature
Expand Down Expand Up @@ -45,22 +47,22 @@ def should_process_pr_logic(data) -> bool:
try:
pr_data = data.get("pullRequest", {})
title = pr_data.get("title", "")

from_ref = pr_data.get("fromRef", {})
source_branch = from_ref.get("displayId", "") if from_ref else ""

to_ref = pr_data.get("toRef", {})
target_branch = to_ref.get("displayId", "") if to_ref else ""

author = pr_data.get("author", {})
user = author.get("user", {}) if author else {}
sender = user.get("name", "") if user else ""

repository = to_ref.get("repository", {}) if to_ref else {}
project = repository.get("project", {}) if repository else {}
project_key = project.get("key", "") if project else ""
repo_slug = repository.get("slug", "") if repository else ""

repo_full_name = f"{project_key}/{repo_slug}" if project_key and repo_slug else ""
pr_id = pr_data.get("id", None)

Expand Down Expand Up @@ -102,7 +104,8 @@ def should_process_pr_logic(data) -> bool:
# Allow_only_specific_folders
allowed_folders = get_settings().config.get("allow_only_specific_folders", [])
if allowed_folders and pr_id and project_key and repo_slug:
from pr_agent.git_providers.bitbucket_server_provider import BitbucketServerProvider
from pr_agent.git_providers.bitbucket_server_provider import \
BitbucketServerProvider
bitbucket_server_url = get_settings().get("BITBUCKET_SERVER.URL", "")
pr_url = f"{bitbucket_server_url}/projects/{project_key}/repos/{repo_slug}/pull-requests/{pr_id}"
provider = BitbucketServerProvider(pr_url=pr_url)
Expand All @@ -114,7 +117,7 @@ def should_process_pr_logic(data) -> bool:
if any(file_path.startswith(folder) for folder in allowed_folders):
all_files_outside = False
break

if all_files_outside:
get_logger().info(f"Ignoring PR because all files {changed_files} are outside allowed folders {allowed_folders}")
return False
Expand All @@ -131,6 +134,7 @@ async def redirect_to_webhook():
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
log_context = {"server_type": "bitbucket_server"}
data = await request.json()
context["settings"] = copy.deepcopy(global_settings)
get_logger().info(json.dumps(data))

webhook_secret = get_settings().get("BITBUCKET_SERVER.WEBHOOK_SECRET", None)
Expand Down
121 changes: 121 additions & 0 deletions tests/unittest/test_apply_repo_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import copy

import pytest
from starlette_context import context, request_cycle_context

from pr_agent.config_loader import get_settings, global_settings
from pr_agent.git_providers import utils as git_utils

REPO_A_TOML = b"""
[pr_reviewer]
extra_instructions = "MARKER-FROM-REPO-A"

[pr_code_suggestions]
extra_instructions = "MARKER-FROM-REPO-A"
"""


class FakeGitProvider:
def __init__(self, repo_settings: bytes):
self._repo_settings = repo_settings

def get_repo_settings(self):
return self._repo_settings

def is_supported(self, feature):
return False

def publish_comment(self, body):
pass

def publish_persistent_comment(self, *args, **kwargs):
pass


@pytest.fixture
def fresh_global_settings():
"""Restore module-level global_settings after each test in case anything mutated it."""
snapshot = copy.deepcopy(global_settings.as_dict())
yield
for section in set(global_settings.as_dict().keys()) - set(snapshot.keys()):
global_settings.unset(section)
for section, contents in snapshot.items():
global_settings.unset(section)
global_settings.set(section, copy.deepcopy(contents), merge=False)


def _extra_instructions(section: str) -> str:
return get_settings().get(f"{section}.extra_instructions", "") or ""


class TestApplyRepoSettings:
"""Verify that the per-request settings clone (set by webhook handlers via
`context['settings'] = copy.deepcopy(global_settings)`) successfully
isolates `apply_repo_settings()` mutations to the request that produced
them — preventing cross-repo `.pr_agent.toml` state leaks reported in #2345.
"""

def test_repo_settings_from_toml_are_applied(self, fresh_global_settings, monkeypatch):
monkeypatch.setattr(
"pr_agent.git_providers.utils.get_git_provider_with_context",
lambda url: FakeGitProvider(REPO_A_TOML),
)
with request_cycle_context({}):
context["settings"] = copy.deepcopy(global_settings)
git_utils.apply_repo_settings("https://git.example/projects/A/repos/a/pull-requests/1")
assert "MARKER-FROM-REPO-A" in _extra_instructions("pr_reviewer")
assert "MARKER-FROM-REPO-A" in _extra_instructions("pr_code_suggestions")

def test_repo_without_toml_does_not_inherit_previous_repo_settings(
self, fresh_global_settings, monkeypatch
):
# Request 1: Repo A with .pr_agent.toml — mutates only this request's settings clone.
monkeypatch.setattr(
"pr_agent.git_providers.utils.get_git_provider_with_context",
lambda url: FakeGitProvider(REPO_A_TOML),
)
with request_cycle_context({}):
context["settings"] = copy.deepcopy(global_settings)
git_utils.apply_repo_settings("https://git.example/projects/A/repos/a/pull-requests/1")
assert "MARKER-FROM-REPO-A" in _extra_instructions("pr_reviewer"), "precondition"

# Request 2: Repo B with no .pr_agent.toml — fresh clone of the unmutated global_settings.
monkeypatch.setattr(
"pr_agent.git_providers.utils.get_git_provider_with_context",
lambda url: FakeGitProvider(b""),
)
with request_cycle_context({}):
context["settings"] = copy.deepcopy(global_settings)
git_utils.apply_repo_settings("https://git.example/projects/B/repos/b/pull-requests/1")
assert "MARKER-FROM-REPO-A" not in _extra_instructions("pr_reviewer"), \
"repo A's [pr_reviewer].extra_instructions leaked into repo B"
assert "MARKER-FROM-REPO-A" not in _extra_instructions("pr_code_suggestions"), \
"repo A's [pr_code_suggestions].extra_instructions leaked into repo B"

def test_unknown_section_does_not_leak_to_next_repo(self, fresh_global_settings, monkeypatch):
"""Catches the case where a repo's `.pr_agent.toml` introduces a section
name not present in the startup defaults. With the per-request clone,
the new section lives in `context['settings']` and dies with the request.
"""
custom_section_toml = b"""
[my_custom_repo_section]
foo = "X-FROM-REPO-A"
"""
monkeypatch.setattr(
"pr_agent.git_providers.utils.get_git_provider_with_context",
lambda url: FakeGitProvider(custom_section_toml),
)
with request_cycle_context({}):
context["settings"] = copy.deepcopy(global_settings)
git_utils.apply_repo_settings("https://git.example/projects/A/repos/a/pull-requests/1")
assert get_settings().get("my_custom_repo_section.foo") == "X-FROM-REPO-A", "precondition"

monkeypatch.setattr(
"pr_agent.git_providers.utils.get_git_provider_with_context",
lambda url: FakeGitProvider(b""),
)
with request_cycle_context({}):
context["settings"] = copy.deepcopy(global_settings)
git_utils.apply_repo_settings("https://git.example/projects/B/repos/b/pull-requests/1")
assert get_settings().get("my_custom_repo_section.foo") is None, \
"repo A's [my_custom_repo_section] leaked into repo B"