55import logging
66import math
77import os
8- import shlex
9- import shutil
108import sys
119from collections .abc import ValuesView
1210from functools import cached_property
1311from typing import IO
1412from typing import Any
1513
16- import pluggy
17- import psutil
14+ import schema
1815import yaml
19- from schema import Optional
20- from schema import Or
21- from schema import Schema
22- from schema import Use
2316
17+ from .discover import default_resource_set
2418from .pluginmanager import HPCConnectPluginManager
19+ from .schemas import config_schema
20+ from .schemas import environment_variable_schema
21+ from .schemas import launch_schema
22+ from .schemas import machine_schema
23+ from .schemas import submit_schema
2524from .util import collections
2625from .util import safe_loads
2726from .util .string import strip_quotes
2827
2928logger = logging .getLogger ("hpc_connect" )
3029
31-
32- def flag_splitter (arg : list [str ] | str ) -> list [str ]:
33- if isinstance (arg , str ):
34- return shlex .split (arg )
35- elif not isinstance (arg , list ) and not all (isinstance (str , _ ) for _ in arg ):
36- raise ValueError ("expected list[str]" )
37- return arg
38-
39-
40- def dict_str_str (arg : Any ) -> bool :
41- f = isinstance
42- return f (arg , dict ) and all ([f (_ , str ) for k , v in arg .items () for _ in (k , v )])
43-
44-
45- class choose_from :
46- def __init__ (self , * choices : str | None ):
47- self .choices = set (choices )
48-
49- def __call__ (self , arg : str | None ) -> str | None :
50- if arg not in self .choices :
51- raise ValueError (f"Invalid choice { arg !r} , choose from { self .choices !r} " )
52- return arg
53-
54-
55- def which (arg : str ) -> str :
56- if path := shutil .which (arg ):
57- return path
58- logger .debug (f"{ arg } not found on PATH" )
59- return arg
60-
61-
62- # Resource spec have the following form:
63- # machine:
64- # resources:
65- # - type: node
66- # count: node_count
67- # resources:
68- # - type: socket
69- # count: sockets_per_node
70- # resources:
71- # - type: resource_name (like cpus)
72- # count: type_per_socket
73- # additional_properties: (optional)
74- # - type: slots
75- # count: 1
76-
77- resource_spec = {
78- "type" : "node" ,
79- "count" : int ,
80- Optional ("additional_properties" ): Or (dict , None ),
81- "resources" : [
82- {
83- "type" : str ,
84- "count" : int ,
85- Optional ("additional_properties" ): Or (dict , None ),
86- Optional ("resources" ): [
87- {
88- "type" : str ,
89- "count" : int ,
90- Optional ("additional_properties" ): Or (dict , None ),
91- },
92- ],
93- },
94- ],
# Validator applied to each top-level configuration section, keyed by section name.
section_schemas: dict[str, schema.Schema] = dict(
    config=config_schema,
    machine=machine_schema,
    submit=submit_schema,
    launch=launch_schema,
)
9636
9737
98- launch_spec = {
99- Optional ("numproc_flag" ): str ,
100- Optional ("default_options" ): Use (flag_splitter ),
101- Optional ("local_options" ): Use (flag_splitter ),
102- Optional ("pre_options" ): Use (flag_splitter ),
103- Optional ("mappings" ): dict_str_str ,
104- }
105-
106- schema = Schema (
107- {
108- "hpc_connect" : {
109- Optional ("config" ): {
110- Optional ("debug" ): bool ,
111- },
112- Optional ("submit" ): {
113- Optional ("backend" ): Use (
114- choose_from (None , "shell" , "slurm" , "sbatch" , "pbs" , "qsub" , "flux" )
115- ),
116- Optional ("default_options" ): Use (flag_splitter ),
117- Optional (str ): {
118- Optional ("default_options" ): Use (flag_splitter ),
119- },
120- },
121- Optional ("machine" ): {
122- Optional ("resources" ): Or ([resource_spec ], None ),
123- },
124- Optional ("launch" ): {
125- Optional ("exec" ): Use (which ),
126- ** launch_spec ,
127- Optional (str ): launch_spec ,
128- },
129- }
130- },
131- ignore_extra_keys = True ,
132- description = "HPC connect configuration schema" ,
133- )
134-
135-
13638class ConfigScope :
13739 def __init__ (self , name : str , file : str | None , data : dict [str , Any ]) -> None :
13840 self .name = name
13941 self .file = file
140- self .data = schema .validate ({"hpc_connect" : data })["hpc_connect" ]
42+ self .data : dict [str , Any ] = {}
43+ for section , data in data .items ():
44+ schema = section_schemas [section ]
45+ self .data [section ] = schema .validate (data )
14146
14247 def __repr__ (self ):
14348 file = self .file or "<none>"
@@ -151,44 +56,50 @@ def __eq__(self, other):
def __iter__(self):
    """Iterate over the section names stored in this scope."""
    sections = self.data
    return iter(sections)
15358
def __contains__(self, section: str) -> bool:
    """Return True when ``section`` is one of the sections held by this scope."""
    is_present = section in self.data
    return is_present
61+
def get_section(self, section: str) -> Any:
    """Return the data stored for ``section``, or None if the scope lacks it."""
    return self.data.get(section, None)
15664
def pop_section(self, section: str) -> Any:
    """Remove ``section`` from this scope and return its data (None if absent)."""
    removed = self.data.pop(section, None)
    return removed
67+
def dump(self) -> None:
    """Write this scope's data to its file as YAML under the ``hpc_connect`` key.

    Does nothing when the scope has no file.
    """
    if self.file is not None:
        document = {"hpc_connect": self.data}
        with open(self.file, "w") as fh:
            yaml.dump(document, fh, default_flow_style=False)
16273
16374
164- config_defaults = {
165- "config" : {
166- "debug" : False ,
167- },
168- "machine" : {
169- "resources" : None ,
170- },
171- "submit" : {
172- "backend" : None ,
173- "default_options" : [],
174- },
175- "launch" : {
176- "exec" : "mpiexec" ,
177- "numproc_flag" : "-n" ,
178- "default_options" : [],
179- "local_options" : [],
180- "pre_options" : [],
181- "mappings" : {},
182- },
183- }
184-
185-
18675class Config :
18776 def __init__ (self ) -> None :
188- self .pluginmanager : pluggy .PluginManager = HPCConnectPluginManager ()
189- self .scopes : dict [str , ConfigScope ] = {
190- "defaults" : ConfigScope ("defaults" , None , config_defaults )
77+ self .pluginmanager : HPCConnectPluginManager = HPCConnectPluginManager ()
78+ rspec = self .pluginmanager .hook .hpc_connect_discover_resources ()
79+ defaults = {
80+ "config" : {
81+ "debug" : False ,
82+ "plugins" : [],
83+ },
84+ "machine" : {
85+ "resources" : rspec ,
86+ },
87+ "submit" : {
88+ "backend" : None ,
89+ "default_options" : [],
90+ },
91+ "launch" : {
92+ "exec" : "mpiexec" ,
93+ "numproc_flag" : "-n" ,
94+ "default_options" : [],
95+ "local_options" : [],
96+ "pre_options" : [],
97+ "mappings" : {},
98+ },
19199 }
100+ self .scopes : dict [str , ConfigScope ] = {}
101+ default_scope = ConfigScope ("defaults" , None , defaults )
102+ self .push_scope (default_scope )
192103 for scope in ("site" , "global" , "local" ):
193104 config_scope = read_config_scope (scope )
194105 self .push_scope (config_scope )
@@ -202,6 +113,10 @@ def read_only_scope(self, scope: str) -> bool:
202113
def push_scope(self, scope: ConfigScope) -> None:
    """Register ``scope`` under its name and hand its plugins to the plugin manager.

    Every entry of the scope's ``config`` section ``plugins`` value, if any, is
    passed to ``self.pluginmanager.consider_plugin``.
    """
    self.scopes[scope.name] = scope
    config_section = scope.get_section("config") or {}
    for plugin in config_section.get("plugins") or ():
        self.pluginmanager.consider_plugin(plugin)
205120
def pop_scope(self, scope: ConfigScope) -> ConfigScope | None:
    """Unregister the scope stored under ``scope.name``.

    Returns the removed scope, or None when no scope with that name is registered.
    """
    removed = self.scopes.pop(scope.name, None)
    return removed
@@ -235,6 +150,14 @@ def get(self, path: str, default: Any = None, scope: str | None = None) -> Any:
235150 value = value [key ]
236151 return value
237152
def get_highest_priority(self, path: str, default: Any = None) -> tuple[Any, str]:
    """Find ``path`` in the most recently registered scope that defines it.

    Scopes are searched in reverse registration order.

    Returns:
        A ``(value, scope_name)`` pair for the first scope that has a value, or
        ``(default, "none")`` when no scope does.
    """
    # A private marker distinguishes "not set" from a stored value of None.
    missing = object()
    for name in reversed(self.scopes):
        found = self.get(path, default=missing, scope=name)
        if found is missing:
            continue
        return found, name
    return default, "none"
160+
238161 def set (self , path : str , value : Any , scope : str | None = None ) -> None :
239162 parts = process_config_path (path )
240163 section = parts .pop (0 )
@@ -336,18 +259,12 @@ def set_main_options(self, args: argparse.Namespace) -> None:
336259
@property
def resource_specs(self) -> list[dict]:
    """Return the machine resource specification.

    The highest-priority ``machine:resources`` value is used when it is not
    None. Otherwise the result of ``default_resource_set()`` is stored in the
    ``defaults`` scope and returned.
    """
    specs, _ = self.get_highest_priority("machine:resources")
    if specs is not None:
        return specs
    resources = default_resource_set()
    # Store the discovered resources, not ``specs`` (which is None on this path),
    # so later lookups find a usable value in the defaults scope.
    self.set("machine:resources", resources, scope="defaults")
    return resources
352269 def resource_types (self ) -> list [str ]:
353270 """Return the types of resources available"""
@@ -486,20 +403,15 @@ def compute_required_resources(
486403 return reqd_resources
487404
def dump(self, stream: IO[Any], scope: str | None = None, **kwargs: Any) -> None:
    """Write the configuration to ``stream`` as YAML under the ``hpc_connect`` key.

    Args:
        stream: Destination handed to ``yaml.dump``.
        scope: When given, section data is read from that scope and sections
            with no data are left out.
        **kwargs: Forwarded to ``yaml.dump``.

    The ``machine`` section is not read per scope; it is always written with
    the value of ``self.resource_specs``.
    """
    scoped = scope is not None
    data: dict[str, Any] = {}
    for section in self.scopes["defaults"]:
        if section != "machine":
            section_data = self.get_config(section, scope=scope)
            if section_data or not scoped:
                data[section] = section_data
    # "machine" is skipped above, so this always adds the key last.
    data["machine"] = {"resources": self.resource_specs}
    yaml.dump({"hpc_connect": data}, stream, **kwargs)
504416
505417
@@ -532,32 +444,10 @@ def get_scope_filename(scope: str) -> str | None:
532444
533445
def read_env_config() -> ConfigScope | None:
    """Build the ``environment`` scope from ``HPC_CONNECT_*`` environment variables.

    Returns:
        None when no environment variable starts with ``HPC_CONNECT_``; otherwise
        a ConfigScope named "environment" (with no file) holding the result of
        ``environment_variable_schema.validate`` on those variables.
    """
    prefix = "HPC_CONNECT_"
    variables = {}
    for key, value in os.environ.items():
        if key.startswith(prefix):
            variables[key] = value
    if variables:
        data = environment_variable_schema.validate(variables)
        return ConfigScope("environment", None, data)
    return None
562452
563453
@@ -609,25 +499,3 @@ def set_logging_level(levelname: str) -> None:
609499 for h in logger .handlers :
610500 h .setLevel (level )
611501 logger .setLevel (level )
612-
613-
614- def default_resource_spec () -> list [dict ]:
615- resource_spec : list [dict ] = [
616- {
617- "type" : "node" ,
618- "count" : 1 ,
619- "resources" : [
620- {
621- "type" : "socket" ,
622- "count" : 1 ,
623- "resources" : [
624- {
625- "type" : "cpu" ,
626- "count" : psutil .cpu_count (),
627- },
628- ],
629- },
630- ],
631- }
632- ]
633- return resource_spec
0 commit comments