@@ -367,7 +367,7 @@ def count_per_rspec(self, rspec: dict[str, Any], type: str) -> int | None:
367367 return child ["count" ]
368368 return None
369369
370- def count_per_node (self , type : str ) -> int :
370+ def count_per_node (self , type : str , default : int | None = None ) -> int :
371371 for rspec in self .resource_specs :
372372 if rspec ["type" ] == "node" :
373373 count = self .count_per_rspec (rspec , type )
@@ -376,18 +376,22 @@ def count_per_node(self, type: str) -> int:
376376 try :
377377 count_per_socket = self .count_per_socket (type )
378378 except ValueError :
379+ if default is not None :
380+ return default
379381 raise ValueError (f"Unable to determine count_per_node for { type !r} " ) from None
380382 else :
381383 return count_per_socket * self .sockets_per_node
382384
383- def count_per_socket (self , type : str ) -> int :
385+ def count_per_socket (self , type : str , default : int | None = None ) -> int :
384386 for rspec1 in self .resource_specs :
385387 if rspec1 ["type" ] == "node" :
386388 for rspec2 in rspec1 ["resources" ]:
387389 if rspec2 ["type" ] == "socket" :
388390 count = self .count_per_rspec (rspec2 , type )
389391 if count is not None :
390392 return count
393+ if default is not None :
394+ return default
391395 raise ValueError (f"Unable to determine count_per_socket for { type !r} " )
392396
393397 @cached_property
@@ -405,47 +409,24 @@ def sockets_per_node(self) -> int:
405409 except ValueError :
406410 return 1
407411
408- @cached_property
409- def cpus_per_socket (self ) -> int :
410- try :
411- return self .count_per_socket ("cpu" )
412- except ValueError :
413- return psutil .cpu_count ()
414-
415- @cached_property
416- def gpus_per_socket (self ) -> int :
417- try :
418- return self .count_per_socket ("gpu" )
419- except ValueError :
420- return 0
421-
422- @cached_property
423- def cpus_per_node (self ) -> int :
424- return self .sockets_per_node * self .cpus_per_socket
425-
426- @cached_property
427- def gpus_per_node (self ) -> int :
428- if gpus_per_resource_type := self .gpus_per_socket :
429- return self .sockets_per_node * gpus_per_resource_type
430- try :
431- return self .count_per_node ("gpu" )
432- except ValueError :
433- return 0
434-
435- @cached_property
436- def cpu_count (self ) -> int :
437- return self .node_count * self .cpus_per_node
438-
439- @cached_property
440- def gpu_count (self ) -> int :
441- return self .node_count * self .gpus_per_node
442-
443- def nodes_required (self , max_cpus : int | None = None , max_gpus : int | None = None ) -> int :
412+ def nodes_required (self , ** types : int ) -> int :
444413 """Nodes required to run ``tasks`` tasks. A task can be thought of as a single MPI
445414 rank"""
446- nodes = max (1 , int (math .ceil ((max_cpus or 1 ) / self .cpus_per_node )))
447- if self .gpus_per_node :
448- nodes = max (nodes , int (math .ceil ((max_gpus or 0 ) / self .gpus_per_node )))
415+ # backward compatible
416+ if n := types .pop ("max_cpus" , None ):
417+ types ["cpu" ] = n
418+ if n := types .pop ("max_gpus" , None ):
419+ types ["gpu" ] = n
420+ nodes : int = 1
421+ for type , count in types .items ():
422+ try :
423+ count_per_node = self .count_per_node (type )
424+ except ValueError :
425+ continue
426+ else :
427+ if count_per_node == 0 :
428+ continue
429+ nodes = max (nodes , int (math .ceil (count / count_per_node )))
449430 return nodes
450431
451432 def compute_required_resources (
@@ -487,8 +468,8 @@ def compute_required_resources(
487468 ranks = ranks_per_socket = 1
488469 nodes = 1
489470 elif ranks is not None and ranks_per_socket is None :
490- ranks_per_socket = min (ranks , self .cpus_per_socket )
491- nodes = int (math .ceil (ranks / self .cpus_per_socket / self .sockets_per_node ))
471+ ranks_per_socket = min (ranks , self .count_per_socket ( "cpu" ) )
472+ nodes = int (math .ceil (ranks / self .count_per_socket ( "cpu" ) / self .sockets_per_node ))
492473 else :
493474 assert ranks is not None
494475 assert ranks_per_socket is not None
0 commit comments