From b78cf1af2ef7c32ce3a7a1baa7e18f9828cdec80 Mon Sep 17 00:00:00 2001 From: schewskone Date: Mon, 26 Jan 2026 17:00:15 +0100 Subject: [PATCH 1/2] added simple benchmark function to MultiEpochsDataloader class --- experanto/utils.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/experanto/utils.py b/experanto/utils.py index d2891320..14c5dc03 100644 --- a/experanto/utils.py +++ b/experanto/utils.py @@ -121,6 +121,40 @@ def __iter__(self): for i in range(len(self)): yield next(self.iterator) + def benchmark( + self, + n_batches=100, + n_warmup=10, + verbose=False, + ): + it = iter(self) + batch_times = [] + + total_batches = len(self) if n_batches is None else n_batches + + # Warmup + for _ in range(n_warmup): + batch = next(it, None) + if batch is None: + it = iter(self) + batch = next(it) + + # Timed iteration + for i in range(total_batches): + start = time.perf_counter() + batch = next(it, None) + if batch is None: + it = iter(self) + batch = next(it) + end = time.perf_counter() + batch_times.append(end - start) + if verbose: + print(f"Batch {i+1}: {batch_times[-1]:.4f}s") + + avg_time = sum(batch_times) / len(batch_times) + std_time = (sum((t - avg_time) ** 2 for t in batch_times) / len(batch_times)) ** 0.5 + return {'avg_time': avg_time, 'std_time': std_time, 'batch_times': batch_times} + # borrowed with <3 from # https://github.com/sinzlab/neuralpredictors/blob/main/neuralpredictors/training/cyclers.py @@ -739,4 +773,4 @@ def __iter__(self): if self.drop_last and len(batch_indices) < self.batch_size: continue - yield batch_indices + yield batch_indices \ No newline at end of file From c065a73f4f98735c22f8dc1aaef7046ed040924e Mon Sep 17 00:00:00 2001 From: github-actions Date: Mon, 26 Jan 2026 16:01:02 +0000 Subject: [PATCH 2/2] style: auto-format with black and isort --- experanto/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/experanto/utils.py b/experanto/utils.py index 14c5dc03..fdb8e145 100644 --- a/experanto/utils.py +++ b/experanto/utils.py @@ -152,8 +152,10 @@ def benchmark( print(f"Batch {i+1}: {batch_times[-1]:.4f}s") avg_time = sum(batch_times) / len(batch_times) - std_time = (sum((t - avg_time) ** 2 for t in batch_times) / len(batch_times)) ** 0.5 - return {'avg_time': avg_time, 'std_time': std_time, 'batch_times': batch_times} + std_time = ( + sum((t - avg_time) ** 2 for t in batch_times) / len(batch_times) + ) ** 0.5 + return {"avg_time": avg_time, "std_time": std_time, "batch_times": batch_times} # borrowed with <3 from @@ -773,4 +775,4 @@ def __iter__(self): if self.drop_last and len(batch_indices) < self.batch_size: continue - yield batch_indices \ No newline at end of file + yield batch_indices