Skip to content

Commit 54b742d

Browse files
authored
extend OggZipHdfDataInput (#94)
1 parent b398f92 commit 54b742d

File tree

1 file changed

+8
-3
lines changed
  • common/setups/rasr/util

1 file changed

+8
-3
lines changed

common/setups/rasr/util/nn.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,15 @@ def get_crp(self, **kwargs):
148148
class OggZipHdfDataInput:
149149
def __init__(
150150
self,
151-
oggzip_files: tk.Path,
151+
oggzip_files: List[tk.Path],
152152
alignments: tk.Path,
153153
context_window: Dict,
154154
audio: Dict,
155-
targets: str,
155+
targets: Optional[str] = None,
156156
partition_epoch: int = 1,
157157
seq_ordering: str = "laplace:.1000",
158+
ogg_args: Optional[Dict[str, Any]] = None,
159+
acoustic_mixtures: Optional[tk.Path] = None,
158160
):
159161
"""
160162
:param oggzip_files:
@@ -172,6 +174,8 @@ def __init__(
172174
self.partition_epoch = partition_epoch
173175
self.seq_ordering = seq_ordering
174176
self.targets = targets
177+
self.ogg_args = ogg_args
178+
self.acoustic_mixtures = acoustic_mixtures
175179

176180
def get_data_dict(self):
177181
return {
@@ -188,10 +192,11 @@ def get_data_dict(self):
188192
"class": "OggZipDataset",
189193
"audio": self.audio,
190194
"partition_epoch": self.partition_epoch,
191-
"path": tuple(self.oggzip_files.get_path()),
195+
"path": self.oggzip_files,
192196
"seq_ordering": self.seq_ordering,
193197
"targets": self.targets,
194198
"use_cache_manager": True,
199+
**(self.ogg_args or {}),
195200
},
196201
},
197202
"seq_order_control_dataset": "ogg",

0 commit comments

Comments
 (0)