Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,24 @@

from typing import Any, TypeVar

from pydantic_ai.messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
VideoUrl,
)

import marvin
from marvin.agents.agent import Agent
from marvin.handlers.handlers import AsyncHandler, Handler
from marvin.thread import Thread
from marvin.utilities.asyncio import run_sync
from marvin.utilities.types import TargetType

# Non-string UserContent types that should be passed as attachments
_ATTACHMENT_TYPES = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)

T = TypeVar("T")

DEFAULT_PROMPT = """
Expand Down Expand Up @@ -83,14 +94,24 @@ async def cast_async(
raise ValueError("Instructions are required when casting to string values.")

task_context = context or {}
task_context["Data to transform"] = data
attachments = []

# Handle non-string UserContent types (images, audio, etc.) as attachments
# to avoid serializing binary data into the prompt text
if isinstance(data, _ATTACHMENT_TYPES):
attachments.append(data)
task_context["Data to transform"] = "(provided as attachment)"
else:
task_context["Data to transform"] = data

prompt = prompt or DEFAULT_PROMPT
if instructions:
prompt += f"\n\nYou must follow these instructions for your transformation:\n{instructions}"

task = marvin.Task[target](
name="Cast Task",
instructions=prompt,
attachments=attachments,
context=task_context,
result_type=target,
agents=[agent] if agent else None,
Expand Down
Empty file added tests/basic/fns/__init__.py
Empty file.
110 changes: 110 additions & 0 deletions tests/basic/fns/test_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Tests for cast function - basic unit tests."""

from pydantic_ai.messages import BinaryImage, ImageUrl

import marvin
from marvin.tasks.task import Task


class TestCastWithAttachments:
"""Test that cast properly handles attachment types like images."""

def test_binary_image_passed_as_attachment(self, test_model):
"""Test that BinaryImage is passed as attachment, not in context."""
binary_image = BinaryImage(data=b"fake image data", media_type="image/png")

# We need to inspect the task that gets created
original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.cast(
binary_image,
target=str,
instructions="Describe the image",
)
finally:
Task.__init__ = original_task_init

# Verify the task was constructed correctly
assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is binary_image
assert captured_task.context["Data to transform"] == "(provided as attachment)"

def test_image_url_passed_as_attachment(self, test_model):
"""Test that ImageUrl is passed as attachment, not in context."""
image_url = ImageUrl(url="https://example.com/image.png")

original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.cast(
image_url,
target=str,
instructions="Describe the image",
)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 1
assert captured_task.attachments[0] is image_url
assert captured_task.context["Data to transform"] == "(provided as attachment)"

def test_string_data_not_treated_as_attachment(self, test_model):
"""Test that string data is still passed in context, not as attachment."""
original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.cast("hello world", target=int)
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 0
assert captured_task.context["Data to transform"] == "hello world"

def test_dict_data_not_treated_as_attachment(self, test_model):
"""Test that dict data is still passed in context, not as attachment."""
original_task_init = Task.__init__
captured_task = None

def capture_task(self, *args, **kwargs):
nonlocal captured_task
original_task_init(self, *args, **kwargs)
captured_task = self

Task.__init__ = capture_task

try:
marvin.cast({"key": "value"}, target=str, instructions="Convert to JSON")
finally:
Task.__init__ = original_task_init

assert captured_task is not None
assert len(captured_task.attachments) == 0
assert captured_task.context["Data to transform"] == {"key": "value"}