diff --git a/docs/source/rtc.mdx b/docs/source/rtc.mdx index f63c00fcaf..6dc386b54e 100644 --- a/docs/source/rtc.mdx +++ b/docs/source/rtc.mdx @@ -50,7 +50,6 @@ policy_cfg = PI0Config() policy_cfg.rtc_config = RTCConfig( enabled=True, execution_horizon=10, # How many steps to blend with previous chunk - max_guidance_weight=10.0, # How strongly to enforce consistency prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend ) @@ -101,7 +100,10 @@ Typical values: 8-12 steps RTCConfig(execution_horizon=10) ``` -**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10 steps flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 is a optimal value. +**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. + +If `max_guidance_weight` is not set, the number of flow matching steps is used as the max guidance weight. +See the following post for details: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html **`prefix_attention_schedule`**: How to weight consistency across the overlap region. @@ -112,6 +114,14 @@ RTCConfig(execution_horizon=10) **`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime. +**`sigma_d`**: The standard deviation of the prior distribution. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. + +Typical values: 0.1-1.0 + +By default, `sigma_d` is set to 1.0, which matches the original RTC paper. You can tune it to your needs: decreasing it strengthens the guidance toward the previous chunk (smoother transitions), while leaving it at 1.0 keeps the policy more reactive. + +See the following post for details: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html + ## Testing RTC Offline Before running on a real robot, test RTC with dataset samples to visualize how it works: @@ -121,7 +131,6 @@ python examples/rtc/eval_dataset.py \ --policy.path=lerobot/pi0_libero_finetuned \ --dataset.repo_id=HuggingFaceVLA/libero \ --rtc.execution_horizon=10 \ - --rtc.max_guidance_weight=10.0 \ --device=cuda ``` diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 4652df1077..c2213b57c7 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -17,10 +17,15 @@ """ Evaluate Real-Time Chunking (RTC) performance on dataset samples. -This script takes two random samples from a dataset: +This script takes two samples from a dataset: - Uses actions from the first sample as previous chunk - Generates new actions for the second sample with and without RTC +Sampling modes: +- Random (default): Two independent random samples +- Correlated (--sample_correlation_shift): Second sample is shifted from the first by N steps + to test temporal correlation and sigma effects + It compares action predictions with and without RTC on dataset samples, measuring consistency and ground truth alignment.
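To make the documented defaults concrete, here is a minimal sketch of the configuration described in the rtc.mdx changes above. It mirrors the docs' own `PI0Config` example; the exact import paths are assumptions inferred from the file layout in this diff, and the values are illustrative rather than prescriptive.

```python
# Minimal sketch of the RTC configuration described in the docs above.
# Import paths are assumed from this diff's src/lerobot/policies/ layout.
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.rtc.configuration_rtc import RTCAttentionSchedule, RTCConfig

policy_cfg = PI0Config()
policy_cfg.rtc_config = RTCConfig(
    enabled=True,
    execution_horizon=10,      # how many steps to blend with the previous chunk
    max_guidance_weight=None,  # None -> clamp defaults to the number of flow matching steps
    sigma_d=1.0,               # 1.0 reproduces the original RTC prior
    prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
```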
@@ -31,17 +36,30 @@ --dataset.repo_id=helper2424/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ - --rtc.max_guidance_weight=10.0 \ --rtc.prefix_attention_schedule=EXP \ + --random_chunks=true \ --seed=10 + uv run python examples/rtc/eval_dataset.py \ + --policy.path=lerobot/pi05_libero_finetuned \ + --rtc.max_guidance_weight=11 \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=10 \ + --device=mps \ + --seed=10 \ + --random_chunks=true \ + --rtc.sigma_d=1 + # Basic usage with pi0.5 policy uv run python examples/rtc/eval_dataset.py \ --policy.path=lerobot/pi05_libero_finetuned \ --dataset.repo_id=HuggingFaceVLA/libero \ --rtc.execution_horizon=10 \ - --device=mps - --seed=10 + --device=mps \ + --seed=10 \ + --sample_correlation_shift=10 \ + --rtc.sigma_d=1.0 \ + --rtc.full_trajectory_alignment=true # Basic usage with pi0.5 policy with cuda device uv run python examples/rtc/eval_dataset.py \ @@ -63,6 +81,16 @@ --rtc.execution_horizon=8 \ --device=cuda + # With sample correlation shift to test temporal correlation (sigma effect) + # Second sample is taken as first_sample_index + shift + uv run python examples/rtc/eval_dataset.py \ + --policy.path=lerobot/pi05_libero_finetuned \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=10 \ + --device=mps \ + --sample_correlation_shift=5 \ + --seed=10 + # With torch.compile for faster inference (PyTorch 2.0+) # Note: CUDA graphs disabled by default due to in-place ops in denoising loop uv run python examples/rtc/eval_dataset.py \ @@ -161,7 +189,6 @@ class RTCEvalConfig(HubMixin): default_factory=lambda: RTCConfig( enabled=True, execution_horizon=20, - max_guidance_weight=10.0, prefix_attention_schedule=RTCAttentionSchedule.EXP, debug=True, debug_maxlen=1000, @@ -191,6 +218,11 @@ class RTCEvalConfig(HubMixin): metadata={"help": "Inference delay for RTC"}, ) + num_inference_steps: int | None = field( + default=None, + metadata={"help": "Number of flow matching inference steps. If None, uses policy default."}, + ) + # Torch compile configuration use_torch_compile: bool = field( default=False, @@ -215,6 +247,22 @@ }, ) + next_inference_after: int = field( + default=10, + metadata={ + "help": "How many steps after the previous inference the next one is assumed to start. " + "Used as the index shift between the first and the second dataset samples." + }, + ) + + random_chunks: bool = field( + default=False, + metadata={ + "help": "If true, pick the second sample at a random index after the first one. Useful to check a bigger difference between the previous action chunk " + "and the newly generated chunk."
+ }, + ) + def __post_init__(self): # Parse policy path policy_path = parser.get_path_arg("policy") @@ -303,6 +351,17 @@ def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool): if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0": config.compile_model = self.cfg.use_torch_compile + # Override number of flow matching steps if specified + if self.cfg.num_inference_steps is not None: + if self.cfg.policy.type == "smolvla": + config.num_steps = self.cfg.num_inference_steps + logging.info(f" Overriding num_steps for SmolVLA: {self.cfg.num_inference_steps}") + elif self.cfg.policy.type in ["pi0", "pi05"]: + config.num_inference_steps = self.cfg.num_inference_steps + logging.info( + f" Overriding num_inference_steps for {self.cfg.policy.type}: {self.cfg.num_inference_steps}" + ) + policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config) policy = policy.to(self.device) policy.eval() @@ -315,6 +374,8 @@ def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool): prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule, debug=rtc_debug, debug_maxlen=self.cfg.rtc.debug_maxlen, + full_trajectory_alignment=self.cfg.rtc.full_trajectory_alignment, + sigma_d=self.cfg.rtc.sigma_d, ) policy.config.rtc_config = rtc_config policy.init_rtc_processor() @@ -433,13 +494,45 @@ def run_evaluation(self): logging.info("=" * 80) logging.info("Starting RTC evaluation") logging.info(f"Inference delay: {self.cfg.inference_delay}") + if self.cfg.num_inference_steps is not None: + logging.info(f"Number of flow matching steps: {self.cfg.num_inference_steps}") + else: + logging.info("Number of flow matching steps: Using policy default") logging.info("=" * 80) - # Load two random samples from dataset - data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True) - loader_iter = iter(data_loader) - first_sample = next(loader_iter) - second_sample = next(loader_iter) + # Correlated sampling: second sample is shifted from first + shift = self.cfg.next_inference_after + logging.info(f"Using correlated sampling: second sample shifted by {shift} from first sample") + + # Get random first index + first_idx = random.randint(0, len(self.dataset) - 1) + + # Calculate second index with shift, ensuring it's within bounds + second_idx = first_idx + shift + + if self.cfg.random_chunks: + second_idx = random.randint(first_idx + 1, len(self.dataset) - 1) + + if second_idx < 0 or second_idx >= len(self.dataset): + raise ValueError( + f"Second sample index {second_idx} is out of bounds [0, {len(self.dataset) - 1}]. " + f"First index: {first_idx}, shift: {shift}. " + f"Please use a smaller shift value or adjust the seed." 
+ ) + + logging.info(f"First sample index: {first_idx}, Second sample index: {second_idx}") + + # Get samples directly from dataset + first_sample = self.dataset[first_idx] + second_sample = self.dataset[second_idx] + + # Add batch dimension (dataset returns unbatched samples) + first_sample = { + k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in first_sample.items() + } + second_sample = { + k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in second_sample.items() + } preprocessed_first_sample = self.preprocessor(first_sample) preprocessed_second_sample = self.preprocessor(second_sample) @@ -461,7 +554,7 @@ def run_evaluation(self): with torch.no_grad(): prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk( preprocessed_first_sample, - )[:, :25, :].squeeze(0) + )[:, shift : shift + 25, :].squeeze(0) logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}") # Destroy policy_prev_chunk to free memory for large models diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 6f051485ab..e1403a8faa 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -142,8 +142,7 @@ class RTCDemoConfig(HubMixin): # RTC configuration rtc: RTCConfig = field( default_factory=lambda: RTCConfig( - execution_horizon=10, - max_guidance_weight=1.0, + execution_horizon=15, prefix_attention_schedule=RTCAttentionSchedule.EXP, ) ) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 9b6f38ad4f..c17d632cca 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -842,6 +842,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): time=time, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, + num_flow_matching_steps=num_steps, ) else: v_t = denoise_step_partial_call(x_t) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index b017bbc57e..6c6f98c7b5 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -814,6 +814,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): time=time, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, + num_flow_matching_steps=num_steps, ) else: v_t = denoise_step_partial_call(x_t) diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py index 70a8dfb096..ebc639402f 100644 --- a/src/lerobot/policies/rtc/configuration_rtc.py +++ b/src/lerobot/policies/rtc/configuration_rtc.py @@ -40,16 +40,30 @@ class RTCConfig: # Core RTC settings # Todo change to exp prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR - max_guidance_weight: float = 10.0 + + # This parameter is used to clip the guidance weight + # In the original RTC it's a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. + # The original paper had value of 5.0, during the implementation it was found that this parameter is not needed and can be replaced with the number of steps. 
+ # See the following post: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html + # The number of steps can be used as the clipping parameter without any additional hyperparameter tuning + # If the user doesn't provide this parameter, then the number of flow matching steps is used as the max guidance weight + max_guidance_weight: float | None = None execution_horizon: int = 10 + # This parameter is the standard deviation of the prior distribution; its square scales the prior variance + # See the following post: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html + # The value should be in the range (0, 1]; if it's 1.0, the behavior is the same as in the original RTC + sigma_d: float = 1.0 + + full_trajectory_alignment: bool = False + # Debug settings debug: bool = False debug_maxlen: int = 100 def __post_init__(self): """Validate RTC configuration parameters.""" - if self.max_guidance_weight <= 0: + if self.max_guidance_weight is not None and self.max_guidance_weight <= 0: raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}") if self.debug_maxlen <= 0: raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}") diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 280905adf9..434fb64a58 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -120,6 +120,7 @@ def denoise_step( inference_delay, time, original_denoise_step_partial, + num_flow_matching_steps, execution_horizon=None, ) -> Tensor: """RTC guidance wrapper around an existing denoiser. @@ -138,6 +139,9 @@ def denoise_step( broadcastable with ``x_t``. original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that computes the base denoised velocity given only ``x_t``. + num_flow_matching_steps (int): Number of flow matching inference steps (must be a positive integer). + If ``max_guidance_weight`` is ``None``, it will be used as the max guidance weight + (Alex Soare optimization). execution_horizon (int | None): Horizon used to build prefix weights. If ``None``, defaults to ``self.rtc_config.execution_horizon``. @@ -153,6 +157,10 @@ def denoise_step( - Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and ``error = (prev_chunk_left_over - x1_t) * weights``. - The final guidance weight is clamped by ``max_guidance_weight`` from the config. + - Alex Soare optimization: If ``max_guidance_weight`` is ``None``, + ``max_guidance_weight`` is automatically set to ``num_flow_matching_steps`` + without requiring hyperparameter tuning. + Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html Reference: https://www.physicalintelligence.company/download/real_time_chunking.pdf @@ -209,18 +217,37 @@ def denoise_step( ) with torch.enable_grad(): - v_t = original_denoise_step_partial(x_t) x_t.requires_grad_(True) + v_t = original_denoise_step_partial(x_t) x1_t = x_t - time * v_t # noqa: N806 err = (prev_chunk_left_over - x1_t) * weights - grad_outputs = err.clone().detach() - correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] - max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) + correction = err + + # If full trajectory alignment is enabled (this is not the default RTC behavior), + # the newly generated trajectory will be fully aligned with the previous chunk.
It is similar to ignoring the gradients + # of the neural network and taking into account only the error between the previous chunk and the newly generated trajectory. + # It runs faster, and if the gap between chunk generations is not too large, it gives smoother transitions. + if not self.rtc_config.full_trajectory_alignment: + grad_outputs = err.clone().detach() + correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] + + # Alex Soare optimization: Use num_flow_matching_steps as max_guidance_weight if not set + # Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html + # The number of flow matching steps can be used as a clipping parameter without hyperparameter tuning + max_guidance_weight = self.rtc_config.max_guidance_weight + if max_guidance_weight is None: + max_guidance_weight = num_flow_matching_steps + + max_guidance_weight = torch.as_tensor(max_guidance_weight) + tau_tensor = torch.as_tensor(tau) squared_one_minus_tau = (1 - tau_tensor) ** 2 - inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau) + prior_variance = torch.as_tensor(self.rtc_config.sigma_d**2) + inv_r2 = (squared_one_minus_tau + tau_tensor**2 * prior_variance) / ( + squared_one_minus_tau * prior_variance + ) c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight) guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight) guidance_weight = torch.minimum(guidance_weight, max_guidance_weight) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index e442b14d5a..a4c1f85f91 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -814,6 +814,7 @@ def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): time=time, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, + num_flow_matching_steps=self.config.num_steps, ) else: v_t = denoise_step_partial_call(x_t) diff --git a/tests/policies/pi0_pi05/test_pi05_rtc.py b/tests/policies/pi0_pi05/test_pi05_rtc.py index 3a753031f9..d20a442e53 100644 --- a/tests/policies/pi0_pi05/test_pi05_rtc.py +++ b/tests/policies/pi0_pi05/test_pi05_rtc.py @@ -66,8 +66,6 @@ def test_pi05_rtc_initialization(): assert policy.rtc_processor is not None assert policy.rtc_processor.rtc_config.enabled is True - print("✓ PI0.5 RTC initialization: Test passed") - @require_cuda def test_pi05_rtc_initialization_without_rtc_config(): @@ -85,7 +83,193 @@ def test_pi05_rtc_initialization_without_rtc_config(): assert policy.model.rtc_processor is None assert policy._rtc_enabled() is False - print("✓ PI0.5 RTC initialization without RTC config: Test passed") + +@require_cuda +def test_pi05_rtc_alex_soare_optimization(): + """Test PI0.5 with Alex Soare optimization (max_guidance_weight=None, uses num_inference_steps during denoise_step).""" + set_seed(42) + + config = PI05Config( + max_action_dim=7, + max_state_dim=14, + dtype="float32", + num_inference_steps=20, # This will be passed to denoise_step + ) + + # Add RTC config WITHOUT max_guidance_weight (optimization happens in denoise_step) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Not provided - optimization happens during denoise_step + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state":
PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = PI05Policy(config) + + # Verify RTC processor has max_guidance_weight still None (optimization happens in denoise_step) + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight is None + + +@require_cuda +def test_pi05_rtc_explicit_max_guidance_weight(): + """Test PI0.5 respects explicit max_guidance_weight when provided.""" + set_seed(42) + + config = PI05Config( + max_action_dim=7, + max_state_dim=14, + dtype="float32", + num_inference_steps=20, + ) + + # Add RTC config WITH explicit max_guidance_weight + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, # Explicitly set - should NOT be overridden + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = PI05Policy(config) + + # Verify RTC processor keeps the explicit max_guidance_weight + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight == 5.0 + assert policy.rtc_processor.rtc_config.max_guidance_weight != config.num_inference_steps + + +@require_cuda +def test_pi05_rtc_inference_with_different_sigma_d_and_auto_guidance(): + """Test PI0.5 inference with different sigma_d values using Alex Soare optimization.""" + set_seed(42) + + config = PI05Config( + max_action_dim=7, + max_state_dim=14, + chunk_size=50, + dtype="float32", + num_inference_steps=10, # Will be used as max_guidance_weight + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats (PI0.5 uses QUANTILES normalization) + dataset_stats = { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Test with sigma_d = 0.2 (stronger guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=0.2, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy1 = PI05Policy(config) + policy1.eval() + preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats) + + # Verify max_guidance_weight was auto-set + assert policy1.rtc_processor.rtc_config.max_guidance_weight is None + assert policy1.rtc_processor.rtc_config.sigma_d == 0.2 + + device = config.device + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + 
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + with torch.no_grad(): + noise = policy1.model.sample_noise((1, config.chunk_size, 7), device) + actions_sigma_02 = policy1.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Now test with sigma_d = 1.0 (weaker guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=1.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy2 = PI05Policy(config) + policy2.eval() + + # Verify max_guidance_weight was auto-set and sigma_d is different + assert policy2.rtc_processor.rtc_config.max_guidance_weight is None + assert policy2.rtc_processor.rtc_config.sigma_d == 1.0 + + with torch.no_grad(): + actions_sigma_10 = policy2.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Verify shapes + assert actions_sigma_02.shape == (1, config.chunk_size, 7) + assert actions_sigma_10.shape == (1, config.chunk_size, 7) + + # Different sigma_d values should produce different results + assert not torch.allclose(actions_sigma_02, actions_sigma_10, rtol=1e-3) @require_cuda @@ -172,8 +356,6 @@ def test_pi05_rtc_inference_with_prev_chunk(): # With previous chunk, actions should be different (RTC guidance applied) assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3) - print("✓ PI0.5 RTC inference with prev_chunk: Test passed") - @require_cuda def test_pi05_rtc_inference_without_prev_chunk(): @@ -250,8 +432,6 @@ def test_pi05_rtc_inference_without_prev_chunk(): # Without previous chunk, RTC should have no effect assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5) - print("✓ PI0.5 RTC inference without prev_chunk: Test passed") - @require_cuda def test_pi05_rtc_validation_rules(): diff --git a/tests/policies/pi0_pi05/test_pi0_rtc.py b/tests/policies/pi0_pi05/test_pi0_rtc.py index 68e94dd943..718b31b83e 100644 --- a/tests/policies/pi0_pi05/test_pi0_rtc.py +++ b/tests/policies/pi0_pi05/test_pi0_rtc.py @@ -66,8 +66,6 @@ def test_pi0_rtc_initialization(): assert policy.rtc_processor is not None assert policy.rtc_processor.rtc_config.enabled is True - print("✓ PI0 RTC initialization: Test passed") - @require_cuda def test_pi0_rtc_initialization_without_rtc_config(): @@ -85,7 +83,183 @@ def test_pi0_rtc_initialization_without_rtc_config(): assert policy.model.rtc_processor is None assert policy._rtc_enabled() is False - print("✓ PI0 RTC initialization without RTC config: Test passed") + +@require_cuda +def test_pi0_rtc_alex_soare_optimization(): + """Test PI0 with Alex Soare optimization (max_guidance_weight=None, uses num_inference_steps during denoise_step).""" + set_seed(42) + + config = PI0Config( + max_action_dim=7, + max_state_dim=14, + dtype="float32", + num_inference_steps=20, # This will be passed to denoise_step + ) + + # Add RTC config WITHOUT max_guidance_weight (optimization happens in denoise_step) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Not provided - optimization happens during denoise_step + 
prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = PI0Policy(config) + + # Verify RTC processor has max_guidance_weight still None (optimization happens in denoise_step) + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight is None + + +@require_cuda +def test_pi0_rtc_explicit_max_guidance_weight(): + """Test PI0 respects explicit max_guidance_weight when provided.""" + set_seed(42) + + config = PI0Config( + max_action_dim=7, + max_state_dim=14, + dtype="float32", + num_inference_steps=20, + ) + + # Add RTC config WITH explicit max_guidance_weight + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, # Explicitly set - should NOT be overridden + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = PI0Policy(config) + + # Verify RTC processor keeps the explicit max_guidance_weight + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight == 5.0 + assert policy.rtc_processor.rtc_config.max_guidance_weight != config.num_inference_steps + + +@require_cuda +def test_pi0_rtc_inference_with_different_sigma_d_and_auto_guidance(): + """Test PI0 inference with different sigma_d values using Alex Soare optimization.""" + set_seed(42) + + config = PI0Config( + max_action_dim=7, + max_state_dim=14, + chunk_size=50, + dtype="float32", + num_inference_steps=10, # Will be used as max_guidance_weight + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Test with sigma_d = 0.2 (stronger guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=0.2, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy1 = PI0Policy(config) + policy1.eval() + preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats) + + # Verify max_guidance_weight was auto-set + assert policy1.rtc_processor.rtc_config.max_guidance_weight is None + assert policy1.rtc_processor.rtc_config.sigma_d == 0.2 + + device = config.device + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, 
dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + with torch.no_grad(): + noise = policy1.model.sample_noise((1, config.chunk_size, 7), device) + actions_sigma_02 = policy1.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Now test with sigma_d = 1.0 (weaker guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=1.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy2 = PI0Policy(config) + policy2.eval() + + # Verify max_guidance_weight was auto-set and sigma_d is different + assert policy2.rtc_processor.rtc_config.max_guidance_weight is None + assert policy2.rtc_processor.rtc_config.sigma_d == 1.0 + + with torch.no_grad(): + actions_sigma_10 = policy2.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Verify shapes + assert actions_sigma_02.shape == (1, config.chunk_size, 7) + assert actions_sigma_10.shape == (1, config.chunk_size, 7) + + # Different sigma_d values should produce different results + assert not torch.allclose(actions_sigma_02, actions_sigma_10, rtol=1e-3) def test_pi0_rtc_inference_with_prev_chunk(): @@ -161,8 +335,6 @@ def test_pi0_rtc_inference_with_prev_chunk(): # With previous chunk, actions should be different (RTC guidance applied) assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3) - print("✓ PI0 RTC inference with prev_chunk: Test passed") - @require_cuda def test_pi0_rtc_inference_without_prev_chunk(): @@ -229,8 +401,6 @@ def test_pi0_rtc_inference_without_prev_chunk(): # Without previous chunk, RTC should have no effect assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5) - print("✓ PI0 RTC inference without prev_chunk: Test passed") - @require_cuda def test_pi0_rtc_validation_rules(): @@ -334,8 +504,6 @@ def test_pi0_rtc_validation_rules(): device = config.device for schedule in schedules: - print(f"Testing schedule: {schedule}") - # Add RTC config with specific schedule config.rtc_config = RTCConfig( enabled=True, @@ -373,6 +541,3 @@ def test_pi0_rtc_validation_rules(): # Verify shape assert actions.shape == (1, config.chunk_size, 7) - print(f" ✓ Schedule {schedule}: Test passed") - - print("✓ PI0 RTC different schedules: All schedules tested") diff --git a/tests/policies/rtc/test_configuration_rtc.py b/tests/policies/rtc/test_configuration_rtc.py index bb4550eaa6..1e9e3a8279 100644 --- a/tests/policies/rtc/test_configuration_rtc.py +++ b/tests/policies/rtc/test_configuration_rtc.py @@ -28,10 +28,11 @@ def test_rtc_config_default_initialization(): assert config.enabled is False assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR - assert config.max_guidance_weight == 10.0 + assert config.max_guidance_weight is None assert config.execution_horizon == 10 assert config.debug is False assert config.debug_maxlen == 100 + assert config.sigma_d == 1.0 def test_rtc_config_custom_initialization(): @@ -51,6 +52,7 @@ def test_rtc_config_custom_initialization(): assert config.execution_horizon == 20 assert config.debug is True assert config.debug_maxlen == 200 + assert config.sigma_d == 1.0 def test_rtc_config_partial_initialization(): 
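The sigma_d assertions in the tests above follow from the modified guidance-weight computation in modeling_rtc.py. As a standalone illustration (not the library API, just the arithmetic lifted from the diff), the sketch below reproduces the clamped weight and shows why a smaller sigma_d strengthens the pull toward the previous chunk; `tau`, `sigma_d`, and `max_w` mirror the names used in the diff, and the printed values are worked examples.

```python
# Standalone sketch of the guidance-weight arithmetic from modeling_rtc.py above.
# tau is the flow matching time, sigma_d the prior std, and max_w the clamp
# (num_flow_matching_steps when max_guidance_weight is None).
import torch

def guidance_weight(tau: float, sigma_d: float, max_w: float) -> torch.Tensor:
    tau_t = torch.as_tensor(tau)
    squared_one_minus_tau = (1 - tau_t) ** 2
    prior_variance = torch.as_tensor(sigma_d**2)
    inv_r2 = (squared_one_minus_tau + tau_t**2 * prior_variance) / (
        squared_one_minus_tau * prior_variance
    )
    c = torch.nan_to_num((1 - tau_t) / tau_t, posinf=max_w)
    weight = torch.nan_to_num(c * inv_r2, posinf=max_w)
    return torch.minimum(weight, torch.as_tensor(max_w))

print(guidance_weight(0.5, 1.0, 10.0))  # tensor(2.)  -- original RTC prior
print(guidance_weight(0.5, 0.2, 10.0))  # tensor(10.) -- stronger pull, clamped at max_w
```

With sigma_d below 1 the weight grows at every tau, which matches the tests' labeling of sigma_d=0.2/0.5 as "stronger guidance" and sigma_d=1.0 as "weaker guidance".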
@@ -63,3 +65,10 @@ def test_rtc_config_partial_initialization(): assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR assert config.execution_horizon == 10 assert config.debug is False + + +def test_rtc_config_sigma_d_initialization(): + """Test RTCConfig initializes with custom values.""" + config = RTCConfig(sigma_d=0.5) + + assert config.sigma_d == 0.5 diff --git a/tests/policies/rtc/test_modeling_rtc.py b/tests/policies/rtc/test_modeling_rtc.py index e7fdc09c65..14c028e1db 100644 --- a/tests/policies/rtc/test_modeling_rtc.py +++ b/tests/policies/rtc/test_modeling_rtc.py @@ -381,6 +381,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Should return v_t unchanged (no guidance) @@ -402,6 +403,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) expected_result = torch.tensor( @@ -452,6 +454,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Output should be 2D (batch dimension removed) @@ -461,7 +464,7 @@ def mock_denoiser(x): def test_denoise_step_uses_custom_execution_horizon(): """Test denoise_step uses custom execution_horizon parameter.""" - config = RTCConfig(execution_horizon=10) + config = RTCConfig(execution_horizon=10, max_guidance_weight=10.0) processor = RTCProcessor(config) x_t = torch.ones(1, 20, 1) @@ -476,6 +479,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, execution_horizon=15, ) @@ -526,6 +530,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.0), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) expected_result = torch.tensor( @@ -587,6 +592,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) assert result.shape == (batch_size, chunk_size, action_dim) @@ -610,6 +616,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(1.0), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Should clamp to max_guidance_weight (no Inf) @@ -630,6 +637,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Should have tracked one step @@ -661,6 +669,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Should not track @@ -696,6 +705,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.8), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Second step - with guidance @@ -705,6 +715,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.6), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Both should complete successfully @@ -734,6 +745,7 @@ def mock_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, ) # Result should be on CUDA @@ -759,6 +771,7 @@ def deterministic_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=deterministic_denoiser, + num_flow_matching_steps=10, ) result2 = processor.denoise_step( @@ -767,7 +780,152 @@ def 
deterministic_denoiser(x): inference_delay=5, time=torch.tensor(0.5), original_denoise_step_partial=deterministic_denoiser, + num_flow_matching_steps=10, ) # Should produce identical results assert torch.allclose(result1, result2) + + +# ====================== Configuration Tests ====================== + + +def test_rtc_config_sigma_d_parameter(): + """Test RTCConfig sigma_d parameter (renamed from sigma_delta).""" + # Test default value + config = RTCConfig() + assert config.sigma_d == 1.0 + + # Test custom value + config = RTCConfig(sigma_d=0.5) + assert config.sigma_d == 0.5 + + # Test that sigma_d affects variance calculation + config1 = RTCConfig(sigma_d=0.5) + config2 = RTCConfig(sigma_d=1.0) + + processor1 = RTCProcessor(config1) + processor2 = RTCProcessor(config2) + + # sigma_d is squared to get variance, so different values should be stored + assert processor1.rtc_config.sigma_d == 0.5 + assert processor2.rtc_config.sigma_d == 1.0 + + +def test_rtc_config_sigma_d_different_values(): + """Test that different sigma_d values produce different guidance.""" + x_t = torch.ones(1, 20, 1) + prev_chunk = torch.full((1, 20, 1), 0.1) + + def mock_denoiser(x): + return x * 0.5 + + # Test with sigma_d = 0.5 (stronger guidance) + config1 = RTCConfig(sigma_d=0.5, max_guidance_weight=10.0) + processor1 = RTCProcessor(config1) + + result1 = processor1.denoise_step( + x_t=x_t.clone(), + prev_chunk_left_over=prev_chunk.clone(), + inference_delay=5, + time=torch.tensor(0.5), + original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=10, + ) + + expected_result = torch.tensor( + [ + [ + [3.7500], + [3.7500], + [3.7500], + [3.7500], + [3.7500], + [3.2083], + [2.6667], + [2.1250], + [1.5833], + [1.0417], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + [0.5000], + ] + ] + ) + + assert torch.allclose(result1, expected_result, atol=1e-4) + + +def test_denoise_step_alex_soare_optimization(): + """Test Alex Soare optimization: num_flow_matching_steps used as max_guidance_weight when None.""" + x_t = torch.ones(1, 20, 1) + prev_chunk = torch.full((1, 20, 1), 0.1) + + def mock_denoiser(x): + return x * 0.5 + + # Test with max_guidance_weight = None (should use num_flow_matching_steps) + config = RTCConfig(max_guidance_weight=None) + processor = RTCProcessor(config) + + # Verify max_guidance_weight is still None in config + assert processor.rtc_config.max_guidance_weight is None + + result = processor.denoise_step( + x_t=x_t.clone(), + prev_chunk_left_over=prev_chunk.clone(), + inference_delay=5, + time=torch.tensor(0.5), + original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=15, # This should be used as max_guidance_weight + ) + + # Result should be computed with max_guidance_weight=15 + assert result.shape == (1, 20, 1) + # The optimization happens internally during denoise_step + + +def test_denoise_step_respects_explicit_max_guidance_weight(): + """Test denoise_step respects explicit max_guidance_weight when provided.""" + x_t = torch.ones(1, 20, 1) + prev_chunk = torch.full((1, 20, 1), 0.1) + + def mock_denoiser(x): + return x * 0.5 + + # Test with explicit max_guidance_weight + config = RTCConfig(max_guidance_weight=5.0) + processor = RTCProcessor(config) + + # Use time=0.9 (tau=0.1) to produce high guidance weight that will be clamped + result1 = processor.denoise_step( + x_t=x_t.clone(), + prev_chunk_left_over=prev_chunk.clone(), + inference_delay=5, + time=torch.tensor(0.9), + 
original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=20, # Should be ignored, use 5.0 instead + ) + + # Test with max_guidance_weight = None (uses num_flow_matching_steps) + config2 = RTCConfig(max_guidance_weight=None) + processor2 = RTCProcessor(config2) + + result2 = processor2.denoise_step( + x_t=x_t.clone(), + prev_chunk_left_over=prev_chunk.clone(), + inference_delay=5, + time=torch.tensor(0.9), + original_denoise_step_partial=mock_denoiser, + num_flow_matching_steps=20, # Should be used as max_guidance_weight + ) + + # Results should be different (different max_guidance_weight used) + assert not torch.allclose(result1, result2, atol=1e-4) diff --git a/tests/policies/smolvla/test_smolvla_rtc.py b/tests/policies/smolvla/test_smolvla_rtc.py index 53e74d940b..7b99b2fd38 100644 --- a/tests/policies/smolvla/test_smolvla_rtc.py +++ b/tests/policies/smolvla/test_smolvla_rtc.py @@ -62,8 +62,6 @@ def test_smolvla_rtc_initialization(): assert policy.rtc_processor is not None assert policy.rtc_processor.rtc_config.enabled is True - print("✓ SmolVLA RTC initialization: Test passed") - @require_package("transformers") @require_cuda @@ -84,7 +82,193 @@ def test_smolvla_rtc_initialization_without_rtc_config(): assert policy.model.rtc_processor is None assert policy._rtc_enabled() is False - print("✓ SmolVLA RTC initialization without RTC config: Test passed") + +@require_package("transformers") +@require_cuda +def test_smolvla_rtc_alex_soare_optimization(): + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 + + """Test SmolVLA with Alex Soare optimization (max_guidance_weight=None, uses num_steps during denoise_step).""" + set_seed(42) + + config = SmolVLAConfig( + max_action_dim=7, + chunk_size=50, + num_steps=15, # SmolVLA uses num_steps (will be passed to denoise_step) + ) + + # Add RTC config WITHOUT max_guidance_weight (optimization happens in denoise_step) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Not provided - optimization happens during denoise_step + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = SmolVLAPolicy(config) + + # Verify RTC processor has max_guidance_weight still None (optimization happens in denoise_step) + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight is None + + +@require_package("transformers") +@require_cuda +def test_smolvla_rtc_explicit_max_guidance_weight(): + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 + + """Test SmolVLA respects explicit max_guidance_weight when provided.""" + set_seed(42) + + config = SmolVLAConfig( + max_action_dim=7, + chunk_size=50, + num_steps=15, + ) + + # Add RTC config WITH explicit max_guidance_weight + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, # Explicitly set - should NOT be overridden + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": 
PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Instantiate policy + policy = SmolVLAPolicy(config) + + # Verify RTC processor keeps the explicit max_guidance_weight + assert policy.rtc_processor is not None + assert policy.rtc_processor.rtc_config.max_guidance_weight == 5.0 + assert policy.rtc_processor.rtc_config.max_guidance_weight != config.num_steps + + +@require_package("transformers") +@require_cuda +def test_smolvla_rtc_inference_with_different_sigma_d_and_auto_guidance(): + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 + + """Test SmolVLA inference with different sigma_d values using Alex Soare optimization.""" + set_seed(42) + + config = SmolVLAConfig( + max_action_dim=7, + chunk_size=50, + num_steps=10, # Will be used as max_guidance_weight + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Test with sigma_d = 0.5 (stronger guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=0.5, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy1 = SmolVLAPolicy(config) + policy1.eval() + + device = config.device + policy1 = policy1.to(device) + + preprocessor, _ = make_pre_post_processors( + policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats + ) + + # Verify max_guidance_weight was auto-set + assert policy1.rtc_processor.rtc_config.max_guidance_weight is None + assert policy1.rtc_processor.rtc_config.sigma_d == 0.5 + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + with torch.no_grad(): + noise = policy1.model.sample_noise((1, config.chunk_size, 7), device) + actions_sigma_05 = policy1.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Now test with sigma_d = 1.0 (weaker guidance) + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=None, # Use Alex Soare optimization + sigma_d=1.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + policy2 = SmolVLAPolicy(config) + policy2.eval() + policy2 = policy2.to(device) + + # Verify max_guidance_weight was auto-set and sigma_d is different + assert policy2.rtc_processor.rtc_config.max_guidance_weight is None + assert policy2.rtc_processor.rtc_config.sigma_d == 1.0 + + with torch.no_grad(): + actions_sigma_10 = policy2.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, +
execution_horizon=10, + ) + + # Verify shapes + assert actions_sigma_05.shape == (1, config.chunk_size, 7) + assert actions_sigma_10.shape == (1, config.chunk_size, 7) + + # Different sigma_d values should produce different results + assert not torch.allclose(actions_sigma_05, actions_sigma_10, rtol=1e-3) @require_package("transformers") @@ -167,8 +351,6 @@ def test_smolvla_rtc_inference_with_prev_chunk(): # With previous chunk, actions should be different (RTC guidance applied) assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3) - print("✓ SmolVLA RTC inference with prev_chunk: Test passed") - @require_package("transformers") @require_cuda @@ -241,8 +423,6 @@ def test_smolvla_rtc_inference_without_prev_chunk(): # Without previous chunk, RTC should have no effect assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5) - print("✓ SmolVLA RTC inference without prev_chunk: Test passed") - @require_package("transformers") @require_cuda
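Taken together, the test changes above all exercise one new call shape for `RTCProcessor.denoise_step`. Below is a condensed, self-contained sketch of that pattern; the import paths are assumptions based on the `src/lerobot/policies/rtc/` layout in this diff, and the toy denoiser stands in for a real policy's velocity prediction.

```python
# Condensed usage sketch of the updated denoise_step signature from this diff.
import torch

# Assumed import paths, inferred from the file paths shown in the diff above.
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor

config = RTCConfig(max_guidance_weight=None, sigma_d=0.5)  # clamp falls back to num steps
processor = RTCProcessor(config)

def toy_denoiser(x):
    # Stand-in for the policy's velocity prediction.
    return x * 0.5

v_t = processor.denoise_step(
    x_t=torch.ones(1, 20, 1),
    prev_chunk_left_over=torch.full((1, 20, 1), 0.1),
    inference_delay=5,
    time=torch.tensor(0.5),
    original_denoise_step_partial=toy_denoiser,
    num_flow_matching_steps=10,  # used as the clamp because max_guidance_weight is None
)
assert v_t.shape == (1, 20, 1)
```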