Skip to content

Commit 48c75e9

Browse files
committed
Fix support for torch 2.0
1 parent c79acea commit 48c75e9

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

AudioLoader/speech/speechcommands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch import Tensor
66
import os
77
import tqdm
8+
import torch
89
__TORCH_GTE_2_0 = False
910
split_version = torch.__version__.split(".")
1011
major_version = int(split_version[0])
@@ -282,4 +283,4 @@ def __getitem__(self, n: int) -> Tuple[Any, Any]:
282283
return waveform, label
283284

284285
def __len__(self) -> int:
285-
return len(self._data)
286+
return len(self._data)

AudioLoader/speech/timit.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
import warnings
1414
from distutils.dir_util import copy_tree
1515
from torchaudio.compliance import kaldi # for downsampling
16-
from torchaudio.datasets.utils import (
17-
download_url,
18-
extract_archive,
19-
)
16+
__TORCH_GTE_2_0 = False
17+
split_version = torch.__version__.split(".")
18+
major_version = int(split_version[0])
19+
if major_version > 1:
20+
__TORCH_GTE_2_0 = True
21+
from torchaudio.datasets.utils import _extract_zip as extract_archive
22+
from torch.hub import download_url_to_file as download_url
23+
else:
24+
from torchaudio.datasets.utils import (
25+
download_url,
26+
extract_archive,
27+
)
2028
import hashlib
2129
import torch.nn.functional as F
2230
from AudioLoader.music.utils import check_md5
@@ -183,4 +191,4 @@ def available_groups(self, groups):
183191
if groups=='all':
184192
return [f'DR{i+1}' for i in range(9)] # select all dialect regions
185193
elif isinstance(groups, list):
186-
return [f'DR{i}' for i in groups]
194+
return [f'DR{i}' for i in groups]

0 commit comments

Comments
 (0)