
Commit d48bfa8

Sherin Thomas and Rick Izzo (rlizzo) authored
New API design for datasets (#206)
* new API design for datasets
* tutorial fix
* clean up for naming fix and docstring fixes
* more test cases for collate function, batching, internal dataset etc
* updates to the getitem/reduce functions, but this does not work
* fixed collate function bug for string datasets, not sure if working as intended
* rebased and updated broken documentation
* fixed issues where nested columns did not work. Changed nested column __getitem__() method to accept subsample key names as well
* making it work
* updated docstrings

Co-authored-by: Rick Izzo <[email protected]>
1 parent 9084381 commit d48bfa8

34 files changed: +1581 -1008 lines

.github/workflows/testsuite.yml

Lines changed: 5 additions & 2 deletions
@@ -29,6 +29,9 @@ jobs:
         # build time with limited macos jobs
         - platform: macos-latest
           python-version: 3.7
+        - platform: windows-latest
+          python-version: 3.7
+          testml: yes
 
     steps:
     - uses: actions/checkout@v2
@@ -43,14 +46,14 @@ jobs:
         python -m pip install tox-gh-actions
     - name: Run Tests Without Coverage Report
       if: matrix.testcover == 'no'
-      run: tox -- -p no:sugar
+      run: tox
      env:
        PYTEST_XDIST_PROC_NR: 2
        TESTCOVER: ${{ matrix.testcover }}
        TESTML: ${{ matrix.testml }}
     - name: Run Tests With Coverage Report
       if: matrix.testcover == 'yes'
-      run: tox -- --cov-report xml -p no:sugar
+      run: tox -- --cov-report xml
      env:
        PYTEST_XDIST_PROC_NR: 2
        TESTCOVER: ${{ matrix.testcover }}

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
@@ -3,6 +3,15 @@ Change Log
 ==========
 
 
+_`In-Progress`
+==============
+
+Improvements
+------------
+
+* New API design for datasets (previously dataloaders) for machine learning libraries.
+  (`#187 <https://github.com/tensorwerk/hangar-py/pull/187>`__) `@hhsecond <https://github.com/hhsecond>`__
+
 `v0.5.2`_ (2020-05-08)
 ======================

MANIFEST.in

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,9 @@ include CODE_OF_CONDUCT.rst
 include LICENSE
 include README.rst
 
-include tox.ini .travis.yml mypy.ini
+include tox.ini
+include mypy.ini
+include setup.py
 
 global-exclude *.py[cod] *.so *.DS_Store
 global-exclude __pycache__ .mypy_cache .pytest_cache .hypothesis

docs/Tutorial-Dataloader.ipynb renamed to docs/Tutorial-Dataset.ipynb

Lines changed: 8 additions & 8 deletions
@@ -161,7 +161,7 @@
 },
 "source": [
  "### Let's make a Tensorflow dataloader\n",
- "Hangar provides `make_tf_dataset` & `make_torch_dataset` for creating Tensorflow & PyTorch datasets from Hangar columns. You can read more about it in the [documentation](https://hangar-py.readthedocs.io/en/latest/api.html#ml-framework-dataloaders). Next we'll make a Tensorflow dataset and loop over it to make sure we have got a proper Tensorflow dataset."
+ "Hangar provides `make_numpy_dataset`, `make_tensorflow_dataset` & `make_torch_dataset` for creating Tensorflow & PyTorch datasets from Hangar columns. You can read more about it in the [documentation](https://hangar-py.readthedocs.io/en/latest/api.html#ml-framework-dataloaders). Next we'll make a Tensorflow dataset and loop over it to make sure we have got a proper Tensorflow dataset."
 ]
 },
 {
@@ -174,7 +174,7 @@
 },
 "outputs": [],
 "source": [
- "from hangar import make_tf_dataset"
+ "from hangar.dataset import make_tensorflow_dataset"
 ]
 },
 {
@@ -223,7 +223,7 @@
  "from matplotlib.pyplot import imshow\n",
  "co = repo.checkout()\n",
  "image_column = co.columns['images']\n",
- "dataset = make_tf_dataset(image_column)\n",
+ "dataset = make_tensorflow_dataset(image_column)\n",
  "for image in dataset:\n",
  "    imshow(image[0].numpy())\n",
  "    break"
@@ -530,7 +530,7 @@
  "### Dataloaders for training\n",
  "We are using Tensorflow to build the network but how do we load this data from Hangar repository to Tensorflow?\n",
  "\n",
- "A naive option would be to run through the samples and load the numpy arrays and pass that to the `sess.run` of Tensorflow. But that would be quite inefficient. Tensorflow uses multiple threads to load the data in memory and its dataloaders can prefetch the data before-hand so that your training loop doesn't get blocked while loading the data. Also, Tensoflow dataloaders brings batching, shuffling, etc. to the table prebuilt. That's cool but how to load data from Hangar to Tensorflow using TF dataset? Well, we have `make_tf_dataset` which accepts the list of columns as a parameter and returns a TF dataset object."
+ "A naive option would be to run through the samples and load the numpy arrays and pass that to the `sess.run` of Tensorflow. But that would be quite inefficient. Tensorflow uses multiple threads to load the data in memory and its dataloaders can prefetch the data before-hand so that your training loop doesn't get blocked while loading the data. Also, Tensoflow dataloaders brings batching, shuffling, etc. to the table prebuilt. That's cool but how to load data from Hangar to Tensorflow using TF dataset? Well, we have `make_tensorflow_dataset` which accepts the list of columns as a parameter and returns a TF dataset object."
 ]
 },
 {
@@ -555,7 +555,7 @@
 }
 ],
 "source": [
- "from hangar import make_tf_dataset\n",
+ "from hangar.dataset import make_tensorflow_dataset\n",
  "co = repo.checkout() # we don't need write checkout here"
 ]
 },
@@ -601,7 +601,7 @@
  "captions_dset = co.columns['captions']\n",
  "pimages_dset = co.columns['processed_images']\n",
  "\n",
- "dataset = make_tf_dataset([pimages_dset, captions_dset], shuffle=True)"
+ "dataset = make_tensorflow_dataset([pimages_dset, captions_dset], shuffle=True)"
 ]
 },
 {
@@ -613,7 +613,7 @@
 "source": [
  "### Padded Batching\n",
  "\n",
- "Batching needs a bit more explanation here since the dataset does not just consist of fixed shaped data. We have two dataset in which one is for captions. As you know captions are sequences which can be variably shaped. So instead of using `dataset.batch` we need to use `dataset.padded_batch` which takes care of padding the tensors with the longest value in each dimension for each batch. This `padded_batch` needs the shape by which the user needs the batch to be padded. Unless you need customization, you can use the shape stored in the `dataset` object by `make_tf_dataset` function."
+ "Batching needs a bit more explanation here since the dataset does not just consist of fixed shaped data. We have two dataset in which one is for captions. As you know captions are sequences which can be variably shaped. So instead of using `dataset.batch` we need to use `dataset.padded_batch` which takes care of padding the tensors with the longest value in each dimension for each batch. This `padded_batch` needs the shape by which the user needs the batch to be padded. Unless you need customization, you can use the shape stored in the `dataset` object by `make_tensorflow_dataset` function."
 ]
 },
 {
@@ -965,7 +965,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.7.3"
+ "version": "3.7.7"
 }
 },
 "nbformat": 4,

docs/api.rst

Lines changed: 7 additions & 2 deletions
@@ -132,9 +132,14 @@ ML Framework Dataloaders
 Tensorflow
 ----------
 
-.. autofunction:: hangar.make_tf_dataset
+.. autofunction:: hangar.dataset.make_tensorflow_dataset
 
 Pytorch
 -------
 
-.. autofunction:: hangar.make_torch_dataset
+.. autofunction:: hangar.dataset.make_torch_dataset
+
+Numpy
+-----
+
+.. autofunction:: hangar.dataset.make_numpy_dataset
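For orientation, the entry points documented above now live under hangar.dataset rather than the top-level hangar namespace; a short hedged sketch of the new import paths (each factory presumably still needs its optional dependency, e.g. tensorflow or torch, installed):

# The three loader factories documented above, imported from their new home.
from hangar.dataset import (
    make_numpy_dataset,
    make_tensorflow_dataset,
    make_torch_dataset,
)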

setup.py

Lines changed: 1 addition & 0 deletions
@@ -119,6 +119,7 @@ def run(self):
     join('src', 'hangar', 'records', 'hashmachine.pyx'),
 ]
 CYTHON_HEADERS = [
+    join('src', 'hangar', 'external_cpython.pxd'),
     join('src', 'hangar', 'optimized_utils.pxd'),
     join('src', 'hangar', 'backends', 'specs.pxd'),
     join('src', 'hangar', 'records', 'recordstructs.pxd'),

src/hangar/__init__.py

Lines changed: 1 addition & 18 deletions
@@ -1,21 +1,4 @@
 __version__ = '0.5.2'
-__all__ = ('make_torch_dataset', 'make_tf_dataset', 'Repository')
+__all__ = ('Repository',)
 
-from functools import partial
 from .repository import Repository
-
-
-def raise_ImportError(message, *args, **kwargs):
-    raise ImportError(message)
-
-
-try:
-    from .dataloaders.tfloader import make_tf_dataset
-except ImportError:
-    make_tf_dataset = partial(raise_ImportError, "Could not import tensorflow. Install dependencies")
-
-try:
-    from .dataloaders.torchloader import make_torch_dataset
-except ImportError:
-    make_torch_dataset = partial(raise_ImportError, "Could not import torch. Install dependencies")
-

src/hangar/_version.py

Lines changed: 24 additions & 22 deletions
@@ -16,6 +16,7 @@
 https://github.com/pypa/packaging/blob/6a09d4015b/LICENSE.BSD
 """
 import re
+import typing
 from collections import namedtuple
 from itertools import dropwhile
 from typing import Callable, Optional, SupportsInt, Tuple, Union
@@ -99,24 +100,25 @@ def __neg__(self) -> InfinityType:
 
 # -------------------- Type Definitions ---------------------------------------
 
-InfiniteTypes = Union[InfinityType, NegativeInfinityType]
-PrePostDevType = Union[InfiniteTypes, Tuple[str, int]]
-SubLocalType = Union[InfiniteTypes, int, str]
-LocalType = Union[
-    NegativeInfinityType,
-    Tuple[
-        Union[
-            SubLocalType,
-            Tuple[SubLocalType, str],
-            Tuple[NegativeInfinityType, SubLocalType],
-        ],
-        ...,
-    ],
-]
-CmpKey = Tuple[
-    int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType
-]
-VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]
+if typing.TYPE_CHECKING:
+    InfiniteTypes = Union[InfinityType, NegativeInfinityType]
+    PrePostDevType = Union[InfiniteTypes, Tuple[str, int]]
+    SubLocalType = Union[InfiniteTypes, int, str]
+    LocalType = Union[
+        NegativeInfinityType,
+        Tuple[
+            Union[
+                SubLocalType,
+                Tuple[SubLocalType, str],
+                Tuple[NegativeInfinityType, SubLocalType],
+            ],
+            ...,
+        ],
+    ]
+    CmpKey = Tuple[
+        int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType
+    ]
+    VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]
 
 
 # ---------------------------- Version Parsing --------------------------------
@@ -142,7 +144,7 @@ class _BaseVersion(object):
     __slots__ = ('_key',)
 
     def __init__(self):
-        self._key: CmpKey = None
+        self._key: 'CmpKey' = None
 
     def __hash__(self) -> int:
         return hash(self._key)
@@ -165,7 +167,7 @@ def __gt__(self, other: '_BaseVersion') -> bool:
     def __ne__(self, other: object) -> bool:
         return self._compare(other, ne)
 
-    def _compare(self, other: object, method: VersionComparisonMethod
+    def _compare(self, other: object, method: 'VersionComparisonMethod'
                  ) -> Union[bool, type(NotImplemented)]:
         if isinstance(other, _BaseVersion):
             return method(self._key, other._key)
@@ -385,7 +387,7 @@ def _parse_letter_version(
 _local_version_separators = re.compile(r"[\._-]")
 
 
-def _parse_local_version(local: str) -> Optional[LocalType]:
+def _parse_local_version(local: str) -> Optional['LocalType']:
     """
     Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
     """
@@ -403,8 +405,8 @@ def _cmpkey(
     pre: Optional[Tuple[str, int]],
     post: Optional[Tuple[str, int]],
     dev: Optional[Tuple[str, int]],
-    local: Optional[Tuple[SubLocalType]],
-) -> CmpKey:
+    local: Optional[Tuple['SubLocalType']],
+) -> 'CmpKey':
 
     # When we compare a release version, we want to compare it with all of the
     # trailing zeros removed. So we'll use a reverse the list, drop all the now
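The _version.py change above moves the packaging type aliases behind typing.TYPE_CHECKING and quotes the annotations that refer to them, so the aliases are only evaluated by static type checkers. A small standalone sketch of that pattern (the names here are illustrative, not Hangar's):

import typing
from typing import Tuple

if typing.TYPE_CHECKING:
    # Evaluated only by type checkers; skipped entirely at runtime,
    # so building the Union/Tuple aliases costs nothing on import.
    CmpKey = Tuple[int, Tuple[int, ...]]

def _cmpkey(epoch: int, release: Tuple[int, ...]) -> 'CmpKey':
    # The quoted annotation is never resolved at runtime, which is why
    # referencing the guarded alias here is safe.
    return (epoch, release)

print(_cmpkey(1, (0, 5, 2)))  # -> (1, (0, 5, 2))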

src/hangar/backends/hdf5_00.py

Lines changed: 2 additions & 1 deletion
@@ -186,7 +186,8 @@
 from .. import __version__
 from ..optimized_utils import SizedDict
 from ..constants import DIR_DATA_REMOTE, DIR_DATA_STAGE, DIR_DATA_STORE, DIR_DATA
-from ..utils import find_next_prime, random_string, set_blosc_nthreads
+from ..utils import random_string, set_blosc_nthreads
+from ..optimized_utils import find_next_prime
 from ..op_state import reader_checkout_only, writer_checkout_only
 from ..typesystem import Descriptor, OneOf, DictItems, SizedIntegerTuple, checkedmeta

src/hangar/backends/hdf5_01.py

Lines changed: 2 additions & 1 deletion
@@ -228,7 +228,8 @@
 from ..optimized_utils import SizedDict
 from ..constants import DIR_DATA_REMOTE, DIR_DATA_STAGE, DIR_DATA_STORE, DIR_DATA
 from ..op_state import writer_checkout_only, reader_checkout_only
-from ..utils import find_next_prime, random_string, set_blosc_nthreads
+from ..utils import random_string, set_blosc_nthreads
+from ..optimized_utils import find_next_prime
 from ..typesystem import Descriptor, OneOf, DictItems, SizedIntegerTuple, checkedmeta
 
 set_blosc_nthreads()

src/hangar/columns/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
     generate_nested_column,
     column_type_object_from_schema
 )
+from .introspection import is_column, is_writer_column
 
 __all__ = (
     'Columns',
@@ -13,4 +14,6 @@
     'generate_nested_column',
     'column_type_object_from_schema',
     'ColumnTxn',
+    'is_column',
+    'is_writer_column'
 )

src/hangar/columns/introspection.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from .layout_flat import FlatSampleReader, FlatSampleWriter
+from .layout_nested import (
+    FlatSubsampleReader,
+    FlatSubsampleWriter,
+    NestedSampleReader,
+    NestedSampleWriter
+)
+
+
+def is_column(obj) -> bool:
+    """Determine if arbitrary input is an instance of a column layout.
+
+    Returns
+    -------
+    bool: True if input is an column, otherwise False.
+    """
+    return isinstance(obj, (FlatSampleReader, FlatSubsampleReader, NestedSampleReader))
+
+
+def is_writer_column(obj) -> bool:
+    """Determine if arbitrary input is an instance of a write-enabled column layout.
+
+    Returns
+    -------
+    bool: True if input is write-enabled column, otherwise False.
+    """
+    return isinstance(obj, (FlatSampleWriter, FlatSubsampleWriter, NestedSampleWriter))
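Combined with the hangar.columns.__init__ change above, the two new helpers become importable from the package. A hedged usage sketch follows; the repository path and the 'images' column are assumptions borrowed from the tutorial, not guaranteed fixtures:

from hangar import Repository
from hangar.columns import is_column, is_writer_column

# Assumes a Hangar repository already initialised at this path
# with an 'images' column, as in the tutorial notebook.
repo = Repository(path='path/to/existing/repo')
co = repo.checkout(write=True)
col = co.columns['images']

print(is_column(col))          # True for any flat or nested column layout
print(is_writer_column(col))   # True only for a write-enabled checkout's column
print(is_column('images'))     # False: a plain string is not a column object

co.close()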

src/hangar/columns/layout_flat.py

Lines changed: 10 additions & 5 deletions
@@ -7,6 +7,7 @@
 """
 from contextlib import ExitStack
 from pathlib import Path
+from operator import attrgetter as op_attrgetter
 from typing import Tuple, Union, Iterable, Optional, Any
 
 from .common import open_file_handles
@@ -23,7 +24,8 @@
 from ..records.parsing import generate_sample_name
 from ..backends import backend_decoder
 from ..op_state import reader_checkout_only
-from ..utils import is_suitable_user_key, valfilter, valfilterfalse
+from ..utils import is_suitable_user_key
+from ..optimized_utils import valfilter, valfilterfalse
 
 
 KeyType = Union[str, int]
@@ -324,7 +326,8 @@ def contains_remote_references(self) -> bool:
         on some remote server. True if all sample data is available on the
         machine's local disk.
         """
-        return not all(map(lambda x: x.islocal, self._samples.values()))
+        _islocal_func = op_attrgetter('islocal')
+        return not all(map(_islocal_func, self._samples.values()))
 
     @property
     def remote_reference_keys(self) -> Tuple[KeyType]:
@@ -336,7 +339,8 @@ def remote_reference_keys(self) -> Tuple[KeyType]:
         list of sample keys in the column whose data references indicate
         they are stored on a remote server.
         """
-        return tuple(valfilterfalse(lambda x: x.islocal, self._samples).keys())
+        _islocal_func = op_attrgetter('islocal')
+        return tuple(valfilterfalse(_islocal_func, self._samples).keys())
 
     def _mode_local_aware_key_looper(self, local: bool) -> Iterable[KeyType]:
         """Generate keys for iteration with dict update safety ensured.
@@ -352,11 +356,12 @@ def _mode_local_aware_key_looper(self, local: bool) -> Iterable[KeyType]:
         Iterable[KeyType]
             Sample keys conforming to the `local` argument spec.
         """
+        _islocal_func = op_attrgetter('islocal')
         if local:
             if self._mode == 'r':
-                yield from valfilter(lambda x: x.islocal, self._samples).keys()
+                yield from valfilter(_islocal_func, self._samples).keys()
             else:
-                yield from tuple(valfilter(lambda x: x.islocal, self._samples).keys())
+                yield from tuple(valfilter(_islocal_func, self._samples).keys())
         else:
             if self._mode == 'r':
                 yield from self._samples.keys()
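The lambda-to-attrgetter swap above is a small standard-library substitution: operator.attrgetter('islocal') builds the same "fetch the islocal attribute" callable as the previous lambdas. A self-contained illustration of the equivalence (the sample objects below are stand-ins, not Hangar's backend specs):

from operator import attrgetter
from types import SimpleNamespace

# Stand-ins for sample records that carry an `islocal` flag.
samples = {
    'cat_0': SimpleNamespace(islocal=True),
    'cat_1': SimpleNamespace(islocal=False),
}

_islocal_func = attrgetter('islocal')   # behaves like `lambda x: x.islocal`

# Mirrors contains_remote_references: True when any sample is non-local.
print(not all(map(_islocal_func, samples.values())))                  # True

# Mirrors remote_reference_keys: keys whose data is not on local disk.
print(tuple(k for k, v in samples.items() if not _islocal_func(v)))   # ('cat_1',)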
