base: develop
encoding targets into latent space #961
Conversation
@sophie-xhonneux @tjhunter @clessig Please have a look, thanks
Thanks for the contribution. Maybe best to have a quick call to go through the questions.
)
(tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = (
    self.tokenizer.batchify_source(  # TODO: KCT, check if anything source related is happening in the function
Can we please remove all the KCTs? We missed a few last time ...
Also, what's the question here?
That was a reminder for me, all good, I removed it.
time_win: tuple,
normalizer,  # dataset
use_normalizer: str,  # "source" or "target"
Remove whitespace
unfortunately ruff is adding that whitespace back :)
time_win: tuple,
normalizer,  # dataset
use_normalizer: str,  # "source" or "target"
Remove whitespace
unfortunately ruff is adding that whitespace back :)
src/weathergen/datasets/utils.py (Outdated)
]
)
for s in stl_b
s.target_srclk_tokens_lens[fstep]
Why did this change? Or is this from a different PR?
Not sure what your question is here
src/weathergen/datasets/utils.py (Outdated)
for itype, s in enumerate(sb):
    for fstep in range(offsets.shape[0]):
        if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0):  # if not empty
Remove whitespace
again ruff...
zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device)
offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
# take offset_base up to last col and append a 0 in the beginning per fstep
Can you expand on the comment? It's not clear to me why this is necessary
I rephrased it to:
# shift the offsets for each fstep by one to the right, adding a zero at the beginning so that the first token starts at 0
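For illustration, a minimal runnable sketch of that shift (the values in offsets_base are made up; only the zeros/cat pattern matches the diff):

import torch

# cumulative end offsets of the tokens for each fstep (2 fsteps, 3 tokens each)
offsets_base = torch.tensor([[2, 5, 9], [1, 4, 6]])
zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device)
offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
# offsets is now tensor([[0, 2, 5], [0, 1, 4]]): each row is shifted right by one
# and starts at 0, so offsets[i, j] is where token j starts rather than where it ends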
src/weathergen/model/model.py (Outdated)
)
tokens_target = self.assimilate_global(model_params, tokens_target)
tokens_target_det = tokens_target.detach()  # explicitly detach as well
Remove whitespace
ruff :(
src/weathergen/model/model.py (Outdated)
num_fsteps = target_srclk_tokens_lens.shape[2]  # TODO: KCT, if there are diff no of tokens per fstep, this may fail
Can we handle this special case if things might break? When would it be triggered?
Actually this is OK, it does not break, I checked.
src/weathergen/run_evaluate.py (Outdated)
@@ -0,0 +1,192 @@
# (C) Copyright 2025 WeatherGenerator contributors.
Why all these changes here?
Ah, I need this to start the code in debugging mode in VS Code. It slipped into a commit; I untracked it again.
Thanks a lot for the comments, @clessig. I have a question about how to handle empty target fsteps: see src/weathergen/model/model.py line 668. We could keep the empty fsteps as well; then we would not have to deal with the fstep shifts between the sources and targets. In the current form, the code needs to introduce some offsetting if there is a forecast_offset, for example.
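For concreteness, a minimal sketch of the two options (encode_fstep, targets_per_fstep, and the forecast_offset value are hypothetical placeholders, not the actual model code):

import torch

forecast_offset = 1  # hypothetical value
targets_per_fstep = [torch.randn(3, 8), torch.empty(0, 8), torch.randn(2, 8)]

def encode_fstep(t):
    # placeholder for the actual encoding of targets into the latent space
    return t

# Option A: keep empty fsteps, so target indices stay aligned with source fsteps
tokens_a = [encode_fstep(t) if t.numel() > 0 else torch.empty(0) for t in targets_per_fstep]

# Option B: drop empty fsteps; the entry at position i then corresponds to
# source fstep i + forecast_offset, so the code must track that shift explicitly
tokens_b = [encode_fstep(t) for t in targets_per_fstep if t.numel() > 0]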
if rdata.is_empty():
    stream_data.add_empty_target(fstep)
    stream_data.add_empty_target_srclk(fstep)
Please change the name to add_empty_target_source_like, or something more readable as a variable/function name.
"source_like" sounds good. It is a bit long, but des not bother me, if @clessig you are also ok, I can change it.
Yes, srclk wasn't clear to me. Rather long and explicit than unclear.
self.source_tokens_cells = torch.tensor([])
self.source_centroids = torch.tensor([])
# >>>>>>>
remove
src/weathergen/model/model.py (Outdated)
return tokens_all

def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> torch.Tensor:
Why is it necessary to duplicate the function embed_cells? I think ideally we avoid that, because each time the embedding engine changes this function also needs to change, i.e. it is quite prone to code rot.
This is one of the discussion points for me as well. If we have a function that we can use both for sources and targets, then we could also call it n times, once per fstep, i.e. taking the loop over fsteps out of the function, as you and Christian have suggested earlier.
As the code is written at this point, this function does not take the variables as input but accesses them directly from the streams_data object. So my suggestion would be (see the sketch below):
- rewrite the function to take the related variables as input rather than the object
- call it in a for loop over fsteps for the targets
I will implement it so we can see how it looks and decide what is better.
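A minimal self-contained sketch of what that refactor could look like (the class, signature, and variable names are illustrative, not the actual embed_cells interface):

import torch

class EmbeddingEngineSketch:
    # illustrative stand-in for the model class
    def embed_cells(self, tokens_cells, tokens_lens, centroids):
        # placeholder: the real function applies the per-stream embedding networks
        return tokens_cells

model = EmbeddingEngineSketch()
num_fsteps = 2
tgt_cells = [torch.randn(4, 8) for _ in range(num_fsteps)]
tgt_lens = [torch.tensor([2, 2]) for _ in range(num_fsteps)]
tgt_centroids = [torch.randn(2, 3) for _ in range(num_fsteps)]

# sources would be a single call with src_* tensors; targets become one call
# per forecast step, with the fstep loop outside the function
tokens_target = [
    model.embed_cells(tgt_cells[f], tgt_lens[f], tgt_centroids[f])
    for f in range(num_fsteps)
]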
Maybe ideally we can do that change in a separate PR? That way we can merge this one, if it looks good, independently of the latent loss, i.e. faster?
Yes, this should go into a separate PR.
for _, sb in enumerate(streams_data):
    for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)):
        for fstep in range(num_fsteps):
            if s.target_source_like_tokens_lens[fstep].sum() != 0:
@clessig @sophie-xhonneux What do you think? Should we skip empty fsteps or return an empty tensor for those?
Empty tensor
Implemented.
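A minimal sketch of the agreed behavior (the tensors and the embedding stand-in here are assumptions, not the actual model code):

import torch

num_fsteps = 3
# per-fstep token lengths; an empty fstep sums to 0
lens = [torch.tensor([2, 3]), torch.tensor([0]), torch.tensor([1])]
cells = [torch.randn(5, 8), torch.empty(0, 8), torch.randn(1, 8)]

tokens_per_fstep = []
for fstep in range(num_fsteps):
    if lens[fstep].sum() != 0:
        toks = cells[fstep]  # stand-in for the actual embedding call
    else:
        # return an empty tensor instead of skipping, so fstep indices
        # stay aligned with the sources
        toks = torch.empty(0)
    tokens_per_fstep.append(toks)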
We should remove the special case handling, yes. But can this go to a separate PR?
Yes, of course.
Description
Encoding target variables into the latent space, similar to sources.
The changes are made for the forecasting mode and tested for both training and inference.
Some open points, discussed in the conversation above: how to handle empty target fsteps, and whether embed_cells can be shared between sources and targets instead of being duplicated.
Issue Number
Closes #941
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60