Refactor CascadePredictor.generate so generate methods can be private#959
Refactor CascadePredictor.generate so generate methods can be private#959
CascadePredictor.generate so generate methods can be private#959Conversation
…e methods private
There was a problem hiding this comment.
Thanks, @AnnaKwa. I think this is on the right track that we need to include more rich set of information with the outputs of the models (e.g., coordinates). I ran into something related to the output information issue in the last PR of the static input refactor. For the cascading code, there is a requirement for the generate_on_batch to have a ModelOutput with latent steps, etc., but that can't be provided with _generate as it's set up.
a) It's not clear whether we'll use CascadePredictor again, so we could just opt to deprecate this model and remove it from the options of inference models and stop updating it for now
--or--
b) generate stays public in order to satisfy the output needs for generate_on_batch, but we will still need to include real coordinate information for the static_input handling within the model
I'm open to a) since we haven't been using these types of models recently.
| current_coarse = generated | ||
| generated = _restore_batch_and_sample_dims(generated, n_samples) | ||
| return generated, generated_norm, latent_steps | ||
| return generated |
There was a problem hiding this comment.
Outputs need to be preserved for the generate_on_batch method of the cascade predictor, which complicates the use of non-public _generate for the underlying model.
|
|
||
|
|
||
| def _batch_data_with_unused_coords(data: TensorMapping) -> BatchData: | ||
| # wrapper function so that we can call each level's |
There was a problem hiding this comment.
Would be nice if this was some kind of subclass of BatchData or easily identifiable that this is not actual data. But on second thought, if I am using the coarse coordinates to define the static input usage, this would break that contract.
I'm also fine with this. I'll close this PR and make an issue to deprecate this code. |
This predictor class is not used currently, as we decided to go with a direct downscaling approach. It gets in the way of some other potential refactors that would be nice to have but are blocked by the requirements of the sequential prediction in `CascadePredictor.generate` (ex. #959 (comment))
I made some changes to address this comment since they were fresh in my mind: #954 (comment)
This PR refactors CascadePredictor.generate to wrap previous step's output into a BatchData object so it can call the underlying model's generate_on_batch_no_target instead of generate.
This allows us to make the DiffusionModel.generate method private.