
Conversation

kctezcan
Contributor

@kctezcan kctezcan commented Sep 25, 2025

Description

Encoding target variables into the latent space, similar to sources.

The changes are made for the forecasting mode and are tested for both training and inference.

Some open points:

  1. How should the variables and functions be named? Currently everything is called "..._srclk", an abbreviation for "source like", indicating that the targets are processed like sources.
  2. In some cases it is indeed possible to reuse the existing source-processing code by looping over the target fsteps and processing each separately (see the sketch after this list). However, this requires significant rewriting of the existing functions, since they do not operate on simple input variables but on input objects whose fields they access directly.
  3. Is this also relevant in the MTM mode? This is not tested.
  4. A test with multiple datasets still needs to be done.
  5. Saving the latent variables is left for future work.
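
A conceptual sketch of the change, to make the "source like" processing concrete; all names here are illustrative stand-ins, not the actual API:

    import torch

    # stand-ins for the real tokenization and embedding steps
    def batchify_source(x: torch.Tensor) -> torch.Tensor:
        return x.reshape(-1, x.shape[-1])

    def embed(x: torch.Tensor) -> torch.Tensor:
        return x @ torch.eye(x.shape[-1])

    source_data = torch.randn(4, 8)
    num_fsteps = 3
    target_data = [torch.randn(4, 8) for _ in range(num_fsteps)]

    # sources are encoded into the latent space as before
    tokens_source = embed(batchify_source(source_data))
    # targets now go through the same source-style path, once per fstep,
    # hence the "..._srclk" ("source like") naming
    tokens_target_srclk = [embed(batchify_source(target_data[f])) for f in range(num_fsteps)]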

Issue Number

Closes #941

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@kctezcan
Contributor Author

@sophie-xhonneux @tjhunter @clessig Please have a look, thanks

Collaborator

@clessig clessig left a comment

Thanks for the contribution. Maybe best to have a quick call to go through the questions.

)
(tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = (
self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function
Collaborator

Can we please remove all the KCTs? We missed a few last time...

Also, what's the question here?

Contributor Author

That was a reminder for me, all good, I removed it.

time_win: tuple,
normalizer, # dataset,
use_normalizer: str, # "source" or "target"
Collaborator

Remove whitespace

Contributor Author

Unfortunately, ruff adds that whitespace back :)

time_win: tuple,
normalizer, # dataset
use_normalizer: str, # "source" or "target"
Collaborator

Remove whitespace

Contributor Author

Unfortunately, ruff adds that whitespace back :)

]
)
for s in stl_b
s.target_srclk_tokens_lens[fstep]
Collaborator

Why did this change? Or is this from a different PR?

Contributor Author

Not sure what your question is here

for itype, s in enumerate(sb):
for fstep in range(offsets.shape[0]):
if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty
Collaborator

Remove whitespace

Contributor Author

again ruff...

zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device)
offsets = torch.cat([zeros_col, offsets_base[:,:-1]], dim=1)
# take offset_base up to last col and append a 0 in the beginning per fstep
zeros_col = torch.zeros(
Collaborator

Can you expand on the comment? It's not clear to me why this is necessary

Contributor Author

I rephrased it to:

    # shift the offsets for each fstep by one to the right and add a zero at the beginning, so the first token starts at 0
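
A small worked example of this shift, with made-up values:

    import torch

    offsets_base = torch.tensor([[3, 5, 9],
                                 [2, 4, 4]])  # cumulative end offsets per fstep
    zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype)
    offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
    # offsets is now [[0, 3, 5], [0, 2, 4]]: per fstep, the first token starts at 0
    # and each subsequent block starts where the previous one ended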

)
tokens_target = self.assimilate_global(model_params, tokens_target)
tokens_target_det = tokens_target.detach() # explicitly detach as well
Collaborator

Remove whitespace

Contributor Author

ruff :(
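
For context on the explicit detach above: assuming the encoded targets serve as regression targets for the latent loss discussed later in this thread (my reading, not confirmed here), detaching stops gradients from flowing back through the target-encoding path. A minimal runnable sketch with made-up names:

    import torch
    import torch.nn.functional as F

    tokens_pred = torch.randn(4, 16, requires_grad=True)    # stand-in for predicted latents
    tokens_target = torch.randn(4, 16, requires_grad=True)  # stand-in for encoded targets

    tokens_target_det = tokens_target.detach()  # no gradients through the target encoding
    loss_latent = F.mse_loss(tokens_pred, tokens_target_det)
    loss_latent.backward()
    assert tokens_target.grad is None  # only the prediction path receives gradients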

num_fsteps = target_srclk_tokens_lens.shape[2] # TODO: KCT, if there are diff no of tokens per fstep, this may fail
num_fsteps = target_srclk_tokens_lens.shape[
2
] # TODO: KCT, if there are diff no of tokens per fstep, this may fail
Collaborator

Can we handle this special case if things might break? When would it be triggered?

Contributor Author

Actually this is OK, it does not break, I checked.
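
For the record, a sketch of why shape[2] is safe, assuming the dimension layout implied by the indexing earlier in the diff ([batch, stream type, fstep, tokens]):

    import torch

    # assumed layout: [batch, num_stream_types, num_fsteps, max_tokens]
    target_srclk_tokens_lens = torch.zeros(2, 3, 5, 7, dtype=torch.long)
    target_srclk_tokens_lens[0, 0, 1, :4] = torch.tensor([1, 2, 0, 3])  # varying counts

    num_fsteps = target_srclk_tokens_lens.shape[2]  # always 5: the fstep axis is fixed;
    # only the length values vary, and empty fsteps simply have all-zero lengths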

@@ -0,0 +1,192 @@
# (C) Copyright 2025 WeatherGenerator contributors.
Collaborator

Why all these changes here?

Contributor Author

Ah, I need this to start the code in debugging mode in VS Code. It slipped into a commit; I have untracked it again.

@kctezcan
Contributor Author

Thanks a lot for the comments, @clessig

I have a question about how to handle empty target fsteps: see src/weathergen/model/model.py line 668.

We can keep the empty fsteps as well, and then we would not have to deal with the fstep shifts between the sources and targets. In the current form, the code needs to introduce some offsetting if there is a forecast_offset, for example.


if rdata.is_empty():
stream_data.add_empty_target(fstep)
stream_data.add_empty_target_srclk(fstep)
Contributor

Please change the name to add_empty_target_source_like or something more readable as a variable/function name.

Contributor Author

"source_like" sounds good. It is a bit long, but des not bother me, if @clessig you are also ok, I can change it.

Collaborator

Yes, srclk wasn't clear to me. I'd rather have it long and explicit.

self.source_tokens_cells = torch.tensor([])
self.source_centroids = torch.tensor([])

# >>>>>>>
Contributor

remove


return tokens_all

def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> torch.Tensor:
Contributor

Why is it necessary to duplicate the function embed_cells? I think ideally we avoid that, because each time the embedding engine changes, this function also needs to change, i.e. it is quite prone to code rot.

Contributor Author

This is one of the discussion points for me as well. If we have a function that we can use both for sources and targets, then we could also call it n times, once per fstep, i.e. taking the loop over fsteps out of the function, as you and Christian suggested earlier.

As the code is currently written, this function does not take variables as input but accesses them directly on the streams_data object. So my suggestion would be:

  1. rewrite the function to take the relevant variables as input instead of the whole object
  2. call it in a for loop over fsteps for the targets.

I will implement it so we can see how it looks and decide what is better.
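
A rough sketch of what steps 1 and 2 could look like; the signature and logic are hypothetical:

    import torch

    # the function takes plain tensors instead of the streams_data object...
    def embed_cells(tokens_cells: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        return tokens_cells * tokens_lens.unsqueeze(-1)  # stand-in for the embedding logic

    num_fsteps = 2
    # ...so the caller can reuse it for the targets in a loop over fsteps
    embedded_per_fstep = [
        embed_cells(torch.randn(5, 8), torch.ones(5)) for _fstep in range(num_fsteps)
    ]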

Contributor

Maybe we can do that change in a separate PR? That way we can merge this one, if it looks good, independently of the latent loss, i.e. faster?

Collaborator

Yes, this should go into a separate PR.

for _, sb in enumerate(streams_data):
for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)):
for fstep in range(num_fsteps):
if s.target_source_like_tokens_lens[fstep].sum() != 0:
Contributor Author

@clessig @sophie-xhonneux
What do you think? Should we skip empty fsteps or return an empty tensor for those?

Collaborator

Empty tensor

Contributor Author

Implemented.
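
The resulting pattern, sketched with made-up names and shapes, keeps the fstep indexing aligned:

    import torch

    dim = 16
    # fstep 1 is empty in this made-up example
    tokens_lens = [torch.tensor([3, 2]), torch.tensor([0, 0]), torch.tensor([1, 4])]

    tokens_per_fstep = [
        torch.randn(int(lens.sum()), dim) if lens.sum() != 0 else torch.empty(0, dim)
        for lens in tokens_lens
    ]
    # tokens_per_fstep[1].shape == (0, 16): the empty fstep is still present and
    # indexable, so no shifting between source and target fsteps is needed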

@clessig
Collaborator

clessig commented Sep 29, 2025

Thanks a lot for the comments, @clessig

I have a question about how to handle empty target fsteps: see src/weathergen/model/model.py line 668.

We can keep the empty fsteps as well, and then we would not have to deal with the fstep shifts between the sources and targets. In the current form, the code needs to introduce some offsetting if there is a forecast_offset, for example.

We should remove the special case handling, yes. But can this go to a separate PR?

@kctezcan
Contributor Author

Thanks a lot for the comments, @clessig
I have a question about how to handle empty target fsteps: see src/weathergen/model/model.py line 668.
We can keep the empty fsteps as well, and then we would not have to deal with the fstep shifts between the sources and targets. In the current form, the code needs to introduce some offsetting if there is a forecast_offset, for example.

We should remove the special case handling, yes. But can this go to a separate PR?

Yes, of course.
