Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions backend/app/api/ws/chat_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ async def on_chat_send(self, sid: str, data: dict) -> dict:

db = SessionLocal()
pipeline_info = None
pipeline_previous_bot_id = None
allow_pending_status = False
try:
# Get user
user = db.query(User).filter(User.id == user_id).first()
Expand Down Expand Up @@ -513,6 +515,11 @@ async def on_chat_send(self, sid: str, data: dict) -> dict:
)
}

pipeline_previous_bot_id = confirm_result.get(
"current_stage_bot_id"
)
allow_pending_status = True

# Emit task:status event to notify frontend that task status changed
# This triggers PipelineStageIndicator to re-fetch pipeline stage info
task_room = f"task:{payload.task_id}"
Expand All @@ -526,7 +533,7 @@ async def on_chat_send(self, sid: str, data: dict) -> dict:
room=task_room,
)
logger.info(
f"[WS] pipeline:confirm emitted task:status PENDING for task {payload.task_id}"
f"[WS] pipeline:confirm emitted task:status RUNNING for task {payload.task_id}"
)

# Get pipeline info (unified logic for all pipeline operations)
Expand Down Expand Up @@ -645,8 +652,8 @@ async def on_chat_send(self, sid: str, data: dict) -> dict:
}

# For pipeline confirm, get the previous stage's bot_id for session management
previous_bot_id = None
if pipeline_info:
previous_bot_id = pipeline_previous_bot_id
if previous_bot_id is None and pipeline_info:
previous_bot_id = pipeline_info.get("current_stage_bot_id")

params = TaskCreationParams(
Expand All @@ -669,6 +676,7 @@ async def on_chat_send(self, sid: str, data: dict) -> dict:
# TaskRequestBuilder will compare this with current bot_id to determine
# if a new session is needed (different bot = new session)
previous_bot_id=previous_bot_id,
allow_pending_status=allow_pending_status,
device_id=payload.device_id,
generate_params=generate_params_dict,
)
Expand Down
18 changes: 14 additions & 4 deletions backend/app/services/chat/storage/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class TaskCreationParams:
# When set and different from current bot_id, a new session will be created
# This ensures each pipeline stage has independent context
previous_bot_id: Optional[int] = None
# Pipeline confirmation marks the task PENDING before creating next-stage subtasks.
allow_pending_status: bool = False
# Device ID for local device execution (saved at task creation to avoid race condition)
device_id: Optional[str] = None
# Video generation parameters (user-selected at generation time)
Expand Down Expand Up @@ -184,19 +186,27 @@ def get_task_with_access_check(
return None, user_id


def check_task_status(db: Session, task: TaskResource) -> None:
def check_task_status(
task: TaskResource,
*,
allow_pending: bool = False,
) -> None:
"""
Check if task is in a valid state for new messages.

Args:
db: Database session
task: Task resource to check
allow_pending: Whether PENDING is allowed for internal pipeline handoff

Raises:
HTTPException: If task is still running or pending
"""
task_crd = Task.model_validate(task.json)
if task_crd.status and task_crd.status.status in ("RUNNING", "PENDING"):
if not task_crd.status:
return

status = task_crd.status.status
if status == "RUNNING" or (status == "PENDING" and not allow_pending):
raise HTTPException(status_code=400, detail="Task is still running")


Expand Down Expand Up @@ -692,7 +702,7 @@ async def create_task_and_subtasks(
# Get existing task with access check
task, subtask_user_id = get_task_with_access_check(db, task_id, user.id)
if task:
check_task_status(db, task)
check_task_status(task, allow_pending=params.allow_pending_status)
if should_trigger_ai:
mark_task_pending(task)
# Update modelId in existing task if provided
Expand Down
83 changes: 83 additions & 0 deletions backend/tests/services/chat/storage/test_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.models.subtask import Subtask, SubtaskRole, SubtaskStatus
from app.models.task import TaskResource
from app.models.user import User
from app.services.chat.storage.task_manager import (
TaskCreationParams,
check_task_status,
create_assistant_subtask,
create_task_and_subtasks,
)
Expand Down Expand Up @@ -172,6 +174,33 @@ def _build_existing_task(task_id: int, user_id: int) -> TaskResource:
)


def test_check_task_status_rejects_pending_by_default(test_user: User):
task = _build_existing_task(task_id=1386, user_id=test_user.id)
task.json["status"]["status"] = "PENDING"

with pytest.raises(HTTPException) as exc_info:
check_task_status(task)

assert exc_info.value.status_code == 400


def test_check_task_status_allows_pending_when_requested(test_user: User):
task = _build_existing_task(task_id=1386, user_id=test_user.id)
task.json["status"]["status"] = "PENDING"

check_task_status(task, allow_pending=True)


def test_check_task_status_rejects_running_when_pending_is_allowed(test_user: User):
task = _build_existing_task(task_id=1386, user_id=test_user.id)
task.json["status"]["status"] = "RUNNING"

with pytest.raises(HTTPException) as exc_info:
check_task_status(task, allow_pending=True)

assert exc_info.value.status_code == 400


@pytest.mark.asyncio
async def test_create_task_and_subtasks_resets_existing_task_status_to_pending(
test_db: Session,
Expand Down Expand Up @@ -223,3 +252,57 @@ async def test_create_task_and_subtasks_resets_existing_task_status_to_pending(
assert status["progress"] == 0
assert status["errorMessage"] == ""
assert result.assistant_subtask is not None


@pytest.mark.asyncio
async def test_create_task_and_subtasks_allows_pipeline_confirm_pending_task(
test_db: Session,
test_user: User,
):
task = _build_existing_task(task_id=1387, user_id=test_user.id)
task.json["status"]["status"] = "PENDING"
test_db.add(task)
test_db.commit()
test_db.refresh(task)

team = SimpleNamespace(
id=1256,
user_id=test_user.id,
name="quickstart",
namespace="default",
)
params = TaskCreationParams(
message="handoff to quickstart",
allow_pending_status=True,
pipeline_bot_ids=[1255],
)

with (
patch(
"app.services.chat.storage.task_manager.initialize_redis_chat_history",
new=AsyncMock(),
),
patch(
"app.services.memory.is_memory_enabled_for_user",
return_value=False,
),
patch(
"app.services.chat.trigger.group_chat.is_task_group_chat",
return_value=False,
),
):
result = await create_task_and_subtasks(
db=test_db,
user=test_user,
team=team,
message=params.message,
params=params,
task_id=task.id,
should_trigger_ai=True,
)

assert result.task.id == task.id
assert result.user_subtask.prompt == "handoff to quickstart"
assert result.user_subtask.bot_ids == [1255]
assert result.assistant_subtask is not None
assert result.assistant_subtask.bot_ids == [1255]