diff --git a/microsoft/testsuites/gpu/gpusuite.py b/microsoft/testsuites/gpu/gpusuite.py index d51da484ca..f860c64bde 100644 --- a/microsoft/testsuites/gpu/gpusuite.py +++ b/microsoft/testsuites/gpu/gpusuite.py @@ -124,11 +124,20 @@ def verify_gpu_provision(self, node: Node, log: Logger) -> None: timeout=TIMEOUT, # min_gpu_count is 8 since it is current # max GPU count available in Azure - requirement=simple_requirement(min_gpu_count=8), + requirement=simple_requirement(min_gpu_count=2), priority=3, ) def verify_max_gpu_provision(self, node: Node, log: Logger) -> None: - _gpu_provision_check(8, node, log) + actual_gpu_count = node.capability.gpu_count + if not isinstance(actual_gpu_count, int): + raise SkippedException("GPU count is not available") + # For "max" GPU test, we want to test with high GPU counts + if actual_gpu_count < 2: + raise SkippedException( + f"Test is for scenarios with more than 2 GPUs, " + f" current Node only has {actual_gpu_count} GPUs." + ) + _gpu_provision_check(actual_gpu_count, node, log) @TestCaseMetadata( description="""