Skip to content

Commit 35b88ff

Browse files
committed
better determination of cpus per socket for slurm scheduler
1 parent 7974627 commit 35b88ff

File tree

5 files changed

+37
-52
lines changed

5 files changed

+37
-52
lines changed

src/hpc_connect/config.py

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def count_per_rspec(self, rspec: dict[str, Any], type: str) -> int | None:
367367
return child["count"]
368368
return None
369369

370-
def count_per_node(self, type: str) -> int:
370+
def count_per_node(self, type: str, default: int | None = None) -> int:
371371
for rspec in self.resource_specs:
372372
if rspec["type"] == "node":
373373
count = self.count_per_rspec(rspec, type)
@@ -376,18 +376,22 @@ def count_per_node(self, type: str) -> int:
376376
try:
377377
count_per_socket = self.count_per_socket(type)
378378
except ValueError:
379+
if default is not None:
380+
return default
379381
raise ValueError(f"Unable to determine count_per_node for {type!r}") from None
380382
else:
381383
return count_per_socket * self.sockets_per_node
382384

383-
def count_per_socket(self, type: str) -> int:
385+
def count_per_socket(self, type: str, default: int | None = None) -> int:
384386
for rspec1 in self.resource_specs:
385387
if rspec1["type"] == "node":
386388
for rspec2 in rspec1["resources"]:
387389
if rspec2["type"] == "socket":
388390
count = self.count_per_rspec(rspec2, type)
389391
if count is not None:
390392
return count
393+
if default is not None:
394+
return default
391395
raise ValueError(f"Unable to determine count_per_socket for {type!r}")
392396

393397
@cached_property
@@ -405,47 +409,24 @@ def sockets_per_node(self) -> int:
405409
except ValueError:
406410
return 1
407411

408-
@cached_property
409-
def cpus_per_socket(self) -> int:
410-
try:
411-
return self.count_per_socket("cpu")
412-
except ValueError:
413-
return psutil.cpu_count()
414-
415-
@cached_property
416-
def gpus_per_socket(self) -> int:
417-
try:
418-
return self.count_per_socket("gpu")
419-
except ValueError:
420-
return 0
421-
422-
@cached_property
423-
def cpus_per_node(self) -> int:
424-
return self.sockets_per_node * self.cpus_per_socket
425-
426-
@cached_property
427-
def gpus_per_node(self) -> int:
428-
if gpus_per_resource_type := self.gpus_per_socket:
429-
return self.sockets_per_node * gpus_per_resource_type
430-
try:
431-
return self.count_per_node("gpu")
432-
except ValueError:
433-
return 0
434-
435-
@cached_property
436-
def cpu_count(self) -> int:
437-
return self.node_count * self.cpus_per_node
438-
439-
@cached_property
440-
def gpu_count(self) -> int:
441-
return self.node_count * self.gpus_per_node
442-
443-
def nodes_required(self, max_cpus: int | None = None, max_gpus: int | None = None) -> int:
412+
def nodes_required(self, **types: int) -> int:
444413
"""Nodes required to run ``tasks`` tasks. A task can be thought of as a single MPI
445414
rank"""
446-
nodes = max(1, int(math.ceil((max_cpus or 1) / self.cpus_per_node)))
447-
if self.gpus_per_node:
448-
nodes = max(nodes, int(math.ceil((max_gpus or 0) / self.gpus_per_node)))
415+
# backward compatible
416+
if n := types.pop("max_cpus", None):
417+
types["cpu"] = n
418+
if n := types.pop("max_gpus", None):
419+
types["gpu"] = n
420+
nodes: int = 1
421+
for type, count in types.items():
422+
try:
423+
count_per_node = self.count_per_node(type)
424+
except ValueError:
425+
continue
426+
else:
427+
if count_per_node == 0:
428+
continue
429+
nodes = max(nodes, int(math.ceil(count / count_per_node)))
449430
return nodes
450431

451432
def compute_required_resources(
@@ -487,8 +468,8 @@ def compute_required_resources(
487468
ranks = ranks_per_socket = 1
488469
nodes = 1
489470
elif ranks is not None and ranks_per_socket is None:
490-
ranks_per_socket = min(ranks, self.cpus_per_socket)
491-
nodes = int(math.ceil(ranks / self.cpus_per_socket / self.sockets_per_node))
471+
ranks_per_socket = min(ranks, self.count_per_socket("cpu"))
472+
nodes = int(math.ceil(ranks / self.count_per_socket("cpu") / self.sockets_per_node))
492473
else:
493474
assert ranks is not None
494475
assert ranks_per_socket is not None

src/hpc_connect/submit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def format_submission_data(
180180
"nodes": nodes,
181181
"cpus": cpus,
182182
"gpus": gpus,
183-
"cpus_per_node": self.config.cpus_per_node,
184-
"gpus_per_node": self.config.gpus_per_node,
183+
"cpus_per_node": self.config.count_per_node("cpu"),
184+
"gpus_per_node": self.config.count_per_node("gpu", default=0),
185185
"user": getpass.getuser(),
186186
"date": datetime.now().strftime("%c"),
187187
"variables": variables or {},

src/hpcc_flux/submit_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ def get_alloc_settings(
254254
alloc: dict[str, Any] = {}
255255
if nodes is not None:
256256
if cpus is None:
257-
cpus = nodes * self.config.cpus_per_node
257+
cpus = nodes * self.config.count_per_node("cpu")
258258
if gpus is None:
259-
gpus = nodes * self.config.gpus_per_node
259+
gpus = nodes * self.config.count_per_node("gpu", default=0)
260260
else:
261261
cpus = cpus or 1
262262
gpus = gpus or 0

src/hpcc_slurm/submit.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,18 @@ def read_sinfo() -> dict[str, Any] | None:
254254
"resources": [
255255
{
256256
"type": "cpu",
257-
"count": cores_per_socket,
258-
"additional_properties": {
259-
"threads_per_core": threads_per_core,
260-
},
257+
"count": int(cpus_per_node / sockets_per_node),
261258
},
262259
],
263260
}
264261
],
262+
"additional_properties": {
263+
"sockets_per_node": sockets_per_node,
264+
"cores_per_socket": cores_per_socket,
265+
"threads_per_core": threads_per_core,
266+
"cpus_per_node": cpus_per_node,
267+
"gres": " ".join(gres),
268+
},
265269
}
266270
for res in gres:
267271
if not res:

tests/pbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_basic():
3939
text = fh.getvalue()
4040
assert "#!/bin/sh" in text
4141
assert "#PBS -N my-job" in text
42-
assert f"#PBS -l nodes=1:ppn={backend.config.cpus_per_node}" in text
42+
assert f"#PBS -l nodes=1:ppn={backend.config.count_per_node('cpu')}" in text
4343
assert "#PBS -l walltime=00:00:01" in text
4444
assert "#PBS --job-name=my-job" in text
4545
assert "#PBS -o my-out.txt" in text

0 commit comments

Comments
 (0)