@@ -201,15 +201,35 @@ def from_dataset(
201
201
>>> ds = from_dataset(new({'a': 1, 'b': 2, 'c': 3, 'd': 4}).filter(lambda x: x%2))
202
202
>>> dict(ds)
203
203
{'a': 1, 'c': 3}
204
+
205
+ # Works with concatenated datasets and duplicated keys
206
+ >>> ds = new({'a': 1, 'b': 2})
207
+ >>> ds = concatenate(ds, ds)
208
+ >>> ds
209
+ DictDataset(len=2)
210
+ MapDataset(_pickle.loads)
211
+ DictDataset(len=2)
212
+ MapDataset(_pickle.loads)
213
+ ConcatenateDataset()
214
+ >>> from_dataset(ds)
215
+ ListDataset(len=4)
216
+ MapDataset(_pickle.loads)
217
+
204
218
"""
205
219
try :
206
220
items = list (examples .items ())
207
221
except ItemsNotDefined :
208
222
return from_list (list (examples ),
209
223
immutable_warranty = immutable_warranty , name = name )
210
224
else :
211
- return from_dict (dict (items ),
212
- immutable_warranty = immutable_warranty , name = name )
225
+ new = dict (items )
226
+ if len (new ) == len (items ):
227
+ return from_dict (new ,
228
+ immutable_warranty = immutable_warranty , name = name )
229
+ else :
230
+ # Duplicates in keys
231
+ return from_list (list (map (operator .itemgetter (1 ), items )),
232
+ immutable_warranty = immutable_warranty , name = name )
213
233
214
234
215
235
def concatenate (* datasets ):
@@ -417,7 +437,10 @@ def copy(self, freeze: bool = False) -> 'Dataset':
417
437
Returns:
418
438
A copy of this dataset
419
439
"""
420
- raise NotImplementedError
440
+ raise NotImplementedError (
441
+ f'copy is not implemented for { self .__class__ } .\n '
442
+ f'self: \n { repr (self )} '
443
+ )
421
444
422
445
def __iter__ (self , with_key = False ):
423
446
if with_key :
@@ -2973,6 +2996,7 @@ def __init__(self, *input_datasets):
2973
2996
]
2974
2997
raise AssertionError (
2975
2998
f'Expect that all input_datasets have the same keys. '
2999
+ f'Missing: { lengths } of { len (keys )} \n '
2976
3000
f'Missing keys: '
2977
3001
f'{ missing_keys } \n { self .input_datasets } '
2978
3002
)
@@ -3067,8 +3091,8 @@ class ItemsDataset(Dataset):
3067
3091
>>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0)) # No keys
3068
3092
>>> list(ds_nokeys.map(lambda x: x + 10).items())
3069
3093
[('a', 11), ('b', 12), ('c', 13)]
3070
- >>> list(ds_nokeys.concatenate(ds_plain).items())
3071
- [('a', 1), (' b', 2 ), ('c', 3 ), ('a', 1), ('b', 2), ('c', 3 )]
3094
+ >>> list(ds_nokeys.map(lambda x: x + 10).concatenate(ds_plain).filter(lambda x: x in [1, 12, 13]).items())
3095
+ [('b', 12), ('c', 13), ('a', 1)]
3072
3096
>>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
3073
3097
[('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
3074
3098
>>> list(ds_plain.key_zip(ds_plain).items())
0 commit comments