From f8bbe38c91ba40865c247584846436988368f5bd Mon Sep 17 00:00:00 2001
From: Guillermo Ortiz-Jimenez
Date: Sat, 15 Oct 2022 10:29:09 -0700
Subject: [PATCH] Internal

PiperOrigin-RevId: 481366497
---
 uncertainty_baselines/datasets/augmix.py | 20 ++++++++++----------
 uncertainty_baselines/datasets/base.py   | 18 +++++++++++++++---
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/uncertainty_baselines/datasets/augmix.py b/uncertainty_baselines/datasets/augmix.py
index 809597e0e..a93c56008 100644
--- a/uncertainty_baselines/datasets/augmix.py
+++ b/uncertainty_baselines/datasets/augmix.py
@@ -140,7 +140,6 @@ def mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
     Mixup regularization applied.
@@ -192,15 +191,16 @@ def mixup(batch_size, aug_params, images, labels):
     labels = tf.reshape(
         tf.tile(labels, [1, aug_count + 1]), [batch_size, aug_count + 1, -1])
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
     labels_mix = tf.reshape(
         tf.transpose(labels_mix, [1, 0, 2]), [batch_size * (aug_count + 1), -1])
   else:
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
-  return images_mix, labels_mix
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
+
+  return images_mix, labels_mix
 
 
 def adaptive_mixup(batch_size, aug_params, images, labels):
@@ -215,7 +215,6 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
     Mixup regularization applied.
@@ -229,8 +228,8 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
     # Need to filter out elements in alpha which equal to 0.
     greater_zero_indicator = tf.cast(alpha > 0, alpha.dtype)
     less_one_indicator = tf.cast(alpha < 1, alpha.dtype)
-    valid_alpha_indicator = tf.cast(
-        greater_zero_indicator * less_one_indicator, tf.bool)
+    valid_alpha_indicator = tf.cast(greater_zero_indicator * less_one_indicator,
+                                    tf.bool)
     sampled_alpha = tf.where(valid_alpha_indicator, alpha, 0.1)
     mix_weight = tfd.Beta(sampled_alpha, sampled_alpha).sample()
     mix_weight = tf.where(valid_alpha_indicator, mix_weight, alpha)
@@ -253,4 +252,5 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
   images_mix = (
       images * images_mix_weight + images[::-1] * (1. - images_mix_weight))
   labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
-  return images_mix, labels_mix
+
+  return images_mix, labels_mix
diff --git a/uncertainty_baselines/datasets/base.py b/uncertainty_baselines/datasets/base.py
index 902988593..0c803e409 100644
--- a/uncertainty_baselines/datasets/base.py
+++ b/uncertainty_baselines/datasets/base.py
@@ -267,6 +267,7 @@ def _add_example_id(enumerate_id, example):
 
   def _load(self,
             *,
             preprocess_fn: Optional[PreProcessFn] = None,
+            process_batch_fn: Optional[PreProcessFn] = None,
             batch_size: int = -1) -> tf.data.Dataset:
     """Transforms the dataset from builder.as_dataset() to batch, repeat, etc.
@@ -278,6 +279,9 @@ def _load(self,
       preprocess_fn: an optional preprocessing function, if not provided then
         a subclass must define _create_process_example_fn() which will be
         used to preprocess the data.
+      process_batch_fn: an optional processing batch function, if not
+        provided then _create_process_batch_fn() will be used to generate the
+        function that will process a batch of data.
       batch_size: the batch size to use.
 
     Returns:
@@ -372,7 +376,8 @@ def _load(self,
     else:
       dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)
 
-    process_batch_fn = self._create_process_batch_fn(batch_size)  # pylint: disable=assignment-from-none
+    if process_batch_fn is None:
+      process_batch_fn = self._create_process_batch_fn(batch_size)
     if process_batch_fn:
       dataset = dataset.map(
           process_batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
@@ -406,6 +411,7 @@ def load(
       self,
       *,
       preprocess_fn: Optional[PreProcessFn] = None,
+      process_batch_fn: Optional[PreProcessFn] = None,
       batch_size: int = -1,
       strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
     """Function definition to support multi-host dataset sharding.
@@ -431,11 +437,13 @@ def load(
 
     Args:
       preprocess_fn: see `load()`.
+      process_batch_fn: see `load()`.
       batch_size: the *global* batch size to use. This should equal
         `per_replica_batch_size * num_replica_in_sync`.
       strategy: the DistributionStrategy used to shard the dataset. Note that
         this is only required if TensorFlow for training, otherwise it can be
         ignored.
+
     Returns:
       A sharded dataset, with its seed combined with the per-host id.
     """
@@ -445,11 +453,15 @@ def _load_distributed(ctx: tf.distribute.InputContext):
             self._seed, ctx.input_pipeline_id)
        per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)
        return self._load(
-            preprocess_fn=preprocess_fn, batch_size=per_replica_batch_size)
+            preprocess_fn=preprocess_fn,
+            process_batch_fn=process_batch_fn,
+            batch_size=per_replica_batch_size)
 
      return strategy.distribute_datasets_from_function(_load_distributed)
    else:
-      return self._load(preprocess_fn=preprocess_fn, batch_size=batch_size)
+      return self._load(preprocess_fn=preprocess_fn,
+                        process_batch_fn=process_batch_fn,
+                        batch_size=batch_size)
 
 
 _BaseDatasetClass = Type[TypeVar('B', bound=BaseDataset)]
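
Usage note (illustrative, not part of the patch): with this change, callers can pass a batch-level processing function directly to load(); when the argument is left as None, _load() falls back to _create_process_batch_fn() as before. A minimal sketch under assumptions: the Cifar10Dataset constructor call and the 'features' key are guesses about the surrounding uncertainty_baselines API, and cast_batch is a hypothetical helper, not something this patch defines.

import tensorflow as tf
import uncertainty_baselines as ub


def cast_batch(batch):
  # Hypothetical batch-level hook: runs once per batch, after batching,
  # e.g. to change the image dtype for the whole batch in a single op.
  batch['features'] = tf.cast(batch['features'], tf.bfloat16)
  return batch


# Assumed dataset; any BaseDataset subclass exposing load() should work.
train_dataset = ub.datasets.Cifar10Dataset(split='train').load(
    batch_size=128,
    process_batch_fn=cast_batch)  # keyword argument added by this patch

for batch in train_dataset.take(1):
  print(batch['features'].dtype)  # bfloat16 if the hook ran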