RTC adjustments. Bug fix & Alex Soare optimization #2499
Changes from 2 commits (three hunks in `denoise_step`):

```diff
@@ -121,6 +121,7 @@ def denoise_step(
         time,
         original_denoise_step_partial,
         execution_horizon=None,
+        num_flow_matching_steps=None,
     ) -> Tensor:
         """RTC guidance wrapper around an existing denoiser.
@@ -163,6 +164,9 @@ def denoise_step(
         # So we need to invert the time
         tau = 1 - time

+        if self.config.use_soare_optimization and num_flow_matching_steps is None:
+            raise ValueError("num_flow_matching_steps must be provided when use_soare_optimization is True")
+
         if prev_chunk_left_over is None:
             # First step, no guidance - return v_t
             v_t = original_denoise_step_partial(x_t)
@@ -217,10 +221,23 @@ def denoise_step(
         grad_outputs = err.clone().detach()
         correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]

         max_guidance_weight = self.rtc_config.max_guidance_weight

+        # Check the following paper - https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
+        # num of steps could be used as clipping parameter without requirements on hyperparameters tuning
+        if self.config.use_soare_optimization:
+            max_guidance_weight = num_flow_matching_steps
+
         max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
         tau_tensor = torch.as_tensor(tau)
         squared_one_minus_tau = (1 - tau_tensor) ** 2
-        inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
+        if self.config.use_soare_optimization:
+            variance_clipping_factor = torch.as_tensor(self.rtc_config.variance_clipping_factor)
+            inv_r2 = (squared_one_minus_tau + tau_tensor**2 * variance_clipping_factor) / (
+                squared_one_minus_tau * variance_clipping_factor
+            )
+        else:
+            inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
         c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
         guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
         guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
```
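For reviewers who want to poke at the numbers, the guidance-weight arithmetic from the last hunk can be exercised in isolation. This is a minimal sketch, not code from the PR: the function name `compute_guidance_weight`, the default values, and the sample `tau` values are illustrative.

```python
import torch


def compute_guidance_weight(tau, max_guidance_weight=5.0, variance_clipping_factor=1.0):
    """Mirror of the PR's guidance-weight arithmetic; tau is the inverted flow time (tau = 1 - time)."""
    tau_t = torch.as_tensor(tau)
    squared_one_minus_tau = (1 - tau_t) ** 2
    vcf = torch.as_tensor(variance_clipping_factor)
    # With variance_clipping_factor == 1.0 this reduces to the original RTC expression.
    inv_r2 = (squared_one_minus_tau + tau_t**2 * vcf) / (squared_one_minus_tau * vcf)
    c = torch.nan_to_num((1 - tau_t) / tau_t, posinf=max_guidance_weight)
    guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
    return torch.minimum(guidance_weight, torch.as_tensor(max_guidance_weight))


# The raw ratio blows up at the endpoints; the clipping keeps the weight bounded.
for tau in (0.0, 0.1, 0.5, 0.9, 1.0):
    print(f"tau={tau:.1f} -> weight={compute_guidance_weight(tau).item():.3f}")
```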
Review comment (Alex Soare):

I think you meant `sigma_d: float = 1.0` instead of `variance_clipping_factor` (or, if you prefer a more descriptive parameter, `prior_variance`, but be careful because "variance" is `sigma_d ** 2`). That parameter is used in all cases: when it is 1.0 you are not using the improvement suggested in my article, otherwise you are. Therefore you can drop `use_soare_optimization` altogether, and don't need to guard any code with `if use_soare_optimization`.
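A sketch of what adopting this suggestion could look like; it is illustrative only (the function name `inv_r2_with_prior` and the `prior_variance` local are mine, and the mapping `variance = sigma_d ** 2` follows the caveat above), not code from the PR or the article:

```python
import torch


def inv_r2_with_prior(tau, sigma_d: float = 1.0):
    """Single-parameter variant: sigma_d == 1.0 recovers the original RTC expression,
    so no use_soare_optimization flag or guarded branch is needed."""
    tau_t = torch.as_tensor(tau)
    squared_one_minus_tau = (1 - tau_t) ** 2
    prior_variance = sigma_d**2  # "variance" is sigma_d ** 2, per the caveat above
    return (squared_one_minus_tau + tau_t**2 * prior_variance) / (squared_one_minus_tau * prior_variance)


print(inv_r2_with_prior(0.3))               # sigma_d = 1.0: unmodified RTC factor
print(inv_r2_with_prior(0.3, sigma_d=0.5))  # sigma_d != 1.0: the article's variant
```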
Review comment (Alex Soare):

Note, as per my article, it's `max_guidance_weight` that you would set equal to `num_flow_matching_steps`. In the RTC paper they don't give guidance for that, and just suggest setting it to 5.0.
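A tiny illustration of that recommendation; the step count of 10 and the example `tau` value are made-up numbers, not taken from the PR:

```python
import torch

# Clip the guidance weight with the number of flow-matching steps instead of a
# hand-tuned constant (the RTC paper just suggests 5.0).
num_flow_matching_steps = 10
max_guidance_weight = float(num_flow_matching_steps)

tau_t = torch.as_tensor(0.05)  # early in the chunk, where the raw ratio is large
c = torch.nan_to_num((1 - tau_t) / tau_t, posinf=max_guidance_weight)
print(torch.clamp(c, max=max_guidance_weight))  # tensor(10.) -> clipped at the step count
```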