Skip to content

Commit 8adc00d

Browse files
committed
Fix linter
Signed-off-by: Piotr Żelasko <[email protected]>
1 parent 5154ad4 commit 8adc00d

File tree

5 files changed

+26
-15
lines changed

5 files changed

+26
-15
lines changed

nemo/collections/asr/models/msdd_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def get_range_average(
12581258
seq_len = end - stt
12591259
if stt < clus_label_tensor.shape[0]:
12601260
target_clus_label_tensor = clus_label_tensor[stt:end]
1261-
emb_seq, seg_length = (
1261+
emb_seq, _ = (
12621262
signals[stt:end, :, :],
12631263
min(
12641264
self.diar_window_length,
@@ -1291,7 +1291,7 @@ def get_range_average(
12911291
return emb_vectors_split, emb_seq, seq_len
12921292

12931293
def get_range_clus_avg_emb(
1294-
self, test_batch: List[torch.Tensor], _test_data_collection: List[Any], device: torch.device('cpu')
1294+
self, test_batch: List[torch.Tensor], _test_data_collection: List[Any], device: torch.device
12951295
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
12961296
"""
12971297
This function is only used when `get_range_average` function is called. This module calculates
@@ -1470,7 +1470,7 @@ def run_overlap_aware_eval(
14701470
verbose=self._cfg.verbose,
14711471
)
14721472
outputs.append(output)
1473-
logging.info(f" \n")
1473+
logging.info(" \n")
14741474
return outputs
14751475

14761476
@classmethod

nemo/collections/audio/losses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
# limitations under the License.
1414

1515
from nemo.collections.audio.losses.audio import MAELoss, MSELoss, SDRLoss
16+
17+
__all__ = ["MAELoss", "MSELoss", "SDRLoss"]

nemo/collections/audio/models/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,12 @@
2020
SchroedingerBridgeAudioToAudioModel,
2121
ScoreBasedGenerativeAudioToAudioModel,
2222
)
23+
24+
__all__ = [
25+
"AudioToAudioModel",
26+
"EncMaskDecAudioToAudioModel",
27+
"FlowMatchingAudioToAudioModel",
28+
"PredictiveAudioToAudioModel",
29+
"SchroedingerBridgeAudioToAudioModel",
30+
"ScoreBasedGenerativeAudioToAudioModel",
31+
]

nemo/core/classes/modelPT.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,27 +1529,27 @@ def extract_state_dict_from(
15291529
Example:
15301530
To convert the .nemo tarfile into a single Model level PyTorch checkpoint::
15311531
1532-
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts')
1532+
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './ckpts')
15331533
15341534
To restore a model from a Model level checkpoint::
15351535
15361536
model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
1537-
model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
1537+
model.load_state_dict(torch.load("./ckpts/model_weights.ckpt"))
15381538
15391539
To convert the .nemo tarfile into multiple Module level PyTorch checkpoints::
15401540
15411541
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from(
1542-
'asr.nemo', './asr_ckpts', split_by_module=True
1542+
'asr.nemo', './ckpts', split_by_module=True
15431543
)
15441544
15451545
To restore a module from a Module level checkpoint::
15461546
15471547
model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
15481548
15491549
# load the individual components
1550-
model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
1551-
model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
1552-
model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
1550+
model.preprocessor.load_state_dict(torch.load("./ckpts/preprocessor.ckpt"))
1551+
model.encoder.load_state_dict(torch.load("./ckpts/encoder.ckpt"))
1552+
model.decoder.load_state_dict(torch.load("./ckpts/decoder.ckpt"))
15531553
15541554
Returns:
15551555
The state dict that was loaded from the original .nemo checkpoint

nemo/core/connectors/save_restore_connector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,27 +299,27 @@ def extract_state_dict_from(self, restore_path: str, save_dir: str, split_by_mod
299299
Example:
300300
To convert the .nemo tarfile into a single Model level PyTorch checkpoint::
301301
302-
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts')
302+
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './ckpts')
303303
304304
To restore a model from a Model level checkpoint::
305305
306306
model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
307-
model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
307+
model.load_state_dict(torch.load("./ckpts/model_weights.ckpt"))
308308
309309
To convert the .nemo tarfile into multiple Module level PyTorch checkpoints::
310310
311311
state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from(
312-
'asr.nemo', './asr_ckpts', split_by_module=True
312+
'asr.nemo', './ckpts', split_by_module=True
313313
)
314314
315315
To restore a module from a Module level checkpoint::
316316
317317
model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
318318
319319
# load the individual components
320-
model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
321-
model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
322-
model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
320+
model.preprocessor.load_state_dict(torch.load("./ckpts/preprocessor.ckpt"))
321+
model.encoder.load_state_dict(torch.load("./ckpts/encoder.ckpt"))
322+
model.decoder.load_state_dict(torch.load("./ckpts/decoder.ckpt"))
323323
324324
Returns:
325325
The state dict that was loaded from the original .nemo checkpoint

0 commit comments

Comments
 (0)