Skip to content

Commit 0c32ff7

Browse files
authored
Merge pull request #66 from boeddeker/master
Fix items bug of dataset that doesn't support items
2 parents 357b4c6 + 7d7e8d8 commit 0c32ff7

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

.github/workflows/run_python_tests.yml

+1-1
Original file line number · Diff line number · Diff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
python-version: ["3.7", "3.8", "3.9", "3.10"]
18+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
1919

2020
steps:
2121
- uses: actions/checkout@v2

lazy_dataset/__init__.py

+3-1
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,4 @@
1+
from . import core
12
from .core import (
23
new,
34
concatenate,
@@ -7,6 +8,7 @@
78
from_dict,
89
from_list,
910
from_dataset,
11+
from_file,
1012
FilterException,
1113
)
12-
from.core import _zip as zip
14+
from .core import _zip as zip

lazy_dataset/core.py

+29-5
Original file line number · Diff line number · Diff line change
@@ -201,15 +201,35 @@ def from_dataset(
201201
>>> ds = from_dataset(new({'a': 1, 'b': 2, 'c': 3, 'd': 4}).filter(lambda x: x%2))
202202
>>> dict(ds)
203203
{'a': 1, 'c': 3}
204+
205+
# Works with concatenated datasets and duplicated keys
206+
>>> ds = new({'a': 1, 'b': 2})
207+
>>> ds = concatenate(ds, ds)
208+
>>> ds
209+
DictDataset(len=2)
210+
MapDataset(_pickle.loads)
211+
DictDataset(len=2)
212+
MapDataset(_pickle.loads)
213+
ConcatenateDataset()
214+
>>> from_dataset(ds)
215+
ListDataset(len=4)
216+
MapDataset(_pickle.loads)
217+
204218
"""
205219
try:
206220
items = list(examples.items())
207221
except ItemsNotDefined:
208222
return from_list(list(examples),
209223
immutable_warranty=immutable_warranty, name=name)
210224
else:
211-
return from_dict(dict(items),
212-
immutable_warranty=immutable_warranty, name=name)
225+
new = dict(items)
226+
if len(new) == len(items):
227+
return from_dict(new,
228+
immutable_warranty=immutable_warranty, name=name)
229+
else:
230+
# Duplicates in keys
231+
return from_list(list(map(operator.itemgetter(1), items)),
232+
immutable_warranty=immutable_warranty, name=name)
213233

214234

215235
def concatenate(*datasets):
@@ -417,7 +437,10 @@ def copy(self, freeze: bool = False) -> 'Dataset':
417437
Returns:
418438
A copy of this dataset
419439
"""
420-
raise NotImplementedError
440+
raise NotImplementedError(
441+
f'copy is not implemented for {self.__class__}.\n'
442+
f'self: \n{repr(self)}'
443+
)
421444

422445
def __iter__(self, with_key=False):
423446
if with_key:
@@ -2973,6 +2996,7 @@ def __init__(self, *input_datasets):
29732996
]
29742997
raise AssertionError(
29752998
f'Expect that all input_datasets have the same keys. '
2999+
f'Missing: {lengths} of {len(keys)}\n'
29763000
f'Missing keys: '
29773001
f'{missing_keys}\n{self.input_datasets}'
29783002
)
@@ -3067,8 +3091,8 @@ class ItemsDataset(Dataset):
30673091
>>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0)) # No keys
30683092
>>> list(ds_nokeys.map(lambda x: x + 10).items())
30693093
[('a', 11), ('b', 12), ('c', 13)]
3070-
>>> list(ds_nokeys.concatenate(ds_plain).items())
3071-
[('a', 1), ('b', 2), ('c', 3), ('a', 1), ('b', 2), ('c', 3)]
3094+
>>> list(ds_nokeys.map(lambda x: x + 10).concatenate(ds_plain).filter(lambda x: x in [1, 12, 13]).items())
3095+
[('b', 12), ('c', 13), ('a', 1)]
30723096
>>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
30733097
[('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
30743098
>>> list(ds_plain.key_zip(ds_plain).items())

0 commit comments

Comments
 (0)