Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions rlm/core/rlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rlm.core.types import (
ClientBackend,
CodeBlock,
ContextPayload,
EnvironmentType,
REPLResult,
RLMChatCompletion,
Expand Down Expand Up @@ -186,7 +187,7 @@ def __init__(
self.verbose.print_metadata(metadata)

@contextmanager
def _spawn_completion_context(self, prompt: str | dict[str, Any]):
def _spawn_completion_context(self, prompt: ContextPayload):
"""
Spawn an LM handler and environment for a single completion call.

Expand Down Expand Up @@ -250,7 +251,7 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):
if not self.persistent and hasattr(environment, "cleanup"):
environment.cleanup()

def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
def _setup_prompt(self, prompt: ContextPayload) -> list[dict[str, Any]]:
"""
Setup the system prompt for the RLM. Also include metadata about the prompt and build
up the initial message history.
Expand All @@ -268,17 +269,15 @@ def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
)
return message_history

def completion(
self, prompt: str | dict[str, Any], root_prompt: str | None = None
) -> RLMChatCompletion:
def completion(self, prompt: ContextPayload, root_prompt: str | None = None) -> RLMChatCompletion:
"""
Recursive Language Model completion call. This is the main entry point for querying an RLM, and
can replace a regular LM completion call.

Spawns its own environment and LM handler for the duration of this call.

Args:
prompt: A single string or dictionary of messages to pass as context to the model.
prompt: A string, dict/list, or pandas DataFrame to pass as context to the model.
root_prompt: We allow the RLM's root LM to see a (small) prompt that the user specifies. A common example of this
is if the user is asking the RLM to answer a question, we can pass the question as the root prompt.
Returns:
Expand Down
20 changes: 19 additions & 1 deletion rlm/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from types import ModuleType
from typing import Any, Literal

# Type alias for context payloads accepted by RLM entry points
# (`completion`) and environment `load_context` / `add_context` methods.
# At runtime this is Any so that pandas DataFrames can be passed without
# making pandas a hard import of this module; the intended types are:
#   str | dict[str, Any] | list[Any] | pd.DataFrame
ContextPayload = Any

ClientBackend = Literal[
"openai",
"portkey",
Expand Down Expand Up @@ -261,8 +266,20 @@ class QueryMetadata:
context_lengths: list[int]
context_total_length: int
context_type: str
context_summary: str | None = None

def __init__(self, prompt: "ContextPayload"):
from rlm.utils.dataframe_utils import dataframe_metadata, get_dataframe_type

df_type = get_dataframe_type(prompt)
if df_type is not None:
rows, cols = prompt.shape
self.context_lengths = [rows]
self.context_total_length = rows * cols
self.context_type = f"{df_type}_dataframe"
self.context_summary = dataframe_metadata(prompt)
return

def __init__(self, prompt: str | list[str] | dict[Any, Any] | list[dict[Any, Any]]):
if isinstance(prompt, str):
self.context_lengths = [len(prompt)]
self.context_type = "str"
Expand Down Expand Up @@ -302,3 +319,4 @@ def __init__(self, prompt: str | list[str] | dict[Any, Any] | list[dict[Any, Any
raise ValueError(f"Invalid prompt type: {type(prompt)}")

self.context_total_length = sum(self.context_lengths)
self.context_summary = None
14 changes: 6 additions & 8 deletions rlm/environments/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

from rlm.core.types import REPLResult
from rlm.core.types import ContextPayload, REPLResult

# =============================================================================
# Custom Tools Support
Expand Down Expand Up @@ -222,7 +222,7 @@ def setup(self):
raise NotImplementedError

@abstractmethod
def load_context(self, context_payload: dict | list | str):
def load_context(self, context_payload: ContextPayload):
raise NotImplementedError

@abstractmethod
Expand All @@ -244,7 +244,7 @@ def setup(self):
raise NotImplementedError

@abstractmethod
def load_context(self, context_payload: dict | list | str):
def load_context(self, context_payload: ContextPayload):
raise NotImplementedError

@abstractmethod
Expand All @@ -267,7 +267,7 @@ def setup(self):
raise NotImplementedError

@abstractmethod
def load_context(self, context_payload: dict | list | str):
def load_context(self, context_payload: ContextPayload):
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -319,9 +319,7 @@ def update_handler_address(self, address: tuple[str, int]) -> None:
"""
...

def add_context(
self, context_payload: dict | list | str, context_index: int | None = None
) -> int:
def add_context(self, context_payload: ContextPayload, context_index: int | None = None) -> int:
"""Add a context payload, making it available as context_N in code.

Versioning:
Expand All @@ -334,7 +332,7 @@ def add_context(
- context (alias to context_0)

Args:
context_payload: The context data (string, dict, or list).
context_payload: The context data.
context_index: Optional specific index, or None to auto-increment.

Returns:
Expand Down
1 change: 1 addition & 0 deletions rlm/environments/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"tqdm>=4.66.0",
"python-dateutil>=2.8.2",
"regex>=2023.0.0",
"pyarrow>=14.0.0",
# For state serialization
"dill>=0.3.7",
]
25 changes: 17 additions & 8 deletions rlm/environments/daytona_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched
from rlm.core.types import REPLResult, RLMChatCompletion
from rlm.core.types import ContextPayload, REPLResult, RLMChatCompletion
from rlm.environments.base_env import IsolatedEnv, extract_tool_value, validate_custom_tools

# =============================================================================
Expand Down Expand Up @@ -64,6 +64,7 @@ def get_default_image() -> Image:
"tqdm>=4.66.0",
"python-dateutil>=2.8.2",
"regex>=2023.0.0",
"pyarrow>=14.0.0",
# For state serialization
"dill>=0.3.7",
)
Expand Down Expand Up @@ -391,7 +392,7 @@ def __init__(
auto_stop_interval: int = 0,
image: Image | None = None,
lm_handler_address: tuple[str, int] | None = None,
context_payload: dict | list | str | None = None,
context_payload: ContextPayload | None = None,
setup_code: str | None = None,
persistent: bool = False,
depth: int = 1,
Expand Down Expand Up @@ -607,17 +608,25 @@ def _handle_llm_request(self, req_data: dict) -> dict:

return {"error": "Unknown request type"}

def load_context(self, context_payload: ContextPayload):
    """Load a context payload into the sandbox as the ``context`` variable.

    DataFrames are shipped as base64-encoded parquet and rebuilt inside the
    sandbox; plain strings are embedded as a Python literal; dict/list
    payloads round-trip through JSON.

    Args:
        context_payload: str, dict/list, or pandas DataFrame to expose to
            sandboxed code.
    """
    # Local import keeps pandas optional for callers that never pass frames.
    from rlm.utils.dataframe_utils import (
        build_dataframe_context_code,
        dataframe_to_parquet_b64,
        get_dataframe_type,
    )

    df_type = get_dataframe_type(context_payload)
    if df_type is not None:
        parquet_b64, df_type = dataframe_to_parquet_b64(context_payload)
        self.execute_code(build_dataframe_context_code(parquet_b64, df_type))
    elif isinstance(context_payload, str):
        # repr() emits a valid Python literal for ANY string. The previous
        # triple-quote escaping broke for payloads ending in '"' (e.g.
        # 'abc"' generated `context = """abc""""`, a syntax error).
        self.execute_code(f"context = {context_payload!r}")
    else:
        # NOTE(review): assumes the payload is JSON-serializable (dict/list
        # of primitives) — json.dumps raises TypeError otherwise.
        context_json = json.dumps(context_payload)
        escaped_json = context_json.replace("\\", "\\\\").replace("'", "\\'")
        self.execute_code(f"import json; context = json.loads('{escaped_json}')")

def execute_code(self, code: str) -> REPLResult:
"""Execute code in the Daytona sandbox and return result."""
Expand Down
33 changes: 28 additions & 5 deletions rlm/environments/docker_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any

from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched
from rlm.core.types import REPLResult, RLMChatCompletion
from rlm.core.types import ContextPayload, REPLResult, RLMChatCompletion
from rlm.environments.base_env import NonIsolatedEnv
from rlm.utils.dataframe_utils import dataframe_to_parquet_bytes, get_dataframe_type


class LLMProxyHandler(BaseHTTPRequestHandler):
Expand Down Expand Up @@ -196,7 +198,7 @@ def __init__(
self,
image: str = "python:3.11-slim",
lm_handler_address: tuple[str, int] | None = None,
context_payload: dict | list | str | None = None,
context_payload: ContextPayload | None = None,
setup_code: str | None = None,
persistent: bool = False,
depth: int = 1,
Expand Down Expand Up @@ -271,15 +273,36 @@ def setup(self):

self.container_id = result.stdout.strip()

# Install dependencies
# Install base dependencies
subprocess.run(
["docker", "exec", self.container_id, "pip", "install", "-q", "dill", "requests"],
capture_output=True,
)
self._pandas_installed = False

def load_context(self, context_payload: dict | list | str):
def _ensure_pandas(self):
"""Install pandas and pyarrow in the container if not already installed."""
if self._pandas_installed:
return
subprocess.run(
["docker", "exec", self.container_id, "pip", "install", "-q", "pandas", "pyarrow"],
capture_output=True,
)
self._pandas_installed = True

def load_context(self, context_payload: ContextPayload):
"""Load context by writing to a file in the mounted workspace."""
if isinstance(context_payload, str):
df_type = get_dataframe_type(context_payload)
if df_type is not None:
self._ensure_pandas()
context_path = os.path.join(self.temp_dir, "context.parquet")
parquet_bytes, df_type = dataframe_to_parquet_bytes(context_payload)
with open(context_path, "wb") as f:
f.write(parquet_bytes)
self.execute_code(
"import pandas as pd\ncontext = pd.read_parquet('/workspace/context.parquet')"
)
elif isinstance(context_payload, str):
context_path = os.path.join(self.temp_dir, "context.txt")
with open(context_path, "w") as f:
f.write(context_payload)
Expand Down
25 changes: 18 additions & 7 deletions rlm/environments/local_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
from typing import Any

from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched
from rlm.core.types import REPLResult, RLMChatCompletion
from rlm.core.types import ContextPayload, REPLResult, RLMChatCompletion
from rlm.environments.base_env import (
RESERVED_TOOL_NAMES,
NonIsolatedEnv,
extract_tool_value,
validate_custom_tools,
)
from rlm.utils.dataframe_utils import (
dataframe_to_parquet_bytes,
get_dataframe_type,
)

# =============================================================================
# Safe Builtins
Expand Down Expand Up @@ -127,7 +131,7 @@ class LocalREPL(NonIsolatedEnv):
def __init__(
self,
lm_handler_address: tuple[str, int] | None = None,
context_payload: dict | list | str | None = None,
context_payload: ContextPayload | None = None,
setup_code: str | None = None,
persistent: bool = False,
depth: int = 1,
Expand Down Expand Up @@ -342,13 +346,11 @@ def _rlm_query_batched(self, prompts: list[str], model: str | None = None) -> li
# Fall back to plain batched LM call if no recursive capability
return self._llm_query_batched(prompts, model)

def load_context(self, context_payload: ContextPayload):
    """Load context into the environment as context_0 (and 'context' alias).

    Thin wrapper that delegates to ``add_context`` with a fixed index of 0.
    """
    self.add_context(context_payload, 0)

def add_context(
self, context_payload: dict | list | str, context_index: int | None = None
) -> int:
def add_context(self, context_payload: ContextPayload, context_index: int | None = None) -> int:
"""
Add a context with versioned variable name.

Expand All @@ -364,7 +366,16 @@ def add_context(

var_name = f"context_{context_index}"

if isinstance(context_payload, str):
df_type = get_dataframe_type(context_payload)
if df_type is not None:
context_path = os.path.join(self.temp_dir, f"context_{context_index}.parquet")
parquet_bytes, df_type = dataframe_to_parquet_bytes(context_payload)
with open(context_path, "wb") as f:
f.write(parquet_bytes)
self.execute_code(
f"import pandas as pd\n{var_name} = pd.read_parquet(r'{context_path}')"
)
elif isinstance(context_payload, str):
context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt")
with open(context_path, "w") as f:
f.write(context_payload)
Expand Down
25 changes: 17 additions & 8 deletions rlm/environments/modal_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import textwrap
import threading
import time
from typing import Any

import modal
import requests

from rlm.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched
from rlm.core.types import REPLResult, RLMChatCompletion
from rlm.core.types import ContextPayload, REPLResult, RLMChatCompletion
from rlm.environments.base_env import IsolatedEnv
from rlm.environments.constants import APT_PACKAGES, PIP_PACKAGES

Expand Down Expand Up @@ -297,7 +298,7 @@ def __init__(
image: modal.Image | None = None,
timeout: int = 600,
lm_handler_address: tuple[str, int] | None = None,
context_payload: dict | list | str | None = None,
context_payload: ContextPayload | None = None,
setup_code: str | None = None,
persistent: bool = False,
depth: int = 1,
Expand Down Expand Up @@ -435,17 +436,25 @@ def _handle_llm_request(self, req_data: dict) -> dict:

return {"error": "Unknown request type"}

def load_context(self, context_payload: ContextPayload):
    """Load a context payload into the sandbox as the ``context`` variable.

    Mirrors the Daytona environment: DataFrames travel as base64 parquet,
    strings are embedded as a Python literal, and dict/list payloads
    round-trip through JSON.

    Args:
        context_payload: str, dict/list, or pandas DataFrame to expose to
            sandboxed code.
    """
    # Local import keeps pandas optional for callers that never pass frames.
    from rlm.utils.dataframe_utils import (
        build_dataframe_context_code,
        dataframe_to_parquet_b64,
        get_dataframe_type,
    )

    df_type = get_dataframe_type(context_payload)
    if df_type is not None:
        parquet_b64, df_type = dataframe_to_parquet_b64(context_payload)
        self.execute_code(build_dataframe_context_code(parquet_b64, df_type))
    elif isinstance(context_payload, str):
        # repr() emits a valid Python literal for ANY string. The previous
        # triple-quote escaping broke for payloads ending in '"' (e.g.
        # 'abc"' generated `context = """abc""""`, a syntax error).
        self.execute_code(f"context = {context_payload!r}")
    else:
        # NOTE(review): assumes the payload is JSON-serializable (dict/list
        # of primitives) — json.dumps raises TypeError otherwise.
        context_json = json.dumps(context_payload)
        escaped_json = context_json.replace("\\", "\\\\").replace("'", "\\'")
        self.execute_code(f"import json; context = json.loads('{escaped_json}')")

def execute_code(self, code: str) -> REPLResult:
"""Execute code in the Modal sandbox and return result."""
Expand Down
Loading
Loading