
Refactor CascadePredictor.generate so generate methods can be private #959

Closed

AnnaKwa wants to merge 1 commit into main from refactor/private-generate-method-rebase

Conversation

AnnaKwa (Contributor) commented Mar 11, 2026

I made some changes to address this comment since they were fresh in my mind: #954 (comment)

This PR refactors CascadePredictor.generate to wrap the previous step's output into a BatchData object, so that it can call the underlying model's generate_on_batch_no_target instead of generate.

This allows us to make the DiffusionModel.generate method private.

frodre (Collaborator) left a comment


Thanks, @AnnaKwa. I think this is on the right track: we need to include a richer set of information (e.g., coordinates) with the models' outputs. I ran into a related output-information issue in the last PR of the static-input refactor. For the cascading code, generate_on_batch is required to produce a ModelOutput with latent steps, etc., but that can't be provided by _generate as it's currently set up.

a) It's not clear whether we'll use CascadePredictor again, so we could opt to deprecate this model, remove it from the inference model options, and stop updating it for now

--or--

b) generate stays public in order to satisfy the output needs for generate_on_batch, but we will still need to include real coordinate information for the static_input handling within the model

I'm open to a) since we haven't been using these types of models recently.

current_coarse = generated
generated = _restore_batch_and_sample_dims(generated, n_samples)
return generated, generated_norm, latent_steps
return generated
Collaborator

Outputs need to be preserved for the generate_on_batch method of the cascade predictor, which complicates the use of non-public _generate for the underlying model.



def _batch_data_with_unused_coords(data: TensorMapping) -> BatchData:
# wrapper function so that we can call each level's
frodre (Collaborator) commented Mar 11, 2026

Would be nice if this was some kind of subclass of BatchData or easily identifiable that this is not actual data. But on second thought, if I am using the coarse coordinates to define the static input usage, this would break that contract.
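The "easily identifiable" idea above could look something like the sketch below. All of this is hypothetical: the project's real `BatchData` is richer, and `PlaceholderBatchData` and `uses_real_coords` are names invented here for illustration. It also shows why the reviewer's second thought matters: any caller that inspects coordinates would need to explicitly reject placeholder batches.

```python
# Hedged sketch: a marker subclass that makes "not actual data" detectable,
# so coordinate-dependent code paths (e.g. static-input handling) can refuse it.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class BatchData:  # simplified stand-in for the project's BatchData
    data: Dict[str, List[float]]
    coords: Dict[str, List[float]] = field(default_factory=dict)


@dataclass
class PlaceholderBatchData(BatchData):
    """Marker subclass: coords here are dummies, not real coordinates."""


def uses_real_coords(batch: BatchData) -> bool:
    # Callers that rely on coordinates could check this and raise instead
    # of silently consuming dummy coordinates.
    return not isinstance(batch, PlaceholderBatchData)
```

The marker makes the fake data identifiable, but it doesn't resolve the contract problem: a model that defines its static-input usage from coarse coordinates still can't operate on a placeholder batch at all.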

AnnaKwa (Contributor, Author) commented Mar 13, 2026

a) It's not clear whether we'll use CascadePredictor again, so we could just opt to deprecate this model and remove it from the options of inference models and stop updating it for now

I'm also fine with this. I'll close this PR and make an issue to deprecate this code.

@AnnaKwa AnnaKwa closed this Mar 13, 2026
frodre pushed a commit that referenced this pull request Mar 13, 2026
This predictor class is not used currently, as we decided to go with a
direct downscaling approach. It gets in the way of some other potential
refactors that would be nice to have but are blocked by the requirements
of the sequential prediction in `CascadePredictor.generate` (ex.
#959 (comment))