Skip to content

Commit 030c796

Browse files
authored
add default machine:resource scope (#4)
* add default machine:resource scope * update schema for environment variables * dont discover the default resource spec so that it can be loaded lazily * Update discover.py
1 parent 98148f8 commit 030c796

File tree

15 files changed

+599
-403
lines changed

15 files changed

+599
-403
lines changed

src/hpc_connect/config.py

Lines changed: 71 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -5,139 +5,44 @@
55
import logging
66
import math
77
import os
8-
import shlex
9-
import shutil
108
import sys
119
from collections.abc import ValuesView
1210
from functools import cached_property
1311
from typing import IO
1412
from typing import Any
1513

16-
import pluggy
17-
import psutil
14+
import schema
1815
import yaml
19-
from schema import Optional
20-
from schema import Or
21-
from schema import Schema
22-
from schema import Use
2316

17+
from .discover import default_resource_set
2418
from .pluginmanager import HPCConnectPluginManager
19+
from .schemas import config_schema
20+
from .schemas import environment_variable_schema
21+
from .schemas import launch_schema
22+
from .schemas import machine_schema
23+
from .schemas import submit_schema
2524
from .util import collections
2625
from .util import safe_loads
2726
from .util.string import strip_quotes
2827

2928
logger = logging.getLogger("hpc_connect")
3029

31-
32-
def flag_splitter(arg: list[str] | str) -> list[str]:
33-
if isinstance(arg, str):
34-
return shlex.split(arg)
35-
elif not isinstance(arg, list) and not all(isinstance(str, _) for _ in arg):
36-
raise ValueError("expected list[str]")
37-
return arg
38-
39-
40-
def dict_str_str(arg: Any) -> bool:
41-
f = isinstance
42-
return f(arg, dict) and all([f(_, str) for k, v in arg.items() for _ in (k, v)])
43-
44-
45-
class choose_from:
46-
def __init__(self, *choices: str | None):
47-
self.choices = set(choices)
48-
49-
def __call__(self, arg: str | None) -> str | None:
50-
if arg not in self.choices:
51-
raise ValueError(f"Invalid choice {arg!r}, choose from {self.choices!r}")
52-
return arg
53-
54-
55-
def which(arg: str) -> str:
56-
if path := shutil.which(arg):
57-
return path
58-
logger.debug(f"{arg} not found on PATH")
59-
return arg
60-
61-
62-
# Resource spec have the following form:
63-
# machine:
64-
# resources:
65-
# - type: node
66-
# count: node_count
67-
# resources:
68-
# - type: socket
69-
# count: sockets_per_node
70-
# resources:
71-
# - type: resource_name (like cpus)
72-
# count: type_per_socket
73-
# additional_properties: (optional)
74-
# - type: slots
75-
# count: 1
76-
77-
resource_spec = {
78-
"type": "node",
79-
"count": int,
80-
Optional("additional_properties"): Or(dict, None),
81-
"resources": [
82-
{
83-
"type": str,
84-
"count": int,
85-
Optional("additional_properties"): Or(dict, None),
86-
Optional("resources"): [
87-
{
88-
"type": str,
89-
"count": int,
90-
Optional("additional_properties"): Or(dict, None),
91-
},
92-
],
93-
},
94-
],
30+
section_schemas: dict[str, schema.Schema] = {
31+
"config": config_schema,
32+
"machine": machine_schema,
33+
"submit": submit_schema,
34+
"launch": launch_schema,
9535
}
9636

9737

98-
launch_spec = {
99-
Optional("numproc_flag"): str,
100-
Optional("default_options"): Use(flag_splitter),
101-
Optional("local_options"): Use(flag_splitter),
102-
Optional("pre_options"): Use(flag_splitter),
103-
Optional("mappings"): dict_str_str,
104-
}
105-
106-
schema = Schema(
107-
{
108-
"hpc_connect": {
109-
Optional("config"): {
110-
Optional("debug"): bool,
111-
},
112-
Optional("submit"): {
113-
Optional("backend"): Use(
114-
choose_from(None, "shell", "slurm", "sbatch", "pbs", "qsub", "flux")
115-
),
116-
Optional("default_options"): Use(flag_splitter),
117-
Optional(str): {
118-
Optional("default_options"): Use(flag_splitter),
119-
},
120-
},
121-
Optional("machine"): {
122-
Optional("resources"): Or([resource_spec], None),
123-
},
124-
Optional("launch"): {
125-
Optional("exec"): Use(which),
126-
**launch_spec,
127-
Optional(str): launch_spec,
128-
},
129-
}
130-
},
131-
ignore_extra_keys=True,
132-
description="HPC connect configuration schema",
133-
)
134-
135-
13638
class ConfigScope:
13739
def __init__(self, name: str, file: str | None, data: dict[str, Any]) -> None:
13840
self.name = name
13941
self.file = file
140-
self.data = schema.validate({"hpc_connect": data})["hpc_connect"]
42+
self.data: dict[str, Any] = {}
43+
for section, data in data.items():
44+
schema = section_schemas[section]
45+
self.data[section] = schema.validate(data)
14146

14247
def __repr__(self):
14348
file = self.file or "<none>"
@@ -151,44 +56,50 @@ def __eq__(self, other):
15156
def __iter__(self):
15257
return iter(self.data)
15358

59+
def __contains__(self, section: str) -> bool:
60+
return section in self.data
61+
15462
def get_section(self, section: str) -> Any:
15563
return self.data.get(section)
15664

65+
def pop_section(self, section: str) -> Any:
66+
return self.data.pop(section, None)
67+
15768
def dump(self) -> None:
15869
if self.file is None:
15970
return
16071
with open(self.file, "w") as fh:
16172
yaml.dump({"hpc_connect": self.data}, fh, default_flow_style=False)
16273

16374

164-
config_defaults = {
165-
"config": {
166-
"debug": False,
167-
},
168-
"machine": {
169-
"resources": None,
170-
},
171-
"submit": {
172-
"backend": None,
173-
"default_options": [],
174-
},
175-
"launch": {
176-
"exec": "mpiexec",
177-
"numproc_flag": "-n",
178-
"default_options": [],
179-
"local_options": [],
180-
"pre_options": [],
181-
"mappings": {},
182-
},
183-
}
184-
185-
18675
class Config:
18776
def __init__(self) -> None:
188-
self.pluginmanager: pluggy.PluginManager = HPCConnectPluginManager()
189-
self.scopes: dict[str, ConfigScope] = {
190-
"defaults": ConfigScope("defaults", None, config_defaults)
77+
self.pluginmanager: HPCConnectPluginManager = HPCConnectPluginManager()
78+
rspec = self.pluginmanager.hook.hpc_connect_discover_resources()
79+
defaults = {
80+
"config": {
81+
"debug": False,
82+
"plugins": [],
83+
},
84+
"machine": {
85+
"resources": rspec,
86+
},
87+
"submit": {
88+
"backend": None,
89+
"default_options": [],
90+
},
91+
"launch": {
92+
"exec": "mpiexec",
93+
"numproc_flag": "-n",
94+
"default_options": [],
95+
"local_options": [],
96+
"pre_options": [],
97+
"mappings": {},
98+
},
19199
}
100+
self.scopes: dict[str, ConfigScope] = {}
101+
default_scope = ConfigScope("defaults", None, defaults)
102+
self.push_scope(default_scope)
192103
for scope in ("site", "global", "local"):
193104
config_scope = read_config_scope(scope)
194105
self.push_scope(config_scope)
@@ -202,6 +113,10 @@ def read_only_scope(self, scope: str) -> bool:
202113

203114
def push_scope(self, scope: ConfigScope) -> None:
204115
self.scopes[scope.name] = scope
116+
if cfg := scope.get_section("config"):
117+
if plugins := cfg.get("plugins"):
118+
for f in plugins:
119+
self.pluginmanager.consider_plugin(f)
205120

206121
def pop_scope(self, scope: ConfigScope) -> ConfigScope | None:
207122
return self.scopes.pop(scope.name, None)
@@ -235,6 +150,14 @@ def get(self, path: str, default: Any = None, scope: str | None = None) -> Any:
235150
value = value[key]
236151
return value
237152

153+
def get_highest_priority(self, path: str, default: Any = None) -> tuple[Any, str]:
154+
sentinel = object()
155+
for scope in reversed(self.scopes.keys()):
156+
value = self.get(path, default=sentinel, scope=scope)
157+
if value is not sentinel:
158+
return value, scope
159+
return default, "none"
160+
238161
def set(self, path: str, value: Any, scope: str | None = None) -> None:
239162
parts = process_config_path(path)
240163
section = parts.pop(0)
@@ -336,18 +259,12 @@ def set_main_options(self, args: argparse.Namespace) -> None:
336259

337260
@property
338261
def resource_specs(self) -> list[dict]:
339-
from .submit import factory
340-
341-
if resource_specs := self.get("machine:resources"):
342-
return resource_specs
343-
if self.get("submit:backend"):
344-
# backend may set resources
345-
factory(config=self)
346-
if resource_specs := self.get("machine:resources"):
347-
return resource_specs
348-
resource_specs = default_resource_spec()
349-
self.set("machine:resources", resource_specs, scope="defaults")
350-
return resource_specs
262+
specs, _ = self.get_highest_priority("machine:resources")
263+
if specs is not None:
264+
return specs
265+
resources = default_resource_set()
266+
self.set("machine:resources", specs, scope="defaults")
267+
return resources
351268

352269
def resource_types(self) -> list[str]:
353270
"""Return the types of resources available"""
@@ -486,20 +403,15 @@ def compute_required_resources(
486403
return reqd_resources
487404

488405
def dump(self, stream: IO[Any], scope: str | None = None, **kwargs: Any) -> None:
489-
from .submit import factory
490-
491-
# initialize the resource spec
492-
if self.get("machine:resources") is None:
493-
if self.get("submit:backend"):
494-
factory(self)
495-
if not self.get("machine:resources"):
496-
self.set("machine:resources", default_resource_spec(), scope="defaults")
497406
data: dict[str, Any] = {}
498407
for section in self.scopes["defaults"]:
408+
if section == "machine":
409+
continue
499410
section_data = self.get_config(section, scope=scope)
500411
if not section_data and scope is not None:
501412
continue
502413
data[section] = section_data
414+
data.setdefault("machine", {})["resources"] = self.resource_specs
503415
yaml.dump({"hpc_connect": data}, stream, **kwargs)
504416

505417

@@ -532,32 +444,10 @@ def get_scope_filename(scope: str) -> str | None:
532444

533445

534446
def read_env_config() -> ConfigScope | None:
535-
def load_mappings(arg: str) -> dict[str, str]:
536-
mappings: dict[str, str] = {}
537-
for kv in arg.split(","):
538-
k, v = [_.strip() for _ in kv.split(":") if _.split()]
539-
mappings[k] = v
540-
return mappings
541-
542-
data: dict[str, Any] = {}
543-
for var in os.environ:
544-
if not var.startswith("HPCC_"):
545-
continue
546-
try:
547-
section, *parts = var[5:].lower().split("_")
548-
key = "_".join(parts)
549-
except ValueError:
550-
continue
551-
if section not in config_defaults:
552-
continue
553-
value: Any
554-
if key == "mappings":
555-
value = load_mappings(os.environ[var])
556-
else:
557-
value = safe_loads(os.environ[var])
558-
data.setdefault(section, {}).update({key: value})
559-
if not data:
447+
variables = {key: var for key, var in os.environ.items() if key.startswith("HPC_CONNECT_")}
448+
if not variables:
560449
return None
450+
data = environment_variable_schema.validate(variables)
561451
return ConfigScope("environment", None, data)
562452

563453

@@ -609,25 +499,3 @@ def set_logging_level(levelname: str) -> None:
609499
for h in logger.handlers:
610500
h.setLevel(level)
611501
logger.setLevel(level)
612-
613-
614-
def default_resource_spec() -> list[dict]:
615-
resource_spec: list[dict] = [
616-
{
617-
"type": "node",
618-
"count": 1,
619-
"resources": [
620-
{
621-
"type": "socket",
622-
"count": 1,
623-
"resources": [
624-
{
625-
"type": "cpu",
626-
"count": psutil.cpu_count(),
627-
},
628-
],
629-
},
630-
],
631-
}
632-
]
633-
return resource_spec

src/hpc_connect/discover.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright NTESS. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
import fnmatch
6+
import json
7+
import os
8+
from typing import Any
9+
10+
import psutil
11+
12+
from .hookspec import hookimpl
13+
14+
15+
def default_resource_set() -> list[dict[str, Any]]:
16+
local_resource = {"type": "cpu", "count": psutil.cpu_count()}
17+
socket_resource = {"type": "socket", "count": 1, "resources": [local_resource]}
18+
return [{"type": "node", "count": 1, "resources": [socket_resource]}]
19+
20+
21+
@hookimpl(tryfirst=True, specname="hpc_connect_discover_resources")
22+
def read_resources_from_hostfile() -> dict[str, list] | None:
23+
if file := os.getenv("HPC_CONNECT_HOSTFILE"):
24+
with open(file) as fh:
25+
data = json.load(fh)
26+
host: str = os.getenv("HPC_CONNECT_HOSTNAME") or os.uname().nodename
27+
for pattern, rspec in data.items():
28+
if fnmatch.fnmatch(host, pattern):
29+
return rspec
30+
return None

0 commit comments

Comments
 (0)