diff --git a/brainio/assemblies.py b/brainio/assemblies.py index 13a19b4..d1659d0 100644 --- a/brainio/assemblies.py +++ b/brainio/assemblies.py @@ -11,6 +11,8 @@ import xarray as xr from xarray import DataArray, IndexVariable +from brainio.stimuli import StimulusSet + BRAINIO_CHUNKS = 'BRAINIO_CHUNKS' _logger = logging.getLogger(__name__) @@ -18,8 +20,8 @@ def is_fastpath(*args, **kwargs): """checks whether a set of args and kwargs would be interpreted by DataArray.__init__""" - n = 7 # maximum length of args if all arguments to DataArray are positional (as of 0.16.1) - return ("fastpath" in kwargs and kwargs["fastpath"]) or (len(args) >= n and args[n-1]) + n = 7 # maximum length of args if all arguments to DataArray are positional (as of 0.16.1) + return ("fastpath" in kwargs and kwargs["fastpath"]) or (len(args) >= n and args[n - 1]) class DataPoint(object): @@ -86,7 +88,7 @@ def __init__(self, values): def __eq__(self, other): return len(self.values) == len(other.values) and \ - all(v1 == v2 for v1, v2 in zip(self.values, other.values)) + all(v1 == v2 for v1, v2 in zip(self.values, other.values)) def __lt__(self, other): return self.values < other.values @@ -337,13 +339,15 @@ def get_metadata_before_2022_06(assembly, dims=None, names_only=False, include_c """ Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`. """ + def what(name, dims, values, names_only): if names_only: return name else: return name, dims, values + if dims is None: - dims = assembly.dims + (None,) # all dims plus dimensionless coords + dims = assembly.dims + (None,) # all dims plus dimensionless coords for name in assembly.coords.variables: values = assembly.coords.variables[name] is_subset = values.dims and (set(values.dims) <= set(dims)) @@ -351,14 +355,14 @@ def what(name, dims, values, names_only): if is_subset or is_dimless: is_index = isinstance(values, IndexVariable) if is_index: - if values.level_names: # it's a MultiIndex + if values.level_names: # it's a MultiIndex if include_multi_indexes: yield what(name, values.dims, values.values, names_only) if include_levels: for level in values.level_names: level_values = assembly.coords[level] yield what(level, level_values.dims, level_values.values, names_only) - else: # it's an Index + else: # it's an Index if include_indexes: yield what(name, values.dims, values.values, names_only) else: @@ -367,17 +371,19 @@ def what(name, dims, values, names_only): def get_metadata_after_2022_06(assembly, dims=None, names_only=False, include_coords=True, - include_indexes=True, include_multi_indexes=False, include_levels=True): + include_indexes=True, include_multi_indexes=False, include_levels=True): """ Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`. """ + def what(name, dims, values, names_only): if names_only: return name else: return name, dims, values + if dims is None: - dims = assembly.dims + (None,) # all dims plus dimensionless coords + dims = assembly.dims + (None,) # all dims plus dimensionless coords for name, values in assembly.coords.items(): none_but_keep = (not values.dims) and None in dims shared = not (set(values.dims).isdisjoint(set(dims))) @@ -407,7 +413,7 @@ def get_metadata(assembly, dims=None, names_only=False, include_coords=True, include_indexes, include_multi_indexes, include_levels) except TypeError as e: yield from get_metadata_before_2022_06(assembly, dims, names_only, include_coords, - include_indexes, include_multi_indexes, include_levels) + include_indexes, include_multi_indexes, include_levels) def coords_for_dim(assembly, dim): @@ -434,7 +440,8 @@ def gather_indexes(assembly): """This is only necessary as long as xarray cannot persist MultiIndex to netCDF. """ coords_d = {} for dim in assembly.dims: - coord_names = list(get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False)) + coord_names = list( + get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False)) if coord_names: coords_d[dim] = coord_names if coords_d: @@ -457,7 +464,7 @@ def load(self): try: import dask result = xr.open_dataarray(self.file_path, group=self.group, chunks=chunks) - except ModuleNotFoundError as e: + except ModuleNotFoundError: result = xr.open_dataarray(self.file_path, group=self.group) result = self.correct_stimulus_id_name(result) result = self.assembly_class(data=result) @@ -501,7 +508,7 @@ def load(self): result = self.merge_stimulus_set_meta(result, self.stimulus_set) return result - def merge_stimulus_set_meta(self, assy, stimulus_set): + def merge_stimulus_set_meta(self, assy: DataAssembly, stimulus_set: StimulusSet) -> DataAssembly: dim_name, index_column = "presentation", "stimulus_id" assy = assy.reset_index(list(assy.indexes)) df_of_coords = pd.DataFrame(coords_for_dim(assy, dim_name)) @@ -539,6 +546,3 @@ def load(self): exc_info=True ) return result - - - diff --git a/brainio/stimuli.py b/brainio/stimuli.py index 2832ee9..7e5ff59 100644 --- a/brainio/stimuli.py +++ b/brainio/stimuli.py @@ -9,7 +9,8 @@ class StimulusSet(pd.DataFrame): # http://pandas.pydata.org/pandas-docs/stable/development/extending.html#subclassing-pandas-data-structures - _metadata = pd.DataFrame._metadata + ["identifier", "get_stimulus", 'get_loader_class', "stimulus_paths", "from_files"] + _metadata = pd.DataFrame._metadata + ["identifier", "get_stimulus", "get_loader_class", + "stimulus_paths", "from_files", "placed_on_screen"] @property def _constructor(self):