diff --git a/models/ctm.py b/models/ctm.py index 833451b..4a9a26b 100644 --- a/models/ctm.py +++ b/models/ctm.py @@ -239,11 +239,20 @@ def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, s selected_left = activated_state[:, neuron_indices_left] selected_right = activated_state[:, neuron_indices_right] + # OLD VERSION (for comparison): # Compute outer product of selected neurons - outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1) - # Resulting matrix is symmetric, so we only need the upper triangle - i, j = torch.triu_indices(n_synch, n_synch) - pairwise_product = outer[:, i, j] + # outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1) + # # Resulting matrix is symmetric, so we only need the upper triangle + # i, j = torch.triu_indices(n_synch, n_synch) + # pairwise_product = outer[:, i, j] + # + # NEW VERSION (optimized): + # Compute pairwise products efficiently without intermediate tensor + # - Equivalent result: selected_left[:, i] * selected_right[:, j] == outer[:, i, j] + # - Memory efficient: skips (B, N, N) intermediate tensor entirely + # - Device aware: ensures indices on same device as tensors + i, j = torch.triu_indices(n_synch, n_synch, device=activated_state.device) + pairwise_product = selected_left[:, i] * selected_right[:, j] elif self.neuron_select_type == 'random-pairing': # For random-pairing, we compute the sync between specific pairs of neurons diff --git a/utils/housekeeping.py b/utils/housekeeping.py index b89c997..1d04001 100644 --- a/utils/housekeeping.py +++ b/utils/housekeeping.py @@ -9,20 +9,18 @@ def zip_python_code(output_filename): """ - Zips all .py files in the current repository and saves it to the + Zips all .py files in the current repository and saves it to the specified output filename. Args: - output_filename: The name of the output zip file. + output_filename: The name of the output zip file. Defaults to "python_code_backup.zip". """ with zipfile.ZipFile(output_filename, 'w') as zipf: files = glob.glob('models/**/*.py', recursive=True) + glob.glob('utils/**/*.py', recursive=True) + glob.glob('tasks/**/*.py', recursive=True) + glob.glob('*.py', recursive=True) for file in files: - root = '/'.join(file.split('/')[:-1]) - nm = file.split('/')[-1] - zipf.write(os.path.join(root, nm)) + zipf.write(file, arcname=file) def set_seed(seed=42, deterministic=True): """