Skip to content

Commit 34d5072

Browse files
authored
[MLIR][test] Check for ml_dtypes before running tests (#123061)
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded: ```python try: import ml_dtypes except ModuleNotFoundError: # The third-party ml_dtypes provides some optional low precision data-types for NumPy. ml_dtypes = None ``` This makes `ml_dtypes` an optional dependency. Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does. This is a replacement for #123051 which was excluding tests too broadly.
1 parent d0a3642 commit 34d5072

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

mlir/test/python/execution_engine.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from mlir.passmanager import *
66
from mlir.execution_engine import *
77
from mlir.runtime import *
8-
from ml_dtypes import bfloat16, float8_e5m2
8+
9+
try:
10+
from ml_dtypes import bfloat16, float8_e5m2
11+
12+
HAS_ML_DTYPES = True
13+
except ModuleNotFoundError:
14+
HAS_ML_DTYPES = False
15+
916

1017
MLIR_RUNNER_UTILS = os.getenv(
1118
"MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
@@ -564,7 +571,8 @@ def testBF16Memref():
564571
log(npout)
565572

566573

567-
run(testBF16Memref)
574+
if HAS_ML_DTYPES:
575+
run(testBF16Memref)
568576

569577

570578
# Test f8E5M2 memrefs
@@ -603,7 +611,8 @@ def testF8E5M2Memref():
603611
log(npout)
604612

605613

606-
run(testF8E5M2Memref)
614+
if HAS_ML_DTYPES:
615+
run(testF8E5M2Memref)
607616

608617

609618
# Test addition of two 2d_memref

0 commit comments

Comments
 (0)