-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
125 lines (102 loc) · 3.97 KB
/
eval.py
File metadata and controls
125 lines (102 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Evaluate Modal inference on ALOHA sim dataset.
Usage:
uv run eval.py --url <modal-url> --num-episodes 5 --max-steps 100
"""
import argparse
import asyncio
import time
import httpx
import numpy as np
from datasets import load_dataset
from protocol import pack_obs, unpack_response
async def evaluate(url: str, num_episodes: int, max_steps_per_episode: int):
    """Run inference on real ALOHA sim trajectories from HuggingFace.

    Streams samples from the ALOHA transfer-cube dataset, groups them by
    ``episode_index``, POSTs each step's observation to the Modal ``/infer``
    endpoint via :func:`run_episode`, and prints latency statistics.

    Args:
        url: Base URL of the Modal inference server.
        num_episodes: Number of episodes to evaluate before stopping.
        max_steps_per_episode: Cap on steps sent per episode; remaining
            frames of a capped episode are drained without inference.
    """
    print("Loading ALOHA sim dataset...")
    ds = load_dataset(
        "lerobot/aloha_sim_transfer_cube_human",
        split="train",
        streaming=True,
    )
    endpoint = url.rstrip("/") + "/infer"
    step_latencies = []
    episode_count = 0
    current_index = None  # episode_index of the episode being accumulated
    episode_steps = []
    # BUGFIX: the original reset `episode = None` after hitting the step cap
    # but kept consuming the same episode's frames, so the leftover frames
    # were mis-treated as a new episode whose first frame immediately
    # re-triggered the cap — producing repeated bogus 1-step "episodes".
    # `draining` skips the remainder of a capped episode instead.
    draining = False

    async with httpx.AsyncClient(timeout=60.0) as client:
        for sample in ds:
            if current_index is None or sample["episode_index"] != current_index:
                # Episode boundary: flush whatever was accumulated.
                if episode_steps:
                    await run_episode(
                        client, endpoint, episode_steps, step_latencies, episode_count
                    )
                    episode_count += 1
                    episode_steps = []
                    if episode_count >= num_episodes:
                        break
                current_index = sample["episode_index"]
                episode_steps = []
                draining = False
            if draining:
                continue  # skip frames past the cap until the next episode
            episode_steps.append(sample)
            if sample["frame_index"] >= max_steps_per_episode - 1:
                # Step cap reached: run the truncated episode now, then
                # drain the rest of its frames without inference.
                await run_episode(
                    client, endpoint, episode_steps, step_latencies, episode_count
                )
                episode_count += 1
                episode_steps = []
                draining = True
                if episode_count >= num_episodes:
                    break
    # Final (possibly partial) episode left over when the stream ended.
    if episode_steps and episode_count < num_episodes:
        await run_episode(
            client, endpoint, episode_steps, step_latencies, episode_count
        )
        episode_count += 1  # BUGFIX: original omitted this, undercounting the report
    # Report
    if step_latencies:
        lats = np.array(step_latencies)
        print(f"\n{'='*60}")
        print(f"Episodes: {episode_count}")
        print(f"Total steps: {len(step_latencies)}")
        print(f"\nLatency (ms):")
        print(f" Mean: {lats.mean():.1f}")
        print(f" Median: {np.median(lats):.1f}")
        print(f" P95: {np.percentile(lats, 95):.1f}")
        print(f" Min: {lats.min():.1f}")
        print(f" Max: {lats.max():.1f}")
        print(f"\nControl frequency (Hz):")
        freqs = 1000 / lats
        print(f" Mean: {freqs.mean():.1f}")
        print(f" Min: {freqs.min():.1f}")
        print(f"{'='*60}\n")
async def run_episode(client, endpoint, steps, latencies, ep_idx):
    """Run inference on episode steps.

    POSTs each step's ``observation.state`` to the server as msgpack and
    records per-request round-trip latency. Stops the episode on the first
    request failure (best-effort: the error is printed, not raised).

    Args:
        client: Open ``httpx.AsyncClient`` used for the requests.
        endpoint: Full URL of the ``/infer`` endpoint.
        steps: Dataset samples belonging to one episode.
        latencies: Mutable list; each request's latency in ms is appended.
        ep_idx: Zero-based episode index, used for progress output only.
    """
    print(f" Episode {ep_idx + 1}: {len(steps)} steps...", end="", flush=True)
    start = time.perf_counter()
    for sample in steps:
        obs = {
            "observation.state": np.array(sample["observation.state"], dtype=np.float32),
        }
        try:
            t0 = time.perf_counter()
            resp = await client.post(
                endpoint,
                content=pack_obs(obs),
                headers={"Content-Type": "application/x-msgpack"},
            )
            resp.raise_for_status()
            unpack_response(resp.content)
            latencies.append((time.perf_counter() - t0) * 1000)
        except Exception as e:
            print(f"\n Error: {e}")
            break
    # BUGFIX: perf_counter() already returns seconds; the original divided
    # the delta by 1000, misreporting e.g. a 5 s episode as "0.00s".
    duration = time.perf_counter() - start
    print(f" {duration:.2f}s")
async def main():
    """Parse command-line options and launch the evaluation run."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", required=True, help="Modal server URL")
    cli.add_argument("--num-episodes", type=int, default=5)
    cli.add_argument("--max-steps", type=int, default=100)
    opts = cli.parse_args()

    print(f"Target: {opts.url}")
    print(f"Episodes: {opts.num_episodes}\n")
    await evaluate(opts.url, opts.num_episodes, opts.max_steps)
# Script entry point: start a fresh event loop and run the async main().
if __name__ == "__main__":
    asyncio.run(main())