Skip to content

Commit e7237d8

Browse files
authored
add mac m1 mps support (#2477)
1 parent 8c336fe commit e7237d8

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

funasr/auto/auto_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def build_model(**kwargs):
184184
device = kwargs.get("device", "cuda")
185185
if ((device =="cuda" and not torch.cuda.is_available())
186186
or (device == "xpu" and not torch.xpu.is_available())
187+
or (device == "mps" and not torch.backends.mps.is_available())
187188
or kwargs.get("ngpu", 1) == 0):
188189
device = "cpu"
189190
kwargs["batch_size"] = 1

funasr/frontends/fused.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def __init__(self, frontends=None, align_method="linear_projection", proj_dim=10
8080
dev = "cuda"
8181
elif torch.xpu.is_available():
8282
dev = "xpu"
83+
elif torch.backends.mps.is_available():
84+
dev = "mps"
8385
else:
8486
dev = "cpu"
8587
if self.align_method == "linear_projection":

funasr/utils/export_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ def export(
2828
**kwargs,
2929
)
3030
elif type == "torchscript":
31-
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
31+
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
3232
print("Exporting torchscripts on device {}".format(device))
3333
_torchscripts(m, path=export_dir, device=device)
3434
elif type == "bladedisc":
3535
assert (
36-
torch.cuda.is_available() or torch.xpu.is_available()
36+
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
3737
), "Currently bladedisc optimization for FunASR only supports GPU"
3838
# bladedisc only optimizes encoder/decoder modules
3939
if hasattr(m, "encoder") and hasattr(m, "decoder"):
@@ -44,7 +44,7 @@ def export(
4444

4545
elif type == "onnx_fp16":
4646
assert (
47-
torch.cuda.is_available() or torch.xpu.is_available()
47+
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
4848
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
4949

5050
if hasattr(m, "encoder") and hasattr(m, "decoder"):

0 commit comments

Comments
 (0)