diff --git a/configs/submit76/config.yaml b/configs/submit76/config.yaml new file mode 100644 index 000000000..c8139f069 --- /dev/null +++ b/configs/submit76/config.yaml @@ -0,0 +1,62 @@ +# Submit76 deployment config +# Deploy with: +# ./configs/submit76/deploy-submit76.sh --live +# +# Target: mohoney@submit76.mit.edu +# Ollama must be running on submit76 at localhost:7870 with gpt-oss:120b pulled. + +name: my_archi + +services: + chat_app: + agent_class: CMSCompOpsAgent + agents_dir: examples/agents + default_provider: local + default_model: "qwen3:32b" + providers: + local: + enabled: true + base_url: http://localhost:7870 + mode: ollama + default_model: "qwen3:32b" + models: + - "gpt-oss:120b" + - "qwen3:32b" + port: 7865 + external_port: 7865 + ab_testing: + enabled: true + pool: + champion: default + variants: + - name: default + provider: local + model: "qwen3:32b" + - name: gpt-oss-120b + provider: local + model: "gpt-oss:120b" + postgres: + port: 5435 + data_manager: + port: 7878 + external_port: 7878 + auth: + enabled: true + +data_manager: + sources: + jira: + enabled: true + max_tickets: 10 + url: https://its.cern.ch/jira/ + projects: + - "CMSPROD" + links: + input_lists: + - /home/submit/pmlugato/random_configs/lists/sso_git.list + redmine: + url: https://cleo.mit.edu + project: emails-to-ticket + projects: + - emails-to-ticket + embedding_name: HuggingFaceEmbeddings diff --git a/docs/docs/api_reference.md b/docs/docs/api_reference.md index 9937ab4cc..37f22ee49 100644 --- a/docs/docs/api_reference.md +++ b/docs/docs/api_reference.md @@ -46,7 +46,60 @@ Retrieve the full trace of a previous request. ### `POST /api/ab/create` -Create an A/B comparison between two model responses. +Create an A/B comparison between two model responses (legacy manual mode). + +### `GET /api/ab/pool` + +Get the server-side A/B testing pool configuration. 
The response shape depends on RBAC: + +- `ab:view` or `ab:manage`: full read-only experiment configuration +- `ab:participate`: participant-focused payload including the effective per-user sample rate and participant eligibility diagnostics +- otherwise: `enabled: false` + +**Response (pool active):** +```json +{ + "success": true, + "enabled": true, + "can_view": true, + "can_manage": false, + "champion": "default", + "variants": ["default", "creative", "concise"], + "comparison_rate": 0.25, + "default_comparison_rate": 0.25, + "participant_eligible": true, + "participant_reason": "eligible" +} +``` + +Participant payloads can also report `participant_reason: "not_targeted"` when the deployment has an active experiment but the current user's roles or permissions are not included in that experiment's target filters. + +### `POST /api/ab/compare` + +Stream a pool-based champion-vs-variant A/B comparison. The server randomly pairs the champion against another variant from the pool and streams interleaved NDJSON events tagged with `arm: "a"` or `arm: "b"`. A final `ab_meta` event carries the `comparison_id` and variant mapping. + +**Request body:** Same as `/api/get_chat_response_stream`. + +### `GET /api/ab/metrics` + +Get per-variant aggregate metrics (wins, losses, ties, total comparisons). + +**Response:** +```json +{ + "success": true, + "metrics": [ + { + "variant_name": "creative", + "wins": 12, + "losses": 5, + "ties": 3, + "total_comparisons": 20, + "last_updated": "2025-01-15T10:30:00" + } + ] +} +``` --- @@ -92,14 +145,15 @@ Get or create the current user. ### `PATCH /api/users/me/preferences` -Update user preferences (model, temperature, prompts, theme). +Update user preferences (model, temperature, prompts, theme, and A/B participation override). 
**Request:** ```json { "theme": "light", "preferred_model": "claude-3-opus", - "preferred_temperature": 0.5 + "preferred_temperature": 0.5, + "ab_participation_rate": 0.75 } ``` @@ -232,6 +286,39 @@ Set the active agent for the current session. } ``` +### `GET /api/ab/agents/list` + +List the Postgres-backed A/B agent catalog for the A/B admin page. Requires A/B page access. + +### `GET /api/ab/agents/template` + +Get the A/B admin template payload with structured tool metadata. Requires `ab:manage`. + +**Response:** +```json +{ + "name": "New A/B Agent", + "prompt": "Write your system prompt here.", + "tools": [ + {"name": "search_vectorstore_hybrid", "description": "Search indexed documents."} + ], + "scope": "ab" +} +``` + +### `POST /api/ab/agents` + +Create a new Postgres-backed A/B agent spec from structured fields. Requires `ab:manage`. + +**Request:** +```json +{ + "name": "A/B Candidate", + "tools": ["search_vectorstore_hybrid"], + "prompt": "You are a helpful A/B experiment agent." +} +``` + --- ## Prompts diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index 39117203f..235fc1668 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -97,6 +97,36 @@ services: - alerts:manage ``` +#### `services.chat_app.auth` + +Authentication can be enabled with SSO or basic auth. + +For RBAC-managed admin access, use SSO plus `auth_roles`. Basic auth supports identity-only login, but it does not assign RBAC roles. + +```yaml +services: + chat_app: + auth: + enabled: true + basic: + enabled: true + auth_roles: + default_role: base-user + roles: + base-user: + permissions: + - chat:query + - ab:participate + ab-reviewer: + permissions: + - documents:view + - ab:view + - ab:metrics + ab-admin: + permissions: + - ab:manage +``` + #### Provider Configuration ```yaml @@ -257,6 +287,111 @@ data_manager: --- +## A/B Testing Pool + +Archi supports champion-vs-variant A/B testing via a server-side variant pool. 
When configured, the system automatically pairs the champion agent against a random variant for each comparison. Users vote on which response is better, and aggregate metrics are tracked per variant. + +Configure A/B testing under `services.chat_app.ab_testing`: + +```yaml +services: + chat_app: + ab_testing: + enabled: true + force_yaml_override: false + comparison_rate: 0.25 + variant_label_mode: post_vote_reveal + activity_panel_default_state: hidden + max_pending_comparisons_per_conversation: 1 + eligible_roles: [] + eligible_permissions: [] + pool: + champion: default + variants: + - label: default + agent_spec: default.md + - label: creative + agent_spec: default.md + provider: openai + model: gpt-4o + recursion_limit: 30 + - label: concise + agent_spec: concise.md + provider: anthropic + model: claude-sonnet-4-20250514 + num_documents_to_retrieve: 3 +``` + +`services.ab_testing` is deprecated and no longer loaded. Use `services.chat_app.ab_testing` only. + +If `enabled: true` is set before the A/B pool is fully configured, Archi starts successfully but keeps A/B inactive until setup is completed in the admin UI. Missing champion/variant selections or unresolved A/B agent-spec records are surfaced as warnings instead of blocking startup. + +The runtime source of truth for A/B agent specs is now PostgreSQL. The optional `ab_agents_dir` path is treated only as a legacy import source during reconciliation; runtime A/B loading never falls back to reading staged container markdown files directly. + +New A/B specs created through the admin UI are stored in the same PostgreSQL-backed catalog. Existing A/B specs are not edited through the admin page; if you need a changed prompt or tool selection, create a new A/B spec and point the variant at that new catalog entry. + +For `services.chat_app.ab_testing`, the persisted PostgreSQL-backed static config becomes authoritative after the first successful bootstrap. 
On later restarts or reseeds, the rendered YAML A/B block is used only if no persisted A/B block exists yet, unless `force_yaml_override: true` is set to intentionally replace the saved A/B state. This flag is bootstrap-only and defaults to `false`. + +### Variant Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `label` | string | *required* | Unique human-facing variant label used in the UI and metrics | +| `agent_spec` | string | *required* | A/B agent-spec filename resolved from the database-backed A/B catalog | +| `provider` | string | `null` | Override LLM provider | +| `model` | string | `null` | Override LLM model | +| `num_documents_to_retrieve` | int | `null` | Override retriever document count | +| `recursion_limit` | int | `null` | Override agent recursion limit | + +### Pool Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `enabled` | boolean | `false` | Enable the experiment pool | +| `ab_agents_dir` | string | `/root/archi/ab_agents` | Optional legacy import directory for migrating A/B markdown specs into the DB catalog | +| `force_yaml_override` | boolean | `false` | Bootstrap-only override that forces the rendered YAML A/B block to replace persisted A/B state on reseed/restart | +| `comparison_rate` | float | `1.0` | Fraction of eligible turns that should run A/B | +| `variant_label_mode` | string | `post_vote_reveal` | One of `hidden`, `post_vote_reveal`, `always_visible` | +| `activity_panel_default_state` | string | `hidden` | One of `hidden`, `collapsed`, `expanded` | +| `max_pending_comparisons_per_conversation` | int | `1` | Maximum unresolved comparisons per conversation | +| `eligible_roles` | list[string] | `[]` | Restrict participation to matching RBAC roles | +| `eligible_permissions` | list[string] | `[]` | Restrict participation to matching permissions | + +### Config-To-UI Mapping + +| Config Field | Config Value | UI Label | Runtime Meaning | 
+|--------------|--------------|----------|-----------------| +| `comparison_rate` | `0.0..1.0` | `Comparison Rate` | Fraction of eligible turns that become A/B comparisons | +| `variant_label_mode` | `hidden` | `Hidden` | Hide variant labels before and after the vote | +| `variant_label_mode` | `post_vote_reveal` | `Post-Vote Reveal` | Hide variant labels until the vote is submitted | +| `variant_label_mode` | `always_visible` | `Always Visible` | Show variant labels throughout the comparison | +| `activity_panel_default_state` | `hidden` | `Hidden` | Do not show the per-arm activity panel by default | +| `activity_panel_default_state` | `collapsed` | `Collapsed` | Show the activity panel in a collapsed state | +| `activity_panel_default_state` | `expanded` | `Expanded` | Show the activity panel expanded by default | +| `max_pending_comparisons_per_conversation` | integer >= 1 | `Max Pending Comparisons Per Conversation` | Limit unresolved comparisons before the user must vote | +| `pool.champion` | existing variant label | `Champion` | Baseline variant that always appears in each comparison | + +The `champion` field must reference an existing variant `label`. At least two variants are required before the experiment becomes active. `name`-only variant config is not supported. When a user enables A/B mode in the chat UI, the pool takes over: the champion always appears in one arm, and a random variant is placed in the other. Arm positions (A vs B) are randomized per comparison. 
+ +### A/B RBAC and User Preference + +Use RBAC to separate participation, read-only review, metrics access, and write access: + +| Permission | Purpose | +|------------|---------| +| `ab:participate` | Makes a user eligible for A/B comparisons and shows the per-user sampling slider in chat settings | +| `ab:view` | Allows read-only access to the A/B admin page | +| `ab:metrics` | Allows access to aggregate A/B metrics | +| `ab:manage` | Allows editing variants, A/B agent specs, and experiment settings | + +`services.chat_app.ab_testing.comparison_rate` remains the deployment default. Users with `ab:participate` can override that default per account with a `0..1` slider in chat settings. + +Users with `ab:view`, `ab:metrics`, or `ab:manage` can open the dedicated A/B Testing page from the data viewer and from chat settings. Users with `ab:participate` but not A/B page access still get the personal sampling slider in chat settings. + +Variant metrics (wins, losses, ties) are tracked in the `ab_variant_metrics` database table and available via `GET /api/ab/metrics`. + +--- + ## Agent Configuration Model Archi no longer uses a top-level `archi:` block in standard deployment YAML. diff --git a/src/archi/archi.py b/src/archi/archi.py index d6d0d584c..570691673 100644 --- a/src/archi/archi.py +++ b/src/archi/archi.py @@ -1,4 +1,5 @@ -import src.archi.pipelines as archiPipelines +from importlib import import_module + from src.utils.config_access import get_full_config from src.utils.logging import get_logger from src.archi.utils.output_dataclass import PipelineOutput @@ -6,6 +7,11 @@ logger = get_logger(__name__) + +def _get_pipelines_module(): + """Load pipeline exports only when the runtime needs to resolve a class.""" + return import_module("src.archi.pipelines") + class archi(): """ Central class of the archi framework. 
@@ -51,7 +57,7 @@ def _create_pipeline_instance(self, class_name, *args, **kwargs): logger.debug("and kwargs:") logger.debug(f"{kwargs}") try: - cls = getattr(archiPipelines, class_name) + cls = getattr(_get_pipelines_module(), class_name) return cls(*args, **kwargs) except AttributeError: raise ValueError(f"Class '{class_name}' not found in module") diff --git a/src/archi/pipelines/__init__.py b/src/archi/pipelines/__init__.py index 3a04015b8..e17eabcb1 100644 --- a/src/archi/pipelines/__init__.py +++ b/src/archi/pipelines/__init__.py @@ -1,17 +1,31 @@ """Pipeline package exposing the available pipeline classes.""" -from .classic_pipelines.base import BasePipeline -from .classic_pipelines.grading import GradingPipeline -from .classic_pipelines.image_processing import ImageProcessingPipeline -from .classic_pipelines.qa import QAPipeline -from .agents.base_react import BaseReActAgent -from .agents.cms_comp_ops_agent import CMSCompOpsAgent - -__all__ = [ - "BasePipeline", - "GradingPipeline", - "ImageProcessingPipeline", - "QAPipeline", - "BaseReActAgent", - "CMSCompOpsAgent", -] +from importlib import import_module + + +_PIPELINE_EXPORTS = { + "BasePipeline": (".classic_pipelines.base", "BasePipeline"), + "GradingPipeline": (".classic_pipelines.grading", "GradingPipeline"), + "ImageProcessingPipeline": (".classic_pipelines.image_processing", "ImageProcessingPipeline"), + "QAPipeline": (".classic_pipelines.qa", "QAPipeline"), + "BaseReActAgent": (".agents.base_react", "BaseReActAgent"), + "CMSCompOpsAgent": (".agents.cms_comp_ops_agent", "CMSCompOpsAgent"), +} + +__all__ = list(_PIPELINE_EXPORTS) + + +def __getattr__(name): + target = _PIPELINE_EXPORTS.get(name) + if target is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module_name, attr_name = target + module = import_module(module_name, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value + + +def __dir__(): + return 
sorted(list(globals().keys()) + __all__) diff --git a/src/archi/pipelines/agents/agent_spec.py b/src/archi/pipelines/agents/agent_spec.py index 5c99f491b..2361d0641 100644 --- a/src/archi/pipelines/agents/agent_spec.py +++ b/src/archi/pipelines/agents/agent_spec.py @@ -13,6 +13,7 @@ class AgentSpec: tools: List[str] prompt: str source_path: Path + ab_only: bool = False class AgentSpecError(ValueError): @@ -31,22 +32,26 @@ def load_agent_spec(path: Path) -> AgentSpec: text = path.read_text() frontmatter, prompt = _parse_frontmatter(text, path) name, tools = _extract_metadata(frontmatter, path) + ab_only = bool(frontmatter.get("ab_only", False)) return AgentSpec( name=name, tools=tools, prompt=prompt, source_path=path, + ab_only=ab_only, ) def load_agent_spec_from_text(text: str) -> AgentSpec: frontmatter, prompt = _parse_frontmatter(text, Path("")) name, tools = _extract_metadata(frontmatter, Path("")) + ab_only = bool(frontmatter.get("ab_only", False)) return AgentSpec( name=name, tools=tools, prompt=prompt, source_path=Path(""), + ab_only=ab_only, ) diff --git a/src/archi/pipelines/agents/base_react.py b/src/archi/pipelines/agents/base_react.py index 43e08f604..ea61d351c 100644 --- a/src/archi/pipelines/agents/base_react.py +++ b/src/archi/pipelines/agents/base_react.py @@ -390,6 +390,29 @@ def stream(self, **kwargs) -> Iterator[PipelineOutput]: content = self._message_content(message) additional_kwargs = getattr(message, "additional_kwargs", None) or {} reasoning_content = additional_kwargs.get("reasoning_content", "") + + # Detect empty AI chunks as implicit thinking activity. + # Some LLM integrations (e.g. langchain-ollama <1.1) drop + # the thinking/reasoning payload, producing chunks where + # both content and reasoning_content are empty while the + # model is still in its thinking phase. We treat these as + # a signal to start (or continue) the thinking indicator + # so the UI stays responsive. 
+ if not content and not reasoning_content: + if thinking_step_id is None and "chunk" in msg_class: + thinking_step_id = str(uuid.uuid4()) + thinking_start_time = time.time() + yield self.finalize_output( + answer="", + memory=self.active_memory, + messages=[], + metadata={ + "event_type": "thinking_start", + "step_id": thinking_step_id, + }, + final=False, + ) + if content or reasoning_content: # Start thinking phase if not already active if thinking_step_id is None: @@ -665,6 +688,23 @@ async def astream(self, **kwargs) -> AsyncIterator[PipelineOutput]: content = self._message_content(message) additional_kwargs = getattr(message, "additional_kwargs", None) or {} reasoning_content = additional_kwargs.get("reasoning_content", "") + + # Detect empty AI chunks as implicit thinking activity. + if not content and not reasoning_content: + if thinking_step_id is None and "chunk" in msg_class: + thinking_step_id = str(uuid.uuid4()) + thinking_start_time = time.time() + yield self.finalize_output( + answer="", + memory=self.active_memory, + messages=[], + metadata={ + "event_type": "thinking_start", + "step_id": thinking_step_id, + }, + final=False, + ) + if content or reasoning_content: # Start thinking phase if not already active if thinking_step_id is None: diff --git a/src/archi/providers/local_provider.py b/src/archi/providers/local_provider.py index b6026eda9..cfbb34e90 100644 --- a/src/archi/providers/local_provider.py +++ b/src/archi/providers/local_provider.py @@ -94,6 +94,7 @@ def _get_ollama_model(self, model_name: str, **kwargs) -> BaseChatModel: model_kwargs = { "model": model_name, "streaming": True, + "keep_alive": "24h", **self.config.extra_kwargs, **kwargs, } diff --git a/src/bin/service_chat.py b/src/bin/service_chat.py index ad30af0e2..9de1fbfb6 100644 --- a/src/bin/service_chat.py +++ b/src/bin/service_chat.py @@ -4,6 +4,7 @@ from flask import Flask +from src.interfaces.chat_app.api import register_api from src.interfaces.chat_app.app import 
FlaskAppWrapper from src.utils.env import read_secret from src.utils.logging import setup_logging @@ -27,30 +28,45 @@ def main(): # Reload config from Postgres (runtime source of truth) config = get_full_config() chat_config = config["services"]["chat_app"] - print(f"Starting Chat Service with (host, port): ({chat_config['host']}, {chat_config['port']})") - print(f"Accessible externally at (host, port): ({chat_config['hostname']}, {chat_config['external_port']})") - generate_script(chat_config) - app = FlaskAppWrapper(Flask( + # Deployment-time fields may not be in Postgres; use sensible defaults + host = chat_config.get("host", "0.0.0.0") + port = chat_config.get("port", 7681) + hostname = chat_config.get("hostname", host) + external_port = chat_config.get("external_port", port) + + # Resolve template/static folders from installed package location + _pkg_dir = os.path.dirname(os.path.abspath(__file__)) + _chat_app_dir = os.path.join(os.path.dirname(_pkg_dir), "interfaces", "chat_app") + template_folder = chat_config.get("template_folder", os.path.join(_chat_app_dir, "templates")) + static_folder = chat_config.get("static_folder", os.path.join(_chat_app_dir, "static")) + + print(f"Starting Chat Service with (host, port): ({host}, {port})") + print(f"Accessible externally at (host, port): ({hostname}, {external_port})") + + generate_script(chat_config, static_folder) + flask_app = Flask( __name__, - template_folder=chat_config["template_folder"], - static_folder=chat_config["static_folder"], - )) - app.run(debug=True, use_reloader=False, port=chat_config["port"], host=chat_config["host"]) + template_folder=template_folder, + static_folder=static_folder, + ) + register_api(flask_app) + app = FlaskAppWrapper(flask_app) + app.run(debug=True, use_reloader=False, port=port, host=host) -def generate_script(chat_config): +def generate_script(chat_config, static_folder): """ This is not elegant but it creates the javascript file from the template using the config.yaml 
parameters """ - script_template = os.path.join(chat_config["static_folder"], "script.js-template") + script_template = os.path.join(static_folder, "script.js-template") with open(script_template, "r") as f: template = f.read() - filled_template = template.replace('XX-NUM-RESPONSES-XX', str(chat_config["num_responses_until_feedback"])) + filled_template = template.replace('XX-NUM-RESPONSES-XX', str(chat_config.get("num_responses_until_feedback", 3))) filled_template = filled_template.replace('XX-TRAINED_ON-XX', str(chat_config.get("trained_on", ""))) - script_file = os.path.join(chat_config["static_folder"], "script.js") + script_file = os.path.join(static_folder, "script.js") with open(script_file, "w") as f: f.write(filled_template) diff --git a/src/cli/cli_main.py b/src/cli/cli_main.py index c94124d58..0453951dd 100644 --- a/src/cli/cli_main.py +++ b/src/cli/cli_main.py @@ -9,7 +9,7 @@ select_autoescape) from src.cli.managers.config_manager import ConfigurationManager -from src.cli.managers.deployment_manager import DeploymentManager +from src.cli.managers.deployment_manager import DeploymentError, DeploymentManager from src.cli.managers.secrets_manager import SecretsManager from src.cli.managers.templates_manager import TemplateManager from src.cli.managers.volume_manager import VolumeManager @@ -56,7 +56,7 @@ def cli(): @click.option('--dry', '--dry-run', is_flag=True, help="Validate configuration and show what would be created without actually deploying") def create(name: str, config_files: list, config_dir: str, env_file: str, services: list, force: bool, dry: bool, verbosity: int, **other_flags): - """Create an ARCHI deployment with selected services and data sources.""" + """Create an Archi deployment with selected services and data sources.""" if not (bool(config_files) ^ bool(config_dir)): raise click.ClickException(f"Must specify only one of config files or config dir") @@ -66,7 +66,7 @@ def create(name: str, config_files: list, config_dir: str, 
env_file: str, servic if len(config_files) != 1: raise click.ClickException("Exactly one config file is supported; please provide a single -c file.") - click.echo("Starting ARCHI deployment process...") + click.echo("Starting Archi deployment process...") setup_cli_logging(verbosity=verbosity) logger = get_logger(__name__) @@ -183,7 +183,7 @@ def create(name: str, config_files: list, config_dir: str, env_file: str, servic @click.option('--podman', '-p', is_flag=True, default=False, help="specify if podman is being used") def delete(name: str, rmi: bool, rmv: bool, keep_files: bool, list_deployments: bool, verbosity: int, podman: bool): """ - Delete an ARCHI deployment with the specified name. + Delete an Archi deployment with the specified name. This command stops containers and optionally removes images, volumes, and files. @@ -266,7 +266,7 @@ def delete(name: str, rmi: bool, rmv: bool, keep_files: bool, list_deployments: @click.option('--service', '-s', type=str, default="chatbot", help="Service to restart (default: chatbot)") @click.option('--config', '-c', 'config_files', type=str, multiple=True, help="Path to .yaml archi configuration") @click.option('--config-dir', '-cd', 'config_dir', type=str, help="Path to configs directory") -@click.option('--env-file', '-e', type=str, required=False, help="Path to .env file with secrets") +@click.option('--env-file', '-e', type=str, required=True, help="Path to .env file with secrets") @click.option('--no-build', is_flag=True, help="Restart without rebuilding the image") @click.option('--with-deps', is_flag=True, help="Also restart dependent services") @click.option('--podman', '-p', is_flag=True, default=False, help="specify if podman is being used") @@ -276,7 +276,7 @@ def restart( service: str, config_files: tuple, config_dir: Optional[str], - env_file: Optional[str], + env_file: str, no_build: bool, with_deps: bool, podman: bool, @@ -390,6 +390,20 @@ def restart( allow_port_reuse=True, ) + deployment_manager = 
DeploymentManager(use_podman=podman) + + if config_files or config_dir: + try: + if deployment_manager.has_service(deployment_dir, "config-seed"): + deployment_manager.run_service_once( + deployment_dir=deployment_dir, + service_name="config-seed", + build=not no_build, + no_deps=True, + ) + except DeploymentError as e: + raise click.ClickException(str(e)) + if not no_build and not (config_files or config_dir): template_manager = TemplateManager(env, verbosity) try: @@ -397,7 +411,9 @@ def restart( except Exception as e: logger.warning(f"Warning: could not update source code before rebuild: {e}", err=True) - deployment_manager = DeploymentManager(use_podman=podman) + if service == "config-seed": + return + deployment_manager.restart_service( deployment_dir=deployment_dir, service_name=service, @@ -410,7 +426,7 @@ def restart( def list_services(): """List all available services""" - click.echo("Available ARCHI services:\n") + click.echo("Available Archi services:\n") # Application services app_services = service_registry.get_application_services() @@ -481,7 +497,7 @@ def list_deployments(): @click.option('--tag', '-t', type=str, default="2000", help="Image tag for built containers") @click.option('--verbosity', '-v', type=int, default=3, help="Logging verbosity level (0-4)") def evaluate(name: str, config_file: str, config_dir: str, env_file: str, force: bool, verbosity: int, **other_flags): - """Create an ARCHI deployment with selected services and data sources.""" + """Create an Archi deployment with selected services and data sources.""" if not (bool(config_file) ^ bool(config_dir)): raise click.ClickException(f"Must specify only one of config files or config dir") if config_dir: @@ -490,7 +506,7 @@ def evaluate(name: str, config_file: str, config_dir: str, env_file: str, force: else: config_files = [item for item in config_file.split(",")] - click.echo("Starting ARCHI benchmarking process...") + click.echo("Starting Archi benchmarking process...") 
setup_cli_logging(verbosity=verbosity) logger = get_logger(__name__) diff --git a/src/cli/managers/config_manager.py b/src/cli/managers/config_manager.py index 7ed9f9405..1f6932abd 100644 --- a/src/cli/managers/config_manager.py +++ b/src/cli/managers/config_manager.py @@ -7,6 +7,7 @@ from src.cli.managers.templates_manager import BASE_CONFIG_TEMPLATE from src.cli.source_registry import source_registry +from src.utils.ab_testing import ABPool, ABPoolError, load_ab_pool_state from src.utils.logging import get_logger logger = get_logger(__name__) @@ -196,6 +197,38 @@ def _validate_chat_app_config(self, config: Dict[str, Any], services: List[str]) if timeout_value > 86400: raise ValueError(f"Invalid field: '{timeout_path}' must be <= 86400 seconds") + self._validate_ab_testing_config(chat_cfg) + + def _validate_ab_testing_config(self, chat_cfg: Dict[str, Any]) -> None: + ab_cfg = chat_cfg.get("ab_testing") + if not isinstance(ab_cfg, dict) or not ab_cfg.get("enabled", False): + return + state = load_ab_pool_state({"services": {"chat_app": chat_cfg}}) + for warning in state.warnings: + logger.warning("A/B testing config warning: %s", warning) + try: + ABPool.from_config(ab_cfg) + except ABPoolError as exc: + incomplete_markers = ( + "ab_testing.pool must be a mapping", + "ab_testing.pool.champion must be a non-empty string", + "ab_testing.pool.variants must be a non-empty list", + "at least 2 variants", + "not found in pool", + "must include a string 'label'", + "must include a string 'agent_spec'", + ) + if any(marker in str(exc) for marker in incomplete_markers): + logger.warning( + "A/B testing config is incomplete and will start inactive until configured in the admin UI: %s", + exc, + ) + return + raise ValueError( + "Invalid field: 'services.chat_app.ab_testing' is misconfigured. 
" + f"{exc}" + ) + def _validate_benchmarking_config(self, config: Dict[str, Any], services: List[str]) -> None: if not services or "benchmarking" not in services: return diff --git a/src/cli/managers/deployment_manager.py b/src/cli/managers/deployment_manager.py index 1fc42d6a0..303979da9 100644 --- a/src/cli/managers/deployment_manager.py +++ b/src/cli/managers/deployment_manager.py @@ -127,6 +127,57 @@ def restart_service(self, deployment_dir: Path, service_name: str, build: bool = raise except subprocess.SubprocessError as e: raise DeploymentError(f"Failed to restart service: {e}", getattr(e, 'returncode', 1)) + + def has_service(self, deployment_dir: Path, service_name: str) -> bool: + compose_file = deployment_dir / "compose.yaml" + if not compose_file.exists(): + raise FileNotFoundError(f"Compose file not found: {compose_file}") + self._validate_compose_file(compose_file) + import yaml + with open(compose_file, 'r') as f: + compose_data = yaml.safe_load(f) or {} + services = compose_data.get("services") or {} + return service_name in services + + def run_service_once(self, deployment_dir: Path, service_name: str, build: bool = True, + no_deps: bool = True) -> None: + """Run a one-shot compose service and remove the container when it exits.""" + compose_file = deployment_dir / "compose.yaml" + + if not compose_file.exists(): + raise FileNotFoundError(f"Compose file not found: {compose_file}") + + logger.info(f"Running one-shot service '{service_name}'") + + try: + self._validate_compose_file(compose_file) + except Exception as e: + raise DeploymentError(f"Invalid compose file: {e}", 1) + + flags = ["--rm"] + if no_deps: + flags.append("--no-deps") + if build: + flags.append("--build") + + flags_str = " ".join(flags) + compose_cmd = f"{self.compose_tool} -f {compose_file} run {flags_str} {service_name}".strip() + + try: + stdout, stderr, exit_code = CommandRunner.run_streaming(compose_cmd, cwd=deployment_dir) + + if exit_code != 0: + error_msg = f"One-shot 
service '{service_name}' failed with exit code {exit_code}" + if stderr.strip(): + error_msg += f"\nError output:\n{stderr}" + raise DeploymentError(error_msg, exit_code, stderr) + + logger.info(f"One-shot service '{service_name}' completed successfully") + except KeyboardInterrupt: + logger.warning("One-shot service interrupted by user") + raise + except subprocess.SubprocessError as e: + raise DeploymentError(f"Failed to run one-shot service: {e}", getattr(e, 'returncode', 1)) def delete_deployment(self, deployment_name: str, remove_images: bool = False, remove_volumes: bool = False, remove_files: bool = True) -> None: diff --git a/src/cli/managers/templates_manager.py b/src/cli/managers/templates_manager.py index cf40c4293..079b1d1aa 100644 --- a/src/cli/managers/templates_manager.py +++ b/src/cli/managers/templates_manager.py @@ -12,6 +12,7 @@ from src.cli.service_registry import service_registry from src.cli.utils.service_builder import DeploymentPlan from src.cli.utils.grafana_styling import assign_feedback_palette +from src.utils.ab_testing import DEFAULT_AB_AGENTS_DIR from src.utils.logging import get_logger logger = get_logger(__name__) @@ -25,6 +26,7 @@ BASE_GRAFANA_DASHBOARDS_TEMPLATE = "grafana/dashboards.yaml" BASE_GRAFANA_ARCHI_DEFAULT_DASHBOARDS_TEMPLATE = "grafana/archi-default-dashboard.json" BASE_GRAFANA_CONFIG_TEMPLATE = "grafana/grafana.ini" +DEPLOYMENT_AGENTS_DIR = "/root/archi/agents" def get_git_information() -> Dict[str, str]: @@ -158,6 +160,7 @@ def _stage_prompts(self, context: TemplateContext) -> None: def _stage_agents(self, context: TemplateContext) -> None: config = context.config_manager.config or {} dst_dir = context.base_dir / "data" / "agents" + ab_dst_dir = context.base_dir / "data" / "ab_agents" services_cfg = config.get("services", {}) or {} if context.benchmarking: @@ -184,17 +187,65 @@ def _stage_agents(self, context: TemplateContext) -> None: if dst_dir.exists() and any(p.suffix.lower() == ".md" for p in dst_dir.iterdir()): 
return raise ValueError("Missing required services.chat_app.agents_dir in config.") - src_dir = Path(agents_dir).expanduser() - if not src_dir.exists() or not src_dir.is_dir(): - raise ValueError(f"Agents directory not found: {src_dir}") - dst_dir.mkdir(parents=True, exist_ok=True) + src_dir = self._resolve_directory_path(str(agents_dir), config) + self._copy_markdown_directory( + src_dir, + dst_dir, + missing_message=f"Agents directory not found: {src_dir}", + empty_message=f"No agent markdown files found in {src_dir}", + required=True, + ) + + ab_dst_dir.mkdir(parents=True, exist_ok=True) + ab_cfg = ((services_cfg.get("chat_app") or {}).get("ab_testing") or {}) + ab_agents_dir = ab_cfg.get("ab_agents_dir") + if not ab_agents_dir: + return + ab_src_dir = self._resolve_directory_path(str(ab_agents_dir), config) + self._copy_markdown_directory( + ab_src_dir, + ab_dst_dir, + missing_message=f"A/B agents directory not found: {ab_src_dir}", + empty_message=f"No A/B agent markdown files found in {ab_src_dir}", + required=False, + ) + + @staticmethod + def _resolve_directory_path(raw_path: str, config: Dict[str, Any]) -> Path: + source_path = Path(str(raw_path)).expanduser() + config_path_raw = config.get("_config_path", "") + config_path = Path(str(config_path_raw)).expanduser() if config_path_raw else None + if source_path.is_absolute() or not config_path: + return source_path + candidate = (config_path.parent / source_path).resolve() + if candidate.exists(): + return candidate + return source_path + + @staticmethod + def _copy_markdown_directory( + source_dir: Path, + destination_dir: Path, + *, + missing_message: str, + empty_message: str, + required: bool, + ) -> None: + if not source_dir.exists() or not source_dir.is_dir(): + if required: + raise ValueError(missing_message) + logger.warning(missing_message) + return + destination_dir.mkdir(parents=True, exist_ok=True) copied = 0 - for agent_file in sorted(src_dir.iterdir()): - if agent_file.is_file() and 
agent_file.suffix.lower() == ".md": - shutil.copyfile(agent_file, dst_dir / agent_file.name) + for source_file in sorted(source_dir.iterdir()): + if source_file.is_file() and source_file.suffix.lower() == ".md": + shutil.copyfile(source_file, destination_dir / source_file.name) copied += 1 if copied == 0: - raise ValueError(f"No agent markdown files found in {src_dir}") + if required: + raise ValueError(empty_message) + logger.warning(empty_message) def _stage_skills(self, context: TemplateContext) -> None: config = context.config_manager.config or {} @@ -325,7 +376,8 @@ def _copy_pipeline_prompts( # config rendering def _render_config_files(self, context: TemplateContext) -> None: configs_path = context.base_dir / "configs" - configs_path.mkdir(exist_ok=True) + configs_path.mkdir(parents=True, exist_ok=True) + benchmarking_enabled = bool(getattr(context, "benchmarking", False)) archi_configs = context.config_manager.get_configs() single_mode = len(archi_configs) == 1 @@ -341,15 +393,19 @@ def _render_config_files(self, context: TemplateContext) -> None: for service_name in ("chat_app", "redmine_mailbox", "piazza"): service_cfg = services_cfg.get(service_name) if isinstance(service_cfg, dict): - service_cfg["agents_dir"] = "/root/archi/agents" + service_cfg["agents_dir"] = DEPLOYMENT_AGENTS_DIR if service_cfg.get("skills_dir"): service_cfg["skills_dir"] = "/root/archi/skills" - if context.benchmarking: + if service_name == "chat_app": + ab_cfg = service_cfg.get("ab_testing") + if isinstance(ab_cfg, dict) and ab_cfg.get("ab_agents_dir"): + ab_cfg["ab_agents_dir"] = DEFAULT_AB_AGENTS_DIR + if benchmarking_enabled: benchmark_cfg = services_cfg.get("benchmarking") if isinstance(benchmark_cfg, dict): agent_md_file = benchmark_cfg.get("agent_md_file") if agent_md_file: - benchmark_cfg["agent_md_file"] = f"/root/archi/agents/{Path(str(agent_md_file)).name}" + benchmark_cfg["agent_md_file"] = f"{DEPLOYMENT_AGENTS_DIR}/{Path(str(agent_md_file)).name}" config_template = 
self.env.get_template(BASE_CONFIG_TEMPLATE) config_rendered = config_template.render(verbosity=context.plan.verbosity, **updated_config) diff --git a/src/cli/templates/base-compose.yaml b/src/cli/templates/base-compose.yaml index 218ff7f86..3ff54dc2a 100644 --- a/src/cli/templates/base-compose.yaml +++ b/src/cli/templates/base-compose.yaml @@ -161,6 +161,7 @@ services: - ./configs:/root/archi/configs - ./data/prompts:/root/archi/data/prompts:ro - ./data/agents:/root/archi/agents + - ./data/ab_agents:/root/archi/ab_agents - ./data/skills:/root/archi/skills:ro {% for prompt_file in prompt_files | default([]) -%} - ./{{ prompt_file }}:/root/archi/{{ prompt_file }} diff --git a/src/cli/templates/base-config.yaml b/src/cli/templates/base-config.yaml index ef6d21fee..83a3402b1 100644 --- a/src/cli/templates/base-config.yaml +++ b/src/cli/templates/base-config.yaml @@ -105,7 +105,52 @@ services: managers: {%- for manager in services.chat_app.alerts.managers | default([], true) %} - {{ manager }} - {%- endfor %} + {%- endfor %} + {%- if services.chat_app.ab_testing is defined and services.chat_app.ab_testing.enabled | default(false) %} + ab_testing: + enabled: {{ services.chat_app.ab_testing.enabled | default(false, true) }} + {%- if services.chat_app.ab_testing.ab_agents_dir is defined %} + ab_agents_dir: "{{ services.chat_app.ab_testing.ab_agents_dir }}" + {%- endif %} + {%- if services.chat_app.ab_testing.force_yaml_override is defined %} + force_yaml_override: {{ services.chat_app.ab_testing.force_yaml_override | default(false, true) }} + {%- endif %} + comparison_rate: {{ services.chat_app.ab_testing.comparison_rate | default(services.chat_app.ab_testing.sample_rate | default(0.2, true), true) }} + variant_label_mode: {{ services.chat_app.ab_testing.variant_label_mode | default(services.chat_app.ab_testing.disclosure_mode | default("post_vote_reveal", true), true) }} + activity_panel_default_state: {{ services.chat_app.ab_testing.activity_panel_default_state | 
default(services.chat_app.ab_testing.default_trace_mode | default("hidden", true), true) }} + max_pending_comparisons_per_conversation: {{ services.chat_app.ab_testing.max_pending_comparisons_per_conversation | default(services.chat_app.ab_testing.max_pending_per_conversation | default(1, true), true) }} + {%- if services.chat_app.ab_testing.eligible_roles is defined or services.chat_app.ab_testing.target_roles is defined %} + eligible_roles: {{ services.chat_app.ab_testing.eligible_roles | default(services.chat_app.ab_testing.target_roles, true) | tojson }} + {%- endif %} + {%- if services.chat_app.ab_testing.eligible_permissions is defined or services.chat_app.ab_testing.target_permissions is defined %} + eligible_permissions: {{ services.chat_app.ab_testing.eligible_permissions | default(services.chat_app.ab_testing.target_permissions, true) | tojson }} + {%- endif %} + {%- if services.chat_app.ab_testing.pool is defined %} + pool: + {%- if services.chat_app.ab_testing.pool.champion is defined or services.chat_app.ab_testing.pool.control is defined %} + champion: {{ services.chat_app.ab_testing.pool.champion | default(services.chat_app.ab_testing.pool.control, true) }} + {%- endif %} + {%- if services.chat_app.ab_testing.pool.variants is defined %} + variants: + {%- for v in services.chat_app.ab_testing.pool.variants %} + - label: {{ v.label }} + agent_spec: "{{ v.agent_spec }}" + {%- if v.provider is defined %} + provider: {{ v.provider }} + {%- endif %} + {%- if v.model is defined %} + model: "{{ v.model }}" + {%- endif %} + {%- if v.num_documents_to_retrieve is defined %} + num_documents_to_retrieve: {{ v.num_documents_to_retrieve }} + {%- endif %} + {%- if v.recursion_limit is defined %} + recursion_limit: {{ v.recursion_limit }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endif %} + {%- endif %} data_manager: auth: enabled: {{ services.data_manager.auth.enabled | default(false) }} @@ -204,7 +249,7 @@ data_manager: verify_urls: {{ 
data_manager.sources.links.html_scraper.verify_urls | default(false, true) }} enable_warnings: {{ data_manager.sources.links.html_scraper.enable_warnings | default(false, true) }} selenium_scraper: - enabled: {{ data_manager.sources.links.selenium_scraper.selenium_scraper.enabled | default(false, True) }} + enabled: {{ data_manager.sources.links.selenium_scraper.selenium_scraper.enabled | default(false, true) }} visible: {{ data_manager.sources.links.selenium_scraper.selenium_scraper.visible | default(false, true) }} use_for_scraping: {{ data_manager.sources.links.selenium_scraper.use_for_scraping | default(false, true) }} selenium_class: {{ data_manager.sources.links.selenium_scraper.selenium_class | default('CERNSSOScraper', true) }} diff --git a/src/cli/templates/init.sql b/src/cli/templates/init.sql index 1334fc23c..60793c94d 100644 --- a/src/cli/templates/init.sql +++ b/src/cli/templates/init.sql @@ -46,6 +46,7 @@ CREATE TABLE IF NOT EXISTS users ( theme VARCHAR(20) NOT NULL DEFAULT 'system', preferred_model VARCHAR(200), -- Override global default preferred_temperature NUMERIC(3,2), -- Override global default + ab_participation_rate NUMERIC(3,2), -- Per-user A/B sampling override preferred_max_tokens INTEGER, -- Override global default preferred_num_documents INTEGER, -- Override retrieval count preferred_condense_prompt VARCHAR(100), -- Prompt selection @@ -503,6 +504,12 @@ CREATE TABLE IF NOT EXISTS ab_comparisons ( config_a_id INTEGER REFERENCES configs(config_id), config_b_id INTEGER REFERENCES configs(config_id), + -- Pool-based variant info (populated when ab_testing pool is active) + variant_a_name VARCHAR(200), + variant_b_name VARCHAR(200), + variant_a_meta JSONB, + variant_b_meta JSONB, + is_config_a_first BOOLEAN NOT NULL, preference VARCHAR(10), preference_ts TIMESTAMPTZ, @@ -513,6 +520,48 @@ CREATE INDEX IF NOT EXISTS idx_ab_comparisons_conversation ON ab_comparisons(con CREATE INDEX IF NOT EXISTS idx_ab_comparisons_models ON 
ab_comparisons(model_a, model_b); CREATE INDEX IF NOT EXISTS idx_ab_comparisons_preference ON ab_comparisons(preference) WHERE preference IS NOT NULL; CREATE INDEX IF NOT EXISTS idx_ab_comparisons_pending ON ab_comparisons(conversation_id) WHERE preference IS NULL; +CREATE INDEX IF NOT EXISTS idx_ab_comparisons_variant_a ON ab_comparisons(variant_a_name) WHERE variant_a_name IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_ab_comparisons_variant_b ON ab_comparisons(variant_b_name) WHERE variant_b_name IS NOT NULL; + +CREATE TABLE IF NOT EXISTS ab_agent_specs ( + spec_id SERIAL PRIMARY KEY, + filename VARCHAR(255) NOT NULL UNIQUE, + current_name VARCHAR(255) NOT NULL UNIQUE, + current_version_id INTEGER, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_saved_by VARCHAR(200) +); + +CREATE TABLE IF NOT EXISTS ab_agent_spec_versions ( + version_id SERIAL PRIMARY KEY, + spec_id INTEGER NOT NULL REFERENCES ab_agent_specs(spec_id) ON DELETE CASCADE, + version_number INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + tools TEXT[] NOT NULL DEFAULT '{}', + prompt TEXT NOT NULL, + content TEXT NOT NULL, + ab_only BOOLEAN NOT NULL DEFAULT FALSE, + content_hash VARCHAR(64) NOT NULL, + prompt_hash VARCHAR(64) NOT NULL, + source_type VARCHAR(50) NOT NULL DEFAULT 'ui', + source_path TEXT, + created_by VARCHAR(200), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (spec_id, version_number) +); + +CREATE INDEX IF NOT EXISTS idx_ab_agent_spec_versions_spec ON ab_agent_spec_versions(spec_id, version_number DESC); + +-- Per-variant aggregate metrics (wins/losses/ties) +CREATE TABLE IF NOT EXISTS ab_variant_metrics ( + variant_name VARCHAR(200) PRIMARY KEY, + wins INTEGER NOT NULL DEFAULT 0, + losses INTEGER NOT NULL DEFAULT 0, + ties INTEGER NOT NULL DEFAULT 0, + total_comparisons INTEGER NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT NOW() +); -- 
============================================================================ -- 9. MIGRATION STATE (for resumable migrations) @@ -576,6 +625,7 @@ GRANT SELECT ON timing, agent_tool_calls, ab_comparisons, + ab_variant_metrics, migration_state TO grafana; {% endif %} diff --git a/src/cli/tools/config_seed.py b/src/cli/tools/config_seed.py index 81b2e29af..3f7d53978 100644 --- a/src/cli/tools/config_seed.py +++ b/src/cli/tools/config_seed.py @@ -15,6 +15,7 @@ import os import sys import yaml +from typing import Any, Dict, Optional from src.utils.postgres_service_factory import PostgresServiceFactory from src.utils.config_service import ConfigService @@ -25,6 +26,39 @@ def load_config(path: str): return yaml.safe_load(f) +def _copy_mapping(value): + if isinstance(value, dict): + return {k: _copy_mapping(v) for k, v in value.items()} + if isinstance(value, list): + return [_copy_mapping(item) for item in value] + return value + + +def _resolve_seeded_services_config( + services: Dict[str, Any], + existing_services: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + seeded_services = _copy_mapping(services or {}) + chat_cfg = seeded_services.setdefault("chat_app", {}) + yaml_ab_cfg = chat_cfg.get("ab_testing") + existing_chat_cfg = ((existing_services or {}).get("chat_app") or {}) + existing_ab_cfg = existing_chat_cfg.get("ab_testing") + + if not isinstance(yaml_ab_cfg, dict): + if isinstance(existing_ab_cfg, dict) and existing_ab_cfg: + chat_cfg["ab_testing"] = _copy_mapping(existing_ab_cfg) + return seeded_services + + force_yaml_override = bool(yaml_ab_cfg.get("force_yaml_override", False)) + yaml_ab_cfg = {k: _copy_mapping(v) for k, v in yaml_ab_cfg.items() if k != "force_yaml_override"} + chat_cfg["ab_testing"] = yaml_ab_cfg + + if isinstance(existing_ab_cfg, dict) and existing_ab_cfg and not force_yaml_override: + chat_cfg["ab_testing"] = _copy_mapping(existing_ab_cfg) + + return seeded_services + + def seed(config: dict, cs: ConfigService): print("[config-seed] 
Starting seed with config keys:", list(config.keys())) dm = config.get("data_manager", {}) @@ -33,15 +67,20 @@ def seed(config: dict, cs: ConfigService): mcp_servers = config.get("mcp_servers", {}) or {} archi_cfg = {**archi_cfg} global_cfg = config.get("global", {}) + current_static = cs.get_static_config(force_reload=True) + seeded_services = _resolve_seeded_services_config( + services, + current_static.services_config if current_static else None, + ) # Embedding dimensions fallback TODO why is this here? embedding_name = dm.get("embedding_name", "HuggingFaceEmbeddings") embedding_class_map = dm.get("embedding_class_map", {}) embedding_dimensions = embedding_class_map.get(embedding_name, {}).get("dimensions", 384) - agent_class = services.get("chat_app", {}).get("agent_class") - provider = services.get("chat_app", {}).get("provider") - model = services.get("chat_app", {}).get("model") + agent_class = seeded_services.get("chat_app", {}).get("agent_class") + provider = seeded_services.get("chat_app", {}).get("provider") + model = seeded_services.get("chat_app", {}).get("model") available_pipelines = [agent_class] if agent_class else [] available_models = [f"{provider}/{model}"] if provider and model else [] available_providers = [provider] if provider else [] @@ -57,9 +96,9 @@ def seed(config: dict, cs: ConfigService): available_pipelines=available_pipelines, available_models=available_models, available_providers=available_providers, - auth_enabled=services.get("chat_app", {}).get("auth", {}).get("enabled", False), + auth_enabled=seeded_services.get("chat_app", {}).get("auth", {}).get("enabled", False), sources_config=dm.get("sources", {}), - services_config=services, + services_config=seeded_services, mcp_servers_config=mcp_servers, data_manager_config=dm, archi_config=archi_cfg, @@ -75,7 +114,7 @@ def seed(config: dict, cs: ConfigService): hybrid = retrievers.get("hybrid_retriever", {}) active_model = f"{provider}/{model}" if provider and model else None 
cs.update_dynamic_config( - active_pipeline=services.get("chat_app", {}).get("agent_class", "CMSCompOpsAgent"), + active_pipeline=seeded_services.get("chat_app", {}).get("agent_class", "CMSCompOpsAgent"), active_model=active_model, num_documents_to_retrieve=hybrid.get("num_documents_to_retrieve", 10), bm25_weight=hybrid.get("bm25_weight", 0.3), diff --git a/src/cli/utils/helpers.py b/src/cli/utils/helpers.py index 4dd960df9..d573719d3 100644 --- a/src/cli/utils/helpers.py +++ b/src/cli/utils/helpers.py @@ -192,7 +192,7 @@ def _render_config_for_compare( updated_config = copy.deepcopy(config) if host_mode: updated_config["host_mode"] = True - TemplateManager(env)._apply_host_mode_port_overrides(updated_config) + TemplateManager(env, verbosity)._apply_host_mode_port_overrides(updated_config) config_template = env.get_template("base-config.yaml") rendered = config_template.render(verbosity=verbosity, **updated_config) diff --git a/src/interfaces/chat_app/api.py b/src/interfaces/chat_app/api.py index 76cdcdc51..3e1e5a305 100644 --- a/src/interfaces/chat_app/api.py +++ b/src/interfaces/chat_app/api.py @@ -13,7 +13,7 @@ from functools import wraps from typing import List, Optional -from flask import Blueprint, jsonify, request, g, current_app +from flask import Blueprint, jsonify, request, g, current_app, session from src.utils.postgres_service_factory import PostgresServiceFactory from src.utils.env import read_secret @@ -46,10 +46,10 @@ def _get_agent_tool_registry(agent_class_name: Optional[str]) -> List[str]: return [] try: from src.archi import pipelines + agent_cls = getattr(pipelines, agent_class_name, None) except Exception as exc: - logger.warning("Failed to import pipelines module: %s", exc) + logger.warning("Failed to load pipeline class %s: %s", agent_class_name, exc) return [] - agent_cls = getattr(pipelines, agent_class_name, None) if not agent_cls or not hasattr(agent_cls, "get_tool_registry"): return [] try: @@ -116,8 +116,11 @@ def get_services() -> 
PostgresServiceFactory: def get_client_id() -> str: """Get client ID from request (session, header, or generate).""" + user = session.get('user') or {} + if user.get('id'): + return user['id'] + # Check session first - from flask import session if 'client_id' in session: return session['client_id'] @@ -165,9 +168,12 @@ def get_current_user(): """ try: services = get_services() + session_user = session.get('user') or {} user = services.user_service.get_or_create_user( user_id=g.client_id, - auth_provider='anonymous', + auth_provider=session_user.get('auth_method', 'anonymous') or 'anonymous', + display_name=session_user.get('name'), + email=session_user.get('email'), ) return jsonify({ @@ -177,11 +183,12 @@ def get_current_user(): 'auth_provider': user.auth_provider, 'theme': user.theme, 'preferred_model': user.preferred_model, - 'preferred_temperature': float(user.preferred_temperature) if user.preferred_temperature else None, - 'has_openrouter_key': user.api_key_openrouter is not None, - 'has_openai_key': user.api_key_openai is not None, - 'has_anthropic_key': user.api_key_anthropic is not None, - 'created_at': user.created_at.isoformat() if user.created_at else None, + 'preferred_temperature': float(user.preferred_temperature) if user.preferred_temperature is not None else None, + 'ab_participation_rate': float(user.ab_participation_rate) if user.ab_participation_rate is not None else None, + 'has_openrouter_key': bool(user.api_key_openrouter), + 'has_openai_key': bool(user.api_key_openai), + 'has_anthropic_key': bool(user.api_key_anthropic), + 'created_at': user.created_at, }), 200 except Exception as e: @@ -212,6 +219,11 @@ def update_user_preferences(): temp = data['preferred_temperature'] if temp is not None and (temp < 0 or temp > 2): return jsonify({'error': 'Temperature must be between 0 and 2'}), 400 + + if 'ab_participation_rate' in data: + rate = data['ab_participation_rate'] + if rate is not None and (rate < 0 or rate > 1): + return jsonify({'error': 
'A/B participation rate must be between 0 and 1'}), 400 services = get_services() user = services.user_service.update_preferences( @@ -220,6 +232,7 @@ def update_user_preferences(): theme=data.get('theme'), preferred_model=data.get('preferred_model'), preferred_temperature=data.get('preferred_temperature'), + ab_participation_rate=data.get('ab_participation_rate'), ) if not user: @@ -230,8 +243,9 @@ def update_user_preferences(): 'display_name': user.display_name, 'theme': user.theme, 'preferred_model': user.preferred_model, - 'preferred_temperature': float(user.preferred_temperature) if user.preferred_temperature else None, - 'updated_at': user.updated_at.isoformat() if user.updated_at else None, + 'preferred_temperature': float(user.preferred_temperature) if user.preferred_temperature is not None else None, + 'ab_participation_rate': float(user.ab_participation_rate) if user.ab_participation_rate is not None else None, + 'updated_at': user.updated_at, }), 200 except Exception as e: diff --git a/src/interfaces/chat_app/app.py b/src/interfaces/chat_app/app.py index 6c3e877dc..14c9d9248 100644 --- a/src/interfaces/chat_app/app.py +++ b/src/interfaces/chat_app/app.py @@ -1,5 +1,6 @@ import json import os +import random import re import time import uuid @@ -31,6 +32,7 @@ from src.archi.archi import archi from src.archi.pipelines.agents.agent_spec import ( + AgentSpec, AgentSpecError, list_agent_files, load_agent_spec, @@ -39,7 +41,6 @@ slugify_agent_name, ) from src.archi.providers.base import ModelInfo, ProviderConfig, ProviderType -from src.utils.config_service import ConfigService from src.archi.utils.output_dataclass import PipelineOutput # from src.data_manager.data_manager import DataManager from src.data_manager.data_viewer_service import DataViewerService @@ -47,7 +48,7 @@ from src.utils.env import read_secret from src.utils.logging import get_logger from src.utils.config_access import get_full_config, get_services_config, get_global_config, get_dynamic_config 
-from src.utils.config_service import ConfigService +from src.utils.config_service import ConfigService, StaticConfig from src.utils.sql import ( SQL_INSERT_CONVO, SQL_INSERT_FEEDBACK, SQL_INSERT_TIMING, SQL_QUERY_CONVO, SQL_CREATE_CONVERSATION, SQL_UPDATE_CONVERSATION_TIMESTAMP, @@ -56,8 +57,6 @@ SQL_DELETE_CONVERSATION_BY_USER, SQL_UPDATE_CONVERSATION_TIMESTAMP_BY_USER, SQL_INSERT_TOOL_CALLS, SQL_QUERY_CONVO_WITH_FEEDBACK, SQL_DELETE_REACTION_FEEDBACK, SQL_GET_REACTION_FEEDBACK, - SQL_INSERT_AB_COMPARISON, SQL_UPDATE_AB_PREFERENCE, SQL_GET_AB_COMPARISON, - SQL_GET_PENDING_AB_COMPARISON, SQL_DELETE_AB_COMPARISON, SQL_GET_AB_COMPARISONS_BY_CONVERSATION, SQL_CREATE_AGENT_TRACE, SQL_UPDATE_AGENT_TRACE, SQL_GET_AGENT_TRACE, SQL_GET_TRACE_BY_MESSAGE, SQL_GET_ACTIVE_TRACE, SQL_CANCEL_ACTIVE_TRACES, ) @@ -66,7 +65,22 @@ register_service_alerts, get_active_banner_alerts, is_alert_manager, ) from src.interfaces.chat_app.utils import collapse_assistant_sequences +from src.utils.ab_testing import ( + ABPool, + ABPoolLoadState, + ABVariant, + ABPoolError, + DEFAULT_DISCLOSURE_MODE, + DEFAULT_TRACE_MODE, + load_ab_pool_state, + normalize_ab_disclosure_mode, + normalize_ab_trace_mode, + resolve_ab_agents_dir, +) +from src.interfaces.chat_app.event_formatter import PipelineEventFormatter +from src.utils.conversation_service import ConversationService from src.utils.user_service import UserService +from src.utils.ab_agent_spec_service import ABAgentSpecService, ABAgentSpecRecord # RBAC imports for role-based access control from src.utils.rbac import ( @@ -74,17 +88,54 @@ get_registry, get_user_roles, has_permission, + get_user_permissions, require_permission, require_any_permission, require_authenticated, ) -from src.utils.rbac.permissions import get_permission_context +from src.utils.rbac.permissions import get_permission_context, is_admin as rbac_is_admin from src.utils.rbac.audit import log_authentication_event logger = get_logger(__name__) +def _static_config_to_full_config( 
+ static: StaticConfig, + *, + resolve_embeddings: bool = False, + config_service: Optional[ConfigService] = None, +) -> Dict[str, Any]: + """ + Build a full runtime config dict directly from a freshly loaded StaticConfig. + + This avoids routing post-write refreshes back through config_access helpers, + which may read from a different cached ConfigService instance. + """ + data_manager_config = dict(static.data_manager_config or {}) + if resolve_embeddings and config_service is not None: + try: + resolved_map = config_service.get_embedding_class_map(resolved=True) + if resolved_map: + data_manager_config["embedding_class_map"] = resolved_map + except Exception: + pass + + return { + "name": static.deployment_name, + "config_version": static.config_version, + "global": static.global_config, + "services": static.services_config, + "data_manager": data_manager_config, + "archi": static.archi_config, + "sources": static.sources_config, + "mcp_servers": static.mcp_servers_config or {}, + "available_pipelines": static.available_pipelines, + "available_models": static.available_models, + "available_providers": static.available_providers, + } + + def _build_provider_config_from_payload(config_payload: Dict[str, Any], provider_type: ProviderType) -> Optional[ProviderConfig]: """Helper to build ProviderConfig from loaded YAML for a provider.""" services_cfg = config_payload.get("services", {}) or {} @@ -221,6 +272,13 @@ class ChatRequestContext: class ChatWrapper: + AUTO_SOURCE_SECTION_LABEL = "Retrieved documents" + AUTO_SOURCE_SECTION_EXPLANATION = "These are the knowledge-base documents retrieved for this answer." 
+ _AUTO_SOURCE_SECTION_PATTERN = re.compile( + r"(Show all sources|Retrieved documents|Sources cited in this answer)\s*\(\d+\)", + flags=re.IGNORECASE, + ) + """ Wrapper which holds functionality for the chatbot """ @@ -240,6 +298,7 @@ def __init__(self): "password": read_secret("PG_PASSWORD"), **self.services_config["postgres"], } + self.config_service = ConfigService(pg_config=self.pg_config) # initialize data manager (ingestion handled by data-manager service) # self.data_manager = DataManager(run_ingestion=False) @@ -259,6 +318,11 @@ def __init__(self): # initialize data viewer service for per-chat document selection self.data_viewer = DataViewerService(data_path=self.data_path, pg_config=self.pg_config) + # shared conversation service for A/B comparisons & metrics + self.conv_service = ConversationService(connection_params=self.pg_config) + self.user_service = UserService(pg_config=self.pg_config) + self.ab_agent_spec_service = ABAgentSpecService(pg_config=self.pg_config) + self.conn = None self.cursor = None @@ -281,7 +345,7 @@ def __init__(self): if self.current_agent_path and self.current_agent_path.exists(): self.current_agent_mtime = self.current_agent_path.stat().st_mtime - agent_class = chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + agent_class = self._get_agent_class_from_cfg(chat_cfg) if not agent_class: raise ValueError("services.chat_app.agent_class must be configured.") default_provider = chat_cfg.get("default_provider") @@ -317,6 +381,147 @@ def __init__(self): if self.default_config_name: self.update_config(config_name=self.default_config_name) + # A/B testing pool (loaded from config; None if not configured) + self.refresh_ab_pool() + + def reload_static_state(self) -> None: + """ + Reload static config snapshots used by the chat wrapper. + + This is primarily used after runtime updates to the persisted chat A/B + configuration so the active process picks up the latest pool settings. 
+ """ + static = self.config_service.get_static_config(force_reload=True) + if static is None: + raise ValueError("Static config not initialized") + self.config = _static_config_to_full_config(static, config_service=self.config_service) + self.global_config = self.config["global"] + self.services_config = self.config["services"] + self.data_path = self.global_config["DATA_PATH"] + self.sources_config = self.config["data_manager"]["sources"] + self.refresh_ab_pool() + + def refresh_ab_pool(self) -> None: + import_diagnostics = self._sync_ab_agent_specs_from_filesystem() + state = load_ab_pool_state( + self.config, + agent_spec_exists=self.ab_agent_spec_service.spec_exists, + ) + warnings = list(import_diagnostics.get("warnings", [])) + warnings.extend(list(getattr(state, "warnings", []) or [])) + self.ab_agent_import_diagnostics = import_diagnostics + self.ab_pool_state = ABPoolLoadState( + pool=state.pool, + warnings=warnings, + enabled_requested=state.enabled_requested, + agent_dir=state.agent_dir, + agent_dir_configured=state.agent_dir_configured, + ) + self.ab_pool = self.ab_pool_state.pool + for warning in self.ab_pool_state.warnings: + logger.warning("%s", warning) + if self.ab_pool: + logger.info( + "A/B pool active: %d variants, champion='%s'", + len(self.ab_pool.variants), self.ab_pool.champion_name, + ) + + def _get_ab_agents_dir(self) -> Path: + chat_cfg = self.services_config.get("chat_app", {}) or {} + path, _ = resolve_ab_agents_dir(chat_cfg) + return path + + def _sync_ab_agent_specs_from_filesystem(self) -> Dict[str, Any]: + """ + Import legacy A/B markdown specs into the DB-backed catalog. + + The database remains the runtime source of truth after import. 
+ """ + directory = self._get_ab_agents_dir() + diagnostics: Dict[str, Any] = { + "directory": str(directory), + "source_exists": directory.exists() and directory.is_dir(), + "imported": 0, + "updated": 0, + "skipped": 0, + "conflicts": [], + "staged_unresolved": [], + "warnings": [], + } + try: + result = self.ab_agent_spec_service.import_directory( + directory, + created_by="system", + ) + diagnostics.update({ + "imported": int(result.get("imported", 0)), + "updated": int(result.get("updated", 0)), + "skipped": int(result.get("skipped", 0)), + "conflicts": list(result.get("conflicts", []) or []), + }) + if result["imported"] or result["updated"]: + logger.info( + "Imported A/B agent specs into DB: imported=%d updated=%d skipped=%d", + result["imported"], + result["updated"], + result["skipped"], + ) + for conflict in result["conflicts"]: + logger.warning("A/B agent import conflict: %s", conflict) + except Exception as exc: + logger.warning("Failed to sync A/B agent specs from filesystem: %s", exc) + diagnostics["conflicts"].append(str(exc)) + + for conflict in diagnostics["conflicts"]: + diagnostics["warnings"].append(f"A/B agent import conflict: {conflict}") + + chat_cfg = self.services_config.get("chat_app", {}) or {} + ab_cfg = (chat_cfg.get("ab_testing") or {}) if isinstance(chat_cfg.get("ab_testing"), dict) else {} + try: + configured_pool = ABPool.from_config(ab_cfg) if ab_cfg.get("enabled") else None + except ABPoolError: + configured_pool = None + + if configured_pool: + for variant in configured_pool.variants: + if self.ab_agent_spec_service.spec_exists(variant.agent_spec): + continue + disk_path = directory / variant.agent_spec + if disk_path.exists(): + diagnostics["staged_unresolved"].append(variant.agent_spec) + + if diagnostics["staged_unresolved"]: + unresolved = sorted(set(diagnostics["staged_unresolved"])) + diagnostics["warnings"].append( + "A/B agent specs are present in the staged import directory but unresolved in PostgreSQL after import: " + 
f"{unresolved}." + ) + + return diagnostics + + @staticmethod + def _variant_with_spec_record(variant: "ABVariant", record: ABAgentSpecRecord) -> ABVariant: + return ABVariant( + label=variant.label, + agent_spec=record.filename, + provider=variant.provider, + model=variant.model, + num_documents_to_retrieve=variant.num_documents_to_retrieve, + recursion_limit=variant.recursion_limit, + agent_spec_id=record.spec_id, + agent_spec_name=record.name, + agent_spec_version_id=record.version_id, + agent_spec_version_number=record.version_number, + agent_spec_content_hash=record.content_hash, + agent_spec_tools=list(record.tools), + agent_spec_prompt_hash=record.prompt_hash, + ) + + def _resolve_runtime_ab_variant(self, variant: "ABVariant") -> tuple["ABVariant", AgentSpec]: + record = self.ab_agent_spec_service.load_agent_spec(variant.agent_spec) + resolved = self._variant_with_spec_record(variant, record) + return resolved, record.to_agent_spec() + def update_config(self, config_name=None): """ Update the active config and apply it to the pipeline. 
@@ -358,7 +563,7 @@ def update_config(self, config_name=None): if self.current_config_name == target_config_name and not agent_changed: return - agent_class = chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + agent_class = self._get_agent_class_from_cfg(chat_cfg) if not agent_class: raise ValueError("services.chat_app.agent_class must be configured.") is_enabled, disabled_reason = _is_provider_enabled_in_config( @@ -493,15 +698,10 @@ def get_top_sources(self, documents, scores): @staticmethod def _format_source_entry(entry): - score = entry["score"] + score_str = ChatWrapper._format_score_str(entry["score"]) link = entry["link"] display_name = entry["display"] - if score == -1.0 or score == "N/A": - score_str = "" - else: - score_str = f" ({score:.2f})" - if link: return f"- [{display_name}]({link}){score_str}\n" return f"- {display_name}{score_str}\n" @@ -524,15 +724,10 @@ def format_links(top_sources): ''' def _entry_html(entry): - score = entry["score"] + score_str = ChatWrapper._format_score_str(entry["score"]).strip() link = entry["link"] display_name = entry["display"] - if score == -1.0 or score == "N/A": - score_str = "" - else: - score_str = f"({score:.2f})" - if link: reference_html = f"{display_name}" else: @@ -546,7 +741,15 @@ def _entry_html(entry): ''' - _output += f'
Show all sources ({len(top_sources)})' + _output += ( + f'
' + f'{ChatWrapper.AUTO_SOURCE_SECTION_EXPLANATION}' + f'
' + ) + _output += ( + f'
{ChatWrapper.AUTO_SOURCE_SECTION_LABEL} ({len(top_sources)})' + ) for entry in top_sources: _output += _entry_html(entry) _output += '
' @@ -560,13 +763,29 @@ def format_links_markdown(top_sources): if not top_sources: return "" - _output = f"\n\n---\n
Show all sources ({len(top_sources)})\n\n" + _output = ( + "\n\n---\n" + f"*{ChatWrapper.AUTO_SOURCE_SECTION_EXPLANATION}*\n\n" + f"
{ChatWrapper.AUTO_SOURCE_SECTION_LABEL} ({len(top_sources)})\n\n" + ) for entry in top_sources: _output += ChatWrapper._format_source_entry(entry) _output += "\n
\n" return _output + @classmethod + def _contains_source_section(cls, output: str) -> bool: + return bool(output and cls._AUTO_SOURCE_SECTION_PATTERN.search(output)) + + @classmethod + def append_source_section(cls, output: str, top_sources, *, render_markdown: bool) -> str: + if not top_sources or cls._contains_source_section(output): + return output + if render_markdown: + return output + cls.format_links(top_sources) + return output + cls.format_links_markdown(top_sources) + @staticmethod def _looks_like_url(value: str | None) -> bool: return isinstance(value, str) and value.startswith(("http://", "https://")) @@ -670,189 +889,6 @@ def get_reaction_feedback(self, message_id: int): self.cursor, self.conn = None, None return row[0] if row else None - # ========================================================================= - # A/B Comparison Methods - # ========================================================================= - - def create_ab_comparison( - self, - conversation_id: int, - user_prompt_mid: int, - response_a_mid: int, - response_b_mid: int, - config_a_id: int, - config_b_id: int, - is_config_a_first: bool, - ) -> int: - """ - Create an A/B comparison record linking two responses to the same user prompt. 
- - Args: - conversation_id: The conversation this comparison belongs to - user_prompt_mid: Message ID of the user's question - response_a_mid: Message ID of response A - response_b_mid: Message ID of response B - config_a_id: Config ID used for response A - config_b_id: Config ID used for response B - is_config_a_first: True if config A was the "first" config before randomization - - Returns: - The comparison_id of the newly created record - """ - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute( - SQL_INSERT_AB_COMPARISON, - (conversation_id, user_prompt_mid, response_a_mid, response_b_mid, - config_a_id, config_b_id, is_config_a_first) - ) - comparison_id = cursor.fetchone()[0] - conn.commit() - logger.info(f"Created A/B comparison {comparison_id} for conversation {conversation_id}") - return comparison_id - finally: - cursor.close() - conn.close() - - def update_ab_preference(self, comparison_id: int, preference: str) -> None: - """ - Record user's preference for an A/B comparison. - - Args: - comparison_id: The comparison to update - preference: 'a', 'b', or 'tie' - """ - if preference not in ('a', 'b', 'tie'): - raise ValueError(f"Invalid preference: {preference}") - - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute( - SQL_UPDATE_AB_PREFERENCE, - (preference, datetime.now(timezone.utc), comparison_id) - ) - conn.commit() - logger.info(f"Updated A/B comparison {comparison_id} with preference '{preference}'") - finally: - cursor.close() - conn.close() - - def get_ab_comparison(self, comparison_id: int) -> Optional[Dict[str, Any]]: - """ - Get an A/B comparison by ID. 
- - Returns: - Dict with comparison data or None if not found - """ - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute(SQL_GET_AB_COMPARISON, (comparison_id,)) - row = cursor.fetchone() - if row is None: - return None - return { - 'comparison_id': row[0], - 'conversation_id': row[1], - 'user_prompt_mid': row[2], - 'response_a_mid': row[3], - 'response_b_mid': row[4], - 'config_a_id': row[5], - 'config_b_id': row[6], - 'is_config_a_first': row[7], - 'preference': row[8], - 'preference_ts': row[9].isoformat() if row[9] else None, - 'created_at': row[10].isoformat() if row[10] else None, - } - finally: - cursor.close() - conn.close() - - def get_pending_ab_comparison(self, conversation_id: int) -> Optional[Dict[str, Any]]: - """ - Get the most recent incomplete A/B comparison for a conversation. - - Returns: - Dict with comparison data or None if no pending comparison - """ - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute(SQL_GET_PENDING_AB_COMPARISON, (conversation_id,)) - row = cursor.fetchone() - if row is None: - return None - return { - 'comparison_id': row[0], - 'conversation_id': row[1], - 'user_prompt_mid': row[2], - 'response_a_mid': row[3], - 'response_b_mid': row[4], - 'config_a_id': row[5], - 'config_b_id': row[6], - 'is_config_a_first': row[7], - 'preference': row[8], - 'preference_ts': row[9].isoformat() if row[9] else None, - 'created_at': row[10].isoformat() if row[10] else None, - } - finally: - cursor.close() - conn.close() - - def delete_ab_comparison(self, comparison_id: int) -> bool: - """ - Delete an A/B comparison (e.g., on abort/failure). 
- - Returns: - True if a record was deleted, False otherwise - """ - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute(SQL_DELETE_AB_COMPARISON, (comparison_id,)) - deleted = cursor.rowcount > 0 - conn.commit() - if deleted: - logger.info(f"Deleted A/B comparison {comparison_id}") - return deleted - finally: - cursor.close() - conn.close() - - def get_ab_comparisons_by_conversation(self, conversation_id: int) -> List[Dict[str, Any]]: - """ - Get all A/B comparisons for a conversation. - - Returns: - List of comparison dicts - """ - conn = psycopg2.connect(**self.pg_config) - cursor = conn.cursor() - try: - cursor.execute(SQL_GET_AB_COMPARISONS_BY_CONVERSATION, (conversation_id,)) - rows = cursor.fetchall() - return [ - { - 'comparison_id': row[0], - 'conversation_id': row[1], - 'user_prompt_mid': row[2], - 'response_a_mid': row[3], - 'response_b_mid': row[4], - 'config_a_id': row[5], - 'config_b_id': row[6], - 'is_config_a_first': row[7], - 'preference': row[8], - 'preference_ts': row[9].isoformat() if row[9] else None, - 'created_at': row[10].isoformat() if row[10] else None, - } - for row in rows - ] - finally: - cursor.close() - conn.close() - # ========================================================================= # Agent Trace Methods # ========================================================================= @@ -933,24 +969,7 @@ def get_agent_trace(self, trace_id: str) -> Optional[Dict[str, Any]]: row = cursor.fetchone() if row is None: return None - return { - 'trace_id': row[0], - 'conversation_id': row[1], - 'message_id': row[2], - 'user_message_id': row[3], - 'config_id': row[4], - 'pipeline_name': row[5], - 'events': row[6], # Already JSON from JSONB - 'started_at': row[7].isoformat() if row[7] else None, - 'completed_at': row[8].isoformat() if row[8] else None, - 'status': row[9], - 'total_tool_calls': row[10], - 'total_tokens_used': row[11], - 'total_duration_ms': row[12], - 'cancelled_by': row[13], - 
'cancellation_reason': row[14], - 'created_at': row[15].isoformat() if row[15] else None, - } + return self._trace_from_row(row) finally: cursor.close() conn.close() @@ -966,24 +985,7 @@ def get_trace_by_message(self, message_id: int) -> Optional[Dict[str, Any]]: row = cursor.fetchone() if row is None: return None - return { - 'trace_id': row[0], - 'conversation_id': row[1], - 'message_id': row[2], - 'user_message_id': row[3], - 'config_id': row[4], - 'pipeline_name': row[5], - 'events': row[6], - 'started_at': row[7].isoformat() if row[7] else None, - 'completed_at': row[8].isoformat() if row[8] else None, - 'status': row[9], - 'total_tool_calls': row[10], - 'total_tokens_used': row[11], - 'total_duration_ms': row[12], - 'cancelled_by': row[13], - 'cancellation_reason': row[14], - 'created_at': row[15].isoformat() if row[15] else None, - } + return self._trace_from_row(row) finally: cursor.close() conn.close() @@ -999,17 +1001,7 @@ def get_active_trace(self, conversation_id: int) -> Optional[Dict[str, Any]]: row = cursor.fetchone() if row is None: return None - return { - 'trace_id': row[0], - 'conversation_id': row[1], - 'message_id': row[2], - 'user_message_id': row[3], - 'config_id': row[4], - 'pipeline_name': row[5], - 'events': row[6], - 'started_at': row[7].isoformat() if row[7] else None, - 'status': row[8], - } + return self._trace_from_row(row) finally: cursor.close() conn.close() @@ -1066,8 +1058,13 @@ def query_conversation_history(self, conversation_id, client_id, user_id: Option # query conversation history cursor.execute(SQL_QUERY_CONVO, (conversation_id,)) - history = cursor.fetchall() - history = collapse_assistant_sequences(history, sender_name=ARCHI_SENDER) + history_rows = cursor.fetchall() + comparisons = self.conv_service.get_conversation_ab_comparisons(str(conversation_id)) + suppressed_ids = self._suppressed_ab_message_ids(comparisons) + if suppressed_ids: + history_rows = [row for row in history_rows if row[2] not in suppressed_ids] + 
history_rows = collapse_assistant_sequences(history_rows, sender_name=ARCHI_SENDER) + history = [(row[0], row[1]) for row in history_rows] # clean up database connection state cursor.close() @@ -1333,10 +1330,112 @@ def _create_provider_llm(self, provider: str, model: str, api_key: str = None): logger.warning(f"Failed to create provider LLM {provider}/{model}: {e}") raise + def _create_variant_archi( + self, + variant: "ABVariant", + *, + variant_agent_spec: Optional[AgentSpec] = None, + request_provider: Optional[str] = None, + request_model: Optional[str] = None, + request_provider_api_key: Optional[str] = None, + ) -> "archi": + """ + Build a temporary archi instance configured for a specific A/B variant. + + Uses the deployment defaults for provider/model unless the variant overrides + them, but always requires an explicit variant agent spec. + """ + chat_cfg = self.services_config.get("chat_app", {}) + + spec_name = (variant.agent_spec or "").strip() + if not spec_name: + raise ABPoolError(f"Variant '{variant.label}' is missing required agent_spec.") + if Path(spec_name).name != spec_name: + raise ABPoolError( + f"Variant '{variant.label}' must use an agent_spec filename in the A/B catalog, got '{spec_name}'." 
+ ) + if variant_agent_spec is None: + record = self.ab_agent_spec_service.load_agent_spec(spec_name) + variant_agent_spec = record.to_agent_spec() + + agent_class = self._get_agent_class_from_cfg(chat_cfg) + default_provider = variant.provider or request_provider or chat_cfg.get("default_provider") + default_model = variant.model or request_model or chat_cfg.get("default_model") + prompt_overrides = chat_cfg.get("prompts", {}) + + variant_archi = archi( + pipeline=agent_class, + agent_spec=variant_agent_spec, + default_provider=default_provider, + default_model=default_model, + prompt_overrides=prompt_overrides, + ) + + if ( + request_provider_api_key + and default_provider + and default_model + and default_provider == request_provider + and hasattr(variant_archi, 'pipeline') + ): + override_llm = self._create_provider_llm( + default_provider, + default_model, + request_provider_api_key, + ) + if override_llm and hasattr(variant_archi.pipeline, 'agent_llm'): + variant_archi.pipeline.agent_llm = override_llm + if hasattr(variant_archi.pipeline, 'refresh_agent'): + variant_archi.pipeline.refresh_agent(force=True) + + # Apply retriever overrides if specified + if variant.num_documents_to_retrieve is not None and hasattr(variant_archi, 'pipeline'): + pipeline = variant_archi.pipeline + if hasattr(pipeline, 'pipeline_config'): + pipeline.pipeline_config['num_documents_to_retrieve'] = variant.num_documents_to_retrieve + + if variant.recursion_limit is not None and hasattr(variant_archi, 'pipeline'): + pipeline = variant_archi.pipeline + if hasattr(pipeline, 'recursion_limit'): + pipeline.recursion_limit = variant.recursion_limit + elif hasattr(pipeline, 'pipeline_config'): + pipeline.pipeline_config['recursion_limit'] = variant.recursion_limit + + return variant_archi + + @staticmethod + def _comparison_canonical_message_id(comparison) -> Optional[int]: + preference = getattr(comparison, "preference", None) + if preference == "b": + return getattr(comparison, 
"response_b_mid", None) + if preference in ("a", "tie", "skip"): + return getattr(comparison, "response_a_mid", None) + return None + + @classmethod + def _suppressed_ab_message_ids(cls, comparisons: List[Any]) -> set: + suppressed: set = set() + for comparison in comparisons or []: + preference = getattr(comparison, "preference", None) + a_mid = getattr(comparison, "response_a_mid", None) + b_mid = getattr(comparison, "response_b_mid", None) + if preference is None: + if a_mid: + suppressed.add(a_mid) + if b_mid: + suppressed.add(b_mid) + continue + + canonical_mid = cls._comparison_canonical_message_id(comparison) + for mid in (a_mid, b_mid): + if mid and mid != canonical_mid: + suppressed.add(mid) + return suppressed + def _prepare_chat_context( self, message: List[str], - conversation_id: int | None, + conversation_id: Optional[str], client_id: str, is_refresh: bool, server_received_msg_ts: datetime, @@ -1368,30 +1467,453 @@ def _prepare_chat_context( if not is_refresh: history = history + [(sender, content)] - if len(history) >= QUERY_LIMIT: - return None, 500 + if len(history) >= QUERY_LIMIT: + return None, 500 + + return ( + ChatRequestContext( + sender=sender, + content=content, + conversation_id=conversation_id, + history=history, + is_refresh=is_refresh, + ), + None, + ) + + def _message_content(self, message) -> str: + content = getattr(message, "content", "") + if isinstance(content, list): + content = " ".join(str(part) for part in content) + return str(content) + + def _truncate_text(self, text: str, max_chars: int) -> str: + if max_chars and len(text) > max_chars: + return text[: max_chars - 3].rstrip() + "..." 
+ return text + + # ========================================================================= + # Shared Helpers (deduplicated from multiple call-sites) + # ========================================================================= + + @staticmethod + def _error_event(error_code: int) -> Dict[str, Any]: + """Map an error code to a structured error event dict.""" + if error_code == 408: + message = CLIENT_TIMEOUT_ERROR_MESSAGE + elif error_code == 403: + message = "conversation not found" + else: + message = "server error; see chat logs for message" + return {"type": "error", "status": error_code, "message": message} + + @staticmethod + def _trace_from_row(row) -> Dict[str, Any]: + """Convert a positional agent trace DB row to a dict. + + Handles both full rows (16 fields) and subset rows (9 fields from get_active_trace). + """ + result = { + 'trace_id': row[0], + 'conversation_id': row[1], + 'message_id': row[2], + 'user_message_id': row[3], + 'config_id': row[4], + 'pipeline_name': row[5], + 'events': row[6], + 'started_at': row[7].isoformat() if row[7] else None, + } + if len(row) > 9: + # Full row from get_agent_trace / get_trace_by_message + result.update({ + 'completed_at': row[8].isoformat() if row[8] else None, + 'status': row[9], + 'total_tool_calls': row[10], + 'total_tokens_used': row[11], + 'total_duration_ms': row[12], + 'cancelled_by': row[13], + 'cancellation_reason': row[14], + 'created_at': row[15].isoformat() if row[15] else None, + }) + else: + # Subset row from get_active_trace + result['status'] = row[8] + return result + + @staticmethod + def _get_agent_class_from_cfg(chat_cfg: dict) -> Optional[str]: + """Extract agent class name from a chat config dict.""" + return chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + + @staticmethod + def _format_score_str(score) -> str: + """Format a source relevance score for display.""" + if score == -1.0 or score == "N/A": + return "" + return f" ({score:.2f})" + + # 
========================================================================= + # Pool-based A/B Comparison Streaming + # ========================================================================= + + def stream_ab_comparison( + self, + message: List[str], + conversation_id: Optional[str], + client_id: str, + is_refresh: bool, + server_received_msg_ts: datetime, + client_sent_msg_ts: float, + client_timeout: float, + config_name: str, + *, + user_id: Optional[str] = None, + provider: Optional[str] = None, + model: Optional[str] = None, + provider_api_key: Optional[str] = None, + ) -> Iterator[Dict[str, Any]]: + """ + Stream a champion-vs-variant A/B comparison. + + Yields interleaved NDJSON events tagged with ``arm: 'a'`` or ``arm: 'b'`` + in real-time as each arm's pipeline produces output. + Each arm emits its own terminal ``final`` event when generation ends. + A final ``ab_meta`` event carries the comparison_id and variant mapping. + """ + import queue + import threading + + if not self.ab_pool: + yield {"type": "error", "message": "A/B pool not configured"} + return + + requested_config = self._resolve_config_name(config_name) + self.update_config(config_name=requested_config) + + # Sample matchup + arm_a_variant, arm_b_variant, is_champion_first = self.ab_pool.sample_matchup() + logger.info( + "A/B matchup: arm_a='%s' arm_b='%s' champion_first=%s", + arm_a_variant.name, arm_b_variant.name, is_champion_first, + ) + + # Prepare chat context (shared — same user message for both arms) + timestamps = self._init_timestamps() + context, error_code = self._prepare_chat_context( + message, + conversation_id, + client_id, + is_refresh, + server_received_msg_ts, + client_sent_msg_ts, + client_timeout, + timestamps, + user_id=user_id, + ) + if error_code is not None: + yield self._error_event(error_code) + return + + # Build variant archis + try: + arm_a_variant, arm_a_agent_spec = self._resolve_runtime_ab_variant(arm_a_variant) + arm_b_variant, arm_b_agent_spec = 
self._resolve_runtime_ab_variant(arm_b_variant) + archi_a = self._create_variant_archi( + arm_a_variant, + variant_agent_spec=arm_a_agent_spec, + request_provider=provider, + request_model=model, + request_provider_api_key=provider_api_key, + ) + archi_b = self._create_variant_archi( + arm_b_variant, + variant_agent_spec=arm_b_agent_spec, + request_provider=provider, + request_model=model, + request_provider_api_key=provider_api_key, + ) + except Exception as exc: + logger.error("Failed to create variant pipelines: %s", exc) + yield {"type": "error", "message": f"Failed to initialise A/B variants: {exc}"} + return + + # Shared queue for real-time interleaving + event_queue: queue.Queue = queue.Queue() + _SENTINEL = object() + + # Track final text per arm (mutated by threads). + # Thread-safety note: each thread writes to its own key ("a" or "b") + # which is safe under CPython's GIL. The "final_text" value relies on + # PipelineEventFormatter yielding *accumulated* content (not deltas); + # the last write per arm is therefore the complete response text. 
+ arm_results = { + "a": { + "final_text": "", + "error": None, + "final_emitted": False, + "duration_ms": None, + }, + "b": { + "final_text": "", + "error": None, + "final_emitted": False, + "duration_ms": None, + }, + } + arm_model_used = { + "a": f"{arm_a_variant.provider or ''}/{arm_a_variant.model or ''}".strip("/"), + "b": f"{arm_b_variant.provider or ''}/{arm_b_variant.model or ''}".strip("/"), + } + + def _stream_arm(arm_archi, arm_label): + """Run one arm's stream in a thread, pushing events to the shared queue.""" + import time as _time + formatter = PipelineEventFormatter(message_content_fn=self._message_content) + t0 = _time.monotonic() + first_event_logged = False + try: + logger.info("A/B arm '%s' thread started (t+0.0s)", arm_label) + vs = self.archi.vs_connector.get_vectorstore() + logger.info( + "A/B arm '%s' vectorstore ready (t+%.1fs)", + arm_label, _time.monotonic() - t0, + ) + for output in arm_archi.pipeline.stream( + history=context.history, + conversation_id=context.conversation_id, + vectorstore=vs, + ): + output_meta = output.metadata or {} + for event in formatter.process(output): + if not first_event_logged: + logger.info( + "A/B arm '%s' first event (t+%.1fs): type=%s", + arm_label, _time.monotonic() - t0, event.get("type"), + ) + first_event_logged = True + event["arm"] = arm_label + if event["type"] == "text": + arm_results[arm_label]["final_text"] = event["content"] + event_queue.put(event) + if output_meta.get("event_type") == "final" and not arm_results[arm_label]["final_emitted"]: + if not first_event_logged: + logger.info( + "A/B arm '%s' first event (t+%.1fs): type=final", + arm_label, _time.monotonic() - t0, + ) + first_event_logged = True + final_text = getattr(output, "answer", "") or formatter.last_text or arm_results[arm_label]["final_text"] + arm_results[arm_label]["final_text"] = final_text + arm_results[arm_label]["final_emitted"] = True + duration_ms = int((_time.monotonic() - t0) * 1000) + 
arm_results[arm_label]["duration_ms"] = duration_ms + event_queue.put({ + "type": "final", + "arm": arm_label, + "response": final_text, + "usage": output_meta.get("usage"), + "model": output_meta.get("model"), + "model_used": arm_model_used[arm_label], + "duration_ms": duration_ms, + }) + except Exception as exc: + arm_results[arm_label]["error"] = str(exc) + event_queue.put({"type": "error", "arm": arm_label, "message": str(exc)}) + finally: + logger.info( + "A/B arm '%s' finished (t+%.1fs)", + arm_label, _time.monotonic() - t0, + ) + event_queue.put(_SENTINEL) + + # Yield arm labels early so the frontend can display variant names + yield { + "type": "ab_arms", + "arm_a_name": arm_a_variant.name, + "arm_b_name": arm_b_variant.name, + "variant_label_mode": self.ab_pool.variant_label_mode, + } + + # Start both arms in parallel threads + thread_a = threading.Thread(target=_stream_arm, args=(archi_a, "a"), daemon=True) + thread_b = threading.Thread(target=_stream_arm, args=(archi_b, "b"), daemon=True) + thread_a.start() + thread_b.start() + + # Drain the queue in real-time, yielding events as they arrive + finished_count = 0 + while finished_count < 2: + item = event_queue.get() + if item is _SENTINEL: + finished_count += 1 + continue + yield item + + thread_a.join() + thread_b.join() + + # Check for errors + arm_a_error = arm_results["a"]["error"] + arm_b_error = arm_results["b"]["error"] + arm_a_final_text = arm_results["a"]["final_text"] + arm_b_final_text = arm_results["b"]["final_text"] + arm_a_duration_ms = arm_results["a"]["duration_ms"] + arm_b_duration_ms = arm_results["b"]["duration_ms"] + + if arm_a_error and arm_b_error: + yield {"type": "error", "message": "Both A/B arms failed", + "arm_a_error": arm_a_error, "arm_b_error": arm_b_error} + return + + if arm_a_error or arm_b_error: + yield {"type": "error", "message": "One A/B arm failed", + "failed_arm": "a" if arm_a_error else "b", + "error": arm_a_error or arm_b_error} + return + + # Store user message 
first (normal chat stores it inline, AB must do so explicitly) + user_prompt_mid = None + if not is_refresh: + try: + conn = psycopg2.connect(**self.pg_config) + cursor = conn.cursor() + insert_tups = [ + ("chat", context.conversation_id, context.sender, context.content, + "", "", datetime.now(), None, None), + ] + psycopg2.extras.execute_values(cursor, SQL_INSERT_CONVO, insert_tups) + row = cursor.fetchone() + user_prompt_mid = row[0] if row else None + conn.commit() + cursor.close() + conn.close() + except Exception as exc: + logger.error("Failed to store user message: %s", exc) + + # Store both responses as messages + arm_a_mid = self._store_assistant_message( + context.conversation_id, + arm_a_final_text, + model_used=f"{arm_a_variant.provider or ''}/{arm_a_variant.model or ''}".strip("/"), + pipeline_used=self.current_pipeline_used, + ) + arm_b_mid = self._store_assistant_message( + context.conversation_id, + arm_b_final_text, + model_used=f"{arm_b_variant.provider or ''}/{arm_b_variant.model or ''}".strip("/"), + pipeline_used=self.current_pipeline_used, + ) + + # Persist per-arm latency for analysis by reusing the timing table keyed by message_id. 
+ self._persist_ab_arm_timing(arm_a_mid, arm_a_duration_ms) + self._persist_ab_arm_timing(arm_b_mid, arm_b_duration_ms) + + # Get user prompt message ID if not already stored above + if not user_prompt_mid: + user_prompt_mid = self._get_last_user_message_id(context.conversation_id) + + # Create comparison record (skip if we have no valid message IDs) + comparison_id = None + if user_prompt_mid and arm_a_mid and arm_b_mid: + try: + comparison_id = self.conv_service.create_ab_comparison( + conversation_id=context.conversation_id, + user_prompt_mid=user_prompt_mid, + response_a_mid=arm_a_mid, + response_b_mid=arm_b_mid, + model_a=f"{arm_a_variant.provider or ''}/{arm_a_variant.model or ''}".strip("/"), + pipeline_a=self.current_pipeline_used or "", + model_b=f"{arm_b_variant.provider or ''}/{arm_b_variant.model or ''}".strip("/"), + pipeline_b=self.current_pipeline_used or "", + is_config_a_first=is_champion_first, + variant_a_name=arm_a_variant.name, + variant_b_name=arm_b_variant.name, + variant_a_meta=arm_a_variant.to_meta_json(), + variant_b_meta=arm_b_variant.to_meta_json(), + ) + except Exception as exc: + logger.error("Failed to create A/B comparison record: %s", exc) + comparison_id = None + + # Emit final metadata event + yield { + "type": "ab_meta", + "comparison_id": comparison_id, + "conversation_id": context.conversation_id, + "arm_a_variant": arm_a_variant.name, + "arm_b_variant": arm_b_variant.name, + "arm_a_model_used": f"{arm_a_variant.provider or ''}/{arm_a_variant.model or ''}".strip("/"), + "arm_b_model_used": f"{arm_b_variant.provider or ''}/{arm_b_variant.model or ''}".strip("/"), + "is_champion_first": is_champion_first, + "arm_a_message_id": arm_a_mid, + "arm_b_message_id": arm_b_mid, + "arm_a_duration_ms": arm_a_duration_ms, + "arm_b_duration_ms": arm_b_duration_ms, + "variant_label_mode": self.ab_pool.variant_label_mode, + } + + def _store_assistant_message(self, conversation_id, content, model_used=None, pipeline_used=None): + """Store an 
assistant message and return the message_id.""" + try: + conn = psycopg2.connect(**self.pg_config) + cursor = conn.cursor() + insert_tups = [ + ("chat", conversation_id, "archi", content, "", "", datetime.now(), model_used, pipeline_used), + ] + psycopg2.extras.execute_values(cursor, SQL_INSERT_CONVO, insert_tups) + row = cursor.fetchone() + mid = row[0] if row else None + conn.commit() + cursor.close() + conn.close() + return mid + except Exception as exc: + logger.error("Failed to store assistant message: %s", exc) + return None + + def _persist_ab_arm_timing(self, message_id: Optional[int], duration_ms: Optional[int]) -> None: + """Persist A/B arm latency into the timing table for post-hoc analysis.""" + if not message_id or duration_ms is None: + return - return ( - ChatRequestContext( - sender=sender, - content=content, - conversation_id=conversation_id, - history=history, - is_refresh=is_refresh, - ), - None, - ) + safe_duration_ms = max(int(duration_ms), 0) + end_ts = datetime.now(timezone.utc) + start_ts = end_ts - timedelta(milliseconds=safe_duration_ms) + + synthetic_timestamps = { + "client_sent_msg_ts": start_ts, + "server_received_msg_ts": start_ts, + "lock_acquisition_ts": start_ts, + "vectorstore_update_ts": start_ts, + "query_convo_history_ts": start_ts, + "chain_finished_ts": end_ts, + "archi_message_ts": end_ts, + "insert_convo_ts": end_ts, + "finish_call_ts": end_ts, + "server_response_msg_ts": end_ts, + } - def _message_content(self, message) -> str: - content = getattr(message, "content", "") - if isinstance(content, list): - content = " ".join(str(part) for part in content) - return str(content) + try: + self.insert_timing(message_id, synthetic_timestamps) + except Exception as exc: + logger.warning("Failed to persist A/B timing for message %s: %s", message_id, exc) - def _truncate_text(self, text: str, max_chars: int) -> str: - if max_chars and len(text) > max_chars: - return text[: max_chars - 3].rstrip() + "..." 
- return text + def _get_last_user_message_id(self, conversation_id): + """Get the most recent user message_id for a conversation.""" + try: + conn = psycopg2.connect(**self.pg_config) + cursor = conn.cursor() + cursor.execute( + "SELECT message_id FROM conversations WHERE conversation_id = %s AND LOWER(sender) = 'user' ORDER BY ts DESC LIMIT 1", + (conversation_id,), + ) + row = cursor.fetchone() + cursor.close() + conn.close() + return row[0] if row else None + except Exception as exc: + logger.error("Failed to get user message id: %s", exc) + return None def _stream_events_from_output( self, @@ -1488,11 +2010,11 @@ def _finalize_result( scores = result.get("metadata", {}).get("retriever_scores", []) top_sources = self.get_top_sources(documents, scores) - # Use markdown links for client-side rendering, HTML for server-side - if render_markdown: - output += self.format_links(top_sources) - else: - output += self.format_links_markdown(top_sources) + output = self.append_source_section( + output, + top_sources, + render_markdown=render_markdown, + ) timestamps["archi_message_ts"] = datetime.now(timezone.utc) context_data = self.prepare_context_for_storage(documents, scores) @@ -1597,7 +2119,7 @@ def __call__(self, message: List[str], conversation_id: int|None, client_id: str def stream( self, message: List[str], - conversation_id: int | None, + conversation_id: Optional[str], client_id: str, is_refresh: bool, server_received_msg_ts: datetime, @@ -1616,49 +2138,14 @@ def stream( timestamps = self._init_timestamps() context = None last_output = None + formatter = PipelineEventFormatter( + message_content_fn=self._message_content, + max_step_chars=max_step_chars, + ) last_streamed_text = "" trace_id = None trace_events: List[Dict[str, Any]] = [] - tool_call_count = 0 stream_start_time = time.time() - emitted_tool_call_ids = set() - emitted_tool_start_ids = set() - pending_tool_call_ids: List[str] = [] - tool_calls_by_id: Dict[str, Dict[str, Any]] = {} - 
synthetic_tool_counter = 0 - - def _next_tool_call_id(tool_name: str) -> str: - nonlocal synthetic_tool_counter - synthetic_tool_counter += 1 - safe_name = re.sub(r"[^a-zA-Z0-9_]+", "_", (tool_name or "unknown")).strip("_") or "unknown" - return f"synthetic_tool_{synthetic_tool_counter}_{safe_name}" - - def _is_empty_tool_args(tool_args: Any) -> bool: - return tool_args in (None, "", {}, []) - - def _has_meaningful_tool_payload(tool_name: Any, tool_args: Any) -> bool: - if isinstance(tool_name, str) and tool_name.strip() and tool_name.strip().lower() != "unknown": - return True - return not _is_empty_tool_args(tool_args) - - def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> None: - if not tool_call_id: - return - current = tool_calls_by_id.get(tool_call_id, {}) - current_name = current.get("tool_name", "unknown") - current_args = current.get("tool_args", {}) - merged_name = ( - tool_name - if isinstance(tool_name, str) - and tool_name.strip() - and tool_name.strip().lower() != "unknown" - else current_name - ) - merged_args = tool_args if not _is_empty_tool_args(tool_args) else current_args - tool_calls_by_id[tool_call_id] = { - "tool_name": merged_name or "unknown", - "tool_args": merged_args, - } try: context, error_code = self._prepare_chat_context( @@ -1673,12 +2160,7 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No user_id=user_id, ) if error_code is not None: - error_message = "server error; see chat logs for message" - if error_code == 408: - error_message = CLIENT_TIMEOUT_ERROR_MESSAGE - elif error_code == 403: - error_message = "conversation not found" - yield {"type": "error", "status": error_code, "message": error_message} + yield self._error_event(error_code) return requested_config = self._resolve_config_name(config_name) @@ -1724,223 +2206,19 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No cancellation_reason='Client timeout', 
total_duration_ms=total_duration_ms, ) - yield {"type": "error", "status": 408, "message": CLIENT_TIMEOUT_ERROR_MESSAGE} + yield self._error_event(408) return last_output = output - # Extract event_type from metadata (new structured events from BaseReActAgent) + # Use shared event formatter for structured event types event_type = output.metadata.get("event_type", "text") if output.metadata else "text" timestamp = datetime.now(timezone.utc).isoformat() - - # Handle different event types - if event_type == "tool_start": - tool_messages = getattr(output, "messages", []) or [] - tool_message = tool_messages[0] if tool_messages else None - tool_calls = getattr(tool_message, "tool_calls", None) if tool_message else None - memory_args_by_id = {} - if output.metadata: - memory_args_by_id = output.metadata.get("tool_inputs_by_id", {}) or {} - raw_args_by_id: Dict[str, Any] = {} - raw_name_by_id: Dict[str, str] = {} - if tool_message is not None: - try: - additional = getattr(tool_message, "additional_kwargs", {}) or {} - raw_tool_calls = additional.get("tool_calls") or [] - for raw_call in raw_tool_calls: - if not isinstance(raw_call, dict): - continue - raw_id = raw_call.get("id") - function_obj = raw_call.get("function") or {} - raw_name = function_obj.get("name") - raw_arguments = function_obj.get("arguments") - parsed_args: Any = None - if isinstance(raw_arguments, str) and raw_arguments.strip(): - try: - parsed_args = json.loads(raw_arguments) - except Exception: - parsed_args = {"_raw_arguments": raw_arguments} - elif isinstance(raw_arguments, dict): - parsed_args = raw_arguments - if raw_id and parsed_args is not None: - raw_args_by_id[raw_id] = parsed_args - if raw_id and isinstance(raw_name, str) and raw_name.strip(): - raw_name_by_id[raw_id] = raw_name.strip() - - # Newer OpenAI/LangChain payloads may carry partial tool calls here. 
- for chunk in getattr(tool_message, "tool_call_chunks", []) or []: - if not isinstance(chunk, dict): - continue - chunk_id = chunk.get("id") - chunk_name = chunk.get("name") - chunk_args = chunk.get("args") - parsed_chunk_args: Any = None - if isinstance(chunk_args, str) and chunk_args.strip(): - try: - parsed_chunk_args = json.loads(chunk_args) - except Exception: - parsed_chunk_args = {"_raw_arguments": chunk_args} - elif isinstance(chunk_args, dict): - parsed_chunk_args = chunk_args - if chunk_id and parsed_chunk_args is not None: - raw_args_by_id[chunk_id] = parsed_chunk_args - if chunk_id and isinstance(chunk_name, str) and chunk_name.strip(): - raw_name_by_id[chunk_id] = chunk_name.strip() - except Exception: - pass - if tool_calls: - for tool_call in tool_calls: - tool_call_id = tool_call.get("id", "") - tool_args = tool_call.get("args", {}) - if _is_empty_tool_args(tool_args): - tool_args = raw_args_by_id.get(tool_call_id, tool_args) - if _is_empty_tool_args(tool_args): - fallback = memory_args_by_id.get(tool_call_id, {}) - if isinstance(fallback, dict): - tool_args = fallback.get("tool_input", tool_args) - tool_name = tool_call.get("name", "unknown") - if (not tool_name or str(tool_name).strip().lower() == "unknown") and tool_call_id in raw_name_by_id: - tool_name = raw_name_by_id[tool_call_id] - if (not tool_name) and isinstance(memory_args_by_id.get(tool_call_id), dict): - tool_name = memory_args_by_id[tool_call_id].get("tool_name", "unknown") - if (not tool_call_id) and (not _has_meaningful_tool_payload(tool_name, tool_args)): - continue - if not tool_call_id: - tool_call_id = _next_tool_call_id(tool_name) - _remember_tool_call(tool_call_id, tool_name, tool_args) - if tool_call_id in emitted_tool_call_ids: - continue - emitted_tool_call_ids.add(tool_call_id) - pending_tool_call_ids.append(tool_call_id) - tool_call_count += 1 - elif memory_args_by_id: - for memory_id, memory_call in memory_args_by_id.items(): - if not isinstance(memory_call, dict): - 
continue - tool_name = memory_call.get("tool_name", "unknown") - tool_args = memory_call.get("tool_input", {}) - if not _has_meaningful_tool_payload(tool_name, tool_args): - continue - tool_call_id = memory_id or _next_tool_call_id(tool_name) - if tool_call_id in emitted_tool_call_ids: - continue - emitted_tool_call_ids.add(tool_call_id) - pending_tool_call_ids.append(tool_call_id) - _remember_tool_call(tool_call_id, tool_name, tool_args) - tool_call_count += 1 - - elif event_type == "tool_output": - tool_messages = getattr(output, "messages", []) or [] - tool_message = tool_messages[0] if tool_messages else None - tool_output = self._message_content(tool_message) if tool_message else "" - truncated = len(tool_output) > max_step_chars - full_length = len(tool_output) if truncated else None - display_output = self._truncate_text(tool_output, max_step_chars) - - output_tool_call_id = getattr(tool_message, "tool_call_id", "") if tool_message else "" - if not output_tool_call_id and pending_tool_call_ids: - output_tool_call_id = pending_tool_call_ids.pop(0) - elif output_tool_call_id in pending_tool_call_ids: - pending_tool_call_ids.remove(output_tool_call_id) - - # Emit tool_start once, immediately before first output for stable ordering. 
- if output_tool_call_id and output_tool_call_id not in emitted_tool_start_ids: - memory_args_by_id = output.metadata.get("tool_inputs_by_id", {}) if output.metadata else {} - fallback = memory_args_by_id.get(output_tool_call_id, {}) - fallback_name = "unknown" - fallback_args: Any = {} - if isinstance(fallback, dict): - fallback_name = fallback.get("tool_name", "unknown") - fallback_args = fallback.get("tool_input", {}) - _remember_tool_call(output_tool_call_id, fallback_name, fallback_args) - call_info = tool_calls_by_id.get(output_tool_call_id, {}) - start_event = { - "type": "tool_start", - "tool_call_id": output_tool_call_id, - "tool_name": call_info.get("tool_name", "unknown"), - "tool_args": call_info.get("tool_args", {}), - "timestamp": timestamp, - "conversation_id": context.conversation_id, - } - trace_events.append(start_event) - emitted_tool_start_ids.add(output_tool_call_id) - if include_tool_steps: - yield start_event - - trace_event = { - "type": "tool_output", - "tool_call_id": output_tool_call_id, - "output": display_output, - "truncated": truncated, - "full_length": full_length, - "timestamp": timestamp, - "conversation_id": context.conversation_id, - } - trace_events.append(trace_event) - if include_tool_steps: - yield trace_event - - elif event_type == "tool_end": - trace_event = { - "type": "tool_end", - "tool_call_id": output.metadata.get("tool_call_id", ""), - "status": output.metadata.get("status", "success"), - "duration_ms": output.metadata.get("duration_ms"), - "timestamp": timestamp, - "conversation_id": context.conversation_id, - } - trace_events.append(trace_event) - if include_tool_steps: - yield trace_event - - elif event_type == "thinking_start": - trace_event = { - "type": "thinking_start", - "step_id": output.metadata.get("step_id", ""), - "timestamp": timestamp, - "conversation_id": context.conversation_id, - } - trace_events.append(trace_event) - if include_tool_steps: - yield trace_event - - elif event_type == "thinking_end": - 
thinking_content = output.metadata.get("thinking_content", "") - trace_event = { - "type": "thinking_end", - "step_id": output.metadata.get("step_id", ""), - "duration_ms": output.metadata.get("duration_ms"), - "thinking_content": thinking_content, - "timestamp": timestamp, - "conversation_id": context.conversation_id, - } - trace_events.append(trace_event) - if include_tool_steps: - yield trace_event - - elif event_type == "text": - # Stream text content - content = getattr(output, "answer", "") or "" - if content and include_agent_steps: - last_streamed_text = content - yield { - "type": "chunk", - "content": content, - "accumulated": True, - "conversation_id": context.conversation_id, - } - # Record text event in trace - if content: - trace_events.append({ - "type": "text", - "content": content, - "timestamp": timestamp, - }) - - elif event_type == "final": - # Final event handled below after loop - pass - else: - # Fallback: legacy event handling for non-agent pipelines + + if event_type == "final": + pass # handled after the loop + elif event_type not in ("tool_start", "tool_output", "tool_end", + "thinking_start", "thinking_end", "text"): + # Legacy fallback for non-agent pipelines if getattr(output, "final", False): continue for event in self._stream_events_from_output( @@ -1951,7 +2229,6 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No max_chars=max_step_chars, ): yield event - if include_agent_steps: content = getattr(output, "answer", "") or "" if content: @@ -1967,6 +2244,30 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No "content": delta[i:i + chunk_size], "conversation_id": context.conversation_id, } + else: + # Formatter handles tool_start/output/end, thinking, text + for event in formatter.process(output): + event["timestamp"] = timestamp + event["conversation_id"] = context.conversation_id + if event["type"] == "text": + # Map to "chunk" type for backward compat with JS client + 
if include_agent_steps: + last_streamed_text = event["content"] + yield { + "type": "chunk", + "content": event["content"], + "accumulated": True, + "conversation_id": context.conversation_id, + } + trace_events.append({ + "type": "text", + "content": event["content"], + "timestamp": timestamp, + }) + else: + trace_events.append(event) + if include_tool_steps: + yield event timestamps["chain_finished_ts"] = datetime.now(timezone.utc) @@ -1981,20 +2282,6 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No ) yield {"type": "error", "status": 500, "message": "server error; see chat logs for message"} return - - # For providers like gpt-5, streamed tool chunks may carry empty args while - # the final AI message contains full tool arguments. Backfill before final. - try: - final_tool_calls = last_output.extract_tool_calls() if hasattr(last_output, "extract_tool_calls") else [] - for tc in final_tool_calls: - tool_call_id = tc.get("id", "") - tool_name = tc.get("name", "unknown") - tool_args = tc.get("args", {}) - if not tool_call_id or _is_empty_tool_args(tool_args): - continue - _remember_tool_call(tool_call_id, tool_name, tool_args) - except Exception: - pass # keep track of total number of queries and log this amount self.number_of_queries += 1 @@ -2045,7 +2332,7 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No events=trace_events, status='completed', message_id=message_ids[-1] if message_ids else None, - total_tool_calls=tool_call_count, + total_tool_calls=formatter.tool_call_count, total_duration_ms=total_duration_ms, ) @@ -2072,7 +2359,7 @@ def _remember_tool_call(tool_call_id: str, tool_name: Any, tool_args: Any) -> No trace_id=trace_id, events=trace_events, status='cancelled', - total_tool_calls=tool_call_count, + total_tool_calls=formatter.tool_call_count, total_duration_ms=total_duration_ms, cancelled_by='user', cancellation_reason='Stream cancelled by client', @@ -2147,6 +2434,12 @@ def 
__init__(self, app, **configs): # Initialize config service for dynamic settings self.config_service = ConfigService(pg_config=self.pg_config) + # Refresh the RBAC registry against the current deployment config. + try: + get_registry(force_reload=True) + except Exception as exc: + logger.warning("Failed to reload RBAC registry from current config: %s", exc) + # Data manager service URL for upload proxy dm_config = self.services_config.get("data_manager", {}) # Use 'hostname' for service discovery (Docker network name), fallback to 'host' for local dev @@ -2214,9 +2507,19 @@ def _inject_alerts(): # A/B testing endpoints logger.info("Adding A/B testing API endpoints") - self.add_endpoint('/api/ab/create', 'ab_create', self.require_auth(self.ab_create_comparison), methods=["POST"]) self.add_endpoint('/api/ab/preference', 'ab_preference', self.require_auth(self.ab_submit_preference), methods=["POST"]) self.add_endpoint('/api/ab/pending', 'ab_pending', self.require_auth(self.ab_get_pending), methods=["GET"]) + self.add_endpoint('/api/ab/pool', 'ab_pool', self.require_auth(self.ab_get_pool), methods=["GET"]) + self.add_endpoint('/api/ab/decision', 'ab_decision', self.require_auth(self.ab_get_decision), methods=["GET"]) + self.add_endpoint('/api/ab/pool/set', 'ab_pool_set', self.require_auth(self.ab_set_pool), methods=["POST"]) + self.add_endpoint('/api/ab/pool/settings/set', 'ab_pool_settings_set', self.require_auth(self.ab_set_settings), methods=["POST"]) + self.add_endpoint('/api/ab/pool/variants/set', 'ab_pool_variants_set', self.require_auth(self.ab_set_variants), methods=["POST"]) + self.add_endpoint('/api/ab/pool/disable', 'ab_pool_disable', self.require_auth(self.ab_disable_pool), methods=["POST"]) + self.add_endpoint('/api/ab/compare', 'ab_compare', self.require_auth(self.ab_compare_stream), methods=["POST"]) + self.add_endpoint('/api/ab/metrics', 'ab_metrics', self.require_auth(self.ab_get_metrics), methods=["GET"]) + self.add_endpoint('/api/ab/agents/list', 
'list_ab_agents', self.require_auth(self.list_ab_agents), methods=["GET"]) + self.add_endpoint('/api/ab/agents/template', 'get_ab_agent_template', self.require_auth(self.get_ab_agent_template), methods=["GET"]) + self.add_endpoint('/api/ab/agents', 'save_ab_agent_spec', self.require_auth(self.save_ab_agent_spec), methods=["POST"]) # Agent trace endpoints logger.info("Adding agent trace API endpoints") @@ -2246,6 +2549,7 @@ def _inject_alerts(): # Enable/disable documents - requires documents:select permission logger.info("Adding data viewer API endpoints") self.add_endpoint('/data', 'data_viewer', self.require_perm(Permission.Documents.VIEW)(self.data_viewer_page)) + self.add_endpoint('/admin/ab-testing', 'ab_testing_admin_page', self.require_auth(self.ab_testing_admin_page)) self.add_endpoint('/api/data/documents', 'list_data_documents', self.require_perm(Permission.Documents.VIEW)(self.list_data_documents), methods=["GET"]) self.add_endpoint('/api/data/documents//content', 'get_data_document_content', self.require_perm(Permission.Documents.VIEW)(self.get_data_document_content), methods=["GET"]) self.add_endpoint('/api/data/documents//chunks', 'get_data_document_chunks', self.require_perm(Permission.Documents.VIEW)(self.get_data_document_chunks), methods=["GET"]) @@ -2391,7 +2695,7 @@ def login(self): flash('Invalid credentials') # Render login page with available auth methods - return render_template('landing.html', + return render_template('login.html', sso_enabled=self.sso_enabled, basic_auth_enabled=self.basic_auth_enabled) @@ -2545,12 +2849,9 @@ def decorated_function(*args, **kwargs): # Redirect to login page which will trigger SSO return redirect(url_for('login')) - # Return 401 Unauthorized response for API requests - return jsonify({'error': 'Unauthorized', 'message': 'Authentication required'}), 401 if request.path.startswith('/api/'): return jsonify({'error': 'Unauthorized', 'message': 'Authentication required'}), 401 - else: - return 
redirect(url_for('login')) + return redirect(url_for('login')) return f(*args, **kwargs) return decorated_function @@ -2785,7 +3086,7 @@ def get_pipeline_default_model(self): """ try: chat_cfg = self.config.get("services", {}).get("chat_app", {}) - agent_class = chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + agent_class = ChatWrapper._get_agent_class_from_cfg(chat_cfg) provider = chat_cfg.get("default_provider") model = chat_cfg.get("default_model") model_name = f"{provider}/{model}" if provider and model else None @@ -2802,11 +3103,300 @@ def get_pipeline_default_model(self): def _get_agents_dir(self) -> Path: agents_dir = self.services_config.get("chat_app", {}).get("agents_dir") or "/root/archi/agents" - return Path(agents_dir) + return Path(agents_dir).expanduser() + + def _get_ab_agents_dir(self) -> Path: + chat_cfg = self.services_config.get("chat_app", {}) or {} + path, _ = resolve_ab_agents_dir(chat_cfg) + return path + + def _get_agent_scope(self) -> str: + scope = None + if request.is_json and request.json: + scope = request.json.get("scope") + if not scope: + scope = request.args.get("scope", "default") + scope = str(scope or "default").strip().lower() + return "ab" if scope == "ab" else "default" + + def _get_agent_dir_for_scope(self, scope: str, *, create: bool = False) -> Path: + if scope == "ab": + raise PermissionError("A/B agent specs are stored in the database") + directory = self._get_agents_dir() + if create: + directory.mkdir(parents=True, exist_ok=True) + return directory + + def _get_ab_agent_spec_service(self) -> ABAgentSpecService: + if hasattr(self, "chat") and getattr(self.chat, "ab_agent_spec_service", None) is not None: + return self.chat.ab_agent_spec_service + if not hasattr(self, "_ab_agent_spec_service"): + self._ab_agent_spec_service = ABAgentSpecService(pg_config=self.pg_config) + return self._ab_agent_spec_service + + def _get_ab_runtime_defaults(self) -> Dict[str, Any]: + chat_cfg = 
self.services_config.get("chat_app", {}) or {} + data_cfg = self.config.get("data_manager", {}) or {} + retrievers_cfg = data_cfg.get("retrievers", {}) or {} + hybrid_cfg = retrievers_cfg.get("hybrid_retriever", {}) or {} + return { + "provider": chat_cfg.get("default_provider"), + "model": chat_cfg.get("default_model"), + "recursion_limit": int(chat_cfg.get("recursion_limit", 50) or 50), + "num_documents_to_retrieve": int(hybrid_cfg.get("num_documents_to_retrieve", 5) or 5), + "ab_catalog_source": "database", + } + + @staticmethod + def _normalize_ab_variant_details(raw_variants: Any) -> List[Dict[str, Any]]: + details: List[Dict[str, Any]] = [] + if not isinstance(raw_variants, list): + return details + for entry in raw_variants: + if isinstance(entry, str): + details.append({"label": entry.strip(), "agent_spec": ""}) + continue + if not isinstance(entry, dict): + details.append({"label": "", "agent_spec": ""}) + continue + details.append({ + "label": str(entry.get("label") or entry.get("name") or "").strip(), + "agent_spec": str(entry.get("agent_spec") or "").strip(), + "provider": entry.get("provider") or None, + "model": entry.get("model") or None, + "num_documents_to_retrieve": entry.get("num_documents_to_retrieve"), + "recursion_limit": entry.get("recursion_limit"), + "agent_spec_id": entry.get("agent_spec_id"), + "agent_spec_name": entry.get("agent_spec_name"), + "agent_spec_version_id": entry.get("agent_spec_version_id"), + "agent_spec_version_number": entry.get("agent_spec_version_number"), + "agent_spec_content_hash": entry.get("agent_spec_content_hash"), + "agent_spec_tools": entry.get("agent_spec_tools"), + "agent_spec_prompt_hash": entry.get("agent_spec_prompt_hash"), + }) + return details + + @staticmethod + def _get_ab_setting( + mapping: Dict[str, Any], + canonical_key: str, + legacy_key: Optional[str] = None, + default: Any = None, + ) -> Any: + if isinstance(mapping, dict): + if canonical_key in mapping: + return mapping.get(canonical_key) + if 
legacy_key and legacy_key in mapping: + return mapping.get(legacy_key) + return default + + @classmethod + def _get_ab_pool_champion(cls, raw_pool: Dict[str, Any]) -> str: + return str(cls._get_ab_setting(raw_pool, "champion", "control", "") or "").strip() + + def _build_admin_ab_pool_payload(self) -> Dict[str, Any]: + chat_cfg = self.services_config.get("chat_app", {}) or {} + raw_ab_cfg = (chat_cfg.get("ab_testing") or {}) if isinstance(chat_cfg.get("ab_testing"), dict) else {} + raw_pool = raw_ab_cfg.get("pool") or {} + state = getattr(self.chat, "ab_pool_state", None) + active_pool = getattr(self.chat, "ab_pool", None) + participation = self._get_ab_participation_state() + defaults = self._get_ab_runtime_defaults() + + variant_details = self._normalize_ab_variant_details(raw_pool.get("variants")) + champion = self._get_ab_pool_champion(raw_pool) + comparison_rate = self._get_ab_setting(raw_ab_cfg, "comparison_rate", "sample_rate", 1.0) + variant_label_mode = normalize_ab_disclosure_mode( + self._get_ab_setting( + raw_ab_cfg, "variant_label_mode", "disclosure_mode", DEFAULT_DISCLOSURE_MODE + ) + ) + activity_panel_default_state = normalize_ab_trace_mode( + self._get_ab_setting( + raw_ab_cfg, + "activity_panel_default_state", + "default_trace_mode", + DEFAULT_TRACE_MODE, + ) + ) + max_pending = self._get_ab_setting( + raw_ab_cfg, + "max_pending_comparisons_per_conversation", + "max_pending_per_conversation", + 1, + ) + + if active_pool: + variant_details = [variant.to_meta() for variant in active_pool.variants] + champion = active_pool.champion_name + comparison_rate = active_pool.comparison_rate + variant_label_mode = active_pool.variant_label_mode + activity_panel_default_state = active_pool.activity_panel_default_state + max_pending = active_pool.max_pending_comparisons_per_conversation + + return { + "success": True, + "is_admin": self._is_admin_request(), + "can_view": self._can_view_ab_testing(), + "can_manage": self._can_manage_ab_testing(), + 
"can_view_metrics": self._can_view_ab_metrics(), + "can_participate": participation["can_participate"], + "participant_eligible": participation["eligible"], + "participant_reason": participation["reason"], + "participant_targeted": participation["targeted"], + "enabled": bool(active_pool and active_pool.enabled), + "enabled_requested": bool(raw_ab_cfg.get("enabled", False)), + "champion": champion, + "variants": [variant.get("label", "") for variant in variant_details if variant.get("label")], + "variant_details": variant_details, + "variant_count": len(variant_details), + "comparison_rate": comparison_rate, + "default_comparison_rate": float( + self._get_ab_setting(raw_ab_cfg, "comparison_rate", "sample_rate", comparison_rate or 1.0) + ), + "eligible_roles": list(self._get_ab_setting(raw_ab_cfg, "eligible_roles", "target_roles", []) or []), + "eligible_permissions": list( + self._get_ab_setting(raw_ab_cfg, "eligible_permissions", "target_permissions", []) or [] + ), + "max_pending_comparisons_per_conversation": max_pending, + "variant_label_mode": variant_label_mode, + "activity_panel_default_state": activity_panel_default_state, + "defaults": defaults, + "warnings": list(getattr(state, "warnings", []) or []), + "import_diagnostics": dict(getattr(self.chat, "ab_agent_import_diagnostics", {}) or {}), + } + + def _resolve_ab_variants( + self, + variant_items: Any, + *, + existing_variants: Optional[Dict[str, ABVariant]] = None, + ) -> tuple[list[ABVariant], list[str]]: + if not isinstance(variant_items, list) or len(variant_items) < 2: + raise ABPoolError("At least 2 variants are required") + + parsed_labels: List[str] = [] + for item in variant_items: + if isinstance(item, str): + label = item.strip() + elif isinstance(item, dict): + label = str(item.get('label') or item.get('name') or '').strip() + else: + label = '' + if not label: + raise ABPoolError("All variants must include a non-empty label") + parsed_labels.append(label) + + if len(set(parsed_labels)) != 
len(parsed_labels): + raise ABPoolError("Variant labels must be unique") + + ab_specs = self._get_ab_agent_spec_service() + spec_records = ab_specs.list_specs() + spec_map: Dict[str, ABAgentSpecRecord] = {record.name: record for record in spec_records} + + chat_cfg = self.chat.services_config.get("chat_app", {}) if hasattr(self, "chat") else self.services_config.get("chat_app", {}) + default_provider = chat_cfg.get("default_provider", "") + default_model = chat_cfg.get("default_model", "") + existing_variants = existing_variants or {} + + variants: List[ABVariant] = [] + for item, label in zip(variant_items, parsed_labels): + item_cfg = item if isinstance(item, dict) else {} + explicit_agent_spec = str(item_cfg.get('agent_spec') or '').strip() + if explicit_agent_spec: + if Path(explicit_agent_spec).name != explicit_agent_spec: + raise ABPoolError( + f"Variant '{label}' must use an A/B catalog filename" + ) + record = ab_specs.get_spec_by_filename(explicit_agent_spec) + if record is None: + raise ABPoolError( + f"Variant '{label}' references missing agent_spec '{explicit_agent_spec}'" + ) + else: + record = spec_map.get(label) + if not record: + raise ABPoolError( + f"Agent '{label}' not found in the A/B catalog; provide agent_spec explicitly" + ) + + existing = existing_variants.get(label) + provider_override = item_cfg.get('provider') + if provider_override is None and existing and existing.provider and existing.provider != default_provider: + provider_override = existing.provider + model_override = item_cfg.get('model') + if model_override is None and existing and existing.model and existing.model != default_model: + model_override = existing.model + + variants.append(ABVariant( + label=label, + agent_spec=record.filename, + provider=provider_override or None, + model=model_override or None, + num_documents_to_retrieve=item_cfg.get('num_documents_to_retrieve') or ( + existing.num_documents_to_retrieve if existing else None + ), + 
recursion_limit=item_cfg.get('recursion_limit') or ( + existing.recursion_limit if existing else None + ), + agent_spec_id=record.spec_id, + agent_spec_name=record.name, + agent_spec_version_id=record.version_id, + agent_spec_version_number=record.version_number, + agent_spec_content_hash=record.content_hash, + agent_spec_tools=list(record.tools), + agent_spec_prompt_hash=record.prompt_hash, + )) + + return variants, parsed_labels + + def _persist_ab_pool_config( + self, + *, + enabled: bool, + champion_name: str, + variants: List[ABVariant], + comparison_rate: float, + variant_label_mode: str, + activity_panel_default_state: str, + max_pending_comparisons_per_conversation: int, + ) -> None: + self.config_service.update_services_config({ + "chat_app": { + "ab_testing": { + "enabled": enabled, + "comparison_rate": comparison_rate, + "variant_label_mode": variant_label_mode, + "activity_panel_default_state": activity_panel_default_state, + "max_pending_comparisons_per_conversation": max_pending_comparisons_per_conversation, + "pool": { + "champion": champion_name, + "variants": [variant.to_meta() for variant in variants], + }, + } + } + }) + self._refresh_runtime_config() + + def _ndjson_response(self, event_iter) -> Response: + """Wrap an event iterator as an NDJSON streaming Response with standard headers.""" + def _event_stream() -> Iterator[str]: + padding = " " * 2048 + yield json.dumps({"type": "meta", "event": "stream_started", "padding": padding}) + "\n" + for event in event_iter: + yield json.dumps(event, default=str) + "\n" + + headers = { + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "Content-Encoding": "identity", + "Content-Type": "application/x-ndjson", + } + return Response(stream_with_context(_event_stream()), headers=headers) def _get_agent_class_name(self) -> Optional[str]: chat_cfg = self.services_config.get("chat_app", {}) - return chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + return 
ChatWrapper._get_agent_class_from_cfg(chat_cfg) def _get_agent_tool_registry(self) -> List[str]: agent_class = self._get_agent_class_name() @@ -2814,10 +3404,10 @@ def _get_agent_tool_registry(self) -> List[str]: return [] try: from src.archi import pipelines + agent_cls = getattr(pipelines, agent_class, None) except Exception as exc: - logger.warning("Failed to import pipelines module: %s", exc) + logger.warning("Failed to load pipeline class %s: %s", agent_class, exc) return [] - agent_cls = getattr(pipelines, agent_class, None) if not agent_cls or not hasattr(agent_cls, "get_tool_registry"): return [] try: @@ -2834,10 +3424,10 @@ def _get_agent_tools(self) -> List[Dict[str, str]]: return [] try: from src.archi import pipelines + agent_cls = getattr(pipelines, agent_class, None) except Exception as exc: - logger.warning("Failed to import pipelines module: %s", exc) + logger.warning("Failed to load pipeline class %s: %s", agent_class, exc) return [] - agent_cls = getattr(pipelines, agent_class, None) if not agent_cls or not hasattr(agent_cls, "get_tool_registry"): return [] try: @@ -2871,32 +3461,145 @@ def _build_agent_template(self, name: str, tools: List[str]) -> str: "Write your system prompt here.\n\n" ) + def _build_agent_template_payload(self, name: str, *, scope: str = "default") -> Dict[str, Any]: + tool_items = self._get_agent_tools() + tools = [tool["name"] for tool in tool_items] + return { + "name": name, + "tools": tool_items, + "prompt": "Write your system prompt here.", + "template": self._build_agent_template(name, tools), + "scope": scope, + } + + def _build_ab_agent_content(self, name: str, tools: List[str], prompt: str) -> str: + normalized_name = str(name or "").strip() + normalized_prompt = str(prompt or "").strip() + normalized_tools = [ + str(tool).strip() + for tool in (tools or []) + if isinstance(tool, str) and str(tool).strip() + ] + if not normalized_name: + raise AgentSpecError("Agent name is required.") + if not normalized_tools: + 
raise AgentSpecError("At least one tool is required.") + if not normalized_prompt: + raise AgentSpecError("Prompt body is required.") + frontmatter = yaml.safe_dump( + { + "name": normalized_name, + "ab_only": True, + "tools": normalized_tools, + }, + sort_keys=False, + default_flow_style=False, + allow_unicode=False, + ).strip() + return f"---\n{frontmatter}\n---\n\n{normalized_prompt}\n" + + def _list_ab_agent_catalog_payload(self) -> Dict[str, Any]: + agents = [] + for record in self._get_ab_agent_spec_service().list_specs(): + agents.append({"name": record.name, "filename": record.filename, "ab_only": True}) + return { + "agents": agents, + "active_name": None, + "scope": "ab", + "directory": None, + } + + def list_ab_agents(self): + try: + if not self._can_view_ab_testing(): + return jsonify({"error": "A/B agent visibility requires A/B page access"}), 403 + return jsonify(self._list_ab_agent_catalog_payload()), 200 + except Exception as exc: + logger.error(f"Error listing A/B agents: {exc}") + return jsonify({"error": str(exc)}), 500 + + def get_ab_agent_template(self): + try: + if not self._can_manage_ab_testing(): + return jsonify({"error": "A/B agent management requires admin access"}), 403 + agent_name = request.args.get("name") or "New A/B Agent" + return jsonify(self._build_agent_template_payload(agent_name, scope="ab")), 200 + except Exception as exc: + logger.error(f"Error building A/B agent template: {exc}") + return jsonify({"error": str(exc)}), 500 + + def save_ab_agent_spec(self): + try: + if not self._can_manage_ab_testing(): + return jsonify({"error": "A/B agent management requires admin access"}), 403 + data = request.get_json(force=True) or {} + name = data.get("name") + tools = data.get("tools") + prompt = data.get("prompt") + if not isinstance(tools, list): + return jsonify({"error": "tools must be a list"}), 400 + content = self._build_ab_agent_content(name, tools, prompt) + created_by = ( + session.get("user", {}).get("email") + or 
session.get("user", {}).get("id") + or data.get("client_id") + or "system" + ) + record = self._get_ab_agent_spec_service().save_spec( + content, + created_by=created_by, + ) + return jsonify({ + "success": True, + "name": record.name, + "filename": record.filename, + "path": None, + "scope": "ab", + }), 200 + except AgentSpecError as exc: + logger.error(f"Invalid A/B agent spec: {exc}") + return jsonify({"error": f"Invalid agent spec: {exc}"}), 400 + except Exception as exc: + logger.error(f"Error saving A/B agent spec: {exc}") + return jsonify({"error": str(exc)}), 500 + def list_agents(self): """ List available agent specs for the dropdown. """ try: - agents_dir = self._get_agents_dir() + scope = self._get_agent_scope() + if scope == "ab": + if not self._can_view_ab_testing(): + return jsonify({"error": "A/B agent visibility requires A/B view access"}), 403 + return jsonify(self._list_ab_agent_catalog_payload()), 200 + agents_dir = self._get_agent_dir_for_scope(scope, create=(scope == "ab")) agent_files = list_agent_files(agents_dir) agents = [] for path in agent_files: try: spec = load_agent_spec(path) - agents.append({"name": spec.name, "filename": path.name}) + agents.append({"name": spec.name, "filename": path.name, "ab_only": spec.ab_only}) except AgentSpecError as exc: logger.warning("Skipping invalid agent spec %s: %s", path, exc) - try: - dynamic = get_dynamic_config() - except Exception: - dynamic = None - active_name = getattr(dynamic, "active_agent_name", None) if dynamic else None - if not active_name: - active_spec = getattr(self.chat, "agent_spec", None) - active_name = getattr(active_spec, "name", None) + active_name = None + if scope != "ab": + try: + dynamic = get_dynamic_config() + except Exception: + dynamic = None + active_name = getattr(dynamic, "active_agent_name", None) if dynamic else None + if not active_name: + active_spec = getattr(self.chat, "agent_spec", None) + active_name = getattr(active_spec, "name", None) return jsonify({ 
"agents": agents, "active_name": active_name, + "scope": scope, + "directory": str(agents_dir), }), 200 + except PermissionError as exc: + return jsonify({"error": str(exc)}), 403 except Exception as exc: logger.error(f"Error listing agents: {exc}") return jsonify({"error": str(exc)}), 500 @@ -2906,10 +3609,30 @@ def get_agent_spec(self): Fetch a single agent spec by name. """ try: + scope = self._get_agent_scope() name = request.args.get("name") - if not name: - return jsonify({"error": "name parameter required"}), 400 - agents_dir = self._get_agents_dir() + filename = request.args.get("filename") + if not name and not filename: + return jsonify({"error": "name or filename parameter required"}), 400 + if scope == "ab": + if not self._can_view_ab_testing(): + return jsonify({"error": "A/B agent visibility requires A/B view access"}), 403 + if filename: + record = self._get_ab_agent_spec_service().get_spec_by_filename(filename) + else: + record = self._get_ab_agent_spec_service().get_spec_by_name(name) + if record is None: + lookup = filename or name + return jsonify({"error": f"Agent '{lookup}' not found"}), 404 + return jsonify({ + "name": record.name, + "filename": record.filename, + "content": record.content, + "tools": list(record.tools), + "prompt": record.prompt, + "scope": scope, + }), 200 + agents_dir = self._get_agent_dir_for_scope(scope, create=(scope == "ab")) for path in list_agent_files(agents_dir): try: spec = load_agent_spec(path) @@ -2920,8 +3643,13 @@ def get_agent_spec(self): "name": spec.name, "filename": path.name, "content": path.read_text(), + "tools": list(getattr(spec, "tools", []) or []), + "prompt": getattr(spec, "prompt", ""), + "scope": scope, }), 200 return jsonify({"error": f"Agent '{name}' not found"}), 404 + except PermissionError as exc: + return jsonify({"error": str(exc)}), 403 except Exception as exc: logger.error(f"Error fetching agent spec: {exc}") return jsonify({"error": str(exc)}), 500 @@ -2931,14 +3659,11 @@ def 
get_agent_template(self): Return a prefilled agent spec template and available tools. """ try: + scope = self._get_agent_scope() + if scope == "ab" and not self._can_manage_ab_testing(): + return jsonify({"error": "A/B agent management requires admin access"}), 403 agent_name = request.args.get("name") or "New Agent" - tool_items = self._get_agent_tools() - tools = [tool["name"] for tool in tool_items] - return jsonify({ - "name": agent_name, - "tools": tool_items, - "template": self._build_agent_template(agent_name, tools), - }), 200 + return jsonify(self._build_agent_template_payload(agent_name, scope=scope)), 200 except Exception as exc: logger.error(f"Error building agent template: {exc}") return jsonify({'error': str(exc)}), 500 @@ -2984,14 +3709,35 @@ def save_agent_spec(self): """ try: data = request.get_json() or {} + scope = self._get_agent_scope() content = data.get("content") mode = data.get("mode", "create") existing_name = data.get("existing_name") if not content or not isinstance(content, str): return jsonify({'error': 'Content is required'}), 400 - agents_dir = self._get_agents_dir() - agents_dir.mkdir(parents=True, exist_ok=True) + if scope == "ab": + if not self._can_manage_ab_testing(): + return jsonify({"error": "A/B agent management requires admin access"}), 403 + created_by = session.get("user", {}).get("email") or session.get("user", {}).get("id") or data.get("client_id") or "system" + ab_service = self._get_ab_agent_spec_service() + if mode == "edit" or existing_name: + return jsonify({ + "error": "Editing A/B agent specs is not supported. Create a new A/B agent spec instead." 
+ }), 400 + record = ab_service.save_spec( + content, + created_by=created_by, + ) + return jsonify({ + 'success': True, + 'name': record.name, + 'filename': record.filename, + 'path': None, + 'scope': scope, + }), 200 + + agents_dir = self._get_agent_dir_for_scope(scope, create=True) if mode == "edit" or existing_name: if not existing_name: @@ -3008,6 +3754,10 @@ def save_agent_spec(self): if not target_path: return jsonify({'error': f"Agent '{existing_name}' not found"}), 404 new_spec = load_agent_spec_from_text(content) + if new_spec.name != existing_name: + return jsonify({ + 'error': 'Agent name cannot be changed in edit mode. Create or clone a new agent instead.' + }), 400 for path in list_agent_files(agents_dir): if path == target_path: continue @@ -3018,18 +3768,12 @@ def save_agent_spec(self): if spec.name == new_spec.name: return jsonify({'error': f"Agent name '{new_spec.name}' already exists"}), 409 target_path.write_text(content) - try: - dynamic = get_dynamic_config() - except Exception: - dynamic = None - if dynamic and dynamic.active_agent_name == existing_name and new_spec.name != existing_name: - cfg = ConfigService(pg_config=self.pg_config) - cfg.update_dynamic_config(active_agent_name=new_spec.name, updated_by=data.get("client_id") or "system") return jsonify({ 'success': True, 'name': new_spec.name, 'filename': target_path.name, 'path': str(target_path), + 'scope': scope, }), 200 # create mode @@ -3062,7 +3806,10 @@ def save_agent_spec(self): 'name': spec.name, 'filename': target_path.name, 'path': str(target_path), + 'scope': scope, }), 200 + except PermissionError as exc: + return jsonify({"error": str(exc)}), 403 except AgentSpecError as exc: logger.error(f"Invalid agent spec: {exc}") return jsonify({'error': f'Invalid agent spec: {exc}'}), 400 @@ -3076,6 +3823,7 @@ def delete_agent_spec(self): """ try: data = request.get_json() or {} + scope = self._get_agent_scope() name = data.get("name") if not name: return jsonify({"error": "name is 
required"}), 400 @@ -3083,7 +3831,15 @@ def delete_agent_spec(self): if name.lower().startswith("name:"): name = name.split(":", 1)[1].strip() - agents_dir = self._get_agents_dir() + if scope == "ab": + if not self._can_manage_ab_testing(): + return jsonify({"error": "A/B agent management requires admin access"}), 403 + deleted = self._get_ab_agent_spec_service().delete_spec_by_name(name) + if not deleted: + return jsonify({"error": f"Agent '{name}' not found"}), 404 + return jsonify({"success": True, "deleted": name}), 200 + + agents_dir = self._get_agent_dir_for_scope(scope, create=(scope == "ab")) target_path = None for path in list_agent_files(agents_dir): try: @@ -3097,14 +3853,17 @@ def delete_agent_spec(self): return jsonify({"error": f"Agent '{name}' not found"}), 404 target_path.unlink() - try: - dynamic = get_dynamic_config() - except Exception: - dynamic = None - if dynamic and dynamic.active_agent_name == name: - cfg = ConfigService(pg_config=self.pg_config) - cfg.update_dynamic_config(active_agent_name=None, updated_by=data.get("client_id") or "system") + if scope != "ab": + try: + dynamic = get_dynamic_config() + except Exception: + dynamic = None + if dynamic and dynamic.active_agent_name == name: + cfg = ConfigService(pg_config=self.pg_config) + cfg.update_dynamic_config(active_agent_name=None, updated_by=data.get("client_id") or "system") return jsonify({"success": True, "deleted": name}), 200 + except PermissionError as exc: + return jsonify({"error": str(exc)}), 403 except Exception as exc: logger.error(f"Error deleting agent spec: {exc}") return jsonify({"error": str(exc)}), 500 @@ -3128,7 +3887,7 @@ def get_agent_info(self): config_payload = self.config chat_cfg = config_payload.get("services", {}).get("chat_app", {}) - agent_class = chat_cfg.get("agent_class") or chat_cfg.get("pipeline") + agent_class = ChatWrapper._get_agent_class_from_cfg(chat_cfg) embedding_name = config_payload.get("data_manager", {}).get("embedding_name") sources = 
config_payload.get("data_manager", {}).get("sources", {}) source_names = list(sources.keys()) if isinstance(sources, dict) else [] @@ -3448,6 +4207,179 @@ def validate_provider_api_key(self): 'error': str(e), }), 200 + def _get_request_client_id(self) -> str: + """Extract client_id from the current request (JSON body or query params).""" + if request.is_json and request.json: + cid = request.json.get('client_id') + if cid: + return cid + return request.args.get('client_id', '') + + def _is_admin_request(self) -> bool: + """Return True when the current request is from an RBAC admin user.""" + try: + return bool(rbac_is_admin()) + except Exception: + return False + + def _can_view_ab_testing(self) -> bool: + return ( + self._can_manage_ab_testing() + or has_permission(Permission.AB.VIEW) + or has_permission(Permission.AB.METRICS) + ) + + def _can_manage_ab_testing(self) -> bool: + return self._is_admin_request() or has_permission(Permission.AB.MANAGE) + + def _can_view_ab_metrics(self) -> bool: + return self._can_manage_ab_testing() or has_permission(Permission.AB.METRICS) + + def _current_request_roles(self) -> List[str]: + return list(session.get('roles', []) or []) + + def _current_request_permissions(self) -> List[str]: + return sorted(get_user_permissions(self._current_request_roles())) + + @staticmethod + def _current_user_id() -> Optional[str]: + user = session.get('user') or {} + return user.get('id') or session.get('client_id') or None + + def _get_effective_ab_sample_rate(self, default_rate: float) -> float: + effective_rate = float(default_rate) + user_id = self._current_user_id() + if not user_id: + return effective_rate + session_user = session.get('user') or {} + try: + user = self.chat.user_service.get_or_create_user( + user_id=user_id, + auth_provider=session_user.get('auth_method', 'anonymous') if session.get('logged_in') else 'anonymous', + display_name=session_user.get('name'), + email=session_user.get('email'), + ) + except Exception: + user = 
None + if user is not None and user.ab_participation_rate is not None: + effective_rate = float(user.ab_participation_rate) + return min(max(effective_rate, 0.0), 1.0) + + def _get_ab_participation_state(self) -> Dict[str, Any]: + pool = getattr(self.chat, "ab_pool", None) + can_participate = has_permission(Permission.AB.PARTICIPATE) + if not can_participate: + return { + "can_participate": False, + "eligible": False, + "reason": "not_participant", + "targeted": False, + } + if not pool or not pool.enabled: + return { + "can_participate": True, + "eligible": False, + "reason": "disabled", + "targeted": False, + } + targeted = pool.is_targeted_user( + roles=self._current_request_roles(), + permissions=self._current_request_permissions(), + ) + if not targeted: + return { + "can_participate": True, + "eligible": False, + "reason": "not_targeted", + "targeted": False, + } + return { + "can_participate": True, + "eligible": True, + "reason": "eligible", + "targeted": True, + } + + def _can_use_ab_testing(self) -> bool: + return bool(self._get_ab_participation_state()["eligible"]) + + def _refresh_runtime_config(self) -> None: + static = self.config_service.get_static_config(force_reload=True) + if static is None: + raise ValueError("Static config not initialized") + self.config = _static_config_to_full_config(static, config_service=self.config_service) + self.global_config = self.config["global"] + self.services_config = self.config["services"] + self.chat_app_config = self.services_config["chat_app"] + self.chat.reload_static_state() + + def _serialize_pending_ab_comparison( + self, + comparison, + ) -> Optional[Dict[str, Any]]: + if comparison is None: + return None + + mids = [comparison.response_a_mid, comparison.response_b_mid] + if not all(mids): + return None + + conn = psycopg2.connect(**self.pg_config) + cursor = conn.cursor() + try: + cursor.execute( + """ + SELECT message_id, sender, content, model_used + FROM conversations + WHERE message_id = ANY(%s) + 
ORDER BY message_id ASC + """, + (mids,), + ) + messages = { + row[0]: { + "message_id": row[0], + "sender": row[1], + "content": row[2], + "model_used": row[3], + "trace": self.chat.get_trace_by_message(row[0]), + } + for row in cursor.fetchall() + } + finally: + cursor.close() + conn.close() + + response_a = messages.get(comparison.response_a_mid) + response_b = messages.get(comparison.response_b_mid) + if not response_a or not response_b: + return None + + return { + "comparison_id": comparison.comparison_id, + "conversation_id": comparison.conversation_id, + "created_at": comparison.created_at.isoformat() if comparison.created_at else None, + "response_a": response_a, + "response_b": response_b, + "variant_a_name": comparison.variant_a_name, + "variant_b_name": comparison.variant_b_name, + "preference": comparison.preference, + "variant_label_mode": self.chat.ab_pool.variant_label_mode if self.chat.ab_pool else DEFAULT_DISCLOSURE_MODE, + "activity_panel_default_state": self.chat.ab_pool.activity_panel_default_state if self.chat.ab_pool else DEFAULT_TRACE_MODE, + } + + def _serialize_pending_ab_comparisons( + self, + comparisons, + ) -> List[Dict[str, Any]]: + """Serialize unresolved comparisons in stable creation order.""" + serialized: List[Dict[str, Any]] = [] + for comparison in comparisons or []: + payload = self._serialize_pending_ab_comparison(comparison) + if payload is not None: + serialized.append(payload) + return serialized + def _parse_chat_request(self) -> Dict[str, Any]: payload = request.get_json(silent=True) or {} @@ -3518,13 +4450,8 @@ def get_chat_response(self): # handle errors if error_code is not None: - if error_code == 408: - output = jsonify({'error': CLIENT_TIMEOUT_ERROR_MESSAGE}) - elif error_code == 403: - output = jsonify({'error': 'conversation not found'}) - else: - output = jsonify({'error': 'server error; see chat logs for message'}) - return output, error_code + err = ChatWrapper._error_event(error_code) + return 
jsonify({'error': err['message']}), error_code # compute timestamp at which message was returned to client timestamps['server_response_msg_ts'] = datetime.now(timezone.utc) @@ -3632,115 +4559,62 @@ def index(self): def terms(self): return render_template('terms.html') - def like(self): + def _with_feedback_lock(self, fn): + """Run fn() under the feedback lock with proper cleanup.""" self.chat.lock.acquire() logger.info("Acquired lock file") try: - data = request.json - message_id = data.get('message_id') - - if not message_id: - logger.warning("Like request missing message_id") - return jsonify({'error': 'message_id is required'}), 400 - - # Check current state for toggle behavior - current_reaction = self.chat.get_reaction_feedback(message_id) - - # Always delete existing reaction first - self.chat.delete_reaction_feedback(message_id) - - # If already liked, just remove (toggle off) - don't re-add - if current_reaction == 'like': - response = {'message': 'Reaction removed', 'state': None} - return jsonify(response), 200 - - # Otherwise, add the like - feedback = { - "message_id" : message_id, - "feedback" : "like", - "feedback_ts" : datetime.now(timezone.utc), - "feedback_msg" : None, - "incorrect" : None, - "unhelpful" : None, - "inappropriate": None, - } - self.chat.insert_feedback(feedback) - - response = {'message': 'Liked', 'state': 'like'} - return jsonify(response), 200 - + return fn() except Exception as e: logger.error(f"Request failed: {str(e)}") return jsonify({'error': str(e)}), 500 - finally: self.chat.lock.release() logger.info("Released lock file") - if self.chat.cursor is not None: self.chat.cursor.close() if self.chat.conn is not None: self.chat.conn.close() - def dislike(self): - self.chat.lock.acquire() - logger.info("Acquired lock file") - try: + def _toggle_reaction(self, reaction_type): + """Shared like/dislike toggle: remove if already set, else insert.""" + def _do(): data = request.json message_id = data.get('message_id') - if not 
message_id: - logger.warning("Dislike request missing message_id") + logger.warning(f"{reaction_type.capitalize()} request missing message_id") return jsonify({'error': 'message_id is required'}), 400 - feedback_msg = data.get('feedback_msg') - incorrect = data.get('incorrect') - unhelpful = data.get('unhelpful') - inappropriate = data.get('inappropriate') - - # Check current state for toggle behavior current_reaction = self.chat.get_reaction_feedback(message_id) - - # Always delete existing reaction first self.chat.delete_reaction_feedback(message_id) - # If already disliked, just remove (toggle off) - don't re-add - if current_reaction == 'dislike': - response = {'message': 'Reaction removed', 'state': None} - return jsonify(response), 200 + if current_reaction == reaction_type: + return jsonify({'message': 'Reaction removed', 'state': None}), 200 - # Otherwise, add the dislike feedback = { "message_id" : message_id, - "feedback" : "dislike", - "feedback_ts" : datetime.now(timezone.utc), - "feedback_msg" : feedback_msg, - "incorrect" : incorrect, - "unhelpful" : unhelpful, - "inappropriate": inappropriate, + "feedback" : reaction_type, + "feedback_ts" : datetime.now(), + "feedback_msg" : data.get('feedback_msg') if reaction_type == 'dislike' else None, + "incorrect" : data.get('incorrect') if reaction_type == 'dislike' else None, + "unhelpful" : data.get('unhelpful') if reaction_type == 'dislike' else None, + "inappropriate": data.get('inappropriate') if reaction_type == 'dislike' else None, } self.chat.insert_feedback(feedback) - response = {'message': 'Disliked', 'state': 'dislike'} - return jsonify(response), 200 + label = f"{reaction_type.capitalize()}d" + return jsonify({'message': label, 'state': reaction_type}), 200 - except Exception as e: - logger.error(f"Request failed: {str(e)}") - return jsonify({'error': str(e)}), 500 + return self._with_feedback_lock(_do) - finally: - self.chat.lock.release() - logger.info("Released lock file") + def like(self): + 
return self._toggle_reaction('like') - if self.chat.cursor is not None: - self.chat.cursor.close() - if self.chat.conn is not None: - self.chat.conn.close() + def dislike(self): + return self._toggle_reaction('dislike') def text_feedback(self): - self.chat.lock.acquire() - logger.info("Acquired lock file for text feedback") - try: + def _do(): data = request.json message_id = data.get('message_id') feedback_msg = (data.get('feedback_msg') or '').strip() @@ -3764,22 +4638,9 @@ def text_feedback(self): "inappropriate": None, } self.chat.insert_feedback(feedback) + return jsonify({'message': 'Feedback submitted'}), 200 - response = {'message': 'Feedback submitted'} - return jsonify(response), 200 - - except Exception as e: - logger.error(f"Request failed: {str(e)}") - return jsonify({'error': str(e)}), 500 - - finally: - self.chat.lock.release() - logger.info("Released lock file") - - if self.chat.cursor is not None: - self.chat.cursor.close() - if self.chat.conn is not None: - self.chat.conn.close() + return self._with_feedback_lock(_do) def list_conversations(self): """ @@ -3869,6 +4730,10 @@ def load_conversation(self): # get history of the conversation along with latest feedback state cursor.execute(SQL_QUERY_CONVO_WITH_FEEDBACK, (conversation_id, )) history_rows = cursor.fetchall() + comparisons = self.chat.conv_service.get_conversation_ab_comparisons(str(conversation_id)) + suppressed_ids = self.chat._suppressed_ab_message_ids(comparisons) + if suppressed_ids: + history_rows = [row for row in history_rows if row[2] not in suppressed_ids] history_rows = collapse_assistant_sequences(history_rows, sender_name=ARCHI_SENDER, sender_index=0) # Build messages list with trace data for assistant messages @@ -3913,12 +4778,17 @@ def load_conversation(self): messages.append(msg) + pending_comparisons = [c for c in comparisons if c.preference is None] + serialized_pending = self._serialize_pending_ab_comparisons(pending_comparisons) + conversation = { 'conversation_id': 
meta_row[0], 'title': meta_row[1] or "New Conversation", 'created_at': meta_row[2].isoformat() if meta_row[2] else None, 'last_message_at': meta_row[3].isoformat() if meta_row[3] else None, - 'messages': messages + 'messages': messages, + 'pending_ab_comparisons': serialized_pending, + 'pending_ab_comparison': serialized_pending[-1] if serialized_pending else None, } # clean up database connection state @@ -4002,74 +4872,6 @@ def delete_conversation(self): # A/B Testing API Endpoints # ========================================================================= - def ab_create_comparison(self): - """ - Create a new A/B comparison record linking two responses. - - POST body: - - conversation_id: The conversation ID - - user_prompt_mid: Message ID of the user's question - - response_a_mid: Message ID of response A - - response_b_mid: Message ID of response B - - config_a_id: Config ID used for response A - - config_b_id: Config ID used for response B - - is_config_a_first: True if config A was the "first" config before randomization - - client_id: Client ID for authorization - - Returns: - JSON with comparison_id - """ - try: - data = request.json - conversation_id = data.get('conversation_id') - user_prompt_mid = data.get('user_prompt_mid') - response_a_mid = data.get('response_a_mid') - response_b_mid = data.get('response_b_mid') - config_a_id = data.get('config_a_id') - config_b_id = data.get('config_b_id') - is_config_a_first = data.get('is_config_a_first', True) - client_id = data.get('client_id') - - # Validate required fields - missing = [] - if not conversation_id: - missing.append('conversation_id') - if not user_prompt_mid: - missing.append('user_prompt_mid') - if not response_a_mid: - missing.append('response_a_mid') - if not response_b_mid: - missing.append('response_b_mid') - if not config_a_id: - missing.append('config_a_id') - if not config_b_id: - missing.append('config_b_id') - if not client_id: - missing.append('client_id') - - if missing: - return 
jsonify({'error': f'Missing required fields: {", ".join(missing)}'}), 400 - - # Create the comparison - comparison_id = self.chat.create_ab_comparison( - conversation_id=conversation_id, - user_prompt_mid=user_prompt_mid, - response_a_mid=response_a_mid, - response_b_mid=response_b_mid, - config_a_id=config_a_id, - config_b_id=config_b_id, - is_config_a_first=is_config_a_first, - ) - - return jsonify({ - 'success': True, - 'comparison_id': comparison_id, - }), 200 - - except Exception as e: - logger.error(f"Error creating A/B comparison: {str(e)}") - return jsonify({'error': str(e)}), 500 - def ab_submit_preference(self): """ Submit user's preference for an A/B comparison. @@ -4087,6 +4889,7 @@ def ab_submit_preference(self): comparison_id = data.get('comparison_id') preference = data.get('preference') client_id = data.get('client_id') + user_id = session.get('user', {}).get('id') or None if not comparison_id: return jsonify({'error': 'comparison_id is required'}), 400 @@ -4097,13 +4900,25 @@ def ab_submit_preference(self): if not client_id: return jsonify({'error': 'client_id is required'}), 400 - # Update the preference - self.chat.update_ab_preference(comparison_id, preference) + # Verify the comparison belongs to the requesting client + comparison = self.chat.conv_service.get_ab_comparison(comparison_id) + if not comparison: + return jsonify({'error': 'Comparison not found'}), 404 + comp_conv_id = comparison.conversation_id if comparison else None + if comp_conv_id: + try: + self.chat.query_conversation_history(comp_conv_id, client_id, user_id) + except ConversationAccessError: + return jsonify({'error': 'Not authorized for this comparison'}), 403 + + result = self.chat.conv_service.submit_ab_preference(comparison_id, preference) return jsonify({ 'success': True, 'comparison_id': comparison_id, 'preference': preference, + 'updated': result.get('updated', False), + 'canonical_message_id': self.chat._comparison_canonical_message_id(result.get('comparison')), }), 
200 except ValueError as e: @@ -4126,23 +4941,483 @@ def ab_get_pending(self): try: conversation_id = request.args.get('conversation_id', type=int) client_id = request.args.get('client_id') + user_id = session.get('user', {}).get('id') or None if not conversation_id: return jsonify({'error': 'conversation_id is required'}), 400 if not client_id: return jsonify({'error': 'client_id is required'}), 400 - comparison = self.chat.get_pending_ab_comparison(conversation_id) + try: + self.chat.query_conversation_history(conversation_id, client_id, user_id) + except ConversationAccessError: + return jsonify({'error': 'Not authorized for this conversation'}), 403 + comparisons = self.chat.conv_service.get_pending_ab_comparisons(conversation_id) + serialized = self._serialize_pending_ab_comparisons(comparisons) return jsonify({ 'success': True, - 'comparison': comparison, + 'comparison': serialized[-1] if serialized else None, + 'comparisons': serialized, + 'pending_count': len(serialized), }), 200 except Exception as e: logger.error(f"Error getting pending A/B comparison: {str(e)}") return jsonify({'error': str(e)}), 500 + def ab_get_pool(self): + """ + Get A/B testing pool configuration. + + Returns: + JSON with pool info (enabled, champion, variant names) or enabled=false. + Only admins see the full pool info; non-admins get enabled=false. 
+ """ + try: + pool = self.chat.ab_pool + can_view = self._can_view_ab_testing() + can_manage = self._can_manage_ab_testing() + participation = self._get_ab_participation_state() + can_use = participation["eligible"] + can_participate = participation["can_participate"] + raw_ab_cfg = ((self.services_config.get("chat_app", {}) or {}).get("ab_testing") or {}) + default_comparison_rate = float( + self._get_ab_setting( + raw_ab_cfg, + "comparison_rate", + "sample_rate", + getattr(pool, "comparison_rate", getattr(pool, "sample_rate", 1.0) or 1.0), + ) + ) + if can_view: + return jsonify(self._build_admin_ab_pool_payload()), 200 + if pool and can_use: + effective_rate = self._get_effective_ab_sample_rate(pool.sample_rate) + return jsonify({ + 'success': True, + 'is_admin': self._is_admin_request(), + 'can_view': False, + 'can_manage': False, + 'can_view_metrics': False, + 'can_participate': can_participate, + 'participant_eligible': True, + 'participant_reason': participation["reason"], + 'participant_targeted': True, + **pool.participant_info(), + 'comparison_rate': effective_rate, + 'default_comparison_rate': default_comparison_rate, + }), 200 + return jsonify({ + 'success': True, + 'enabled': False, + 'enabled_requested': bool(raw_ab_cfg.get('enabled', False)), + 'is_admin': self._is_admin_request(), + 'can_view': can_view, + 'can_manage': can_manage, + 'can_view_metrics': self._can_view_ab_metrics(), + 'can_participate': can_participate, + 'participant_eligible': participation["eligible"], + 'participant_reason': participation["reason"], + 'participant_targeted': participation["targeted"], + 'comparison_rate': self._get_effective_ab_sample_rate(default_comparison_rate) if can_participate else default_comparison_rate, + 'default_comparison_rate': default_comparison_rate, + }), 200 + except Exception as e: + logger.error(f"Error getting A/B pool: {str(e)}") + return jsonify({'error': str(e)}), 500 + + def ab_get_decision(self): + """ + Decide on the server whether the 
next turn should use A/B comparison. + + This keeps sampling authoritative on the backend instead of relying on + browser-side Math.random(). + """ + try: + client_id = request.args.get('client_id', '') + conversation_id = request.args.get('conversation_id', type=int) + user_id = session.get('user', {}).get('id') or None + pool = self.chat.ab_pool + participation = self._get_ab_participation_state() + + if not pool or not pool.enabled: + return jsonify({'success': True, 'enabled': False, 'use_ab': False, 'reason': 'disabled'}), 200 + if not participation["can_participate"]: + return jsonify({'success': True, 'enabled': True, 'use_ab': False, 'reason': 'not_participant'}), 200 + if not participation["eligible"]: + return jsonify({'success': True, 'enabled': True, 'use_ab': False, 'reason': participation["reason"]}), 200 + + if conversation_id: + try: + self.chat.query_conversation_history(conversation_id, client_id, user_id) + except ConversationAccessError: + return jsonify({'error': 'Not authorized for this conversation'}), 403 + + pending_count = self.chat.conv_service.count_pending_ab_comparisons(conversation_id) + if pending_count >= int(pool.max_pending_per_conversation): + pending = self.chat.conv_service.get_pending_ab_comparison(conversation_id) + return jsonify({ + 'success': True, + 'enabled': True, + 'use_ab': False, + 'reason': 'pending_vote', + 'comparison_id': getattr(pending, 'comparison_id', None), + 'pending_count': pending_count, + 'max_pending_comparisons_per_conversation': pool.max_pending_comparisons_per_conversation, + }), 200 + + sample_rate = self._get_effective_ab_sample_rate(pool.sample_rate) + if sample_rate <= 0: + use_ab = False + roll = None + elif sample_rate >= 1: + use_ab = True + roll = None + else: + roll = random.random() + use_ab = roll < sample_rate + + logger.info( + "A/B decision: use_ab=%s comparison_rate=%.3f roll=%s conversation_id=%s client_id=%s", + use_ab, + sample_rate, + "forced" if roll is None else f"{roll:.5f}", + 
conversation_id, + client_id, + ) + + return jsonify({ + 'success': True, + 'enabled': True, + 'use_ab': use_ab, + 'reason': 'sampled' if use_ab else 'not_sampled', + 'comparison_rate': sample_rate, + 'default_comparison_rate': float(pool.comparison_rate), + 'variant_label_mode': pool.variant_label_mode, + 'activity_panel_default_state': pool.activity_panel_default_state, + 'max_pending_comparisons_per_conversation': pool.max_pending_comparisons_per_conversation, + }), 200 + except Exception as e: + logger.error(f"Error deciding A/B sampling: {str(e)}") + return jsonify({'error': str(e)}), 500 + + def ab_set_pool(self): + """ + Set the A/B testing pool from the UI. + Admin only. Accepts JSON with champion label plus at least two variants. + """ + if not self._can_manage_ab_testing(): + return jsonify({'error': 'Admin access required'}), 403 + try: + data = request.get_json(force=True) + champion_name = str(data.get('champion') or data.get('control') or '').strip() + comparison_rate = float(data.get('comparison_rate', data.get('sample_rate', 1.0))) + variant_label_mode = normalize_ab_disclosure_mode( + data.get('variant_label_mode') or data.get('disclosure_mode') or DEFAULT_DISCLOSURE_MODE + ) + activity_panel_default_state = normalize_ab_trace_mode( + data.get('activity_panel_default_state') or data.get('default_trace_mode') or DEFAULT_TRACE_MODE + ) + max_pending = int( + data.get('max_pending_comparisons_per_conversation', data.get('max_pending_per_conversation', 1)) + ) + if not champion_name: + return jsonify({'error': 'champion is required'}), 400 + variant_items = data.get('variants') or [] + existing_variants = { + variant.label: variant for variant in (self.chat.ab_pool.variants if self.chat.ab_pool else []) + } + variants, parsed_labels = self._resolve_ab_variants( + variant_items, + existing_variants=existing_variants, + ) + if champion_name not in parsed_labels: + return jsonify({'error': 'Champion must be one of the variants'}), 400 + + pool = ABPool( + 
variants=variants, + champion_name=champion_name, + enabled=True, + sample_rate=comparison_rate, + disclosure_mode=variant_label_mode, + default_trace_mode=activity_panel_default_state, + max_pending_per_conversation=max_pending, + ) + self._persist_ab_pool_config( + enabled=True, + champion_name=champion_name, + variants=variants, + comparison_rate=pool.comparison_rate, + variant_label_mode=pool.variant_label_mode, + activity_panel_default_state=pool.activity_panel_default_state, + max_pending_comparisons_per_conversation=pool.max_pending_comparisons_per_conversation, + ) + logger.info("Persisted A/B pool update: champion='%s', variants=%s", champion_name, parsed_labels) + return jsonify(self._build_admin_ab_pool_payload()), 200 + except ABPoolError as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.error("Error setting A/B pool: %s", exc) + return jsonify({'error': str(exc)}), 500 + + def ab_set_settings(self): + """Persist only the experiment-settings section of the A/B admin page.""" + if not self._can_manage_ab_testing(): + return jsonify({'error': 'Admin access required'}), 403 + try: + data = request.get_json(force=True) + chat_cfg = self.services_config.get("chat_app", {}) or {} + raw_ab_cfg = (chat_cfg.get("ab_testing") or {}) if isinstance(chat_cfg.get("ab_testing"), dict) else {} + raw_pool = raw_ab_cfg.get("pool") or {} + champion_name = str( + data.get('champion') or data.get('control') or self._get_ab_pool_champion(raw_pool) or '' + ).strip() + if not champion_name: + return jsonify({'error': 'champion is required'}), 400 + + existing_variants = { + variant.label: variant for variant in (self.chat.ab_pool.variants if self.chat.ab_pool else []) + } + variant_items = self._normalize_ab_variant_details(raw_pool.get("variants")) + variants, parsed_labels = self._resolve_ab_variants( + variant_items, + existing_variants=existing_variants, + ) + if champion_name not in parsed_labels: + return jsonify({'error': 'Champion 
must match one of the saved variants'}), 400 + + pool = ABPool( + variants=variants, + champion_name=champion_name, + enabled=True, + sample_rate=float( + data.get( + 'comparison_rate', + data.get( + 'sample_rate', + self._get_ab_setting(raw_ab_cfg, 'comparison_rate', 'sample_rate', 1.0), + ), + ) + ), + disclosure_mode=normalize_ab_disclosure_mode( + data.get('variant_label_mode') + or data.get('disclosure_mode') + or self._get_ab_setting(raw_ab_cfg, 'variant_label_mode', 'disclosure_mode', DEFAULT_DISCLOSURE_MODE) + ), + default_trace_mode=normalize_ab_trace_mode( + data.get('activity_panel_default_state') + or data.get('default_trace_mode') + or self._get_ab_setting( + raw_ab_cfg, 'activity_panel_default_state', 'default_trace_mode', DEFAULT_TRACE_MODE + ) + ), + max_pending_per_conversation=int( + data.get( + 'max_pending_comparisons_per_conversation', + data.get( + 'max_pending_per_conversation', + self._get_ab_setting( + raw_ab_cfg, + 'max_pending_comparisons_per_conversation', + 'max_pending_per_conversation', + 1, + ), + ), + ) + ), + ) + self._persist_ab_pool_config( + enabled=True, + champion_name=champion_name, + variants=variants, + comparison_rate=pool.comparison_rate, + variant_label_mode=pool.variant_label_mode, + activity_panel_default_state=pool.activity_panel_default_state, + max_pending_comparisons_per_conversation=pool.max_pending_comparisons_per_conversation, + ) + logger.info("Persisted A/B settings update: champion='%s'", champion_name) + return jsonify(self._build_admin_ab_pool_payload()), 200 + except ABPoolError as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.error("Error setting A/B experiment settings: %s", exc) + return jsonify({'error': str(exc)}), 500 + + def ab_set_variants(self): + """Persist only the variant-list section of the A/B admin page.""" + if not self._can_manage_ab_testing(): + return jsonify({'error': 'Admin access required'}), 403 + try: + data = request.get_json(force=True) + 
variant_items = data.get('variants') or [] + chat_cfg = self.services_config.get("chat_app", {}) or {} + raw_ab_cfg = (chat_cfg.get("ab_testing") or {}) if isinstance(chat_cfg.get("ab_testing"), dict) else {} + raw_pool = raw_ab_cfg.get("pool") or {} + existing_variants = { + variant.label: variant for variant in (self.chat.ab_pool.variants if self.chat.ab_pool else []) + } + variants, parsed_labels = self._resolve_ab_variants( + variant_items, + existing_variants=existing_variants, + ) + + champion_name = self._get_ab_pool_champion(raw_pool) + if champion_name not in parsed_labels: + champion_name = parsed_labels[0] + + comparison_rate = float(self._get_ab_setting(raw_ab_cfg, 'comparison_rate', 'sample_rate', 1.0)) + variant_label_mode = normalize_ab_disclosure_mode( + self._get_ab_setting(raw_ab_cfg, 'variant_label_mode', 'disclosure_mode', DEFAULT_DISCLOSURE_MODE) + ) + activity_panel_default_state = normalize_ab_trace_mode( + self._get_ab_setting( + raw_ab_cfg, 'activity_panel_default_state', 'default_trace_mode', DEFAULT_TRACE_MODE + ) + ) + max_pending = int( + self._get_ab_setting( + raw_ab_cfg, + 'max_pending_comparisons_per_conversation', + 'max_pending_per_conversation', + 1, + ) + ) + enabled_requested = bool(raw_ab_cfg.get('enabled', False)) + + # Validate the resulting pool shape even if currently disabled. 
+ ABPool( + variants=variants, + champion_name=champion_name, + enabled=True, + sample_rate=comparison_rate, + disclosure_mode=variant_label_mode, + default_trace_mode=activity_panel_default_state, + max_pending_per_conversation=max_pending, + ) + self._persist_ab_pool_config( + enabled=enabled_requested, + champion_name=champion_name, + variants=variants, + comparison_rate=comparison_rate, + variant_label_mode=variant_label_mode, + activity_panel_default_state=activity_panel_default_state, + max_pending_comparisons_per_conversation=max_pending, + ) + logger.info("Persisted A/B variants update: champion='%s', variants=%s", champion_name, parsed_labels) + return jsonify(self._build_admin_ab_pool_payload()), 200 + except ABPoolError as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.error("Error setting A/B variants: %s", exc) + return jsonify({'error': str(exc)}), 500 + + def ab_disable_pool(self): + """ + Disable (clear) the A/B testing pool. Admin only. + + Note: The disabled state is persisted to the services config via + config_service.update_services_config, so it survives server restarts + until A/B testing is re-enabled from the UI or config.yaml. + """ + if not self._can_manage_ab_testing(): + return jsonify({'error': 'Admin access required'}), 403 + try: + self.config_service.update_services_config({ + "chat_app": { + "ab_testing": { + "enabled": False, + } + } + }) + self._refresh_runtime_config() + logger.info("Persisted A/B pool disable") + return jsonify(self._build_admin_ab_pool_payload()), 200 + except Exception as exc: + logger.error("Error disabling A/B pool: %s", exc) + return jsonify({'error': str(exc)}), 500 + + def ab_compare_stream(self): + """ + Stream a pool-based A/B comparison (champion vs variant). + + POST body: + - message: [sender, content] pair + - conversation_id: The conversation ID (optional) + - client_id: Client ID for authorization + - config_name: Config name (optional) + + Returns: + NDJSON stream with arm-tagged events. 
+ """ + if not self._can_use_ab_testing(): + return jsonify({'error': 'A/B testing is not enabled for this user'}), 403 + + server_received_msg_ts = datetime.now() + request_data = self._parse_chat_request() + + message = request_data["message"] + conversation_id = request_data["conversation_id"] + config_name = request_data["config_name"] + is_refresh = request_data["is_refresh"] + client_sent_msg_ts = request_data["client_sent_msg_ts"] + client_timeout = request_data["client_timeout"] + client_id = request_data["client_id"] + provider = request_data["provider"] + model = request_data["model"] + user_id = session.get('user', {}).get('id') or None + session_api_key = None + + if not client_id: + return jsonify({"error": "client_id missing"}), 400 + + if provider and 'provider_api_keys' in session: + session_api_key = session.get('provider_api_keys', {}).get(provider.lower()) + + if conversation_id: + pending_count = self.chat.conv_service.count_pending_ab_comparisons(conversation_id) + max_pending = ( + int(self.chat.ab_pool.max_pending_comparisons_per_conversation) + if self.chat.ab_pool else 1 + ) + if pending_count >= max_pending: + return jsonify({ + 'error': 'Resolve one of the pending comparisons before sending another message', + 'pending_count': pending_count, + 'max_pending_comparisons_per_conversation': max_pending, + }), 409 + + return self._ndjson_response(self.chat.stream_ab_comparison( + message, + conversation_id, + client_id, + is_refresh, + server_received_msg_ts, + client_sent_msg_ts, + client_timeout, + config_name, + user_id=user_id, + provider=provider, + model=model, + provider_api_key=session_api_key, + )) + + def ab_get_metrics(self): + """ + Get per-variant A/B testing metrics. Admin only. + + Returns: + JSON with variant metrics (wins, losses, ties, total). 
+ """ + if not self._can_view_ab_metrics(): + return jsonify({'error': 'Admin access required'}), 403 + try: + metrics = self.chat.conv_service.get_all_variant_metrics() + return jsonify({'success': True, 'metrics': metrics}), 200 + except Exception as e: + logger.error(f"Error getting A/B metrics: {str(e)}") + return jsonify({'error': str(e)}), 500 + # ========================================================================= # Agent Trace Endpoints # ========================================================================= @@ -4238,7 +5513,21 @@ def cancel_stream(self): def data_viewer_page(self): """Render the data viewer page.""" - return render_template('data.html') + return render_template( + 'data.html', + can_view_ab_testing=self._can_view_ab_testing(), + ) + + def ab_testing_admin_page(self): + """Render the dedicated admin A/B testing management page.""" + can_view = self._can_view_ab_testing() + if not can_view: + return "Forbidden", 403 + return render_template( + 'ab_testing.html', + can_manage_ab_testing=self._can_manage_ab_testing(), + can_view_ab_metrics=self._can_view_ab_metrics(), + ) def list_data_documents(self): """ @@ -4655,6 +5944,49 @@ def upload_git(self): logger.error(f"Error cloning Git repo: {str(e)}") return jsonify({"error": str(e)}), 500 + def _delete_source_documents(self, source_type: str, where_clause: str, params: tuple, label: str): + """ + Shared helper: mark documents as deleted and remove their chunks. 
+ + Args: + source_type: 'git' or 'jira' + where_clause: SQL WHERE fragment after 'source_type = %s AND NOT is_deleted AND' + params: bind-parameters for the WHERE clause + label: human-readable label for log/response messages + """ + conn = psycopg2.connect(**self.chat.pg_config) + try: + with conn.cursor() as cursor: + # Get resource hashes of documents to delete + cursor.execute( + f"SELECT resource_hash FROM documents WHERE source_type = %s AND NOT is_deleted AND {where_clause}", + (source_type, *params), + ) + hashes_to_delete = [row[0] for row in cursor.fetchall()] + + if hashes_to_delete: + cursor.execute( + "DELETE FROM document_chunks WHERE metadata->>'resource_hash' = ANY(%s)", + (hashes_to_delete,), + ) + logger.info(f"Deleted {cursor.rowcount} chunks for {len(hashes_to_delete)} {source_type} documents") + + cursor.execute( + f"UPDATE documents SET is_deleted = TRUE, deleted_at = NOW() WHERE source_type = %s AND NOT is_deleted AND {where_clause}", + (source_type, *params), + ) + deleted_count = cursor.rowcount + conn.commit() + + logger.info(f"Deleted {deleted_count} documents from {label}") + return jsonify({ + "success": True, + "deleted_count": deleted_count, + "message": f"Removed {deleted_count} documents from {label}", + }), 200 + finally: + conn.close() + def _delete_git_repo(self): """ Delete a Git repository and all its indexed documents. 
@@ -4667,56 +5999,12 @@ def _delete_git_repo(self): if not repo_name: return jsonify({"error": "missing_repo_name"}), 400 - # Build a pattern to match the repo URL - # repo_name could be a URL (https://github.com/org/repo) or just a repo name (org/repo) - # URLs in database are like: https://github.com/pallets/click/blob/main/file.py - conn = psycopg2.connect(**self.chat.pg_config) - try: - with conn.cursor() as cursor: - # First, get the resource hashes of documents to delete - cursor.execute(""" - SELECT resource_hash FROM documents - WHERE source_type = 'git' - AND NOT is_deleted - AND ( - url LIKE %s - OR url LIKE %s - ) - """, (f'{repo_name}/%', f'%/{repo_name}/%')) - hashes_to_delete = [row[0] for row in cursor.fetchall()] - - if hashes_to_delete: - # Delete chunks for these documents - cursor.execute(""" - DELETE FROM document_chunks - WHERE metadata->>'resource_hash' = ANY(%s) - """, (hashes_to_delete,)) - chunks_deleted = cursor.rowcount - logger.info(f"Deleted {chunks_deleted} chunks for {len(hashes_to_delete)} documents") - - # Mark documents as deleted - cursor.execute(""" - UPDATE documents - SET is_deleted = TRUE, deleted_at = NOW() - WHERE source_type = 'git' - AND NOT is_deleted - AND ( - url LIKE %s - OR url LIKE %s - ) - """, (f'{repo_name}/%', f'%/{repo_name}/%')) - deleted_count = cursor.rowcount - conn.commit() - - logger.info(f"Deleted {deleted_count} documents from git repo: {repo_name}") - return jsonify({ - "success": True, - "deleted_count": deleted_count, - "message": f"Removed {deleted_count} documents from repository" - }), 200 - finally: - conn.close() - + return self._delete_source_documents( + source_type='git', + where_clause='(url LIKE %s OR url LIKE %s)', + params=(f'{repo_name}/%', f'%/{repo_name}/%'), + label=f"git repo: {repo_name}", + ) except Exception as e: logger.error(f"Error deleting Git repo: {str(e)}") return jsonify({"error": str(e)}), 500 @@ -5203,47 +6491,12 @@ def _delete_jira_project(self): if not project_key: 
return jsonify({"error": "missing_project_key"}), 400 - conn = psycopg2.connect(**self.chat.pg_config) - try: - with conn.cursor() as cursor: - # First, get the resource hashes of documents to delete - cursor.execute(""" - SELECT resource_hash FROM documents - WHERE source_type = 'jira' - AND NOT is_deleted - AND display_name LIKE %s - """, (f'{project_key}-%',)) - hashes_to_delete = [row[0] for row in cursor.fetchall()] - - if hashes_to_delete: - # Delete chunks for these documents - cursor.execute(""" - DELETE FROM document_chunks - WHERE metadata->>'resource_hash' = ANY(%s) - """, (hashes_to_delete,)) - chunks_deleted = cursor.rowcount - logger.info(f"Deleted {chunks_deleted} chunks for {len(hashes_to_delete)} Jira documents") - - # Mark documents from this Jira project as deleted - cursor.execute(""" - UPDATE documents - SET is_deleted = TRUE, deleted_at = NOW() - WHERE source_type = 'jira' - AND NOT is_deleted - AND display_name LIKE %s - """, (f'{project_key}-%',)) - deleted_count = cursor.rowcount - conn.commit() - - logger.info(f"Deleted {deleted_count} documents from Jira project: {project_key}") - return jsonify({ - "success": True, - "deleted_count": deleted_count, - "message": f"Removed {deleted_count} tickets from project {project_key}" - }), 200 - finally: - conn.close() - + return self._delete_source_documents( + source_type='jira', + where_clause='display_name LIKE %s', + params=(f'{project_key}-%',), + label=f"Jira project: {project_key}", + ) except Exception as e: logger.error(f"Error deleting Jira project: {str(e)}") return jsonify({"error": str(e)}), 500 diff --git a/src/interfaces/chat_app/event_formatter.py b/src/interfaces/chat_app/event_formatter.py new file mode 100644 index 000000000..25b8e8d58 --- /dev/null +++ b/src/interfaces/chat_app/event_formatter.py @@ -0,0 +1,333 @@ +"""Shared event formatter for converting PipelineOutput into structured JSON events. 
+ +Both Chat.stream() (regular chat) and _stream_arm() (A/B testing) use this +formatter so event structure is defined in exactly one place. + +Usage:: + + formatter = PipelineEventFormatter(message_content_fn=self._message_content) + for output in pipeline.stream(...): + for event in formatter.process(output): + # caller adds context fields (arm, conversation_id, timestamp …) + yield event +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any, Callable, Dict, Iterator, Optional + +from src.archi.utils.output_dataclass import PipelineOutput + +logger = logging.getLogger(__name__) + + +class PipelineEventFormatter: + """Stateful converter: PipelineOutput → structured streaming event dicts. + + Key behaviours: + * **Deferred tool_start** – on ``tool_start`` events the formatter parses + and *remembers* tool calls but does NOT yield events. When the + corresponding ``tool_output`` arrives it yields ``tool_start`` then + ``tool_output``, ensuring every tool-start has a matching output. + * **Progressive merging** – tool info is aggregated from ``tool_calls``, + ``additional_kwargs.tool_calls``, ``tool_call_chunks``, and + ``metadata.tool_inputs_by_id`` so the emitted ``tool_start`` carries the + best-available name and args. + * **Caller decorates** – yielded events contain only the canonical fields + (``type`` + type-specific data). Callers add ``conversation_id``, + ``timestamp``, ``arm``, etc. 
+ """ + + def __init__( + self, + *, + message_content_fn: Callable, + max_step_chars: int = 800, + ) -> None: + self._message_content = message_content_fn + self._max_chars = max_step_chars + + # Tool-call tracking + self._emitted_ids: set[str] = set() # all tool_call_ids we've seen + self._emitted_start_ids: set[str] = set() # ids we've yielded tool_start for + self._pending_ids: list[str] = [] # ids awaiting their output (ordered) + self._calls: Dict[str, Dict[str, Any]] = {} # id → {tool_name, tool_args} + self._synthetic_counter: int = 0 + + # Public counters for callers + self.tool_call_count: int = 0 + self.last_text: str = "" + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def process(self, output: PipelineOutput) -> Iterator[Dict[str, Any]]: + """Yield zero or more event dicts from a single *PipelineOutput*.""" + if not isinstance(output, PipelineOutput): + return + + meta = output.metadata or {} + event_type = meta.get("event_type", "text") + + handler = { + "tool_start": self._on_tool_start, + "tool_output": self._on_tool_output, + "tool_end": self._on_tool_end, + "thinking_start": self._on_thinking_start, + "thinking_end": self._on_thinking_end, + "text": self._on_text, + "final": self._on_final, + }.get(event_type) + + if handler: + yield from handler(output, meta) + else: + yield from self._on_unknown(output, meta, event_type) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _next_id(self, name: str) -> str: + self._synthetic_counter += 1 + safe = re.sub(r"[^a-zA-Z0-9_]+", "_", (name or "unknown")).strip("_") or "unknown" + return f"synthetic_tool_{self._synthetic_counter}_{safe}" + + @staticmethod + def _empty_args(args: Any) -> bool: + return args in (None, "", {}, []) + + @staticmethod + def _meaningful(name: Any, 
args: Any) -> bool: + if isinstance(name, str) and name.strip() and name.strip().lower() != "unknown": + return True + return args not in (None, "", {}, []) + + def _remember(self, tc_id: str, name: Any, args: Any) -> None: + if not tc_id: + return + cur = self._calls.get(tc_id, {}) + merged_name = ( + name + if isinstance(name, str) and name.strip() and name.strip().lower() != "unknown" + else cur.get("tool_name", "unknown") + ) + merged_args = args if not self._empty_args(args) else cur.get("tool_args", {}) + self._calls[tc_id] = { + "tool_name": merged_name or "unknown", + "tool_args": merged_args, + } + + def _truncate(self, text: str) -> tuple: + """Return (display_text, truncated_bool, full_length_or_None).""" + if self._max_chars and len(text) > self._max_chars: + return text[: self._max_chars - 3].rstrip() + "...", True, len(text) + return text, False, None + + # ------------------------------------------------------------------ + # Raw arg extraction from message additional_kwargs / tool_call_chunks + # ------------------------------------------------------------------ + + @staticmethod + def _extract_raw_tool_info(msg) -> tuple: + """Parse additional_kwargs.tool_calls and tool_call_chunks. + + Returns (raw_args_by_id, raw_names_by_id) dicts. 
+ """ + raw_args: Dict[str, Any] = {} + raw_names: Dict[str, str] = {} + if msg is None: + return raw_args, raw_names + + try: + additional = getattr(msg, "additional_kwargs", {}) or {} + for raw_call in additional.get("tool_calls") or []: + if not isinstance(raw_call, dict): + continue + rid = raw_call.get("id") + fn = raw_call.get("function") or {} + rname = fn.get("name") + rargs = fn.get("arguments") + parsed = _try_parse_args(rargs) + if rid and parsed is not None: + raw_args[rid] = parsed + if rid and isinstance(rname, str) and rname.strip(): + raw_names[rid] = rname.strip() + + for chunk in getattr(msg, "tool_call_chunks", []) or []: + if not isinstance(chunk, dict): + continue + cid = chunk.get("id") + cname = chunk.get("name") + cargs = chunk.get("args") + parsed = _try_parse_args(cargs) + if cid and parsed is not None: + raw_args[cid] = parsed + if cid and isinstance(cname, str) and cname.strip(): + raw_names[cid] = cname.strip() + except Exception: + pass + + return raw_args, raw_names + + # ------------------------------------------------------------------ + # Event handlers + # ------------------------------------------------------------------ + + def _on_tool_start(self, output: PipelineOutput, meta: dict) -> Iterator[Dict[str, Any]]: + """Parse and remember tool calls; actual emission is deferred.""" + msg = (output.messages or [None])[0] + tool_calls = getattr(msg, "tool_calls", None) if msg else None + memory = meta.get("tool_inputs_by_id", {}) or {} + raw_args, raw_names = self._extract_raw_tool_info(msg) + + if tool_calls: + for tc in tool_calls: + tc_id = tc.get("id", "") + args = tc.get("args", {}) + if self._empty_args(args): + args = raw_args.get(tc_id, args) + if self._empty_args(args): + fb = memory.get(tc_id, {}) + if isinstance(fb, dict): + args = fb.get("tool_input", args) + name = tc.get("name", "unknown") + if (not name or str(name).strip().lower() == "unknown") and tc_id in raw_names: + name = raw_names[tc_id] + if not name and 
isinstance(memory.get(tc_id), dict): + name = memory[tc_id].get("tool_name", "unknown") + if not tc_id and not self._meaningful(name, args): + continue + if not tc_id: + tc_id = self._next_id(name) + self._remember(tc_id, name, args) + if tc_id in self._emitted_ids: + continue + self._emitted_ids.add(tc_id) + self._pending_ids.append(tc_id) + self.tool_call_count += 1 + elif memory: + for mid, mc in memory.items(): + if not isinstance(mc, dict): + continue + name = mc.get("tool_name", "unknown") + args = mc.get("tool_input", {}) + if not self._meaningful(name, args): + continue + tc_id = mid or self._next_id(name) + if tc_id in self._emitted_ids: + continue + self._emitted_ids.add(tc_id) + self._pending_ids.append(tc_id) + self._remember(tc_id, name, args) + self.tool_call_count += 1 + + # Deferred – don't yield anything here + return () + + def _on_tool_output(self, output: PipelineOutput, meta: dict) -> Iterator[Dict[str, Any]]: + """Emit deferred tool_start (if needed) then tool_output.""" + msg = (output.messages or [None])[0] + tool_output = self._message_content(msg) if msg else "" + tc_id = getattr(msg, "tool_call_id", "") if msg else "" + + if not tc_id and self._pending_ids: + tc_id = self._pending_ids.pop(0) + elif tc_id in self._pending_ids: + self._pending_ids.remove(tc_id) + + # Emit deferred tool_start if not yet sent + if tc_id and tc_id not in self._emitted_start_ids: + memory = meta.get("tool_inputs_by_id", {}) or {} + fb = memory.get(tc_id, {}) + fb_name: str = "unknown" + fb_args: Any = {} + if isinstance(fb, dict): + fb_name = fb.get("tool_name", "unknown") + fb_args = fb.get("tool_input", {}) + self._remember(tc_id, fb_name, fb_args) + info = self._calls.get(tc_id, {}) + self._emitted_start_ids.add(tc_id) + yield { + "type": "tool_start", + "tool_call_id": tc_id, + "tool_name": info.get("tool_name", "unknown"), + "tool_args": info.get("tool_args", {}), + } + + display, truncated, full_length = self._truncate(tool_output) + evt: Dict[str, Any] = 
{ + "type": "tool_output", + "tool_call_id": tc_id, + "output": display, + "truncated": truncated, + } + if full_length is not None: + evt["full_length"] = full_length + yield evt + + def _on_tool_end(self, _output: PipelineOutput, meta: dict) -> Iterator[Dict[str, Any]]: + yield { + "type": "tool_end", + "tool_call_id": meta.get("tool_call_id", ""), + "status": meta.get("status", "success"), + "duration_ms": meta.get("duration_ms"), + } + + def _on_thinking_start(self, _output: PipelineOutput, meta: dict) -> Iterator[Dict[str, Any]]: + yield { + "type": "thinking_start", + "step_id": meta.get("step_id", ""), + } + + def _on_thinking_end(self, _output: PipelineOutput, meta: dict) -> Iterator[Dict[str, Any]]: + yield { + "type": "thinking_end", + "step_id": meta.get("step_id", ""), + "duration_ms": meta.get("duration_ms"), + "thinking_content": meta.get("thinking_content", ""), + } + + def _on_text(self, output: PipelineOutput, _meta: dict) -> Iterator[Dict[str, Any]]: + content = output.answer or "" + if content: + self.last_text = content + yield { + "type": "text", + "content": content, + } + + def _on_final(self, _output: PipelineOutput, _meta: dict) -> Iterator[Dict[str, Any]]: + # Callers handle finalization themselves + return () + + def _on_unknown(self, output: PipelineOutput, _meta: dict, event_type: str) -> Iterator[Dict[str, Any]]: + """Fallback for unrecognised event types.""" + if getattr(output, "final", False): + return + content = output.answer or "" + if content: + yield { + "type": event_type, + "content": content, + } + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + +def _try_parse_args(raw: Any) -> Any: + """Attempt to parse raw tool arguments into a dict.""" + if isinstance(raw, str) and raw.strip(): + try: + return json.loads(raw) + except Exception: + return {"_raw_arguments": raw} + elif isinstance(raw, dict): + return raw 
+ return None diff --git a/src/interfaces/chat_app/static/chat.css b/src/interfaces/chat_app/static/chat.css index 775d0772c..6b8cd0916 100644 --- a/src/interfaces/chat_app/static/chat.css +++ b/src/interfaces/chat_app/static/chat.css @@ -1214,7 +1214,8 @@ body { } .agent-dropdown-edit, -.agent-dropdown-delete { +.agent-dropdown-delete, +.agent-dropdown-clone { display: inline-flex; align-items: center; justify-content: center; @@ -1227,6 +1228,11 @@ body { transition: color var(--transition-fast), background var(--transition-fast); } +.agent-dropdown-clone:hover { + color: var(--primary); + background: rgba(16, 185, 129, 0.1); +} + .agent-dropdown-edit:hover { color: var(--text-primary); background: var(--surface); @@ -1520,6 +1526,51 @@ body { line-height: 1.5; } +.settings-link-btn { + display: inline-flex; + align-items: center; + justify-content: center; + min-height: 36px; + padding: 0 14px; + border-radius: 10px; + border: 1px solid var(--border-color); + background: var(--bg-secondary); + color: var(--text-primary); + text-decoration: none; + font-size: var(--text-sm); + font-weight: 600; + transition: background-color 0.15s ease, border-color 0.15s ease; +} + +.settings-link-btn:hover { + background: var(--bg-tertiary); + border-color: var(--accent-soft); +} + +.settings-range { + width: 100%; + accent-color: var(--accent); + margin: 4px 0 10px; +} + +.settings-range-meta { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + font-size: var(--text-xs); + color: var(--text-secondary); +} + +.settings-ab-percent { + font-size: var(--text-sm); + color: var(--accent); +} + +.settings-inline-error { + color: #dc2626; +} + /* Compact Toggle (for inline use) */ .settings-toggle-compact { position: relative; @@ -2242,96 +2293,454 @@ body { ============================================================================= */ /* ----------------------------------------------------------------------------- - A/B Comparison Container - 
Side-by-Side Layout + A/B Pool Editor ----------------------------------------------------------------------------- */ -.ab-comparison { - display: flex; - gap: 16px; - margin: 0 auto; - max-width: calc(var(--message-max-width) * 2 + 32px); - padding: var(--message-padding-y) var(--message-padding-x); +.ab-pool-status { + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.04em; + padding: 3px 8px; + border-radius: 4px; + background: var(--bg-tertiary, #e2e8f0); + color: var(--text-tertiary, #94a3b8); } -.ab-response { - flex: 1; - background: var(--bg-secondary); - border: 1px solid var(--border-color); - border-radius: 12px; - overflow: hidden; +.ab-pool-status.active { + background: rgba(16, 185, 129, 0.12); + color: #10b981; +} + +.ab-pool-editor { + margin-top: 12px; +} + +.ab-pool-agent-list { display: flex; flex-direction: column; - max-height: 500px; + gap: 0; + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 8px; + overflow: hidden; + max-height: 260px; + overflow-y: auto; +} + +.ab-pool-agent-row { + display: grid; + grid-template-columns: 32px 1fr auto; + align-items: center; + gap: 10px; + padding: 9px 12px; + background: var(--bg-primary); + border-bottom: 1px solid var(--border-color, #e2e8f0); + transition: background 0.1s ease; + cursor: pointer; +} + +.ab-pool-agent-row:last-child { + border-bottom: none; } -.ab-response-header { +.ab-pool-agent-row:hover { + background: var(--bg-secondary, #f8fafc); +} + +.ab-pool-agent-row.selected { + background: rgba(16, 185, 129, 0.04); +} + +.ab-pool-agent-row.selected.champion { + background: rgba(59, 130, 246, 0.06); +} + +.ab-pool-agent-check { display: flex; align-items: center; - gap: 8px; - padding: 12px 16px; - background: var(--bg-tertiary); - border-bottom: 1px solid var(--border-color); - font-size: var(--text-sm); + justify-content: center; +} + +.ab-pool-agent-check input[type="checkbox"] { + width: 16px; + height: 16px; + accent-color: 
var(--accent-color, #10b981); + cursor: pointer; +} + +.ab-pool-agent-name { + font-size: 0.85rem; + font-weight: 500; + color: var(--text-primary); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.ab-pool-champion-btn { + display: none; /* hidden until agent is selected */ + align-items: center; + gap: 4px; + padding: 3px 10px; + font-size: 0.7rem; font-weight: 600; - color: var(--text-secondary); + text-transform: uppercase; + letter-spacing: 0.03em; + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 4px; + background: transparent; + color: var(--text-tertiary, #94a3b8); + cursor: pointer; + transition: all 0.15s ease; + white-space: nowrap; +} + +.ab-pool-agent-row.selected .ab-pool-champion-btn { + display: inline-flex; +} + +.ab-pool-champion-btn:hover { + border-color: #3b82f6; + color: #3b82f6; + background: rgba(59, 130, 246, 0.06); } -.ab-response-label { +.ab-pool-champion-btn.is-champion { + border-color: #3b82f6; + background: #3b82f6; + color: #fff; + cursor: default; +} + +.ab-pool-actions { display: flex; - align-items: center; - gap: 6px; + gap: 8px; + margin-top: 12px; } -.ab-response-label::before { - content: ''; +.ab-pool-btn { + padding: 7px 16px; + font-size: 0.8rem; + font-weight: 500; + border-radius: 6px; + border: 1px solid transparent; + cursor: pointer; + transition: all 0.15s ease; } -.ab-response-content { - padding: 16px; - overflow-y: auto; - flex: 1; +.ab-pool-btn:disabled { + opacity: 0.45; + cursor: not-allowed; } -/* Winner/Loser States */ -.ab-response-winner { - border: none; +.ab-pool-btn-save { + background: var(--accent-color, #10b981); + color: #fff; + border-color: var(--accent-color, #10b981); +} + +.ab-pool-btn-save:not(:disabled):hover { + filter: brightness(1.1); +} + +.ab-pool-btn-disable { background: transparent; - flex: none; - width: 100%; - max-width: var(--message-max-width); + color: #ef4444; + border-color: rgba(239, 68, 68, 0.35); } -.ab-comparison-resolved { - 
display: block; - gap: 0; +.ab-pool-btn-disable:hover { + background: rgba(239, 68, 68, 0.08); + border-color: #ef4444; } -.ab-comparison-resolved .ab-response-winner { - padding: 0; +.ab-pool-message { + margin-top: 8px; + font-size: 0.78rem; + min-height: 1.2em; } -.ab-comparison-resolved .ab-response-winner .ab-response-content { - padding: 0; +.ab-pool-message.success { + color: #10b981; +} + +.ab-pool-message.error { + color: #ef4444; +} + +/* Agent actions inside pool row */ +.ab-pool-agent-actions { + display: flex; + align-items: center; + gap: 4px; } -.ab-response-loser { +.ab-pool-variant-btn { display: none; + align-items: center; + justify-content: center; + width: 22px; + height: 22px; + padding: 0; + font-size: 1rem; + font-weight: 600; + line-height: 1; + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 4px; + background: transparent; + color: var(--text-tertiary, #94a3b8); + cursor: pointer; + transition: all 0.15s ease; } -.ab-response-tie { - border-color: var(--text-tertiary); +.ab-pool-agent-row:hover .ab-pool-variant-btn { + display: inline-flex; +} + +.ab-pool-variant-btn:hover { + border-color: #10b981; + color: #10b981; + background: rgba(16, 185, 129, 0.06); +} + +/* AB-only badge */ +.ab-pool-ab-badge { + display: inline-block; + margin-left: 6px; + padding: 1px 5px; + font-size: 0.6rem; + font-weight: 700; + text-transform: uppercase; + letter-spacing: 0.04em; + border-radius: 3px; + background: rgba(59, 130, 246, 0.1); + color: #3b82f6; + vertical-align: middle; +} + +/* Quick Variant Panel */ +.ab-quick-variant-panel { + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 8px; + background: var(--bg-secondary, #f8fafc); + padding: 12px; + margin-bottom: 12px; + animation: modalSlideIn 0.15s ease-out; +} + +.ab-qv-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 10px; } -.ab-winner-badge { +.ab-qv-header strong { + font-size: 0.85rem; + color: 
var(--text-primary); +} + +.ab-qv-close { + background: none; + border: none; + font-size: 1.2rem; + color: var(--text-tertiary); + cursor: pointer; + padding: 0 4px; + line-height: 1; +} + +.ab-qv-close:hover { + color: var(--text-primary); +} + +.ab-qv-label { + display: block; + font-size: 0.75rem; + font-weight: 600; + color: var(--text-secondary); + margin-bottom: 4px; + margin-top: 8px; +} + +.ab-qv-name { + width: 100%; + padding: 6px 10px; + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 6px; + background: var(--bg-primary); + color: var(--text-primary); + font-size: 0.82rem; + font-family: inherit; +} + +.ab-qv-name:focus { + outline: none; + border-color: var(--accent-color, #10b981); + box-shadow: 0 0 0 2px rgba(16, 185, 129, 0.1); +} + +.ab-qv-tools { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 4px; + max-height: 100px; + overflow-y: auto; +} + +.ab-qv-tool { display: inline-flex; align-items: center; gap: 4px; - margin-left: auto; - padding: 2px 8px; - background: var(--accent); - color: white; - border-radius: 9999px; + padding: 3px 8px; + font-size: 0.75rem; + background: var(--bg-primary); + border: 1px solid var(--border-color, #e2e8f0); + border-radius: 4px; + cursor: pointer; + transition: all 0.1s ease; + color: var(--text-secondary); +} + +.ab-qv-tool:has(input:checked) { + border-color: #10b981; + background: rgba(16, 185, 129, 0.06); + color: var(--text-primary); +} + +.ab-qv-tool input { + width: 12px; + height: 12px; + accent-color: #10b981; +} + +.ab-qv-footer { + display: flex; + align-items: center; + gap: 10px; + margin-top: 10px; +} + +.ab-qv-msg { + font-size: 0.75rem; +} + +.ab-qv-msg.error { + color: #ef4444; +} + +.ab-qv-msg.success { + color: #10b981; +} + +/* Legacy banner (kept for backward compat) */ +.ab-pool-banner { + background: var(--bg-secondary, #f0f4f8); + border: 1px solid var(--border-color, #d1d9e6); + border-radius: 8px; + padding: 10px 14px; + margin-bottom: 12px; +} + 
+.ab-pool-info { + display: flex; + flex-direction: column; + gap: 4px; + font-size: 0.85rem; + color: var(--text-secondary, #5a6a7a); +} + +.ab-pool-info strong { + color: var(--text-primary, #1a2a3a); + font-size: 0.9rem; +} + +.ab-pool-info em { + font-style: normal; + font-weight: 500; + color: var(--accent-color, #4a7eff); +} + +/* ----------------------------------------------------------------------------- + A/B Comparison Container - Side-by-Side Layout + ----------------------------------------------------------------------------- */ +.ab-comparison { + display: flex; + gap: 16px; + margin: 0 auto; + max-width: calc(var(--message-max-width) * 2 + 32px); + padding: var(--message-padding-y) var(--message-padding-x); +} + +/* Each arm is a normal .message.assistant with an extra .ab-arm class */ +.ab-arm { + flex: 1; + min-width: 0; /* allow flex shrink */ +} + +/* Remove default message max-width when inside AB comparison */ +.ab-comparison .ab-arm.message { + max-width: none; +} + +.ab-arm-header-copy { + display: flex; + flex-direction: column; + gap: 2px; + min-width: 0; + flex: 1; +} + +.ab-arm-title-row { + display: flex; + align-items: center; + gap: 8px; + min-width: 0; + flex-wrap: wrap; +} + +/* A/B label badge */ +.ab-arm-label { + display: inline-flex; + align-items: center; + justify-content: center; + width: auto; + min-height: 22px; + padding: 0 9px; + margin-left: 0; + background: var(--accent-light); + color: var(--accent); + border: 1px solid var(--accent); + border-radius: 999px; + font-size: var(--text-xs); + font-weight: 700; + flex-shrink: 0; + letter-spacing: 0.01em; +} + +/* A/B variant name shown below the sender row */ +.ab-arm-variant-name { font-size: var(--text-xs); + color: var(--text-secondary); font-weight: 500; + line-height: 1.35; + white-space: normal; + overflow: hidden; + text-overflow: ellipsis; + max-width: 100%; +} + +/* Tie state — dim both arms */ +.ab-arm-tie { + opacity: 0.6; +} + +/* Responsive: stack vertically on 
small screens */ +@media (max-width: 768px) { + .ab-comparison { + flex-direction: column; + max-width: var(--message-max-width); + } } /* ----------------------------------------------------------------------------- diff --git a/src/interfaces/chat_app/static/chat.js b/src/interfaces/chat_app/static/chat.js index 187f7b4dc..ab24e19c6 100644 --- a/src/interfaces/chat_app/static/chat.js +++ b/src/interfaces/chat_app/static/chat.js @@ -19,8 +19,6 @@ const CONFIG = { SELECTED_PROVIDER: 'archi_selected_provider', SELECTED_MODEL: 'archi_selected_model', SELECTED_MODEL_CUSTOM: 'archi_selected_model_custom', - SELECTED_PROVIDER_B: 'archi_selected_provider_b', - SELECTED_MODEL_B: 'archi_selected_model_b', }, ENDPOINTS: { STREAM: '/api/get_chat_response_stream', @@ -29,9 +27,14 @@ const CONFIG = { LOAD_CONVERSATION: '/api/load_conversation', NEW_CONVERSATION: '/api/new_conversation', DELETE_CONVERSATION: '/api/delete_conversation', - AB_CREATE: '/api/ab/create', AB_PREFERENCE: '/api/ab/preference', AB_PENDING: '/api/ab/pending', + AB_POOL: '/api/ab/pool', + AB_DECISION: '/api/ab/decision', + AB_POOL_SET: '/api/ab/pool/set', + AB_POOL_DISABLE: '/api/ab/pool/disable', + AB_COMPARE: '/api/ab/compare', + AB_METRICS: '/api/ab/metrics', TRACE_GET: '/api/trace', CANCEL_STREAM: '/api/cancel_stream', PROVIDERS: '/api/providers', @@ -47,6 +50,8 @@ const CONFIG = { AGENTS_LIST: '/api/agents/list', AGENT_SPEC: '/api/agents/spec', AGENT_ACTIVE: '/api/agents/active', + USER_ME: '/api/users/me', + USER_PREFERENCES: '/api/users/me/preferences', LIKE: '/api/like', DISLIKE: '/api/dislike', TEXT_FEEDBACK: '/api/text_feedback', @@ -219,6 +224,47 @@ const API = { return data; }, + /** + * Shared NDJSON reader: reads a fetch Response body and yields parsed JSON objects. + * Properly flushes any remaining buffer content after the stream ends. 
+ */ + async *_readNDJSON(response) { + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + try { + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop(); // Keep incomplete line in buffer + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed) continue; + try { + yield JSON.parse(trimmed); + } catch (e) { + console.warn('Failed to parse NDJSON line:', trimmed); + } + } + } + // Flush remaining buffer after stream ends + if (buffer.trim()) { + try { + yield JSON.parse(buffer.trim()); + } catch (e) { + console.warn('Failed to parse final NDJSON line:', buffer.trim()); + } + } + } finally { + reader.releaseLock(); + } + }, + async getConfigs() { return this.fetchJson(CONFIG.ENDPOINTS.CONFIGS); }, @@ -269,10 +315,10 @@ const API = { client_sent_msg_ts: Date.now(), client_timeout: CONFIG.STREAMING.TIMEOUT, client_id: this.clientId, - include_agent_steps: true, // Required for streaming chunks - include_tool_steps: true, // Enable tool step events for trace - provider: provider, // Provider-based model selection - model: model, // Model ID/name for the provider + include_agent_steps: true, + include_tool_steps: true, + provider: provider, + model: model, }), signal: signal, }); @@ -287,47 +333,10 @@ const API = { throw new Error(text || `Request failed (${response.status})`); } - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - - try { - while (true) { - const { value, done } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split('\n'); - buffer = lines.pop(); - - for (const line of lines) { - const trimmed = line.trim(); - if (!trimmed) continue; - - try { - yield JSON.parse(trimmed); - } catch (e) { - console.error('Failed to parse stream 
event:', e); - } - } - } - } finally { - reader.releaseLock(); - } + yield* this._readNDJSON(response); }, // A/B Testing API methods - async createABComparison(data) { - return this.fetchJson(CONFIG.ENDPOINTS.AB_CREATE, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - ...data, - client_id: this.clientId, - }), - }); - }, - async submitABPreference(comparisonId, preference) { return this.fetchJson(CONFIG.ENDPOINTS.AB_PREFERENCE, { method: 'POST', @@ -345,6 +354,84 @@ const API = { return this.fetchJson(url); }, + // Pool-based A/B testing API methods + async getABPool() { + return this.fetchJson(`${CONFIG.ENDPOINTS.AB_POOL}?client_id=${encodeURIComponent(this.clientId)}`); + }, + + async getABDecision(conversationId = null) { + const params = new URLSearchParams({ client_id: this.clientId }); + if (conversationId != null) { + params.set('conversation_id', String(conversationId)); + } + return this.fetchJson(`${CONFIG.ENDPOINTS.AB_DECISION}?${params.toString()}`); + }, + + async getABMetrics() { + return this.fetchJson(`${CONFIG.ENDPOINTS.AB_METRICS}?client_id=${encodeURIComponent(this.clientId)}`); + }, + + async saveABPool(payload) { + return this.fetchJson(CONFIG.ENDPOINTS.AB_POOL_SET, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ ...payload, client_id: this.clientId }), + }); + }, + + async disableABPool() { + return this.fetchJson(CONFIG.ENDPOINTS.AB_POOL_DISABLE, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ client_id: this.clientId }), + }); + }, + + /** + * Stream a pool-based A/B comparison. Returns an async iterator of NDJSON events. + * Each event has an 'arm' field ('a' or 'b') plus 'type', 'content', etc. 
+ */ + async *streamABComparison(history, conversationId, configName, signal, provider = null, model = null) { + const streamOverride = window.__ARCHI_PLAYWRIGHT__?.ab?.streamOverride; + if (typeof streamOverride === 'function') { + yield* streamOverride({ + history, + conversationId, + configName, + signal, + provider, + model, + clientId: this.clientId, + }); + return; + } + + const body = { + last_message: history.slice(-1), + conversation_id: conversationId, + config_name: configName || null, + client_id: this.clientId, + client_sent_msg_ts: Date.now(), + client_timeout: CONFIG.STREAMING.TIMEOUT, + provider, + model, + }; + + const response = await fetch(CONFIG.ENDPOINTS.AB_COMPARE, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + signal, + }); + + if (!response.ok) { + const errText = await response.text(); + throw new Error(`A/B compare failed: ${response.status} ${errText}`); + } + + yield* this._readNDJSON(response); + }, + // Provider API methods async getProviders() { return this.fetchJson(CONFIG.ENDPOINTS.PROVIDERS); @@ -377,6 +464,18 @@ const API = { return this.fetchJson(url); }, + async getCurrentUser() { + return this.fetchJson(CONFIG.ENDPOINTS.USER_ME); + }, + + async updateUserPreferences(payload) { + return this.fetchJson(CONFIG.ENDPOINTS.USER_PREFERENCES, { + method: 'PATCH', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(payload), + }); + }, + async setActiveAgent(name) { return this.fetchJson(CONFIG.ENDPOINTS.AGENT_ACTIVE, { method: 'POST', @@ -586,6 +685,7 @@ window.Markdown = Markdown; const UI = { elements: {}, sendBtnDefaultHtml: null, + traceTimerIntervals: new Map(), init() { this.elements = { @@ -600,14 +700,12 @@ const UI = { inputField: document.querySelector('.input-field'), sendBtn: document.querySelector('.send-btn'), modelSelectA: null, - modelSelectB: document.querySelector('.model-select-b'), + settingsBtn: document.querySelector('.settings-btn'), 
dataTab: document.getElementById('data-tab'), settingsModal: document.querySelector('.settings-modal'), settingsBackdrop: document.querySelector('.settings-backdrop'), settingsClose: document.querySelector('.settings-close'), - abCheckbox: document.querySelector('.ab-checkbox'), - abModelGroup: document.querySelector('.ab-model-group'), traceVerboseOptions: document.querySelector('.trace-verbose-options'), agentDropdown: document.querySelector('.agent-dropdown'), agentDropdownBtn: document.querySelector('.agent-dropdown-btn'), @@ -635,7 +733,7 @@ const UI = { // Provider selection elements providerSelect: document.getElementById('provider-select'), modelSelectPrimary: document.getElementById('model-select-primary'), - providerSelectB: document.getElementById('provider-select-b'), + providerStatus: document.getElementById('provider-status'), // User profile elements userProfileWidget: document.getElementById('user-profile-widget'), @@ -649,6 +747,15 @@ const UI = { customModelRow: document.getElementById('custom-model-row'), activeModelLabel: document.getElementById('active-model-label'), darkModeToggle: document.getElementById('dark-mode-toggle'), + abSettingsNav: document.getElementById('ab-settings-nav'), + abSettingsSection: document.getElementById('settings-ab-testing'), + abParticipationGroup: document.getElementById('ab-participation-group'), + abParticipationSlider: document.getElementById('ab-participation-slider'), + abParticipationValue: document.getElementById('ab-participation-value'), + abParticipationDefault: document.getElementById('ab-participation-default'), + abParticipationNote: document.getElementById('ab-participation-note'), + abParticipationInactive: document.getElementById('ab-participation-inactive'), + abAdminLinkSection: document.getElementById('ab-settings-section'), }; this.sendBtnDefaultHtml = this.elements.sendBtn?.innerHTML || ''; @@ -746,16 +853,23 @@ const UI = { const target = e.target; const row = 
target.closest('.agent-dropdown-item'); if (!row) return; + e.preventDefault(); + e.stopPropagation(); // Handle inline delete confirmation buttons if (target.closest('.agent-dropdown-confirm-yes')) { const name = row.dataset.agentName; - this.closeAgentDropdown(); this.doDeleteAgent(name); return; } if (target.closest('.agent-dropdown-confirm-no')) { // Cancel: re-render list to remove confirmation state - Chat.loadAgents(); + this.renderAgentsList(Chat.state.allAgents || Chat.state.agents || [], Chat.state.activeAgentName); + return; + } + if (target.closest('.agent-dropdown-clone')) { + const name = row.dataset.agentName; + this.closeAgentDropdown(); + this.openAgentSpecEditor({ mode: 'clone', name }); return; } if (target.closest('.agent-dropdown-edit')) { @@ -789,37 +903,58 @@ const UI = { // Resize handle for agent spec modal this.initAgentSpecResize(); - // A/B toggle in settings - this.elements.abCheckbox?.addEventListener('change', (e) => { - const isEnabled = e.target.checked; - if (isEnabled) { - // Show warning modal before enabling - const dismissed = sessionStorage.getItem(CONFIG.STORAGE_KEYS.AB_WARNING_DISMISSED); - if (!dismissed) { - e.target.checked = false; // Reset checkbox - this.showABWarningModal( - () => { - // On confirm - e.target.checked = true; - if (this.elements.abModelGroup) { - this.elements.abModelGroup.style.display = 'block'; - } - sessionStorage.setItem(CONFIG.STORAGE_KEYS.AB_WARNING_DISMISSED, 'true'); - }, - () => { - // On cancel - e.target.checked = false; - } - ); - return; + // A/B pool editor — save & disable buttons + document.getElementById('ab-pool-save')?.addEventListener('click', async () => { + const sel = UI._getABPoolSelection(); + if (!sel || !sel.champion || sel.variants.length < 2) return; + const saveBtn = document.getElementById('ab-pool-save'); + const msgEl = document.getElementById('ab-pool-message'); + const sampleRate = Number(document.getElementById('ab-sample-rate')?.value || 1); + const disclosureMode 
= document.getElementById('ab-disclosure-mode')?.value || 'post_vote_reveal'; + const defaultTraceMode = document.getElementById('ab-trace-mode')?.value || 'hidden'; + saveBtn.disabled = true; + saveBtn.textContent = 'Saving…'; + try { + const result = await API.saveABPool({ + champion: sel.champion, + variants: sel.variants, + comparison_rate: sampleRate, + variant_label_mode: disclosureMode, + activity_panel_default_state: defaultTraceMode, + }); + if (result?.success) { + if (msgEl) { msgEl.textContent = 'Pool saved'; msgEl.className = 'ab-pool-message success'; } + Chat.state.abPool = result; + // Re-render to reflect saved state + UI.updateABPoolUI(result); + } else { + if (msgEl) { msgEl.textContent = result?.error || 'Save failed'; msgEl.className = 'ab-pool-message error'; } } + } catch (e) { + if (msgEl) { msgEl.textContent = e.message || 'Save failed'; msgEl.className = 'ab-pool-message error'; } + } finally { + saveBtn.textContent = 'Save Pool'; + UI._updateABPoolSaveState(); } - if (this.elements.abModelGroup) { - this.elements.abModelGroup.style.display = isEnabled ? 
'block' : 'none'; - } - // If disabling A/B mode while vote is pending, re-enable input - if (!isEnabled && Chat.state.abVotePending) { - Chat.cancelPendingABComparison(); + }); + + document.getElementById('ab-pool-disable')?.addEventListener('click', async () => { + const disableBtn = document.getElementById('ab-pool-disable'); + const msgEl = document.getElementById('ab-pool-message'); + disableBtn.disabled = true; + try { + const result = await API.disableABPool(); + if (result?.success) { + Chat.state.abPool = null; + UI.updateABPoolUI({ enabled: false }); + if (msgEl) { msgEl.textContent = 'Pool disabled'; msgEl.className = 'ab-pool-message success'; } + // If A/B mode was active in chat, deactivate + if (Chat.state.abVotePending) Chat.cancelPendingABComparison(); + } + } catch (e) { + if (msgEl) { msgEl.textContent = e.message || 'Failed'; msgEl.className = 'ab-pool-message error'; } + } finally { + disableBtn.disabled = false; } }); @@ -836,6 +971,14 @@ const UI = { localStorage.setItem('archi_theme', isDark ? 
'dark' : 'light'); }); + this.elements.abParticipationSlider?.addEventListener('input', (e) => { + this.updateABParticipationPreview(Number(e.target.value)); + }); + + this.elements.abParticipationSlider?.addEventListener('change', async (e) => { + await Chat.saveABParticipationPreference(Number(e.target.value) / 100); + }); + // Provider selection this.elements.providerSelect?.addEventListener('change', (e) => { Chat.handleProviderChange(e.target.value); @@ -849,25 +992,21 @@ const UI = { Chat.handleCustomModelChange(e.target.value); }); - this.elements.providerSelectB?.addEventListener('change', (e) => { - Chat.handleProviderBChange(e.target.value); - }); - // User profile widget interactions this.elements.userRolesToggle?.addEventListener('click', (e) => { e.stopPropagation(); this.toggleUserRolesPanel(); }); - + this.elements.userProfileWidget?.addEventListener('click', () => { this.toggleUserRolesPanel(); }); - + this.elements.userLogoutBtn?.addEventListener('click', (e) => { e.stopPropagation(); window.location.href = '/logout'; }); - + // Close modal on Escape document.addEventListener('keydown', (e) => { if (e.key === 'Escape' && this.elements.settingsModal?.style.display !== 'none') { @@ -1159,6 +1298,9 @@ const UI = {
${checkmark}${Utils.escapeHtml(name)}
+ @@ -1177,27 +1319,52 @@ const UI = { if (!this.elements.agentSpecModal) return; this.elements.agentSpecModal.style.display = 'flex'; this.setAgentSpecStatus(''); - this.agentSpecMode = mode; - this.agentSpecName = name; + // Clone mode → load source spec, then switch to create for saving + this.agentSpecMode = mode === 'clone' ? 'create' : mode; + this.agentSpecName = mode === 'clone' ? null : name; + this.agentSpecOriginalName = mode === 'edit' ? name : null; // Restore persisted size this.restoreAgentSpecSize(); if (this.elements.agentSpecTitle) { - this.elements.agentSpecTitle.textContent = mode === 'edit' ? `Edit ${name || 'Agent'}` : 'New Agent'; + if (mode === 'clone') { + this.elements.agentSpecTitle.textContent = `New Variant of ${name || 'Agent'}`; + } else if (mode === 'edit') { + this.elements.agentSpecTitle.textContent = `Edit ${name || 'Agent'}`; + } else { + this.elements.agentSpecTitle.textContent = 'New Agent'; + } } // Update reset button label if (this.elements.agentSpecReset) { this.elements.agentSpecReset.textContent = mode === 'edit' ? 'Revert changes' : 'Reset template'; } + if (this.elements.agentSpecName) { + this.elements.agentSpecName.readOnly = mode === 'edit'; + this.elements.agentSpecName.title = mode === 'edit' + ? 'Agent name is fixed while editing. Clone or create a new agent to use a different name.' 
+ : ''; + } // Clear validation errors this.clearAgentSpecValidation(); - if (mode === 'edit' && name) { + if (mode === 'clone' && name) { + // Load tool palette first, then load source spec and modify name + await this.loadAgentToolPalette(); + await this.loadAgentSpecByName(name); + // Append " (variant)" to the name so user can tweak tools & save + if (this.elements.agentSpecName) { + this.elements.agentSpecName.value = `${name} (variant)`; + } + this.setAgentSpecStatus('Cloned — adjust tools and name, then save.', 'info'); + setTimeout(() => this.elements.agentSpecName?.select(), 100); + } else if (mode === 'edit' && name) { await this.loadAgentToolPalette(); await this.loadAgentSpecByName(name); + this.setAgentSpecStatus('Editing updates this agent in place. Clone or create a new agent to use a different name.', 'info'); } else { await this.loadAgentSpecTemplate(); } // Auto-focus name in create mode - if (mode === 'create') { + if (mode === 'create' && !name) { setTimeout(() => this.elements.agentSpecName?.focus(), 100); } }, @@ -1243,8 +1410,9 @@ const UI = { }, /** Serialise structured form fields back to .md format */ - serialiseAgentSpec(name, tools, prompt) { + serialiseAgentSpec(name, tools, prompt, { ab_only = false } = {}) { let yaml = `---\nname: ${name}\n`; + if (ab_only) yaml += 'ab_only: true\n'; if (tools.length) { yaml += 'tools:\n'; for (const t of tools) yaml += ` - ${t}\n`; @@ -1327,9 +1495,9 @@ const UI = { resetAgentSpecForm() { this.clearAgentSpecValidation(); this.setAgentSpecStatus(''); - if (this.agentSpecMode === 'edit' && this.agentSpecName) { + if (this.agentSpecMode === 'edit' && this.agentSpecOriginalName) { // Revert to saved version - this.loadAgentSpecByName(this.agentSpecName); + this.loadAgentSpecByName(this.agentSpecOriginalName); } else { this.loadAgentSpecTemplate(); } @@ -1375,6 +1543,11 @@ const UI = { hasError = true; } if (hasError) return; + if (this.agentSpecMode === 'edit' && this.agentSpecOriginalName && name !== 
this.agentSpecOriginalName) { + this.elements.agentSpecName?.classList.add('field-error'); + this.setAgentSpecStatus('Agent name cannot be changed in edit mode. Clone or create a new agent instead.', 'error'); + return; + } // Serialise to .md format const content = this.serialiseAgentSpec(name, tools, prompt); if (this.elements.agentSpecEditor) this.elements.agentSpecEditor.value = content; @@ -1386,12 +1559,13 @@ const UI = { const response = await API.saveAgentSpec({ content, mode: this.agentSpecMode || 'create', - existing_name: this.agentSpecName || null, + existing_name: this.agentSpecOriginalName || this.agentSpecName || null, }); if (this.agentSpecMode === 'edit') { - const savedName = Utils.normalizeAgentName(response?.name || this.agentSpecName || ''); + const savedName = Utils.normalizeAgentName(response?.name || this.agentSpecOriginalName || this.agentSpecName || ''); if (savedName) { this.agentSpecName = savedName; + this.agentSpecOriginalName = savedName; } if (Utils.normalizeAgentName(Chat.state.activeAgentName) === Utils.normalizeAgentName(savedName)) { await Chat.setActiveAgent(savedName); @@ -1478,7 +1652,51 @@ const UI = { }, isABEnabled() { - return this.elements.abCheckbox?.checked ?? false; + // A/B mode is active when the server reports this user is eligible + return Chat.state.abPool?.enabled === true; + }, + + getABDisclosureMode() { + return this.normalizeABDisclosureMode( + Chat.state.abPool?.variant_label_mode ?? Chat.state.abPool?.disclosure_mode + ); + }, + + getABTraceMode() { + return this.normalizeABTraceMode( + Chat.state.abPool?.activity_panel_default_state ?? Chat.state.abPool?.default_trace_mode + ); + }, + + normalizeABDisclosureMode(mode) { + if (mode === 'reveal_after_vote') return 'post_vote_reveal'; + if (mode === 'show_during_streaming') return 'always_visible'; + return ['hidden', 'post_vote_reveal', 'always_visible'].includes(mode) + ? 
mode + : 'post_vote_reveal'; + }, + + normalizeABTraceMode(mode) { + return ['hidden', 'collapsed', 'expanded'].includes(mode) + ? mode + : 'hidden'; + }, + + isTraceVisibleMode(mode) { + return !['minimal', 'hidden'].includes(mode); + }, + + isTraceCollapsedMode(mode) { + return ['normal', 'collapsed'].includes(mode); + }, + + isTraceExpandedMode(mode) { + return ['verbose', 'expanded'].includes(mode); + }, + + shouldUseABForNextTurn() { + if (!this.isABEnabled()) return false; + return true; }, autoResizeInput() { @@ -1566,14 +1784,6 @@ const UI = { select.value = ''; } - // Also populate provider B select for A/B testing - const selectB = this.elements.providerSelectB; - if (selectB) { - selectB.innerHTML = '' + - enabledProviders - .map(p => ``) - .join(''); - } }, renderProviderModels(models, selectedModel = null, providerType = null) { @@ -1608,29 +1818,7 @@ const UI = { } }, - renderModelBOptions(models, selectedModel = null, providerType = null) { - const select = this.elements.modelSelectB; - if (!select) return; - - if (!models || models.length === 0) { - select.innerHTML = ''; - return; - } - - const options = models - .map(m => ``) - .join(''); - const customOption = providerType === 'openrouter' - ? 
'' - : ''; - select.innerHTML = options + customOption; - if (selectedModel === '__custom__' && providerType === 'openrouter') { - select.value = '__custom__'; - } else if (selectedModel && models.some(m => m.id === selectedModel)) { - select.value = selectedModel; - } - }, updateProviderStatus(status, message) { const statusEl = this.elements.providerStatus; @@ -1949,70 +2137,417 @@ const UI = { // A/B Testing UI Methods // ========================================================================= - showABWarningModal(onConfirm, onCancel) { - // Prevent duplicate modals - if (document.getElementById('ab-warning-modal')) { - return; + setABSectionVisible(visible) { + if (this.elements.abAdminLinkSection) { + this.elements.abAdminLinkSection.style.display = visible ? '' : 'none'; } - - const modalHtml = ` -
-
-
- -

Enable A/B Testing Mode

-
-
-

This will compare two AI responses for each message.

-
    -
  • 2× API usage - Each message generates two responses
  • -
  • Voting required - You must choose the better response before continuing
  • -
  • You can disable A/B mode at any time to skip voting
  • -
-
-
- - -
-
-
`; + }, - document.body.insertAdjacentHTML('beforeend', modalHtml); - const modal = document.getElementById('ab-warning-modal'); + setABSettingsVisible(visible) { + if (this.elements.abSettingsNav) { + this.elements.abSettingsNav.style.display = visible ? '' : 'none'; + } + if (this.elements.abSettingsSection && !visible) { + this.elements.abSettingsSection.hidden = true; + this.elements.abSettingsSection.classList.remove('active'); + } + }, - const closeModal = () => modal?.remove(); + updateABParticipationPreview(value) { + if (this.elements.abParticipationValue) { + this.elements.abParticipationValue.textContent = `${Math.round(value)}%`; + } + }, - modal.querySelector('.ab-warning-btn-cancel').addEventListener('click', () => { - closeModal(); - onCancel?.(); - }); + updateABSettingsSection() { + const abState = Chat.state.abPool || {}; + const capabilities = Chat.state.abCapabilities || {}; + const currentUser = Chat.state.currentUser || {}; + const preferenceSaveState = Chat.state.abPreferenceSaveState || null; + const canParticipate = capabilities.canParticipate === true; + const canViewAdmin = capabilities.canView === true; + const shouldShow = canParticipate || canViewAdmin; - modal.querySelector('.ab-warning-btn-confirm').addEventListener('click', () => { - closeModal(); - onConfirm?.(); - }); + this.setABSettingsVisible(shouldShow); + this.setABSectionVisible(canViewAdmin); - // Close on backdrop click - modal.addEventListener('click', (e) => { - if (e.target === modal) { - closeModal(); - onCancel?.(); - } - }); - }, + if (this.elements.abParticipationGroup) { + this.elements.abParticipationGroup.style.display = canParticipate ? '' : 'none'; + } + if (!canParticipate) { + return; + } - showToast(message, duration = 3000) { - // Remove existing toast - document.querySelector('.toast')?.remove(); + const defaultRate = Number( + abState.default_comparison_rate + ?? abState.default_sample_rate + ?? abState.comparison_rate + ?? abState.sample_rate + ?? 
1 + ); + const usingDefault = currentUser.ab_participation_rate == null || Number.isNaN(Number(currentUser.ab_participation_rate)); + const effectiveRate = usingDefault ? defaultRate : Number(currentUser.ab_participation_rate); + const percent = Math.max(0, Math.min(100, Math.round(effectiveRate * 100))); - const toast = document.createElement('div'); - toast.className = 'toast'; - toast.textContent = message; - document.body.appendChild(toast); + if (this.elements.abParticipationSlider) { + this.elements.abParticipationSlider.value = String(percent); + } + this.updateABParticipationPreview(percent); + + if (this.elements.abParticipationDefault) { + this.elements.abParticipationDefault.textContent = `Default: ${Math.round(defaultRate * 100)}%`; + } + if (this.elements.abParticipationNote) { + if (preferenceSaveState?.type === 'error') { + this.elements.abParticipationNote.textContent = preferenceSaveState.message || 'Your last change was not saved.'; + this.elements.abParticipationNote.classList.add('settings-inline-error'); + } else { + this.elements.abParticipationNote.textContent = preferenceSaveState?.type === 'success' + ? (preferenceSaveState.message || 'Saved for your account.') + : (usingDefault + ? 'Currently using the deployment default until you choose your own rate.' + : 'Your saved setting applies only to your account.'); + this.elements.abParticipationNote.classList.remove('settings-inline-error'); + } + } + if (this.elements.abParticipationInactive) { + const reason = String(abState.participant_reason || ''); + let inactiveMessage = ''; + if (reason === 'not_targeted') { + inactiveMessage = 'The current experiment does not target your role or permissions. Your saved rate will apply automatically if a future experiment includes you.'; + } else if (reason === 'disabled') { + inactiveMessage = 'Experiments are currently inactive. 
Your preference will be used again if A/B testing is enabled.'; + } + this.elements.abParticipationInactive.textContent = inactiveMessage; + this.elements.abParticipationInactive.style.display = inactiveMessage ? '' : 'none'; + } + }, + + updateABPoolUI(poolInfo) { + // Render pool editor with current agents + pool state + const agentList = document.getElementById('ab-pool-agent-list'); + const statusBadge = document.getElementById('ab-pool-status'); + const disableBtn = document.getElementById('ab-pool-disable'); + const sampleRateInput = document.getElementById('ab-sample-rate'); + const disclosureModeInput = document.getElementById('ab-disclosure-mode'); + const traceModeInput = document.getElementById('ab-trace-mode'); + if (!agentList) return; + + // Use allAgents so AB-only variants appear in the pool editor + const agents = Chat.state.allAgents || Chat.state.agents || []; + const poolEnabled = poolInfo?.enabled === true; + const currentChampion = poolInfo?.champion || poolInfo?.control || null; + const currentVariants = poolInfo?.variants || []; + + // Update status badge + if (statusBadge) { + statusBadge.textContent = poolEnabled ? 'Active' : 'Inactive'; + statusBadge.classList.toggle('active', poolEnabled); + } + + // Show/hide disable button + if (disableBtn) { + disableBtn.style.display = poolEnabled ? '' : 'none'; + } + if (sampleRateInput) { + sampleRateInput.value = String(poolInfo?.comparison_rate ?? poolInfo?.sample_rate ?? 1); + } + if (disclosureModeInput) { + disclosureModeInput.value = poolInfo?.variant_label_mode || poolInfo?.disclosure_mode || 'post_vote_reveal'; + } + if (traceModeInput) { + traceModeInput.value = poolInfo?.activity_panel_default_state || poolInfo?.default_trace_mode || 'hidden'; + } + + // Render agent rows + agentList.innerHTML = agents.map(agent => { + const inPool = currentVariants.includes(agent.name); + const isChampion = agent.name === currentChampion; + const selectedClass = inPool ? 
' selected' : ''; + const championClass = isChampion ? ' champion' : ''; + const isABOnly = agent.ab_only === true; + return ` + `; + }).join(''); + + // Wire up events + agentList.querySelectorAll('.ab-pool-agent-row').forEach(row => { + const agentName = row.dataset.agent; + const checkbox = row.querySelector('input[type="checkbox"]'); + const champBtn = row.querySelector('.ab-pool-champion-btn'); + const variantBtn = row.querySelector('.ab-pool-variant-btn'); + + checkbox.addEventListener('change', () => { + row.classList.toggle('selected', checkbox.checked); + if (!checkbox.checked && row.classList.contains('champion')) { + row.classList.remove('champion'); + champBtn.classList.remove('is-champion'); + champBtn.innerHTML = '☆ Champion'; + } + this._updateABPoolSaveState(); + }); + + champBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + if (!checkbox.checked) return; + // Clear previous champion + agentList.querySelectorAll('.ab-pool-agent-row').forEach(r => { + r.classList.remove('champion'); + const btn = r.querySelector('.ab-pool-champion-btn'); + btn.classList.remove('is-champion'); + btn.innerHTML = '☆ Champion'; + }); + row.classList.add('champion'); + champBtn.classList.add('is-champion'); + champBtn.innerHTML = '★ Champion'; + this._updateABPoolSaveState(); + }); + + if (variantBtn) { + variantBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + this._showQuickVariantPanel(agentName); + }); + } + }); + + this._updateABPoolSaveState(); + }, + + _updateABPoolSaveState() { + const agentList = document.getElementById('ab-pool-agent-list'); + const saveBtn = document.getElementById('ab-pool-save'); + const msgEl = document.getElementById('ab-pool-message'); + if (!agentList || !saveBtn) return; + + const selected = agentList.querySelectorAll('.ab-pool-agent-row.selected'); + const hasChampion = !!agentList.querySelector('.ab-pool-agent-row.champion'); + const valid = selected.length >= 2 && 
hasChampion; + saveBtn.disabled = !valid; + + if (msgEl) { + if (selected.length < 2 && selected.length > 0) { + msgEl.textContent = 'Select at least 2 agents'; + msgEl.className = 'ab-pool-message error'; + } else if (selected.length >= 2 && !hasChampion) { + msgEl.textContent = 'Click "Champion" to designate the baseline variant'; + msgEl.className = 'ab-pool-message error'; + } else { + msgEl.textContent = ''; + msgEl.className = 'ab-pool-message'; + } + } + }, + + _getABPoolSelection() { + const agentList = document.getElementById('ab-pool-agent-list'); + if (!agentList) return null; + const variants = []; + let champion = null; + agentList.querySelectorAll('.ab-pool-agent-row.selected').forEach(row => { + const name = row.dataset.agent; + variants.push(name); + if (row.classList.contains('champion')) champion = name; + }); + return { champion, variants }; + }, + + /** + * Show an inline panel to quickly create a variant of an existing agent. + * The variant is saved with ab_only: true so it only appears in the pool editor. + */ + async _showQuickVariantPanel(sourceAgentName) { + // Remove any existing panel + document.querySelector('.ab-quick-variant-panel')?.remove(); + + // Fetch source agent spec and tool palette in parallel + let sourceSpec = null; + let availableTools = []; + try { + const [specResp, templateResp] = await Promise.all([ + API.getAgentSpec(sourceAgentName), + API.getAgentTemplate(), + ]); + sourceSpec = UI.parseAgentSpec(specResp?.content || ''); + availableTools = (templateResp?.tools || []).map(t => typeof t === 'string' ? t : t.name); + } catch (e) { + console.error('Failed to load source agent for variant:', e); + return; + } + + const sourceTools = sourceSpec?.tools || []; + + // Build panel HTML + const panel = document.createElement('div'); + panel.className = 'ab-quick-variant-panel'; + panel.innerHTML = ` +
+ New variant of "${Utils.escapeHtml(sourceAgentName)}" + +
+ + + +
+ ${availableTools.map(t => ` + + `).join('')} +
+ + `; + + // Insert panel after the agent list + const agentList = document.getElementById('ab-pool-agent-list'); + agentList?.parentElement?.insertBefore(panel, agentList.nextSibling); + + // References + const nameInput = panel.querySelector('.ab-qv-name'); + const saveBtn = panel.querySelector('.ab-qv-save'); + const closeBtn = panel.querySelector('.ab-qv-close'); + const msgEl = panel.querySelector('.ab-qv-msg'); + + closeBtn.addEventListener('click', () => panel.remove()); + + nameInput.addEventListener('input', () => { + if (msgEl) { msgEl.textContent = ''; msgEl.className = 'ab-qv-msg'; } + }); + + saveBtn.addEventListener('click', async () => { + const variantName = (nameInput.value || '').trim(); + if (!variantName) { + if (msgEl) { msgEl.textContent = 'Name is required'; msgEl.className = 'ab-qv-msg error'; } + nameInput.focus(); + return; + } + + // Client-side duplicate check + const existingNames = (Chat.state.allAgents || []).map(a => a.name); + if (existingNames.includes(variantName)) { + if (msgEl) { msgEl.textContent = '"' + variantName + '" already exists \u2014 choose a different name'; msgEl.className = 'ab-qv-msg error'; } + nameInput.focus(); + nameInput.select(); + return; + } + + const selectedTools = [...panel.querySelectorAll('.ab-qv-tools input:checked')].map(cb => cb.value); + const specContent = UI.serialiseAgentSpec(variantName, selectedTools, sourceSpec?.prompt || '', { ab_only: true }); + + saveBtn.disabled = true; + saveBtn.textContent = 'Saving\u2026'; + if (msgEl) { msgEl.textContent = ''; msgEl.className = 'ab-qv-msg'; } + + try { + const result = await API.saveAgentSpec({ content: specContent, mode: 'create' }); + if (result?.success) { + if (msgEl) { msgEl.textContent = 'Created!'; msgEl.className = 'ab-qv-msg success'; } + // Refresh agent lists + re-render pool editor + await Chat.loadAgents(); + UI.updateABPoolUI(Chat.state.abPool || {}); + // Panel replaced by re-render; remove just in case + setTimeout(() => 
panel.remove(), 600); + } else { + if (msgEl) { msgEl.textContent = result?.error || 'Save failed'; msgEl.className = 'ab-qv-msg error'; } + saveBtn.disabled = false; + saveBtn.textContent = 'Create Variant'; + } + } catch (e) { + if (msgEl) { msgEl.textContent = e.message || 'Save failed'; msgEl.className = 'ab-qv-msg error'; } + saveBtn.disabled = false; + saveBtn.textContent = 'Create Variant'; + } + }); + + // Focus the name input + nameInput.focus(); + nameInput.select(); + }, + + showABWarningModal(onConfirm, onCancel) { + // Prevent duplicate modals + if (document.getElementById('ab-warning-modal')) { + return; + } + + const modalHtml = ` +
+
+
+ +

Enable A/B Testing Mode

+
+
+

This will compare two AI responses for each message.

+
    +
  • 2× API usage - Each message generates two responses
  • +
  • Voting required - Once the pending comparison limit is reached, you must resolve one before continuing
  • +
  • You can disable A/B mode at any time to skip voting
  • +
+
+
+ + +
+
+
`; + + document.body.insertAdjacentHTML('beforeend', modalHtml); + const modal = document.getElementById('ab-warning-modal'); + + const closeModal = () => modal?.remove(); + + modal.querySelector('.ab-warning-btn-cancel').addEventListener('click', () => { + closeModal(); + onCancel?.(); + }); + + modal.querySelector('.ab-warning-btn-confirm').addEventListener('click', () => { + closeModal(); + onConfirm?.(); + }); + + // Close on backdrop click + modal.addEventListener('click', (e) => { + if (e.target === modal) { + closeModal(); + onCancel?.(); + } + }); + }, + + showToast(message, duration = 3000) { + // Remove existing toast + document.querySelector('.toast')?.remove(); + + const toast = document.createElement('div'); + toast.className = 'toast'; + toast.textContent = message; + document.body.appendChild(toast); // Trigger animation requestAnimationFrame(() => toast.classList.add('show')); @@ -2023,50 +2558,197 @@ const UI = { }, duration); }, - addABComparisonContainer(msgIdA, msgIdB) { + getTraceModeForMessage(messageId) { + const container = document.querySelector(`.trace-container[data-message-id="${messageId}"]`); + return container?.dataset.traceMode || Chat.state.traceVerboseMode || 'normal'; + }, + + getTraceIconSvg() { + return ``; + }, + + getTraceLabelText(toolCount = 0) { + return toolCount > 0 + ? `Agent Activity (${toolCount} tool${toolCount === 1 ? 
'' : 's'})` + : 'Agent Activity'; + }, + + bindTraceToggleHandlers(root = document) { + if (!root?.querySelectorAll) return; + root.querySelectorAll('[data-trace-toggle]').forEach(el => { + if (!el._traceToggleBound) { + el._traceToggleBound = true; + el.addEventListener('click', () => UI.toggleTraceExpanded(el.dataset.traceToggle)); + } + }); + }, + + addABComparisonContainer(msgIdA, msgIdB, options = {}) { // Remove empty state if present const empty = this.elements.messagesInner?.querySelector('.messages-empty'); if (empty) empty.remove(); - const showTrace = Chat.state.traceVerboseMode !== 'minimal'; - const traceIconSvg = ``; + const traceMode = this.normalizeABTraceMode(options.traceMode || this.getABTraceMode()); + const showTrace = this.isTraceVisibleMode(traceMode); + const traceIconSvg = this.getTraceIconSvg(); + const traceCollapsed = this.isTraceCollapsedMode(traceMode); const traceHtml = (id) => showTrace ? ` -
-
+
+
${traceIconSvg} - Agent Activity - + ${this.getTraceLabelText()} + 0.0s + +
+
+
-
` : ''; - const html = ` -
-
-
- Model A -
- ${traceHtml(msgIdA)} -
-
-
-
- Model B + // Use normal message structure for each arm — looks like two regular chat messages side by side + const armHtml = (id, label) => ` +
+
+
+
+
+
+ archi + ${label} +
+ +
+
+ ${traceHtml(id)} +
+ +
+ + + +
- ${traceHtml(msgIdB)} -
-
+
`; + + const comparisonKey = Utils.escapeAttr(options.comparisonKey || `${msgIdA}-${msgIdB}`); + const html = ` +
+ ${armHtml(msgIdA, 'Response A')} + ${armHtml(msgIdB, 'Response B')}
`; this.elements.messagesInner?.insertAdjacentHTML('beforeend', html); + // Bind trace toggle handlers (replacing inline onclick for CSP compliance) + this.bindTraceToggleHandlers(this.elements.messagesInner || document); + // Start timers for both A/B arms + if (showTrace) { + this.startTraceTimer(msgIdA); + this.startTraceTimer(msgIdB); + } this.scrollToBottom(); }, + findABComparisonElement(comparisonState = null) { + if (!comparisonState) return null; + + const comparisonId = comparisonState.comparisonId; + if (comparisonId != null) { + const byId = document.querySelector(`.ab-comparison[data-comparison-id="${comparisonId}"]`); + if (byId) return byId; + } + + const armIds = [ + comparisonState.responseAUiId, + comparisonState.responseBUiId, + comparisonState.responseAId, + comparisonState.responseBId, + ].filter(Boolean); + + for (const armId of armIds) { + const arm = document.querySelector(`.ab-arm[data-id="${armId}"], .ab-response[data-id="${armId}"]`); + if (arm) { + const container = arm.closest('.ab-comparison'); + if (container) return container; + } + } + return null; + }, + + setABComparisonId(comparisonState = null) { + const comparison = this.findABComparisonElement(comparisonState); + if (!comparison || !comparisonState?.comparisonId) return; + comparison.dataset.comparisonId = String(comparisonState.comparisonId); + }, + + updateABVariantLabel(armId, variantName) { + const labelEl = document.querySelector(`.ab-arm-variant-name[data-arm-id="${armId}"]`); + if (labelEl) { + labelEl.textContent = variantName || ''; + } + }, + + updateABArmMeta(armId, metaText, visible = true) { + const container = document.querySelector(`.ab-arm[data-id="${armId}"], .ab-response[data-id="${armId}"]`); + const metaEl = container?.querySelector('.message-meta'); + if (!metaEl) return; + metaEl.textContent = metaText || ''; + metaEl.style.display = visible && metaText ? 
'' : 'none'; + }, + + updateABArmPresentation( + armId, + { variantName = '', modelUsed = '' } = {}, + { disclosureMode = 'post_vote_reveal', reveal = false } = {}, + ) { + const normalizedDisclosureMode = this.normalizeABDisclosureMode(disclosureMode); + const showVariant = normalizedDisclosureMode === 'always_visible' + || (normalizedDisclosureMode === 'post_vote_reveal' && reveal); + this.updateABVariantLabel(armId, showVariant ? variantName : ''); + this.updateABArmMeta(armId, modelUsed, showVariant && !!modelUsed); + }, + + rekeyABArm(oldId, newId) { + if (!oldId || !newId || String(oldId) === String(newId)) return; + + const arm = document.querySelector(`.ab-arm[data-id="${oldId}"], .ab-response[data-id="${oldId}"]`); + if (arm) { + arm.dataset.id = String(newId); + } + + const variantLabel = document.querySelector(`.ab-arm-variant-name[data-arm-id="${oldId}"]`); + if (variantLabel) { + variantLabel.dataset.armId = String(newId); + } + + const traceContainer = document.querySelector(`.trace-container[data-message-id="${oldId}"]`); + if (traceContainer) { + traceContainer.dataset.messageId = String(newId); + const toggle = traceContainer.querySelector('[data-trace-toggle]'); + if (toggle) { + toggle.dataset.traceToggle = String(newId); + } + } + + const activeInterval = this.traceTimerIntervals.get(String(oldId)); + if (activeInterval != null) { + this.traceTimerIntervals.set(String(newId), activeInterval); + this.traceTimerIntervals.delete(String(oldId)); + } + }, + updateABResponse(responseId, html, streaming = false) { - const container = document.querySelector(`.ab-response[data-id="${responseId}"]`); + const container = document.querySelector(`.ab-arm[data-id="${responseId}"], .ab-response[data-id="${responseId}"]`); if (!container) return; - const contentEl = container.querySelector('.ab-response-content'); + const contentEl = container.querySelector('.message-content'); if (contentEl) { contentEl.innerHTML = html; if (streaming) { @@ -2076,21 +2758,27 
@@ const UI = { this.scrollToBottom(); }, - showABVoteButtons(comparisonId) { - const comparison = document.getElementById('ab-comparison-active'); + showABVoteButtons(comparisonState) { + const comparison = this.findABComparisonElement(comparisonState); if (!comparison) return; + this.hideABVoteButtons(); + const voteHtml = ` -
-
Which response was better?
+
+
Which response do you prefer?
+
`; @@ -2098,7 +2786,7 @@ const UI = { comparison.insertAdjacentHTML('afterend', voteHtml); // Bind vote button events - document.querySelectorAll('.ab-vote-btn').forEach((btn) => { + comparison.nextElementSibling?.querySelectorAll('.ab-vote-btn').forEach((btn) => { btn.addEventListener('click', () => { const vote = btn.dataset.vote; Chat.submitABPreference(vote); @@ -2112,54 +2800,75 @@ const UI = { document.querySelector('.ab-vote-container')?.remove(); }, - markABWinner(preference) { - const comparison = document.getElementById('ab-comparison-active'); - if (!comparison) return; + stopTraceTimersInElement(container) { + if (!container) return; + container.querySelectorAll('.trace-container[data-message-id]').forEach((traceEl) => { + this.stopTraceTimer(traceEl.dataset.messageId); + }); + }, - const responseA = comparison.querySelector('.ab-response-a'); - const responseB = comparison.querySelector('.ab-response-b'); - - let winnerContent = ''; - let winnerTrace = ''; - if (preference === 'a') { - winnerContent = responseA?.querySelector('.ab-response-content')?.innerHTML || ''; - winnerTrace = responseA?.querySelector('.trace-container')?.outerHTML || ''; - } else if (preference === 'b') { - winnerContent = responseB?.querySelector('.ab-response-content')?.innerHTML || ''; - winnerTrace = responseB?.querySelector('.trace-container')?.outerHTML || ''; - } else { - // Tie - keep both visible but mark them - responseA?.classList.add('ab-response-tie'); - responseB?.classList.add('ab-response-tie'); - comparison.removeAttribute('id'); + markABWinner(preference, comparisonState = null) { + const comparisonEl = this.findABComparisonElement(comparisonState); + if (!comparisonEl) return; + + if (comparisonState?.responseAId) { + this.rekeyABArm(comparisonState.responseAUiId || comparisonState.responseAId, comparisonState.responseAId); + this.updateABArmPresentation(comparisonState.responseAId, { + variantName: comparisonState.variantA, + modelUsed: 
comparisonState.responseAModelUsed, + }, { + disclosureMode: comparisonState.disclosureMode, + reveal: true, + }); + } + if (comparisonState?.responseBId) { + this.rekeyABArm(comparisonState.responseBUiId || comparisonState.responseBId, comparisonState.responseBId); + this.updateABArmPresentation(comparisonState.responseBId, { + variantName: comparisonState.variantB, + modelUsed: comparisonState.responseBModelUsed, + }, { + disclosureMode: comparisonState.disclosureMode, + reveal: true, + }); + } + + this.stopTraceTimersInElement(comparisonEl); + + const arms = comparisonEl.querySelectorAll('.ab-arm'); + const armA = arms[0]; + const armB = arms[1]; + + if (preference === 'tie') { + // Tie — dim both equally and add a badge + armA?.classList.add('ab-arm-tie'); + armB?.classList.add('ab-arm-tie'); + delete comparisonEl.dataset.comparisonId; return; } - // Replace the entire comparison with a normal archi message (matching createMessageHTML format) - // Include the trace container from the winning response - const metaLabel = Chat.getEntryMetaLabel(); - const metaHtml = metaLabel - ? `
${Utils.escapeHtml(metaLabel)}
` - : ''; + // Winner/loser — collapse to single message + const winner = preference === 'a' ? armA : armB; - const normalMessage = ` -
-
-
-
- archi -
- ${winnerTrace} -
${winnerContent}
- ${metaHtml} -
-
`; + if (winner) { + // Remove the AB label + winner.querySelector('.ab-arm-label')?.remove(); + winner.classList.remove('ab-arm'); + } - comparison.outerHTML = normalMessage; + // Move the live winner node out of the comparison container so its + // finalized timer text and bound trace interactions are preserved. + if (winner && comparisonEl.parentNode) { + comparisonEl.parentNode.insertBefore(winner, comparisonEl); + comparisonEl.remove(); + } }, - removeABComparisonContainer() { - document.getElementById('ab-comparison-active')?.remove(); + removeABComparisonContainer(comparisonState = null) { + const comparisonEl = comparisonState + ? this.findABComparisonElement(comparisonState) + : document.querySelector('.ab-comparison:last-of-type'); + this.stopTraceTimersInElement(comparisonEl); + comparisonEl?.remove(); this.hideABVoteButtons(); }, @@ -2197,12 +2906,12 @@ const UI = { const existingTrace = inner.querySelector('.trace-container'); if (existingTrace) return; - const traceIconSvg = ``; + const traceIconSvg = this.getTraceIconSvg(); const traceHtml = `
${traceIconSvg} - Agent Activity + ${this.getTraceLabelText()} 0.0s
`; inner.insertAdjacentHTML('afterbegin', traceHtml); - + + // Start collapsed in normal mode (user can expand on demand) + if (Chat.state.traceVerboseMode === 'normal') { + const tc = inner.querySelector('.trace-container'); + if (tc) { + tc.classList.add('collapsed'); + const ti = tc.querySelector('.toggle-icon'); + if (ti) ti.innerHTML = '▶'; + } + } + // Start elapsed timer this.startTraceTimer(messageId); }, + getTraceTimerElement(messageId) { + return document.querySelector(`.trace-container[data-message-id="${messageId}"] .trace-timer`); + }, + + getTraceElapsedMs(messageId) { + const timerEl = this.getTraceTimerElement(messageId); + if (!timerEl?.dataset.start) return null; + const startTime = Number.parseInt(timerEl.dataset.start, 10); + if (!Number.isFinite(startTime)) return null; + return Math.max(Date.now() - startTime, 0); + }, + startTraceTimer(messageId) { - const timerEl = document.querySelector(`.trace-container[data-message-id="${messageId}"] .trace-timer`); + const timerKey = String(messageId); + const timerEl = this.getTraceTimerElement(messageId); if (!timerEl) return; - - const startTime = parseInt(timerEl.dataset.start, 10); + + this.stopTraceTimer(timerKey); + delete timerEl.dataset.finalDurationMs; + + const startTime = Number.parseInt(timerEl.dataset.start, 10); + if (!Number.isFinite(startTime)) { + timerEl.dataset.start = String(Date.now()); + } const updateTimer = () => { - const elapsed = (Date.now() - startTime) / 1000; + const baseTime = Number.parseInt(timerEl.dataset.start, 10); + const elapsed = (Date.now() - baseTime) / 1000; timerEl.textContent = elapsed.toFixed(1) + 's'; }; - - const intervalId = setInterval(updateTimer, 100); - timerEl.dataset.intervalId = intervalId; + + updateTimer(); + const intervalId = window.setInterval(updateTimer, 100); + this.traceTimerIntervals.set(timerKey, intervalId); }, - stopTraceTimer(messageId) { - const timerEl = document.querySelector(`.trace-container[data-message-id="${messageId}"] 
.trace-timer`); - if (!timerEl || !timerEl.dataset.intervalId) return; - - clearInterval(parseInt(timerEl.dataset.intervalId, 10)); - delete timerEl.dataset.intervalId; + stopTraceTimer(messageId, durationMs = null) { + const timerKey = String(messageId); + const intervalId = this.traceTimerIntervals.get(timerKey); + if (intervalId != null) { + clearInterval(intervalId); + this.traceTimerIntervals.delete(timerKey); + } + + const timerEl = this.getTraceTimerElement(messageId); + if (!timerEl) return; + + let resolvedDurationMs = durationMs; + if (resolvedDurationMs == null && timerEl.dataset.finalDurationMs) { + const storedDurationMs = Number.parseInt(timerEl.dataset.finalDurationMs, 10); + if (Number.isFinite(storedDurationMs)) { + resolvedDurationMs = storedDurationMs; + } + } + + const elapsedMs = resolvedDurationMs ?? this.getTraceElapsedMs(messageId); + if (elapsedMs != null) { + timerEl.dataset.finalDurationMs = String(elapsedMs); + timerEl.textContent = Utils.formatDuration(elapsedMs); + } }, toggleTraceExpanded(messageId) { @@ -2291,7 +3049,7 @@ const UI = { }, renderThinkingEnd(messageId, event) { - const step = document.querySelector(`.thinking-step[data-step-id="${event.step_id}"]`); + const step = document.querySelector(`.trace-container[data-message-id="${messageId}"] .thinking-step[data-step-id="${event.step_id}"]`); if (!step) return; // If no thinking content, remove the step entirely - it's just noise @@ -2372,7 +3130,7 @@ const UI = { this.scrollToBottom(); // Auto-expand if verbose mode - if (Chat.state.traceVerboseMode === 'verbose') { + if (this.isTraceExpandedMode(this.getTraceModeForMessage(messageId))) { const step = timeline.querySelector(`[data-step-id="${event.tool_call_id}"]`); step?.classList.add('expanded'); const details = step?.querySelector('.step-details'); @@ -2397,7 +3155,7 @@ const UI = { }, renderToolOutput(messageId, event) { - const step = document.querySelector(`.tool-step[data-tool-call-id="${event.tool_call_id}"]`); + 
const step = document.querySelector(`.trace-container[data-message-id="${messageId}"] .tool-step[data-tool-call-id="${event.tool_call_id}"]`); if (!step) return; const outputSection = step.querySelector('.tool-output-section'); @@ -2426,7 +3184,7 @@ const UI = { }, renderToolEnd(messageId, event) { - const step = document.querySelector(`.tool-step[data-tool-call-id="${event.tool_call_id}"]`); + const step = document.querySelector(`.trace-container[data-message-id="${messageId}"] .tool-step[data-tool-call-id="${event.tool_call_id}"]`); if (!step) return; step.classList.remove('tool-running'); @@ -2449,8 +3207,9 @@ const UI = { } // Auto-collapse if many tools - const toolCount = document.querySelectorAll('.tool-step').length; - if (Chat.state.traceVerboseMode === 'normal' && toolCount > CONFIG.TRACE.AUTO_COLLAPSE_TOOL_COUNT) { + const timeline = step.closest('.step-timeline'); + const toolCount = timeline?.querySelectorAll('.tool-step').length || 0; + if (this.isTraceCollapsedMode(this.getTraceModeForMessage(messageId)) && toolCount > CONFIG.TRACE.AUTO_COLLAPSE_TOOL_COUNT) { step.classList.remove('expanded'); const details = step.querySelector('.step-details'); if (details) details.style.display = 'none'; @@ -2501,15 +3260,20 @@ const UI = { // ========================================================================= finalizeTrace(messageId, trace, finalEvent) { - this.stopTraceTimer(messageId); + this.stopTraceTimer(messageId, finalEvent?.duration_ms ?? null); const container = document.querySelector(`.trace-container[data-message-id="${messageId}"]`); if (!container) return; - const toolCount = trace.toolCalls.size; + const toolCalls = trace?.toolCalls; + const toolCount = toolCalls instanceof Map + ? toolCalls.size + : Array.isArray(toolCalls) + ? toolCalls.length + : 0; const label = container.querySelector('.trace-label'); - if (label && toolCount > 0) { - label.textContent = `Agent Activity (${toolCount} tool${toolCount === 1 ? 
'' : 's'})`; + if (label) { + label.textContent = this.getTraceLabelText(toolCount); } // Update context meter if usage available @@ -2518,7 +3282,7 @@ const UI = { } // Auto-collapse in normal mode - if (Chat.state.traceVerboseMode === 'normal') { + if (this.isTraceCollapsedMode(this.getTraceModeForMessage(messageId))) { container.classList.add('collapsed'); const toggleIcon = container.querySelector('.toggle-icon'); if (toggleIcon) toggleIcon.innerHTML = '▶'; @@ -2568,12 +3332,10 @@ const UI = { const durationMs = trace.total_duration_ms || 0; const durationStr = Utils.formatDuration(durationMs); - const traceIconSvg = ``; + const traceIconSvg = this.getTraceIconSvg(); // Build trace container with collapsed state - const labelText = toolCount > 0 - ? `Agent Activity (${toolCount} tool${toolCount === 1 ? '' : 's'})` - : 'Agent Activity'; + const labelText = this.getTraceLabelText(toolCount); const traceHtml = ` diff --git a/src/interfaces/chat_app/templates/index.html b/src/interfaces/chat_app/templates/index.html index 670101a55..699e8e560 100644 --- a/src/interfaces/chat_app/templates/index.html +++ b/src/interfaces/chat_app/templates/index.html @@ -186,6 +186,14 @@

Settings

Advanced --> +
--> - - - - - - - - +