-
Notifications
You must be signed in to change notification settings - Fork 36
Benchmarking Dataloaders #97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -121,6 +121,42 @@ def __iter__(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(len(self)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield next(self.iterator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def benchmark( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_batches=100, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_warmup=10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| verbose=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | |
| ): | |
| """Benchmark the data loading time over a number of batches. | |
| Args: | |
| n_batches (int, optional): Number of batches to time. If ``None``, | |
| iterate over ``len(self)`` batches. Defaults to ``100``. | |
| n_warmup (int, optional): Number of initial batches to run without | |
| timing, to allow for warmup. Defaults to ``10``. | |
| verbose (bool, optional): If ``True``, print the time for each | |
| individual batch. Defaults to ``False``. | |
| Returns: | |
| dict: A dictionary with the following keys: | |
| - ``avg_time`` (float): Average time per batch in seconds. | |
| - ``std_time`` (float): Standard deviation of per-batch times | |
| in seconds. | |
| - ``batch_times`` (List[float]): List of per-batch timings in | |
| seconds. | |
| Example: | |
| >>> stats = dataloader.benchmark(n_batches=50, n_warmup=5) | |
| >>> print(stats["avg_time"], stats["std_time"]) | |
| """ |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The iterator exhaustion handling has a logical issue. Using next(it, None) with a default value of None assumes that None is never a valid batch value. If a dataloader could legitimately return None as a batch, this would incorrectly trigger iterator reinitialization. A more robust approach would be to catch StopIteration exception instead of using a default value, similar to how the cycle function handles iteration at line 169. This would make the behavior more predictable and aligned with Python's iterator protocol.
| batch = next(it, None) | |
| if batch is None: | |
| it = iter(self) | |
| batch = next(it) | |
| # Timed iteration | |
| for i in range(total_batches): | |
| start = time.perf_counter() | |
| batch = next(it, None) | |
| if batch is None: | |
| try: | |
| batch = next(it) | |
| except StopIteration: | |
| it = iter(self) | |
| batch = next(it) | |
| # Timed iteration | |
| for i in range(total_batches): | |
| start = time.perf_counter() | |
| try: | |
| batch = next(it) | |
| except StopIteration: |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The benchmark method measures both the data loading time and the iterator reinitialization time together when the iterator is exhausted (lines 146-148). This means that some batch times in the results will include the overhead of creating a new iterator, making the timing measurements inconsistent and potentially misleading. Consider either: (1) creating an iterator that won't exhaust during the benchmark, (2) excluding batches that required reinitialization from the timing results, or (3) separately tracking and reporting reinitialization overhead.
| start = time.perf_counter() | |
| batch = next(it, None) | |
| if batch is None: | |
| it = iter(self) | |
| batch = next(it) | |
| batch = next(it, None) | |
| if batch is None: | |
| # Iterator exhausted: reinitialize before starting timing | |
| it = iter(self) | |
| start = time.perf_counter() | |
| batch = next(it) | |
| else: | |
| start = time.perf_counter() | |
| # Measure the time to obtain the next batch | |
| # (batch was already fetched above for consistency) |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assignment to 'batch' is unnecessary as it is redefined before this value is used.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When n_batches is None, the method attempts to iterate through len(self) batches. However, if the dataloader has fewer batches than expected or if there's an issue with iterator exhaustion, the logic at lines 146-148 that reinitializes the iterator might create the iterator multiple times unnecessarily. Consider capping total_batches when n_batches is None, or adding validation to ensure the requested number of batches is reasonable. Also, the logic currently doesn't handle the case where the dataloader is empty (len(self) == 0), which would cause division by zero at line 154.
| # Warmup | |
| for _ in range(n_warmup): | |
| batch = next(it, None) | |
| if batch is None: | |
| it = iter(self) | |
| batch = next(it) | |
| # Timed iteration | |
| for i in range(total_batches): | |
| start = time.perf_counter() | |
| batch = next(it, None) | |
| if batch is None: | |
| it = iter(self) | |
| batch = next(it) | |
| end = time.perf_counter() | |
| batch_times.append(end - start) | |
| if verbose: | |
| print(f"Batch {i+1}: {batch_times[-1]:.4f}s") | |
| # Handle empty dataloader or non-positive batch request | |
| if len(self) == 0 or total_batches is None or total_batches <= 0: | |
| warnings.warn( | |
| "MultiEpochsDataLoader.benchmark: no batches available for benchmarking.", | |
| RuntimeWarning, | |
| ) | |
| return { | |
| "avg_time": math.nan, | |
| "std_time": math.nan, | |
| "batch_times": batch_times, | |
| } | |
| # Warmup (do not warm up more than total_batches) | |
| effective_warmup = min(n_warmup, total_batches) | |
| for _ in range(effective_warmup): | |
| batch = next(it, None) | |
| if batch is None: | |
| # Reinitialize iterator once; if still no batch, stop warmup | |
| it = iter(self) | |
| batch = next(it, None) | |
| if batch is None: | |
| break | |
| # Timed iteration | |
| for i in range(total_batches): | |
| start = time.perf_counter() | |
| batch = next(it, None) | |
| if batch is None: | |
| # Reinitialize iterator once; if still no batch, stop timing | |
| it = iter(self) | |
| batch = next(it, None) | |
| if batch is None: | |
| break | |
| end = time.perf_counter() | |
| batch_times.append(end - start) | |
| if verbose: | |
| print(f"Batch {i+1}: {batch_times[-1]:.4f}s") | |
| # Avoid division by zero if no batch times were recorded | |
| if not batch_times: | |
| warnings.warn( | |
| "MultiEpochsDataLoader.benchmark: no batch times recorded.", | |
| RuntimeWarning, | |
| ) | |
| return { | |
| "avg_time": math.nan, | |
| "std_time": math.nan, | |
| "batch_times": batch_times, | |
| } |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The standard deviation calculation is inefficient, iterating through batch_times twice (once for mean, once for variance). Since numpy is already imported in this file (line 20), consider using numpy for statistical calculations which would be more efficient and accurate: import numpy as np at the top is already available, so you could use np.mean(batch_times) and np.std(batch_times) for better performance and numerical stability.
| avg_time = sum(batch_times) / len(batch_times) | |
| std_time = ( | |
| sum((t - avg_time) ** 2 for t in batch_times) / len(batch_times) | |
| ) ** 0.5 | |
| return {"avg_time": avg_time, "std_time": std_time, "batch_times": batch_times} | |
| avg_time = float(np.mean(batch_times)) | |
| std_time = float(np.std(batch_times)) | |
| return { | |
| "avg_time": avg_time, | |
| "std_time": std_time, | |
| "batch_times": batch_times, | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO we shouldn't include any benchmarking functions into the actual package and especially not into a specific dataloader, but rather have scripts in a separate place to benchmark important parts of experanto and easily restart benchmarks after changes to see if runtime is affected
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think benchmarking dataloaders is something relevant for all users and not just development that's why I put it here. I also would not mind placing it in a separate benchmarking folder and refactor it so that it accepts a dataloader as an argument. My main goal was to make this as prominent as possible and that's why I put it here originally.
For all cases other than data loaders I agree 100% with you.