
Commit 2feb430

Merge branch 'master' of github.com:keras-team/keras

2 parents: 458f100 + b026ff7
21 files changed (+442, -315 lines)

keras/backend/jax/core.py (2 additions & 2 deletions)

```diff
@@ -61,11 +61,11 @@ def convert_to_tensor(x, dtype=None, sparse=None):
         if dtype and dtype != x.dtype:
             return x.value.astype(dtype)
         return x.value
-    return jnp.array(x, dtype=dtype)
+    return jnp.asarray(x, dtype=dtype)


 def convert_to_numpy(x):
-    return np.array(x)
+    return np.asarray(x)


 def is_tensor(x):
```
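Both substitutions trade an unconditional copy for a no-copy fast path: `asarray` returns its input as-is when it is already an array of the requested dtype, while `array` copies by default. The same change appears in the TensorFlow backend below. A minimal standalone sketch in plain NumPy (not Keras code):

```python
import numpy as np

x = np.zeros((1024, 1024), dtype="float32")

# np.array copies by default, even for an ndarray input.
assert np.array(x) is not x

# np.asarray is a no-op when the dtype already matches: no copy.
assert np.asarray(x) is x

# With a dtype mismatch, both must convert (and therefore copy).
assert np.asarray(x, dtype="float64") is not x
```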

keras/backend/jax/trainer.py (4 additions & 5 deletions)

```diff
@@ -904,11 +904,10 @@ def distribute_single_value(d):


 class JAXEpochIterator(EpochIterator):
-    def _get_iterator(self, return_type="auto"):
-        if return_type in ("np", "auto"):
-            # enable prefetching when using numpy_iterator
-            return self._prefetch_numpy_iterator(super()._get_iterator("np"))
-        return super()._get_iterator(return_type)
+    def _get_iterator(self):
+        return self._prefetch_numpy_iterator(
+            self.data_adapter.get_jax_iterator()
+        )

     def _prefetch_numpy_iterator(self, numpy_iterator):
         """Shard and prefetch batches on device.
```

keras/backend/numpy/trainer.py (3 additions & 3 deletions)

```diff
@@ -198,7 +198,7 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+        for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_predict_batch_begin(step)
             batch_outputs = self.predict_function(data)
             outputs = append_to_outputs(batch_outputs, outputs)
@@ -242,7 +242,7 @@ def evaluate(

         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
+            for _, data in epoch_iterator.enumerate_epoch():
                 data_batch = data[0]
                 self._symbolic_build(data_batch)
                 break
@@ -264,7 +264,7 @@ def evaluate(
         callbacks.on_test_begin()
         logs = None
         self.reset_metrics()
-        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+        for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_test_batch_begin(step)
             logs = self.test_function(data)
             callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
```
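These call sites drop `return_type` because, after this commit, the batch format is chosen by each backend's `EpochIterator` subclass via `_get_iterator()` rather than by the caller. A hypothetical minimal subclass under that contract (the class name is illustrative, not from the diff):

```python
class NumpyEpochIterator(EpochIterator):
    def _get_iterator(self):
        # The backend picks its native batch format; callers just write
        # `for step, data in epoch_iterator.enumerate_epoch():`.
        return self.data_adapter.get_numpy_iterator()
```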

keras/backend/tensorflow/core.py (1 addition & 1 deletion)

```diff
@@ -126,7 +126,7 @@ def convert_to_numpy(x):
         x.set_shape(x_shape)
     elif isinstance(x, tf.IndexedSlices):
         x = tf.convert_to_tensor(x)
-    return np.array(x)
+    return np.asarray(x)


 def is_tensor(x):
```

keras/backend/tensorflow/trainer.py (4 additions & 1 deletion)

```diff
@@ -629,14 +629,17 @@ class TFEpochIterator(EpochIterator):
     def __init__(self, distribute_strategy=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._distribute_strategy = distribute_strategy
-        dataset = self.data_adapter.get_tf_dataset()
+        dataset = self._get_iterator()
         if not isinstance(dataset, tf.distribute.DistributedDataset):
             dataset = self._distribute_strategy.experimental_distribute_dataset(
                 dataset
             )
         self._distributed_dataset = dataset
         self._steps_seen = 0

+    def _get_iterator(self):
+        return self.data_adapter.get_tf_dataset()
+
     def enumerate_epoch(self):
         if self.steps_per_epoch:
             if not self._current_iterator:
```
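The `isinstance` guard is standard `tf.distribute` usage: wrap a plain dataset with `experimental_distribute_dataset`, but pass an already-distributed one through untouched. A standalone sketch of the same pattern:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(4)

# Only plain datasets get wrapped; a DistributedDataset passes through.
if not isinstance(dataset, tf.distribute.DistributedDataset):
    dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dataset:
    pass  # each global batch is split across the strategy's replicas
```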

keras/backend/torch/trainer.py (2 additions & 40 deletions)

```diff
@@ -1,5 +1,3 @@
-import collections
-import itertools
 import warnings

 import numpy as np
@@ -10,7 +8,6 @@
 from keras import backend
 from keras import callbacks as callbacks_module
 from keras import optimizers as optimizers_module
-from keras.trainers import data_adapters
 from keras.trainers import trainer as base_trainer
 from keras.trainers.data_adapters import data_adapter_utils
 from keras.trainers.epoch_iterator import EpochIterator
@@ -496,40 +493,5 @@ def predict_on_batch(self, x):


 class TorchEpochIterator(EpochIterator):
-    def _get_iterator(self, return_type="auto"):
-        if return_type == "auto" and isinstance(
-            self.data_adapter, data_adapters.TorchDataLoaderAdapter
-        ):
-            return self.data_adapter.get_torch_dataloader()
-        elif return_type in ("np", "auto"):
-            # enable prefetching when using numpy_iterator
-            return self._prefetch_numpy_iterator(super()._get_iterator("np"))
-        return super()._get_iterator(return_type)
-
-    def _prefetch_numpy_data(self, data):
-        return tree.map_structure(backend.convert_to_tensor, data)
-
-    def _prefetch_numpy_iterator(self, numpy_iterator):
-        """Prefetch batches on device.
-
-        The idea has been borrowed from
-        `torchtnt.utils.data.CudaDataPrefetcher`
-
-        This utility takes an iterator and returns a new iterator which fills an
-        on device prefetch buffer. Eager prefetching can improve the performance
-        of training loops significantly by overlapping compute and data
-        transfer.
-        """
-        queue = collections.deque()
-
-        # If you're training on GPUs, 2 is generally the best choice because
-        # this guarantees that you can overlap a training step on GPU with a
-        # data prefetch step on CPU.
-        def enqueue(n=2):
-            for data in itertools.islice(numpy_iterator, n):
-                queue.append(self._prefetch_numpy_data(data))
-
-        enqueue(n=2)  # TODO: should we make `n` configurable?
-        while queue:
-            yield queue.popleft()
-            enqueue(1)
+    def _get_iterator(self):
+        return self.data_adapter.get_torch_dataloader()
```
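With the custom prefetch path removed, the torch iterator relies entirely on `torch.utils.data.DataLoader` (now built for every adapter by `get_torch_dataloader`, see `array_data_adapter.py` below), which already yields batched torch tensors and can prefetch via worker processes. A standalone sketch of what the returned loader provides:

```python
import torch

# A DataLoader over tensors yields ready-to-use torch batches, so no
# separate numpy-to-tensor prefetch step is needed in the trainer.
data = torch.arange(12.0).reshape(6, 2)
labels = torch.zeros(6, dtype=torch.long)
dataset = torch.utils.data.TensorDataset(data, labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=2)

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([2, 2]) torch.Size([2])
```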

keras/layers/normalization/batch_normalization_test.py (10 additions & 4 deletions)

```diff
@@ -92,8 +92,10 @@ def test_correctness(
         broadcast_shape = [1] * len(input_shape)
         broadcast_shape[axis] = input_shape[axis]
         out = backend.convert_to_numpy(out)
-        out -= np.reshape(backend.convert_to_numpy(layer.beta), broadcast_shape)
-        out /= np.reshape(
+        out = out - np.reshape(
+            backend.convert_to_numpy(layer.beta), broadcast_shape
+        )
+        out = out / np.reshape(
             backend.convert_to_numpy(layer.gamma), broadcast_shape
         )

@@ -200,8 +202,12 @@ def test_trainable_behavior(self):
         out = layer(x, training=True)

         out = backend.convert_to_numpy(out)
-        out -= np.reshape(backend.convert_to_numpy(layer.beta), (1, 1, 1, 3))
-        out /= np.reshape(backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3))
+        out = out - np.reshape(
+            backend.convert_to_numpy(layer.beta), (1, 1, 1, 3)
+        )
+        out = out / np.reshape(
+            backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3)
+        )

         self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
         self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
```
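The switch from augmented assignment (`-=`, `/=`) to out-of-place expressions lines up with the `np.asarray` change above: `asarray` can return a read-only view of the backing tensor, and NumPy refuses in-place writes to a read-only buffer. A minimal illustration (the `setflags` call stands in for such a view):

```python
import numpy as np

out = np.arange(4.0)
out.setflags(write=False)  # simulate a read-only view of a backend tensor

try:
    out -= 1.0  # in-place write to a read-only buffer
except ValueError as err:
    print(err)  # "output array is read-only"

out = out - 1.0  # out-of-place: allocates a fresh, writable array
```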

keras/trainers/data_adapters/array_data_adapter.py (73 additions & 4 deletions)

```diff
@@ -6,6 +6,7 @@
 from keras import backend
 from keras.trainers.data_adapters import data_adapter_utils
 from keras.trainers.data_adapters.data_adapter import DataAdapter
+from keras.utils.dataset_utils import is_torch_tensor
 from keras.utils.nest import lists_to_tuples

 try:
@@ -98,13 +99,23 @@ def __init__(

     def get_numpy_iterator(self):
         inputs = self._inputs
-        if self._shuffle:
+        if self._shuffle and self._shuffle != "batch":
             inputs = data_adapter_utils.sync_shuffle(
                 inputs, num_samples=self._num_samples
             )
         for i in range(self._size):
-            start, stop = i * self._batch_size, (i + 1) * self._batch_size
-            yield tree.map_structure(lambda x: x[start:stop], inputs)
+            start = i * self._batch_size
+            stop = min((i + 1) * self._batch_size, self._num_samples)
+            if self._shuffle == "batch":
+
+                def slice_and_shuffle(x):
+                    return data_adapter_utils.sync_shuffle(
+                        x[start:stop], num_samples=(stop - start)
+                    )
+
+                yield tree.map_structure(slice_and_shuffle, inputs)
+            else:
+                yield tree.map_structure(lambda x: x[start:stop], inputs)

     def get_tf_dataset(self):
         from keras.utils.module_utils import tensorflow as tf
@@ -237,6 +248,62 @@ def shuffle_batch(*batch):
         dataset = dataset.with_options(options)
         return dataset.prefetch(tf.data.AUTOTUNE)

+    def get_jax_iterator(self):
+        return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
+
+    def get_torch_dataloader(self):
+        import torch
+
+        from keras.backend.torch.core import convert_to_tensor
+
+        class ArrayDataset(torch.utils.data.Dataset):
+            def __init__(self, array):
+                self.array = array
+
+            def __getitem__(self, index):
+                def slice_and_convert(x):
+                    return convert_to_tensor(x[index])
+
+                return tree.map_structure(slice_and_convert, self.array)
+
+            def __len__(self):
+                return len(self.array[0])
+
+        class RandomBatchSampler(torch.utils.data.Sampler):
+            def __init__(self, sampler):
+                self.sampler = sampler
+
+            def __iter__(self):
+                for batch in self.sampler:
+                    yield [batch[i] for i in torch.randperm(len(batch))]
+
+            def __len__(self):
+                return len(self.sampler)
+
+        if self._shuffle == "batch":
+            batch_sampler = RandomBatchSampler(
+                torch.utils.data.BatchSampler(
+                    range(self._num_samples),
+                    batch_size=self._batch_size,
+                    drop_last=False,
+                )
+            )
+        elif self._shuffle:
+            batch_sampler = torch.utils.data.BatchSampler(
+                torch.utils.data.RandomSampler(range(self._num_samples)),
+                batch_size=self._batch_size,
+                drop_last=False,
+            )
+        else:
+            batch_sampler = torch.utils.data.BatchSampler(
+                torch.utils.data.SequentialSampler(range(self._num_samples)),
+                batch_size=self._batch_size,
+                drop_last=False,
+            )
+
+        dataset = ArrayDataset(self._inputs)
+        return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
+
     @property
     def num_batches(self):
         return self._size
@@ -315,7 +382,9 @@ def convert_single_array(x):
         # `torch.Tensor`, as well as any other tensor-like object that has
         # added numpy support.
         if hasattr(x, "__array__"):
-            x = backend.convert_to_numpy(x)
+            if is_torch_tensor(x):
+                x = x.cpu()
+            x = np.asarray(x)
         else:
             raise ValueError(
                 "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
```
