diff --git a/caikit_nlp/toolkit/data_stream_wrapper.py b/caikit_nlp/toolkit/data_stream_wrapper.py
index c486561b..f5b2bf31 100644
--- a/caikit_nlp/toolkit/data_stream_wrapper.py
+++ b/caikit_nlp/toolkit/data_stream_wrapper.py
@@ -16,65 +16,149 @@
 and objects for training / evaluating PyTorch models built around DataStreams,
 e.g., PyTorch DataLoaders, with minimal boilerplate.
 """
+# Standard
+from typing import Any, Iterator, Optional
+
 # Third Party
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, get_worker_info
 
 # First Party
+from caikit.core.data_model import DataStream
 from caikit.core.toolkit import error_handler
 import alog
 
-log = alog.use_channel("PEFT_PROMPT")
+log = alog.use_channel("STREAM_WRAP")
 error = error_handler.get(log)
 
 
 class SimpleIterableStreamWrapper(IterableDataset):
     """DataStream wrapper as an iterable PyTorch dataset; we use this to add
     compatability with PyTorch data loaders.
+
+    NOTE: this wrapper supports shuffling iterable datasets across multiple
+    workers as a true partition, but for it to work correctly, you must set
+    persistent_workers=True when initializing your dataloader. Otherwise,
+    your workers will be destroyed and recreated each epoch, causing them to
+    reuse the same shuffle seed every time.
+
+    To verify that multi-worker shuffling is working properly, you can turn on
+    debug logs and check that the logged shuffle seed changes as you iterate
+    through your dataset.
     """
 
-    def __init__(self, stream, shuffle, buffer_size=None):
+    def __init__(
+        self,
+        stream: DataStream[Any],
+        shuffle: bool,
+        buffer_size: Optional[int] = None,
+        seed: int = 42,
+    ):
         error.type_check("", bool, shuffle=shuffle)
         error.type_check(
             "", int, buffer_size=buffer_size, allow_none=True
         )
+        error.type_check("", int, seed=seed)
+        self.seed = seed
+        self.shuffles_completed = 0
         self.stream = stream
         self.shuffle = shuffle
         self.buffer_size = buffer_size
+        self.stream_length = len(stream)
         # Load the whole data set in memory
         if self.shuffle and buffer_size is None:
-            self.buffer_size = len(stream)
+            self.buffer_size = self.stream_length
         log.debug("Shuffling enabled? {}".format(self.shuffle))
         log.debug("Shuffling buffer size: {}".format(self.buffer_size))
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
+        """Initialize a consumable iterator. If we have n workers, we handle the
+        shuffle behavior first, then yield every nth element, forming a partition
+        across the substreams produced by the workers at the cost of having to
+        skip items. If no workers are configured, we simply return prior to
+        partitioning.
 
-        # FIXME: We are currently not handling case where we have to work with
-        # multiple workers, so currently duplicate data will get processed by
-        # each worker.
+        Returns:
+            Iterator
+                Iterator pertaining to one worker's partition or to the full
+                dataset.
+        """
+        worker_info = get_worker_info()
         if self.shuffle:
-            log.debug4("Reshuffling training data!")
-            return iter(self.stream.shuffle(self.buffer_size))
-        return iter(self.stream)
-        # worker_info = get_worker_info()
-        # if worker_info is None:  # single-process data loading, return the full iterator
-        #     if self.shuffle:
-        #         log.debug4("Reshuffling training data!")
-        #         return iter(self.stream.shuffle(self.buffer_size))
-        #     return iter(self.stream)
-
-        # When num_workers > 0, each worker process will have a different copy of
-        # the dataset object, so we configure each copy independently to avoid
-        # having duplicate data returned from each worker
-        # else:  # in a worker process
-        #     # split workload
-        #     per_worker = int(
-        #         math.ceil((self.end - self.start) / float(worker_info.num_workers))
-        #     )
-        #     worker_id = worker_info.id
-        #     iter_start = self.start + worker_id * per_worker
-        #     iter_end = min(iter_start + per_worker, self.end)
-        #     return iter(range(iter_start, iter_end))
-
-    def __len__(self):
-        return len(self.stream)
+            # Get the next shuffle seed; we use the root seed + number of
+            # shuffles completed so far to ensure that every worker will
+            # shuffle the same way for each epoch.
+            shuffle_seed = self._get_shuffle_seed(worker_info)
+            log.debug("Reshuffling training data with seed: {}".format(shuffle_seed))
+            cycle_stream = self.stream.shuffle(self.buffer_size, seed=shuffle_seed)
+            self._increment_shuffle_seed(worker_info)
+        else:
+            cycle_stream = self.stream
+        # Once shuffling has been handled, consider workers; if we have multiple
+        # workers, create a substream from the main cycle stream to form a partition.
+        if worker_info is not None:
+            cycle_stream = self._get_stream_partition(
+                cycle_stream, worker_info.id, worker_info.num_workers
+            )
+        return iter(cycle_stream)
+
+    def _get_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> int:
+        """Gets the current seed for this shuffle.
+
+        Args:
+            worker_info: Optional["torch.utils.data._utils.worker.WorkerInfo"]
+                Torch dataloader worker or None.
+
+        Returns:
+            int
+                The seed to be used for the next shuffle of the
+                encapsulated stream.
+        """
+        if worker_info is None:
+            return self.seed + self.shuffles_completed
+        return self.seed + worker_info.dataset.shuffles_completed
+
+    def _increment_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> None:
+        """Increments the current seed to prepare for the next shuffle.
+        IMPORTANT: we must use persistent workers (persistent_workers=True)
+        when shuffling across multiple workers! Otherwise the worker will be
+        destroyed, and our shuffle counter will be lost, which will make
+        shuffling look like it's not working.
+
+        Args:
+            worker_info: Optional["torch.utils.data._utils.worker.WorkerInfo"]
+                Torch dataloader worker or None.
+        """
+        if worker_info is None:
+            self.shuffles_completed += 1
+        else:
+            worker_info.dataset.shuffles_completed += 1
+
+    def _get_stream_partition(
+        self, cycle_stream: DataStream[Any], worker_id: int, num_workers: int
+    ) -> Iterator[Any]:
+        """Generator for a subset of a wrapped datastream; here, we simply traverse
+        a stream, which is assumed to be preshuffled, and yield the elements that
+        align with the scheme 'worker n gets every nth entry' after shuffling. This
+        ensures that each record in the stream is encountered at most once per
+        epoch as long as shuffling is consistent across the different workers.
+
+        Args:
+            cycle_stream: DataStream[Any]
+                Datastream that we're trying to partition.
+            worker_id: int
+                ID of the current worker.
+            num_workers: int
+                Number of workers being used to load the dataset.
+
+        Yields:
+            Any
+                Elements of the stream belonging to this worker's partition.
+        """
+        for idx, elem in enumerate(cycle_stream):
+            if (idx - worker_id) % num_workers == 0:
+                yield elem
+
+    def __len__(self) -> int:
+        """Gets the encapsulated stream length. Note that we cache this value,
+        because taking the length of a datastream (re-entrant generator) requires
+        iterating until the end of it, which is expensive.
+
+        Returns:
+            int
+                Number of objects in the stream.
+        """
+        return self.stream_length
diff --git a/tests/toolkit/test_data_stream_wrapper.py b/tests/toolkit/test_data_stream_wrapper.py
index b0d3604c..f007cb7e 100644
--- a/tests/toolkit/test_data_stream_wrapper.py
+++ b/tests/toolkit/test_data_stream_wrapper.py
@@ -5,7 +5,6 @@
 
 # Third Party
 from torch.utils.data._utils import worker
-import pytest
 
 # First Party
 from caikit.core.data_model import DataStream
@@ -66,30 +65,36 @@ def test_shuffle_full_buffer(requires_determinism):
     assert not all(test_results)
 
 
-# def test_iter_with_multi_worker(requires_determinism):
-#     """Ensure that we are able to iterate properly over data in case of workers
-#     managed by torch"""
-
-#     test_results = []
-#     # Get the IDs of all objects in the stream
-#     get_stream_id_order = lambda s: [id(datum) for datum in s]
-#     # Compare the data stream at two different iteration points; here, we
-#     # produce True if two streams have the same objects in the same order
-#     have_same_id_order = lambda id_set1, id_set2: all(
-#         [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)]
-#     ) and len(id_set1) == len(id_set2)
-
-#     dummy_worker_info = worker.WorkerInfo(
-#         id=1,
-#         num_workers=2,
-#         seed=7,
-#     )
-
-#     with mock.patch.object(worker, '_worker_info', dummy_worker_info):
-#         wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False)
-#         initialize_order = get_stream_id_order(wrapper)
-#         for _ in range(NUM_CYCLES):
-#             cycle_ids = get_stream_id_order(wrapper)
-#             test_res = have_same_id_order(initialize_order, cycle_ids)
-#             test_results.append(test_res)
-#         assert not all(test_results)
+def test_iter_with_multi_worker():
+    """Ensure that we are able to iterate properly over data in the case of
+    workers managed by torch."""
+    test_results = []
+    w1_info = worker.WorkerInfo(id=0, num_workers=3, seed=7)
+    w2_info = worker.WorkerInfo(id=1, num_workers=3, seed=7)
+    w3_info = worker.WorkerInfo(id=2, num_workers=3, seed=7)
+    # Worker distribution works round robin after we consider shuffling.
+    # Since we don't shuffle in this patched test, the elements should just
+    # be divided as is.
+    index_stream = [
+        {"label": 0},  # goes to worker 0
+        {"label": 1},  # goes to worker 1
+        {"label": 2},  # goes to worker 2
+        {"label": 3},  # goes to worker 0
+        {"label": 4},  # goes to worker 1
+        {"label": 5},  # goes to worker 2
+    ]
+    worker_info = [
+        (w1_info, [index_stream[0], index_stream[3]]),
+        (w2_info, [index_stream[1], index_stream[4]]),
+        (w3_info, [index_stream[2], index_stream[5]]),
+    ]
+    for (dummy_worker, expected_elements) in worker_info:
+        with mock.patch.object(worker, "_worker_info", dummy_worker):
+            wrapper = SimpleIterableStreamWrapper(stream=index_stream, shuffle=False)
+            for _ in range(NUM_CYCLES):
+                actual_elements = list(wrapper)
+                test_results.append(
+                    actual_elements == expected_elements
+                    and len(actual_elements) == len(expected_elements)
+                )
+    assert all(test_results)
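
For reviewers, a minimal usage sketch of how the new wrapper is meant to be driven from a torch DataLoader. This is illustrative only and not part of the diff: the stream contents, batch size, and worker count are invented, and it assumes caikit's `DataStream.from_iterable` constructor. The load-bearing detail, per the class docstring above, is `persistent_workers=True`, which keeps each worker's dataset copy (and its `shuffles_completed` counter) alive across epochs.

```python
# Hypothetical usage sketch; stream contents, batch size, and worker count
# are made up for illustration.
from caikit.core.data_model import DataStream
from torch.utils.data import DataLoader

from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper

# Wrap a small in-memory stream; assumes DataStream.from_iterable is available.
stream = DataStream.from_iterable([{"label": i} for i in range(8)])
wrapper = SimpleIterableStreamWrapper(stream=stream, shuffle=True, seed=42)

# persistent_workers=True is required for multi-worker shuffling: without it,
# the workers are torn down after each epoch, the shuffle counters reset, and
# every epoch reuses the same shuffle seed.
loader = DataLoader(
    wrapper,
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
)

for epoch in range(2):
    for batch in loader:
        # With the round-robin scheme in _get_stream_partition, each record
        # appears at most once per epoch across both workers.
        ...
```

Worth noting as a design trade-off: the round-robin partition swaps the old duplicate-data problem for skipped reads, since every worker still traverses the whole shuffled stream and discards all but every nth element.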