diff --git a/backend/app/api/ws/chat_namespace.py b/backend/app/api/ws/chat_namespace.py index 61bdb6b94..9cf395c53 100644 --- a/backend/app/api/ws/chat_namespace.py +++ b/backend/app/api/ws/chat_namespace.py @@ -464,6 +464,8 @@ async def on_chat_send(self, sid: str, data: dict) -> dict: db = SessionLocal() pipeline_info = None + pipeline_previous_bot_id = None + allow_pending_status = False try: # Get user user = db.query(User).filter(User.id == user_id).first() @@ -513,6 +515,11 @@ async def on_chat_send(self, sid: str, data: dict) -> dict: ) } + pipeline_previous_bot_id = confirm_result.get( + "current_stage_bot_id" + ) + allow_pending_status = True + # Emit task:status event to notify frontend that task status changed # This triggers PipelineStageIndicator to re-fetch pipeline stage info task_room = f"task:{payload.task_id}" @@ -526,7 +533,7 @@ async def on_chat_send(self, sid: str, data: dict) -> dict: room=task_room, ) logger.info( - f"[WS] pipeline:confirm emitted task:status PENDING for task {payload.task_id}" + f"[WS] pipeline:confirm emitted task:status RUNNING for task {payload.task_id}" ) # Get pipeline info (unified logic for all pipeline operations) @@ -645,8 +652,8 @@ async def on_chat_send(self, sid: str, data: dict) -> dict: } # For pipeline confirm, get the previous stage's bot_id for session management - previous_bot_id = None - if pipeline_info: + previous_bot_id = pipeline_previous_bot_id + if previous_bot_id is None and pipeline_info: previous_bot_id = pipeline_info.get("current_stage_bot_id") params = TaskCreationParams( @@ -669,6 +676,7 @@ async def on_chat_send(self, sid: str, data: dict) -> dict: # TaskRequestBuilder will compare this with current bot_id to determine # if a new session is needed (different bot = new session) previous_bot_id=previous_bot_id, + allow_pending_status=allow_pending_status, device_id=payload.device_id, generate_params=generate_params_dict, ) diff --git a/backend/app/services/chat/storage/task_manager.py b/backend/app/services/chat/storage/task_manager.py index 8b462144d..31e3cb98a 100644 --- a/backend/app/services/chat/storage/task_manager.py +++ b/backend/app/services/chat/storage/task_manager.py @@ -73,6 +73,8 @@ class TaskCreationParams: # When set and different from current bot_id, a new session will be created # This ensures each pipeline stage has independent context previous_bot_id: Optional[int] = None + # Pipeline confirmation marks the task PENDING before creating next-stage subtasks. + allow_pending_status: bool = False # Device ID for local device execution (saved at task creation to avoid race condition) device_id: Optional[str] = None # Video generation parameters (user-selected at generation time) @@ -184,19 +186,27 @@ def get_task_with_access_check( return None, user_id -def check_task_status(db: Session, task: TaskResource) -> None: +def check_task_status( + task: TaskResource, + *, + allow_pending: bool = False, +) -> None: """ Check if task is in a valid state for new messages. Args: - db: Database session task: Task resource to check + allow_pending: Whether PENDING is allowed for internal pipeline handoff Raises: HTTPException: If task is still running or pending """ task_crd = Task.model_validate(task.json) - if task_crd.status and task_crd.status.status in ("RUNNING", "PENDING"): + if not task_crd.status: + return + + status = task_crd.status.status + if status == "RUNNING" or (status == "PENDING" and not allow_pending): raise HTTPException(status_code=400, detail="Task is still running") @@ -692,7 +702,7 @@ async def create_task_and_subtasks( # Get existing task with access check task, subtask_user_id = get_task_with_access_check(db, task_id, user.id) if task: - check_task_status(db, task) + check_task_status(task, allow_pending=params.allow_pending_status) if should_trigger_ai: mark_task_pending(task) # Update modelId in existing task if provided diff --git a/backend/tests/services/chat/storage/test_task_manager.py b/backend/tests/services/chat/storage/test_task_manager.py index 54a0262fb..e14455bd0 100644 --- a/backend/tests/services/chat/storage/test_task_manager.py +++ b/backend/tests/services/chat/storage/test_task_manager.py @@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, patch import pytest +from fastapi import HTTPException from sqlalchemy.orm import Session from app.models.subtask import Subtask, SubtaskRole, SubtaskStatus @@ -16,6 +17,7 @@ from app.models.user import User from app.services.chat.storage.task_manager import ( TaskCreationParams, + check_task_status, create_assistant_subtask, create_task_and_subtasks, ) @@ -172,6 +174,33 @@ def _build_existing_task(task_id: int, user_id: int) -> TaskResource: ) +def test_check_task_status_rejects_pending_by_default(test_user: User): + task = _build_existing_task(task_id=1386, user_id=test_user.id) + task.json["status"]["status"] = "PENDING" + + with pytest.raises(HTTPException) as exc_info: + check_task_status(task) + + assert exc_info.value.status_code == 400 + + +def test_check_task_status_allows_pending_when_requested(test_user: User): + task = _build_existing_task(task_id=1386, user_id=test_user.id) + task.json["status"]["status"] = "PENDING" + + check_task_status(task, allow_pending=True) + + +def test_check_task_status_rejects_running_when_pending_is_allowed(test_user: User): + task = _build_existing_task(task_id=1386, user_id=test_user.id) + task.json["status"]["status"] = "RUNNING" + + with pytest.raises(HTTPException) as exc_info: + check_task_status(task, allow_pending=True) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio async def test_create_task_and_subtasks_resets_existing_task_status_to_pending( test_db: Session, @@ -223,3 +252,57 @@ async def test_create_task_and_subtasks_resets_existing_task_status_to_pending( assert status["progress"] == 0 assert status["errorMessage"] == "" assert result.assistant_subtask is not None + + +@pytest.mark.asyncio +async def test_create_task_and_subtasks_allows_pipeline_confirm_pending_task( + test_db: Session, + test_user: User, +): + task = _build_existing_task(task_id=1387, user_id=test_user.id) + task.json["status"]["status"] = "PENDING" + test_db.add(task) + test_db.commit() + test_db.refresh(task) + + team = SimpleNamespace( + id=1256, + user_id=test_user.id, + name="quickstart", + namespace="default", + ) + params = TaskCreationParams( + message="handoff to quickstart", + allow_pending_status=True, + pipeline_bot_ids=[1255], + ) + + with ( + patch( + "app.services.chat.storage.task_manager.initialize_redis_chat_history", + new=AsyncMock(), + ), + patch( + "app.services.memory.is_memory_enabled_for_user", + return_value=False, + ), + patch( + "app.services.chat.trigger.group_chat.is_task_group_chat", + return_value=False, + ), + ): + result = await create_task_and_subtasks( + db=test_db, + user=test_user, + team=team, + message=params.message, + params=params, + task_id=task.id, + should_trigger_ai=True, + ) + + assert result.task.id == task.id + assert result.user_subtask.prompt == "handoff to quickstart" + assert result.user_subtask.bot_ids == [1255] + assert result.assistant_subtask is not None + assert result.assistant_subtask.bot_ids == [1255]