diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index c823e3419..98be7bb1b 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -11,6 +11,7 @@ from areal.api.workflow_api import RolloutWorkflow from areal.infra import workflow_context +from areal.infra.rpc.rtensor import RTensor from areal.infra.rpc.serialization import deserialize_value from areal.infra.utils.http import async_http_retry from areal.utils import logging, stats_tracker @@ -245,8 +246,12 @@ async def _run_online( if not traj: return None - if "rewards" in traj and len(traj["rewards"]) > 0: - last_reward = float(traj["rewards"][-1]) + rewards_tensor = traj.get("rewards") + if isinstance(rewards_tensor, RTensor): + rewards_tensor = rewards_tensor.to_local() + + if rewards_tensor is not None and len(rewards_tensor) > 0: + last_reward = float(rewards_tensor[-1]) elif ( "interactions" in traj and traj["interactions"] diff --git a/areal/infra/workflow_executor.py b/areal/infra/workflow_executor.py index 949d2af0b..dcf1eabd1 100644 --- a/areal/infra/workflow_executor.py +++ b/areal/infra/workflow_executor.py @@ -22,6 +22,7 @@ from areal.api.cli_args import InferenceEngineConfig from areal.api import RolloutWorkflow +from areal.infra.rpc.rtensor import RTensor from .async_task_runner import ( AsyncTaskRunner, TaskQueueFullError, @@ -840,6 +841,8 @@ async def _dump_trajectory( if traj is None: return False, "trajectory is None" + traj = RTensor.localize(traj) + dump_dir = self._get_dump_dir(is_eval) if dump_dir is None: return False, "dump dir is empty"