Ktezcan/dev/iss941 encode targets sepfstep #1019
base: develop
Changes from all commits
First changed file:

@@ -41,13 +41,20 @@ def batchify_source(
     source: np.array,
     times: np.array,
     time_win: tuple,
-    normalizer,  # dataset
+    normalizer,  # dataset,
+    use_normalizer: str,  # "source_normalizer" or "target_normalizer"
 ):
     init_loggers()
     token_size = stream_info["token_size"]
     is_diagnostic = stream_info.get("diagnostic", False)
     tokenize_spacetime = stream_info.get("tokenize_spacetime", False)
 
+    channel_normalizer = (
+        normalizer.normalize_source_channels
+        if use_normalizer == "source_normalizer"
+        else normalizer.normalize_target_channels
+    )
+
     tokenize_window = partial(
         tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
         time_win=time_win,
@@ -56,7 +63,7 @@ def batchify_source(
         hpy_verts_rots=self.hpy_verts_rots_source[-1],
         n_coords=normalizer.normalize_coords,
         n_geoinfos=normalizer.normalize_geoinfos,
-        n_data=normalizer.normalize_source_channels,
+        n_data=channel_normalizer,
         enc_time=encode_times_source,
     )

Review comment on the new `use_normalizer` parameter: "Rename …"
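Not part of the diff, but for orientation, a hypothetical call-site sketch of the new flag. The owning object, the exact argument order, and the variable names are assumptions for illustration only; the diff confirms only that `use_normalizer` selects between `normalize_source_channels` and `normalize_target_channels`.

```python
# Hypothetical usage sketch (names and argument order assumed, not taken from the PR):
# the same batchify routine can now normalize with either set of channel statistics.
tokens_source = self.batchify_source(
    stream_info, source, times, time_win,
    normalizer,
    use_normalizer="source_normalizer",   # genuine source windows
)
tokens_target_like = self.batchify_source(
    stream_info, source, times, time_win,
    normalizer,
    use_normalizer="target_normalizer",   # "source-like" targets, normalized with target statistics
)
```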
Second changed file:

@@ -48,12 +48,19 @@ def batchify_source(
     times: np.array,
     time_win: tuple,
     normalizer,  # dataset
+    use_normalizer: str,  # "source_normalizer" or "target_normalizer"
 ):
     init_loggers()
     token_size = stream_info["token_size"]
     is_diagnostic = stream_info.get("diagnostic", False)
     tokenize_spacetime = stream_info.get("tokenize_spacetime", False)
 
+    channel_normalizer = (
+        normalizer.normalize_source_channels
+        if use_normalizer == "source_normalizer"
+        else normalizer.normalize_target_channels
+    )
+
     tokenize_window = partial(
         tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
         time_win=time_win,
@@ -62,7 +69,7 @@ def batchify_source(
         hpy_verts_rots=self.hpy_verts_rots_source[-1],
         n_coords=normalizer.normalize_coords,
         n_geoinfos=normalizer.normalize_geoinfos,
-        n_data=normalizer.normalize_source_channels,
+        n_data=channel_normalizer,
         enc_time=encode_times_source,
     )

Review comment on the new `use_normalizer` parameter: same remark as above, "Rename …"
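The selection block is duplicated verbatim in both files above. As a sketch only (the helper name is hypothetical and not part of the PR), a small shared function could express the dispatch once and also reject unexpected flag values:

```python
def select_channel_normalizer(normalizer, use_normalizer: str):
    """Return the channel-normalization callable named by `use_normalizer`."""
    if use_normalizer == "source_normalizer":
        return normalizer.normalize_source_channels
    if use_normalizer == "target_normalizer":
        return normalizer.normalize_target_channels
    raise ValueError(f"unknown use_normalizer: {use_normalizer!r}")
```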
Third changed file:

@@ -677,6 +677,84 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData:
     return batch
 
 
+def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> StreamData:
+    """
+    Compute auxiliary information for the scatter operation that changes from stream-centric to
+    cell-centric computations.
+
+    Parameters
+    ----------
+    batch : StreamData
+        batch of stream data information for which offsets have to be computed
+
+    Returns
+    -------
+    StreamData
+        stream data with offsets added as members
+    """
+
+    # collect source_tokens_lens for all stream datas
+    target_source_like_tokens_lens = torch.stack(
+        [
+            torch.stack(
+                [
+                    torch.stack(
+                        [
+                            s.target_source_like_tokens_lens[fstep]
+                            if len(s.target_source_like_tokens_lens[fstep]) > 0
+                            else torch.tensor([])
+                            for fstep in range(len(s.target_source_like_tokens_lens))
+                        ]
+                    )
+                    for s in stl_b
+                ]
+            )
+            for stl_b in batch
+        ]
+    )
+
+    # precompute index sets for scatter operation after embed
+    offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1)
+    # shift the offsets for each fstep by one to the right, add a zero to the
+    # beginning as the first token starts at 0
+    zeros_col = torch.zeros(
+        (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device
+    )
+    offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
+    offsets_pe = torch.zeros_like(offsets)
+
+    for ib, sb in enumerate(batch):
+        for itype, s in enumerate(sb):
+            for fstep in range(offsets.shape[0]):
+                if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:  # if not empty
+                    s.target_source_like_idxs_embed[fstep] = torch.cat(
+                        [
+                            torch.arange(offset, offset + token_len, dtype=torch.int64)
+                            for offset, token_len in zip(
+                                offsets[fstep],
+                                target_source_like_tokens_lens[ib, itype, fstep],
+                                strict=False,
+                            )
+                        ]
+                    )
+                    s.target_source_like_idxs_embed_pe[fstep] = torch.cat(
+                        [
+                            torch.arange(offset, offset + token_len, dtype=torch.int32)
+                            for offset, token_len in zip(
+                                offsets_pe[fstep],
+                                target_source_like_tokens_lens[ib][itype][fstep],
+                                strict=False,
+                            )
+                        ]
+                    )
+
+                    # advance offsets
+                    offsets[fstep] += target_source_like_tokens_lens[ib][itype][fstep]
+                    offsets_pe[fstep] += target_source_like_tokens_lens[ib][itype][fstep]
+
+    return batch
+
+
 def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list:
     """
     Compute auxiliary information for prediction

Review comment on the nested `torch.stack` construction: "Use fewer lines, because it looks more complex than it actually is. If this was caused by ruff then just forget about this comment..."
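As a sketch of what the reviewer may have in mind, reusing the names from the function above and assuming each `s.target_source_like_tokens_lens[fstep]` is a 1-D tensor (possibly empty), the triple nesting can be split with a small hypothetical helper:

```python
def _stack_fstep_lens(s):
    # one row per forecast step; empty steps become empty tensors
    return torch.stack(
        [t if len(t) > 0 else torch.tensor([]) for t in s.target_source_like_tokens_lens]
    )

target_source_like_tokens_lens = torch.stack(
    [torch.stack([_stack_fstep_lens(s) for s in stl_b]) for stl_b in batch]
)
```

For orientation, here is a tiny, self-contained worked example of the cumsum-and-shift offset construction used above, with made-up token lengths:

```python
import torch

lens = torch.tensor([[2, 3, 1],   # hypothetical token lengths: 2 fsteps x 3 windows
                     [4, 0, 2]])
base = lens.cumsum(1)             # [[2, 5, 6], [4, 4, 6]]
zeros_col = torch.zeros((lens.shape[0], 1), dtype=base.dtype)
starts = torch.cat([zeros_col, base[:, :-1]], dim=1)
# starts == [[0, 2, 5], [0, 4, 4]]: the start index of each window's tokens per fstep
```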
Review comment on the emptiness checks (`len(...) > 0` / `.sum() != 0`): "Replace … with … for slightly better efficiency. Maybe you can find a way to replace `len(s)` with a way to do the check in constant time without having to write multiple lines of code."
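The exact spans the reviewer would swap are not recoverable from this page, but as a sketch of constant-time alternatives (assumption: the intent is to detect whether any tokens are present, which only coincides with `.sum() != 0` when a non-empty length tensor is never all zeros):

```python
import torch

lens = torch.tensor([3, 0, 2])         # per-window token lengths for one fstep (hypothetical)

has_entries = lens.numel() > 0         # O(1): the step has any length entries at all
has_tokens = bool((lens != 0).any())   # at least one window actually contains tokens
old_check = bool(lens.sum() != 0)      # the check used in the PR: an O(n) reduction
```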