From 376cef2b5b1406562295e7755b6a110a54244d18 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 16:15:22 -0500 Subject: [PATCH 1/6] Preliminary implementation for distributed shufflable streams Signed-off-by: Alex-Brooks --- caikit_nlp/toolkit/data_stream_wrapper.py | 63 +++++++++++++---------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py index c486561b..c7a19dad 100644 --- a/caikit_nlp/toolkit/data_stream_wrapper.py +++ b/caikit_nlp/toolkit/data_stream_wrapper.py @@ -18,7 +18,7 @@ """ # Third Party -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, get_worker_info # First Party from caikit.core.toolkit import error_handler @@ -33,11 +33,13 @@ class SimpleIterableStreamWrapper(IterableDataset): compatability with PyTorch data loaders. """ - def __init__(self, stream, shuffle, buffer_size=None): + def __init__(self, stream, shuffle, buffer_size=None, seed=42): error.type_check("", bool, shuffle=shuffle) error.type_check( "", int, buffer_size=buffer_size, allow_none=True ) + self.seed = seed + self.shuffles_completed = 0 self.stream = stream self.shuffle = shuffle self.buffer_size = buffer_size @@ -48,33 +50,40 @@ def __init__(self, stream, shuffle, buffer_size=None): log.debug("Shuffling buffer size: {}".format(self.buffer_size)) def __iter__(self): - - # FIXME: We are currently not handling case where we have to work with - # multiple workers, so currently duplicate data will get processed by - # each worker. 
+ worker_info = get_worker_info() if self.shuffle: - log.debug4("Reshuffling training data!") - return iter(self.stream.shuffle(self.buffer_size)) - return iter(self.stream) - # worker_info = get_worker_info() - # if worker_info is None: # single-process data loading, return the full iterator - # if self.shuffle: - # log.debug4("Reshuffling training data!") - # return iter(self.stream.shuffle(self.buffer_size)) - # return iter(self.stream) + # Get the next shuffle seed; we use the root seed + number of + # shuffles completed so far to ensure that every worker will + # shuffle the same way for each epoch. + shuffle_seed = self._get_shuffle_seed(worker_info) + log.debug(f"Reshuffling training data with seed: {shuffle_seed}") + cycle_stream = self.stream.shuffle(self.buffer_size, seed=shuffle_seed) + self._increment_shuffle_seed(worker_info) + else: + cycle_stream = self.stream + # Once shuffling has been handled, consider workers; if we have multiple + # then create a substream from the main cycle stream to form a partition. 
+ if worker_info is not None: + cycle_stream = self._get_stream_partition( + cycle_stream, worker_info.id, worker_info.num_workers + ) + return iter(cycle_stream) + + def _get_shuffle_seed(self, worker_info): + if worker_info is None: + return self.seed + self.shuffles_completed + return self.seed + worker_info.dataset.shuffles_completed + + def _increment_shuffle_seed(self, worker_info): + if worker_info is None: + self.shuffles_completed += 1 + else: + worker_info.dataset.shuffles_completed += 1 - # When num_workers > 0, each worker process will have a different copy of - # the dataset object, so we configure each copy independently to avoid - # having duplicate data returned from each worker - # else: # in a worker process - # # split workload - # per_worker = int( - # math.ceil((self.end - self.start) / float(worker_info.num_workers)) - # ) - # worker_id = worker_info.id - # iter_start = self.start + worker_id * per_worker - # iter_end = min(iter_start + per_worker, self.end) - # return iter(range(iter_start, iter_end)) + def _get_stream_partition(self, cycle_stream, worker_id, num_workers): + for idx, elem in enumerate(cycle_stream): + if (idx - worker_id) % num_workers == 0: + yield (elem) def __len__(self): return len(self.stream) From 67ee5a360692c70fc4e3836736a10e465221bfbf Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 16:41:16 -0500 Subject: [PATCH 2/6] Add stream partition docstrings and type hints Signed-off-by: Alex-Brooks --- caikit_nlp/toolkit/data_stream_wrapper.py | 72 ++++++++++++++++++++--- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py index c7a19dad..95d509ed 100644 --- a/caikit_nlp/toolkit/data_stream_wrapper.py +++ b/caikit_nlp/toolkit/data_stream_wrapper.py @@ -16,15 +16,17 @@ and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch DataLoaders, with minimal boilerplate. 
""" +from typing import Any, Iterator, Optional # Third Party from torch.utils.data import IterableDataset, get_worker_info # First Party from caikit.core.toolkit import error_handler +from caikit.core.data_model import DataStream import alog -log = alog.use_channel("PEFT_PROMPT") +log = alog.use_channel("STREAM_WRAP") error = error_handler.get(log) @@ -33,11 +35,12 @@ class SimpleIterableStreamWrapper(IterableDataset): compatability with PyTorch data loaders. """ - def __init__(self, stream, shuffle, buffer_size=None, seed=42): + def __init__(self, stream: DataStream[Any], shuffle: bool, buffer_size: Optional[int]=None, seed: int=42): error.type_check("", bool, shuffle=shuffle) error.type_check( "", int, buffer_size=buffer_size, allow_none=True ) + error.type_check("", int, seed=seed) self.seed = seed self.shuffles_completed = 0 self.stream = stream @@ -49,7 +52,16 @@ def __init__(self, stream, shuffle, buffer_size=None, seed=42): log.debug("Shuffling enabled? {}".format(self.shuffle)) log.debug("Shuffling buffer size: {}".format(self.buffer_size)) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: + """Initialize a consumable iterator. If we have n workers, we handle the shuffle + behaviors first, then return every nth element, forming a partition across the + substreams produced by each iterator at the cost of having to skip items. If + we don't configure workers, we simply return prior to partitioning. + + Returns: + Iterator + iterator pertaining to one worker or the full dataset. + """ worker_info = get_worker_info() if self.shuffle: # Get the next shuffle seed; we use the root seed + number of @@ -69,21 +81,65 @@ def __iter__(self): ) return iter(cycle_stream) - def _get_shuffle_seed(self, worker_info): + def _get_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> int: + """Gets the current seed for this shuffle. + + Args: + worker_info: Optional["torch.utils.data._utils.worker.WorkerInfo"] + Torch dataloader worker or None. 
+ + Returns: + int + the seed to be used for the next shuffle on the + encapsulated stream. + """ if worker_info is None: return self.seed + self.shuffles_completed return self.seed + worker_info.dataset.shuffles_completed - def _increment_shuffle_seed(self, worker_info): + def _increment_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> None: + """Increments the current seed to prepare for the next shuffle. + IMPORTANT: we must use persistent loaders when shuffling across + multiple workers! Otherwise the worker will be destroyed, and our + shuffle counter will be lost, which will cause shuffle to look + like it's not working. + + Args: + worker_info: Optional["torch.utils.data._utils.worker.WorkerInfo"] + Torch dataloader worker or None. + """ if worker_info is None: self.shuffles_completed += 1 else: worker_info.dataset.shuffles_completed += 1 - def _get_stream_partition(self, cycle_stream, worker_id, num_workers): + def _get_stream_partition(self, + cycle_stream: DataStream[Any], + worker_id: int, + num_workers: int): + """Generator for a subset of a wrapped datastream; here, we simply traverse a stream, + which is assumed to be preshuffled, and yield the elements that align with the + scheme 'worker n gets every nth entry' after shuffling. This ensures that each + record in the stream is encountered at most once per epoch as long as shuffling + is consistent across the different workers. + + Args: + cycle_stream: DataStream[Any] + datastream that we're trying to partition. + worker_id: int + ID of the current worker. + num_workers: int + Number of workers being used to load the dataset. + """ for idx, elem in enumerate(cycle_stream): if (idx - worker_id) % num_workers == 0: - yield (elem) + yield elem - def __len__(self): + def __len__(self) -> int: + """Gets the encapsulated stream length. + + Returns: + int + number of objects in the stream. 
+ """ return len(self.stream) From 89b2858afc452eac60e90665c4c55f9aa155a08d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 17:12:59 -0500 Subject: [PATCH 3/6] Add test for multiworker shuffling Signed-off-by: Alex-Brooks --- tests/toolkit/test_data_stream_wrapper.py | 60 ++++++++++++----------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/tests/toolkit/test_data_stream_wrapper.py b/tests/toolkit/test_data_stream_wrapper.py index b0d3604c..33294991 100644 --- a/tests/toolkit/test_data_stream_wrapper.py +++ b/tests/toolkit/test_data_stream_wrapper.py @@ -5,7 +5,6 @@ # Third Party from torch.utils.data._utils import worker -import pytest # First Party from caikit.core.data_model import DataStream @@ -66,30 +65,35 @@ def test_shuffle_full_buffer(requires_determinism): assert not all(test_results) -# def test_iter_with_multi_worker(requires_determinism): -# """Ensure that we are able to iterate properly over data in case of workers -# managed by torch""" - -# test_results = [] -# # Get the IDs of all objects in the stream -# get_stream_id_order = lambda s: [id(datum) for datum in s] -# # Compare the data stream at two different iteration points; here, we -# # produce True if two streams have the same objects in the same order -# have_same_id_order = lambda id_set1, id_set2: all( -# [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)] -# ) and len(id_set1) == len(id_set2) - -# dummy_worker_info = worker.WorkerInfo( -# id=1, -# num_workers=2, -# seed=7, -# ) - -# with mock.patch.object(worker, '_worker_info', dummy_worker_info): -# wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False) -# initialize_order = get_stream_id_order(wrapper) -# for _ in range(NUM_CYCLES): -# cycle_ids = get_stream_id_order(wrapper) -# test_res = have_same_id_order(initialize_order, cycle_ids) -# test_results.append(test_res) -# assert not all(test_results) +def test_iter_with_multi_worker(): + """Ensure that we are able to 
iterate properly over data in case of workers + managed by torch""" + test_results = [] + w1_info = worker.WorkerInfo(id=0, num_workers=3, seed=7) + w2_info = worker.WorkerInfo(id=1, num_workers=3, seed=7) + w3_info = worker.WorkerInfo(id=2, num_workers=3, seed=7) + # Worker distribution works round robin after we consider shuffling. + # Since we don't shuffle in this patched test, they should just be + # divided as is. + index_stream = [ + {"label": 0}, # goes to worker 0 + {"label": 1}, # goes to worker 1 + {"label": 2}, # goes to worker 2 + {"label": 3}, # goes to worker 0 + {"label": 4}, # goes to worker 1 + {"label": 5}, # goes to worker 2 + ] + worker_info = [ + (w1_info, [index_stream[0], index_stream[3]]), + (w2_info, [index_stream[1], index_stream[4]]), + (w3_info, [index_stream[2], index_stream[5]]), + ] + for (dummy_worker, expected_elements) in worker_info: + with mock.patch.object(worker, '_worker_info', dummy_worker): + wrapper = SimpleIterableStreamWrapper(stream=index_stream, shuffle=False) + for _ in range(NUM_CYCLES): + actual_elements = list(wrapper) + test_results.append( + actual_elements == expected_elements and len(actual_elements) == len(expected_elements) + ) + assert all(test_results) From b7bd4cb48a596b9fca1df8fdd9b9e0cc7d2e9bf8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 16:17:26 -0600 Subject: [PATCH 4/6] Stream wrapper linter and code formatting Signed-off-by: Alex-Brooks --- caikit_nlp/toolkit/data_stream_wrapper.py | 22 ++++++++++++++-------- tests/toolkit/test_data_stream_wrapper.py | 17 +++++++++-------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py index 95d509ed..4cecd762 100644 --- a/caikit_nlp/toolkit/data_stream_wrapper.py +++ b/caikit_nlp/toolkit/data_stream_wrapper.py @@ -16,14 +16,15 @@ and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch DataLoaders, with 
minimal boilerplate. """ +# Standard from typing import Any, Iterator, Optional # Third Party from torch.utils.data import IterableDataset, get_worker_info # First Party -from caikit.core.toolkit import error_handler from caikit.core.data_model import DataStream +from caikit.core.toolkit import error_handler import alog log = alog.use_channel("STREAM_WRAP") @@ -35,7 +36,13 @@ class SimpleIterableStreamWrapper(IterableDataset): compatability with PyTorch data loaders. """ - def __init__(self, stream: DataStream[Any], shuffle: bool, buffer_size: Optional[int]=None, seed: int=42): + def __init__( + self, + stream: DataStream[Any], + shuffle: bool, + buffer_size: Optional[int] = None, + seed: int = 42, + ): error.type_check("", bool, shuffle=shuffle) error.type_check( "", int, buffer_size=buffer_size, allow_none=True @@ -68,7 +75,7 @@ def __iter__(self) -> Iterator[Any]: # shuffles completed so far to ensure that every worker will # shuffle the same way for each epoch. shuffle_seed = self._get_shuffle_seed(worker_info) - log.debug(f"Reshuffling training data with seed: {shuffle_seed}") + log.debug("Reshuffling training data with seed: {}".format(shuffle_seed)) cycle_stream = self.stream.shuffle(self.buffer_size, seed=shuffle_seed) self._increment_shuffle_seed(worker_info) else: @@ -113,10 +120,9 @@ def _increment_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> None: else: worker_info.dataset.shuffles_completed += 1 - def _get_stream_partition(self, - cycle_stream: DataStream[Any], - worker_id: int, - num_workers: int): + def _get_stream_partition( + self, cycle_stream: DataStream[Any], worker_id: int, num_workers: int + ): """Generator for a subset of a wrapped datastream; here, we simply traverse a stream, which is assumed to be preshuffled, and yield the elements that align with the scheme 'worker n gets every nth entry' after shuffling. 
This ensures that each @@ -137,7 +143,7 @@ def _get_stream_partition(self, def __len__(self) -> int: """Gets the encapsulated stream length. - + Returns: int number of objects in the stream. diff --git a/tests/toolkit/test_data_stream_wrapper.py b/tests/toolkit/test_data_stream_wrapper.py index 33294991..f007cb7e 100644 --- a/tests/toolkit/test_data_stream_wrapper.py +++ b/tests/toolkit/test_data_stream_wrapper.py @@ -76,12 +76,12 @@ def test_iter_with_multi_worker(): # Since we don't shuffle in this patched test, they should just be # divided as is. index_stream = [ - {"label": 0}, # goes to worker 0 - {"label": 1}, # goes to worker 1 - {"label": 2}, # goes to worker 2 - {"label": 3}, # goes to worker 0 - {"label": 4}, # goes to worker 1 - {"label": 5}, # goes to worker 2 + {"label": 0}, # goes to worker 0 + {"label": 1}, # goes to worker 1 + {"label": 2}, # goes to worker 2 + {"label": 3}, # goes to worker 0 + {"label": 4}, # goes to worker 1 + {"label": 5}, # goes to worker 2 ] worker_info = [ (w1_info, [index_stream[0], index_stream[3]]), @@ -89,11 +89,12 @@ def test_iter_with_multi_worker(): (w3_info, [index_stream[2], index_stream[5]]), ] for (dummy_worker, expected_elements) in worker_info: - with mock.patch.object(worker, '_worker_info', dummy_worker): + with mock.patch.object(worker, "_worker_info", dummy_worker): wrapper = SimpleIterableStreamWrapper(stream=index_stream, shuffle=False) for _ in range(NUM_CYCLES): actual_elements = list(wrapper) test_results.append( - actual_elements == expected_elements and len(actual_elements) == len(expected_elements) + actual_elements == expected_elements + and len(actual_elements) == len(expected_elements) ) assert all(test_results) From b6d1b82bcba064fce1291e4bacaaa53d8fdf88d2 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 16:21:14 -0600 Subject: [PATCH 5/6] Add notice for multiworker shuffling behaviors Signed-off-by: Alex-Brooks --- caikit_nlp/toolkit/data_stream_wrapper.py | 10 ++++++++++ 1 file 
changed, 10 insertions(+) diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py index 4cecd762..8d3265f0 100644 --- a/caikit_nlp/toolkit/data_stream_wrapper.py +++ b/caikit_nlp/toolkit/data_stream_wrapper.py @@ -34,6 +34,16 @@ class SimpleIterableStreamWrapper(IterableDataset): """DataStream wrapper as an iterable PyTorch dataset; we use this to add compatability with PyTorch data loaders. + + NOTE: this wrapper does support shuffling iterable datasets with multiple + workers as a true partition, but for it to work correctly, you must + set persistent_workers=True when initializing your dataloader. Otherwise, + your workers will be destroyed, causing them to have the same shuffle + seed every time. + + To verify that multiworker shuffling is working properly, you can turn on + debug logs and verify that the logged shuffle seed changes as you iterate + through your dataset. """ def __init__( From 13929916ba2a36e95ecd49f13a025fc0cabc84d3 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 25 Jul 2023 16:23:13 -0600 Subject: [PATCH 6/6] Cache datastream length Signed-off-by: Alex-Brooks --- caikit_nlp/toolkit/data_stream_wrapper.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py index 8d3265f0..f5b2bf31 100644 --- a/caikit_nlp/toolkit/data_stream_wrapper.py +++ b/caikit_nlp/toolkit/data_stream_wrapper.py @@ -63,9 +63,10 @@ def __init__( self.stream = stream self.shuffle = shuffle self.buffer_size = buffer_size + self.stream_length = len(stream) # Load the whole data set in memory if self.shuffle and buffer_size is None: - self.buffer_size = len(stream) + self.buffer_size = self.stream_length log.debug("Shuffling enabled? 
{}".format(self.shuffle)) log.debug("Shuffling buffer size: {}".format(self.buffer_size)) @@ -152,10 +153,12 @@ def _get_stream_partition( yield elem def __len__(self) -> int: - """Gets the encapsulated stream length. + """Gets the encapsulated stream length. Note that we cache this attribute, + because taking the length of a datastream (re-entrant generator) requires + iterating until the end of it, which is expensive. Returns: int number of objects in the stream. """ - return len(self.stream) + return self.stream_length