Skip to content

Commit 1ffb9b4

Browse files
committedDec 3, 2024
add back support of duplicated keys in ConcatenateDataset and change from_dataset
1 parent 1c99a4b commit 1ffb9b4

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed
 

Diff for: ‎lazy_dataset/core.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,13 @@ def from_dataset(
223223
immutable_warranty=immutable_warranty, name=name)
224224
else:
225225
new = dict(items)
226-
assert len(new) == len(items), f'{len(new)} != {len(items)}\nYou found a bug!\n{examples!r}'
227-
return from_dict(new,
228-
immutable_warranty=immutable_warranty, name=name)
226+
if len(new) == len(items):
227+
return from_dict(new,
228+
immutable_warranty=immutable_warranty, name=name)
229+
else:
230+
# Duplicates in keys
231+
return from_list(list(map(operator.itemgetter(1), items)),
232+
immutable_warranty=immutable_warranty, name=name)
229233

230234

231235
def concatenate(*datasets):
@@ -2705,12 +2709,6 @@ def ordered(self) -> bool:
27052709
return all(ds.ordered for ds in self.input_datasets)
27062710

27072711
def __iter__(self, with_key=False):
2708-
if with_key:
2709-
try:
2710-
self.keys()
2711-
except AssertionError as e:
2712-
raise _ItemsNotDefined(self.__class__.__name__) from e
2713-
27142712
for input_dataset in self.input_datasets:
27152713
if with_key:
27162714
iterable = input_dataset.__iter__(with_key=True)
@@ -3093,8 +3091,8 @@ class ItemsDataset(Dataset):
30933091
>>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0)) # No keys
30943092
>>> list(ds_nokeys.map(lambda x: x + 10).items())
30953093
[('a', 11), ('b', 12), ('c', 13)]
3096-
>>> list(ds_nokeys.concatenate(ds_plain).items())
3097-
[('a', 1), ('b', 2), ('c', 3), ('a', 1), ('b', 2), ('c', 3)]
3094+
>>> list(ds_nokeys.map(lambda x: x + 10).concatenate(ds_plain).filter(lambda x: x in [1, 12, 13]).items())
3095+
[('b', 12), ('c', 13), ('a', 1)]
30983096
>>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
30993097
[('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
31003098
>>> list(ds_plain.key_zip(ds_plain).items())

0 commit comments

Comments
 (0)