From 56085d9b98eb98377eeb28c159d90abb6360be95 Mon Sep 17 00:00:00 2001 From: Ales Mikholap Date: Sat, 27 Jun 2015 21:43:34 +0300 Subject: [PATCH] Fix HDF5 dataset 1. Don't provide parameters for superclass __new__ method (object.__new__). 2. Add function calls where required. 3. Fix interaction with iterators. --- pylearn2/datasets/hdf5.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pylearn2/datasets/hdf5.py b/pylearn2/datasets/hdf5.py index 7b586390d9..14b505db11 100644 --- a/pylearn2/datasets/hdf5.py +++ b/pylearn2/datasets/hdf5.py @@ -15,6 +15,7 @@ import tables except ImportError: tables = None +import numpy as np import warnings from os.path import isfile from pylearn2.compat import OrderedDict @@ -86,9 +87,7 @@ def __new__(cls, filename, X=None, topo_view=None, y=None, load_all=False, return HDF5DatasetDeprecated(filename, X, topo_view, y, load_all, cache_size, **kwargs) else: - return super(HDF5Dataset, cls).__new__( - cls, filename, sources, spaces, aliases, load_all, cache_size, - use_h5py, **kwargs) + return super(HDF5Dataset, cls).__new__(cls) def __init__(self, filename, sources, spaces, aliases=None, load_all=False, cache_size=None, use_h5py='auto', **kwargs): @@ -204,7 +203,7 @@ def iterator(self, mode=None, data_specs=None, batch_size=None, provided when the dataset object has been created will be used. """ if data_specs is None: - data_specs = (self._get_sources, self._get_spaces) + data_specs = (self._get_spaces(), self._get_sources()) [mode, batch_size, num_batches, rng, data_specs] = self._init_iterator( mode, batch_size, num_batches, rng, data_specs) @@ -240,7 +239,7 @@ def _get_spaces(self): ------- A Space or a list of Spaces. """ - space = [self.spaces[s] for s in self._get_sources] + space = [self.spaces[s] for s in self._get_sources()] return space[0] if len(space) == 1 else tuple(space) def get_data_specs(self, source_or_alias=None): @@ -310,16 +309,16 @@ def get(self, sources, indexes): sources[s], *e.args)) if (isinstance(indexes, (slice, py_integer_types)) or len(indexes) == 1): - rval.append(sdata[indexes]) + val = sdata[indexes] else: warnings.warn('Accessing non sequential elements of an ' 'HDF5 file will be at best VERY slow. Avoid ' 'using iteration schemes that access ' 'random/shuffled data with hdf5 datasets!!') - val = [] - [val.append(sdata[idx]) for idx in indexes] - rval.append(val) - return tuple(rval) + val = [sdata[idx] for idx in indexes] + val = tuple(tuple(row) for row in val) + rval.append(val) + return tuple(np.array(v) for v in rval) @wraps(Dataset.get_num_examples, assigned=(), updated=()) def get_num_examples(self, source_or_alias=None):