diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py index ce247ecd..f4527714 100644 --- a/rlm/core/rlm.py +++ b/rlm/core/rlm.py @@ -8,6 +8,7 @@ from rlm.core.types import ( ClientBackend, CodeBlock, + ContextPayload, EnvironmentType, REPLResult, RLMChatCompletion, @@ -186,7 +187,7 @@ def __init__( self.verbose.print_metadata(metadata) @contextmanager - def _spawn_completion_context(self, prompt: str | dict[str, Any]): + def _spawn_completion_context(self, prompt: ContextPayload): """ Spawn an LM handler and environment for a single completion call. @@ -250,7 +251,7 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]): if not self.persistent and hasattr(environment, "cleanup"): environment.cleanup() - def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]: + def _setup_prompt(self, prompt: ContextPayload) -> list[dict[str, Any]]: """ Setup the system prompt for the RLM. Also include metadata about the prompt and build up the initial message history. @@ -268,9 +269,7 @@ def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]: ) return message_history - def completion( - self, prompt: str | dict[str, Any], root_prompt: str | None = None - ) -> RLMChatCompletion: + def completion(self, prompt: ContextPayload, root_prompt: str | None = None) -> RLMChatCompletion: """ Recursive Language Model completion call. This is the main entry point for querying an RLM, and can replace a regular LM completion call. @@ -278,7 +277,7 @@ def completion( Spawns its own environment and LM handler for the duration of this call. Args: - prompt: A single string or dictionary of messages to pass as context to the model. + prompt: A string, dict/list, or pandas DataFrame to pass as context to the model. root_prompt: We allow the RLM's root LM to see a (small) prompt that the user specifies. A common example of this is if the user is asking the RLM to answer a question, we can pass the question as the root prompt. 
Returns: diff --git a/rlm/core/types.py b/rlm/core/types.py index bd9e0d0a..8931fd62 100644 --- a/rlm/core/types.py +++ b/rlm/core/types.py @@ -2,6 +2,11 @@ from types import ModuleType from typing import Any, Literal +# Type alias for context payloads. At runtime this is Any (to accept pandas +# DataFrames without requiring pandas), but intended types are: +# str | dict[str, Any] | list[Any] | pd.DataFrame +ContextPayload = Any + ClientBackend = Literal[ "openai", "portkey", @@ -261,8 +266,20 @@ class QueryMetadata: context_lengths: list[int] context_total_length: int context_type: str + context_summary: str | None = None + + def __init__(self, prompt: "ContextPayload"): + from rlm.utils.dataframe_utils import dataframe_metadata, get_dataframe_type + + df_type = get_dataframe_type(prompt) + if df_type is not None: + rows, cols = prompt.shape + self.context_lengths = [rows] + self.context_total_length = rows * cols + self.context_type = f"{df_type}_dataframe" + self.context_summary = dataframe_metadata(prompt) + return - def __init__(self, prompt: str | list[str] | dict[Any, Any] | list[dict[Any, Any]]): if isinstance(prompt, str): self.context_lengths = [len(prompt)] self.context_type = "str" @@ -302,3 +319,4 @@ def __init__(self, prompt: str | list[str] | dict[Any, Any] | list[dict[Any, Any raise ValueError(f"Invalid prompt type: {type(prompt)}") self.context_total_length = sum(self.context_lengths) + self.context_summary = None diff --git a/rlm/environments/base_env.py b/rlm/environments/base_env.py index afd6387f..32d46127 100644 --- a/rlm/environments/base_env.py +++ b/rlm/environments/base_env.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Any, Protocol, runtime_checkable -from rlm.core.types import REPLResult +from rlm.core.types import ContextPayload, REPLResult # ============================================================================= # Custom Tools Support @@ -222,7 +222,7 @@ def setup(self): raise NotImplementedError 
@abstractmethod - def load_context(self, context_payload: dict | list | str): + def load_context(self, context_payload: ContextPayload): raise NotImplementedError @abstractmethod @@ -244,7 +244,7 @@ def setup(self): raise NotImplementedError @abstractmethod - def load_context(self, context_payload: dict | list | str): + def load_context(self, context_payload: ContextPayload): raise NotImplementedError @abstractmethod @@ -267,7 +267,7 @@ def setup(self): raise NotImplementedError @abstractmethod - def load_context(self, context_payload: dict | list | str): + def load_context(self, context_payload: ContextPayload): raise NotImplementedError @abstractmethod @@ -319,9 +319,7 @@ def update_handler_address(self, address: tuple[str, int]) -> None: """ ... - def add_context( - self, context_payload: dict | list | str, context_index: int | None = None - ) -> int: + def add_context(self, context_payload: ContextPayload, context_index: int | None = None) -> int: """Add a context payload, making it available as context_N in code. Versioning: @@ -334,7 +332,7 @@ def add_context( - context (alias to context_0) Args: - context_payload: The context data (string, dict, or list). + context_payload: The context data. context_index: Optional specific index, or None to auto-increment. 
def load_context(self, context_payload: ContextPayload):
    """Load ``context_payload`` into the sandbox as the ``context`` variable.

    The payload is shipped by generating a small Python snippet and running
    it through ``self.execute_code``:

    * ``str``        -> assigned from a Python string literal.
    * DataFrame      -> serialized to base64 parquet and rebuilt in the sandbox.
    * anything else  -> serialized with ``json.dumps`` and parsed in the sandbox.

    Args:
        context_payload: The context data to expose to sandboxed code.
    """
    if isinstance(context_payload, str):
        # repr() always produces a valid string literal. The previous
        # triple-quote escaping generated a SyntaxError for payloads ending
        # in a double quote (e.g. 'say "hi"' -> """say "hi"""").
        self.execute_code(f"context = {context_payload!r}")
        return

    # Imported lazily so plain-string payloads never touch the dataframe helpers.
    from rlm.utils.dataframe_utils import (
        build_dataframe_context_code,
        dataframe_to_parquet_b64,
        get_dataframe_type,
    )

    if get_dataframe_type(context_payload) is not None:
        parquet_b64, df_type = dataframe_to_parquet_b64(context_payload)
        self.execute_code(build_dataframe_context_code(parquet_b64, df_type))
        return

    # json.dumps output is embedded via repr() so quotes and backslashes in
    # the JSON text survive the trip through Python source unchanged.
    context_json = json.dumps(context_payload)
    self.execute_code(f"import json; context = json.loads({context_json!r})")
capture_output=True, ) + self._pandas_installed = False - def load_context(self, context_payload: dict | list | str): + def _ensure_pandas(self): + """Install pandas and pyarrow in the container if not already installed.""" + if self._pandas_installed: + return + subprocess.run( + ["docker", "exec", self.container_id, "pip", "install", "-q", "pandas", "pyarrow"], + capture_output=True, + ) + self._pandas_installed = True + + def load_context(self, context_payload: ContextPayload): """Load context by writing to a file in the mounted workspace.""" - if isinstance(context_payload, str): + df_type = get_dataframe_type(context_payload) + if df_type is not None: + self._ensure_pandas() + context_path = os.path.join(self.temp_dir, "context.parquet") + parquet_bytes, df_type = dataframe_to_parquet_bytes(context_payload) + with open(context_path, "wb") as f: + f.write(parquet_bytes) + self.execute_code( + "import pandas as pd\ncontext = pd.read_parquet('/workspace/context.parquet')" + ) + elif isinstance(context_payload, str): context_path = os.path.join(self.temp_dir, "context.txt") with open(context_path, "w") as f: f.write(context_payload) diff --git a/rlm/environments/local_repl.py b/rlm/environments/local_repl.py index afdde3b3..9f446143 100644 --- a/rlm/environments/local_repl.py +++ b/rlm/environments/local_repl.py @@ -13,13 +13,17 @@ from typing import Any from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched -from rlm.core.types import REPLResult, RLMChatCompletion +from rlm.core.types import ContextPayload, REPLResult, RLMChatCompletion from rlm.environments.base_env import ( RESERVED_TOOL_NAMES, NonIsolatedEnv, extract_tool_value, validate_custom_tools, ) +from rlm.utils.dataframe_utils import ( + dataframe_to_parquet_bytes, + get_dataframe_type, +) # ============================================================================= # Safe Builtins @@ -127,7 +131,7 @@ class LocalREPL(NonIsolatedEnv): def __init__( self, 
lm_handler_address: tuple[str, int] | None = None, - context_payload: dict | list | str | None = None, + context_payload: ContextPayload | None = None, setup_code: str | None = None, persistent: bool = False, depth: int = 1, @@ -342,13 +346,11 @@ def _rlm_query_batched(self, prompts: list[str], model: str | None = None) -> li # Fall back to plain batched LM call if no recursive capability return self._llm_query_batched(prompts, model) - def load_context(self, context_payload: dict | list | str): + def load_context(self, context_payload: ContextPayload): """Load context into the environment as context_0 (and 'context' alias).""" self.add_context(context_payload, 0) - def add_context( - self, context_payload: dict | list | str, context_index: int | None = None - ) -> int: + def add_context(self, context_payload: ContextPayload, context_index: int | None = None) -> int: """ Add a context with versioned variable name. @@ -364,7 +366,16 @@ def add_context( var_name = f"context_{context_index}" - if isinstance(context_payload, str): + df_type = get_dataframe_type(context_payload) + if df_type is not None: + context_path = os.path.join(self.temp_dir, f"context_{context_index}.parquet") + parquet_bytes, df_type = dataframe_to_parquet_bytes(context_payload) + with open(context_path, "wb") as f: + f.write(parquet_bytes) + self.execute_code( + f"import pandas as pd\n{var_name} = pd.read_parquet(r'{context_path}')" + ) + elif isinstance(context_payload, str): context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt") with open(context_path, "w") as f: f.write(context_payload) diff --git a/rlm/environments/modal_repl.py b/rlm/environments/modal_repl.py index 4287a44c..20cf8df2 100644 --- a/rlm/environments/modal_repl.py +++ b/rlm/environments/modal_repl.py @@ -3,12 +3,13 @@ import textwrap import threading import time +from typing import Any import modal import requests from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched -from 
def load_context(self, context_payload: ContextPayload):
    """Load ``context_payload`` into the sandbox as the ``context`` variable.

    The payload is shipped by generating a small Python snippet and running
    it through ``self.execute_code``:

    * ``str``        -> assigned from a Python string literal.
    * DataFrame      -> serialized to base64 parquet and rebuilt in the sandbox.
    * anything else  -> serialized with ``json.dumps`` and parsed in the sandbox.

    Args:
        context_payload: The context data to expose to sandboxed code.
    """
    if isinstance(context_payload, str):
        # repr() always produces a valid string literal. The previous
        # triple-quote escaping generated a SyntaxError for payloads ending
        # in a double quote (e.g. 'say "hi"' -> """say "hi"""").
        self.execute_code(f"context = {context_payload!r}")
        return

    # Imported lazily so plain-string payloads never touch the dataframe helpers.
    from rlm.utils.dataframe_utils import (
        build_dataframe_context_code,
        dataframe_to_parquet_b64,
        get_dataframe_type,
    )

    if get_dataframe_type(context_payload) is not None:
        parquet_b64, df_type = dataframe_to_parquet_b64(context_payload)
        self.execute_code(build_dataframe_context_code(parquet_b64, df_type))
        return

    # json.dumps output is embedded via repr() so quotes and backslashes in
    # the JSON text survive the trip through Python source unchanged.
    context_json = json.dumps(context_payload)
    self.execute_code(f"import json; context = json.loads({context_json!r})")
def load_context(self, context_payload: ContextPayload):
    """Load ``context_payload`` into the sandbox as the ``context`` variable.

    The payload is shipped by generating a small Python snippet and running
    it through ``self.execute_code``:

    * ``str``        -> assigned from a Python string literal.
    * DataFrame      -> serialized to base64 parquet and rebuilt in the sandbox.
    * anything else  -> serialized with ``json.dumps`` and parsed in the sandbox.

    Args:
        context_payload: The context data to expose to sandboxed code.
    """
    if isinstance(context_payload, str):
        # repr() always produces a valid string literal. The previous
        # triple-quote escaping generated a SyntaxError for payloads ending
        # in a double quote (e.g. 'say "hi"' -> """say "hi"""").
        self.execute_code(f"context = {context_payload!r}")
        return

    # Imported lazily so plain-string payloads never touch the dataframe helpers.
    from rlm.utils.dataframe_utils import (
        build_dataframe_context_code,
        dataframe_to_parquet_b64,
        get_dataframe_type,
    )

    if get_dataframe_type(context_payload) is not None:
        parquet_b64, df_type = dataframe_to_parquet_b64(context_payload)
        self.execute_code(build_dataframe_context_code(parquet_b64, df_type))
        return

    # json.dumps output is embedded via repr() so quotes and backslashes in
    # the JSON text survive the trip through Python source unchanged.
    context_json = json.dumps(context_payload)
    self.execute_code(f"import json; context = json.loads({context_json!r})")
and return result.""" diff --git a/rlm/utils/dataframe_utils.py b/rlm/utils/dataframe_utils.py new file mode 100644 index 00000000..9eeab716 --- /dev/null +++ b/rlm/utils/dataframe_utils.py @@ -0,0 +1,84 @@ +import base64 +import io +from typing import Any, Literal + +DataFrameType = Literal["pandas"] + + +def is_pandas_dataframe(value: Any) -> bool: + try: + import pandas as pd + except ImportError: + return False + return isinstance(value, pd.DataFrame) + + +def get_dataframe_type(value: Any) -> DataFrameType | None: + if is_pandas_dataframe(value): + return "pandas" + return None + + +def dataframe_to_parquet_bytes(value: Any) -> tuple[bytes, DataFrameType]: + df_type = get_dataframe_type(value) + if df_type is None: + raise ValueError(f"Unsupported DataFrame type: {type(value)}") + + buffer = io.BytesIO() + value.to_parquet(buffer, index=False) + return buffer.getvalue(), df_type + + +def dataframe_to_parquet_b64(value: Any) -> tuple[str, DataFrameType]: + data, df_type = dataframe_to_parquet_bytes(value) + return base64.b64encode(data).decode("ascii"), df_type + + +def dataframe_metadata(value: Any) -> str: + df_type = get_dataframe_type(value) + if df_type is None: + raise ValueError(f"Unsupported DataFrame type: {type(value)}") + + rows, cols = value.shape + all_cols = list(value.columns) + dtypes = {col: str(dtype) for col, dtype in value.dtypes.items()} + null_counts = value.isna().sum().to_dict() + memory_bytes = int(value.memory_usage(deep=True).sum()) + + if cols > 20: + displayed_cols = all_cols[:20] + extra = cols - 20 + dtypes = {col: dtypes[col] for col in displayed_cols} + dtypes["..."] = f"{extra} more columns" + null_counts = {col: null_counts[col] for col in displayed_cols} + null_counts["..."] = f"{extra} more columns" + preview_df = value[displayed_cols] + else: + preview_df = value + + head_rows = preview_df.head(3).to_dict(orient="records") + tail_rows = preview_df.tail(3).to_dict(orient="records") + + summary_lines = [ + f"DataFrame type: 
{df_type}", + f"Shape: {rows} rows x {cols} columns", + f"Dtypes: {dtypes}", + f"Null counts: {null_counts}", + f"Estimated memory: {memory_bytes} bytes", + f"Head (3): {head_rows}", + f"Tail (3): {tail_rows}", + ] + return "\n".join(summary_lines) + + +def build_dataframe_context_code( + parquet_b64: str, + df_type: DataFrameType, + var_name: str = "context", +) -> str: + return ( + "import base64, io\n" + "import pandas as pd\n" + f"_parquet_bytes = base64.b64decode('{parquet_b64}')\n" + f"{var_name} = pd.read_parquet(io.BytesIO(_parquet_bytes))" + ) diff --git a/rlm/utils/prompts.py b/rlm/utils/prompts.py index 3add902e..e619d7d5 100644 --- a/rlm/utils/prompts.py +++ b/rlm/utils/prompts.py @@ -155,7 +155,19 @@ def build_rlm_system_prompt( # Insert custom tools section into the system prompt final_system_prompt = system_prompt.format(custom_tools_section=custom_tools_section) - metadata_prompt = f"Your context is a {context_type} with {context_total_length} total characters, and is broken up into chunks of char lengths: {context_lengths}." + if "dataframe" in context_type: + metadata_prompt = ( + f"Your context is a {context_type}. It has {context_total_length} total cells " + f"and is broken up into chunk lengths: {context_lengths}." + ) + else: + metadata_prompt = ( + f"Your context is a {context_type} with {context_total_length} total characters, " + f"and is broken up into chunks of char lengths: {context_lengths}." 
"""Comprehensive tests for rlm.utils.dataframe_utils."""

import base64
import io

import pytest

from rlm.utils.dataframe_utils import (
    build_dataframe_context_code,
    dataframe_metadata,
    dataframe_to_parquet_b64,
    dataframe_to_parquet_bytes,
    get_dataframe_type,
    is_pandas_dataframe,
)

# NOTE: every test that writes or reads parquet also skips when pyarrow is
# missing. Guarding on pandas alone makes those tests error (instead of skip)
# on installs that have pandas but no parquet engine.


# ---------------------------------------------------------------------------
# Detection helpers
# ---------------------------------------------------------------------------


class TestIsPandasDataframe:
    """Tests for is_pandas_dataframe()."""

    def test_true_for_dataframe(self):
        pd = pytest.importorskip("pandas")
        assert is_pandas_dataframe(pd.DataFrame({"a": [1]})) is True

    def test_false_for_dict(self):
        assert is_pandas_dataframe({"a": 1}) is False

    def test_false_for_list(self):
        assert is_pandas_dataframe([1, 2, 3]) is False

    def test_false_for_string(self):
        assert is_pandas_dataframe("hello") is False

    def test_false_for_none(self):
        assert is_pandas_dataframe(None) is False

    def test_false_for_int(self):
        assert is_pandas_dataframe(42) is False


class TestGetDataframeType:
    """Tests for get_dataframe_type()."""

    def test_pandas_dataframe(self):
        pd = pytest.importorskip("pandas")
        assert get_dataframe_type(pd.DataFrame({"x": [1]})) == "pandas"

    def test_non_dataframe_returns_none(self):
        assert get_dataframe_type("hello") is None
        assert get_dataframe_type(42) is None
        assert get_dataframe_type({"a": 1}) is None
        assert get_dataframe_type([1, 2]) is None
        assert get_dataframe_type(None) is None


# ---------------------------------------------------------------------------
# Parquet serialization
# ---------------------------------------------------------------------------


class TestDataframeToParquetBytes:
    """Tests for dataframe_to_parquet_bytes()."""

    def test_roundtrip(self):
        pd = pytest.importorskip("pandas")
        pytest.importorskip("pyarrow")
        df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
        data, df_type = dataframe_to_parquet_bytes(df)
        assert df_type == "pandas"
        assert isinstance(data, bytes)
        roundtrip = pd.read_parquet(io.BytesIO(data))
        assert list(roundtrip.columns) == ["a", "b"]
        assert roundtrip["a"].tolist() == [1, 2, 3]

    def test_rejects_non_dataframe(self):
        with pytest.raises(ValueError, match="Unsupported DataFrame type"):
            dataframe_to_parquet_bytes({"not": "a df"})


class TestDataframeToParquetB64:
    """Tests for dataframe_to_parquet_b64()."""

    def test_roundtrip(self):
        pd = pytest.importorskip("pandas")
        pytest.importorskip("pyarrow")
        df = pd.DataFrame({"col": [10, 20]})
        b64_str, df_type = dataframe_to_parquet_b64(df)
        assert df_type == "pandas"
        raw = base64.b64decode(b64_str)
        roundtrip = pd.read_parquet(io.BytesIO(raw))
        assert roundtrip["col"].tolist() == [10, 20]


# ---------------------------------------------------------------------------
# Metadata
# ---------------------------------------------------------------------------


class TestDataframeMetadata:
    """Tests for dataframe_metadata()."""

    def test_basic(self):
        pd = pytest.importorskip("pandas")
        df = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
        meta = dataframe_metadata(df)
        assert "Shape: 2 rows x 2 columns" in meta
        assert "pandas" in meta

    def test_nulls(self):
        pd = pytest.importorskip("pandas")
        df = pd.DataFrame({"a": [1, None, 3], "b": [None, None, "x"]})
        meta = dataframe_metadata(df)
        assert "Null counts" in meta

    def test_many_columns_truncation(self):
        pd = pytest.importorskip("pandas")
        df = pd.DataFrame({f"col_{i}": [i] for i in range(25)})
        meta = dataframe_metadata(df)
        assert "5 more columns" in meta
        # Head/tail rows should also be truncated to displayed columns
        assert "col_20" not in meta
        assert "col_0" in meta

    def test_rejects_non_dataframe(self):
        with pytest.raises(ValueError, match="Unsupported DataFrame type"):
            dataframe_metadata("not a dataframe")


# ---------------------------------------------------------------------------
# Code generation — DataFrame path
# ---------------------------------------------------------------------------


class TestBuildDataframeContextCode:
    """Tests for build_dataframe_context_code()."""

    def test_exec_roundtrip(self):
        pd = pytest.importorskip("pandas")
        pytest.importorskip("pyarrow")
        df = pd.DataFrame({"x": [10, 20, 30]})
        b64_str, df_type = dataframe_to_parquet_b64(df)
        code = build_dataframe_context_code(b64_str, df_type)
        ns = {}
        exec(code, ns)
        assert ns["context"]["x"].tolist() == [10, 20, 30]

    def test_custom_var_name(self):
        pd = pytest.importorskip("pandas")
        pytest.importorskip("pyarrow")
        df = pd.DataFrame({"v": [1]})
        b64_str, df_type = dataframe_to_parquet_b64(df)
        code = build_dataframe_context_code(b64_str, df_type, var_name="my_df")
        ns = {}
        exec(code, ns)
        assert "my_df" in ns
        assert ns["my_df"]["v"].tolist() == [1]
def test_dataframe_preserves_dtypes(self):
    """Test that int/float/str types survive the parquet round-trip."""
    pd = pytest.importorskip("pandas")
    # Loading a DataFrame goes through parquet; skip when no engine is installed.
    pytest.importorskip("pyarrow")
    df = pd.DataFrame({
        "int_col": [1, 2, 3],
        "float_col": [1.5, 2.5, 3.5],
        "str_col": ["a", "b", "c"],
    })
    with LocalREPL(context_payload=df) as repl:
        result = repl.execute_code(
            "dtypes = {c: str(context[c].dtype) for c in context.columns}"
        )
        # A failed execution would otherwise surface only as a KeyError below.
        assert result.stderr == ""
        dtypes = repl.locals["dtypes"]
    assert "int" in dtypes["int_col"]
    assert "float" in dtypes["float_col"]
    # Compare with the source frame's own dtype instead of hard-coding
    # "object": pandas may infer a dedicated string dtype for text columns,
    # and the point of the test is that the dtype is preserved.
    assert dtypes["str_col"] == str(df["str_col"].dtype)
def test_dataframe_with_nulls(self):
    """A frame containing nulls is typed as pandas and its summary reports null counts."""
    pandas = pytest.importorskip("pandas")
    frame = pandas.DataFrame({"a": [1, None, 3], "b": [None, None, "x"]})
    metadata = QueryMetadata(frame)
    assert metadata.context_type == "pandas_dataframe"
    assert "Null counts" in metadata.context_summary

def test_dataframe_many_columns(self):
    """A 25-column frame has its summary truncated, leaving 5 columns unlisted."""
    pandas = pytest.importorskip("pandas")
    columns = {f"col_{i}": [i] for i in range(25)}
    metadata = QueryMetadata(pandas.DataFrame(columns))
    assert metadata.context_type == "pandas_dataframe"
    assert "5 more columns" in metadata.context_summary

def test_dataframe_single_row(self):
    """A 1x1 frame reports one row and one cell."""
    pandas = pytest.importorskip("pandas")
    metadata = QueryMetadata(pandas.DataFrame({"x": [42]}))
    assert metadata.context_lengths == [1]
    assert metadata.context_total_length == 1

def test_dataframe_empty(self):
    """A frame with columns but no rows reports zero rows and zero cells."""
    pandas = pytest.importorskip("pandas")
    metadata = QueryMetadata(pandas.DataFrame({"a": [], "b": []}))
    assert metadata.context_lengths == [0]
    assert metadata.context_total_length == 0

def test_dict_prompt(self):
    """A dict prompt is typed as ``dict`` with a positive total length."""
    metadata = QueryMetadata({"key": "value", "num": "42"})
    assert metadata.context_type == "dict"
    assert metadata.context_total_length > 0

def test_list_of_strings(self):
    """A list of strings reports one length per element."""
    metadata = QueryMetadata(["hello", "world"])
    assert metadata.context_type == "list"
    assert metadata.context_lengths == [5, 5]

def test_empty_list(self):
    """An empty list is typed as ``list`` with a single zero length."""
    metadata = QueryMetadata([])
    assert metadata.context_type == "list"
    assert metadata.context_lengths == [0]

def test_invalid_type_raises(self):
    """Unsupported prompt types are rejected with a ValueError."""
    with pytest.raises(ValueError, match="Invalid prompt type"):
        QueryMetadata(12345)