2222# cython: c_string_type=unicode, c_string_encoding=default
2323# cython: language_level=3
2424
25- from pyslurm.core.error import RPCError
25+ from pyslurm.core.error import RPCError, verify_rpc
2626from pyslurm.utils.helpers import (
2727 instance_to_dict,
2828 user_to_uid,
@@ -33,6 +33,47 @@ from pyslurm import settings
3333from pyslurm import xcollections
3434
3535
36+ cdef class AssociationList(SlurmList):
37+
38+ def __init__ (self , owned = True ):
39+ self .info = slurm.slurm_list_create(slurm.slurmdb_destroy_assoc_rec)
40+ self .owned = owned
41+
42+ def append (self , Association assoc ):
43+ slurm.slurm_list_append(self .info, assoc.ptr)
44+ assoc.owned = False
45+ self .cnt = slurm.slurm_list_count(self .info)
46+
47+ def __iter__ (self ):
48+ return super ().__iter__()
49+
50+ def __next__ (self ):
51+ if self .is_null or self .is_itr_null:
52+ raise StopIteration
53+
54+ if self .itr_cnt < self .cnt:
55+ self .itr_cnt += 1
56+ assoc = Association.from_ptr(< slurmdb_assoc_rec_t* > slurm.slurm_list_next(self .itr))
57+ assoc.owned = False
58+ return assoc
59+
60+ self ._dealloc_itr()
61+ raise StopIteration
62+
63+ def extend (self , list_in ):
64+ for item in list_in:
65+ self .append(< Association> item)
66+
67+
68+ class AssociationModifyResponse :
69+
70+ def __init__ (self , user = None , account = None , cluster = None , partition = None ):
71+ self .user = user
72+ self .account = account
73+ self .cluster = cluster
74+ self .partition = partition
75+
76+
3677cdef class Associations(MultiClusterMap):
3778
3879 def __init__ (self , assocs = None ):
@@ -63,7 +104,7 @@ cdef class Associations(MultiClusterMap):
63104 conn.ptr, cond.ptr))
64105
65106 if assoc_data.is_null:
66- raise RPCError(msg = " Failed to get Association data from slurmdbd" )
107+ raise RPCError(msg = " Failed to get Association data from slurmdbd. " )
67108
68109 # Fetch other necessary dependencies needed for translating some
69110 # attributes (i.e QoS IDs to its name)
@@ -104,16 +145,13 @@ cdef class Associations(MultiClusterMap):
104145 afilter = < AssociationFilter> db_filter
105146 afilter._create()
106147
107- # Setup DB conn
108148 conn = _open_conn_or_error(db_connection)
109149
110150 # Any data that isn't parsed yet or needs validation is done in this
111151 # function.
112152 _create_assoc_ptr(changes, conn)
113153
114- # Modify associations, get the result
115- # This returns a List of char* with the associations that were
116- # modified
154+ # Returns a List of char* with the associations that were modified
117155 response = SlurmList.wrap(slurmdb_associations_modify(
118156 conn.ptr, afilter.ptr, changes.ptr))
119157
@@ -128,7 +166,7 @@ cdef class Associations(MultiClusterMap):
128166
129167 elif not response.is_null:
130168 # There was no real error, but simply nothing has been modified
131- raise RPCError( msg = " Nothing was modified " )
169+ return None
132170 else :
133171 # Autodetects the last slurm error
134172 raise RPCError()
@@ -139,6 +177,30 @@ cdef class Associations(MultiClusterMap):
139177
140178 return out
141179
180+ @staticmethod
181+ def create (associations , Connection db_connection = None ):
182+ cdef:
183+ Connection conn
184+ Association assoc
185+ AssociationList assoc_list = AssociationList(owned = False )
186+
187+ if not associations:
188+ return
189+
190+ for i, assoc in enumerate (associations):
191+ # Make sure to remove any duplicate associations, i.e. associations
192+ # having the same account name set. For some reason, the slurmdbd
193+ # doesn't like that.
194+ if assoc not in assoc_list:
195+ assoc_list.append(assoc)
196+
197+ conn = _open_conn_or_error(db_connection)
198+ verify_rpc(slurmdb_associations_add(conn.ptr, assoc_list.info))
199+
200+ if not db_connection:
201+ # Autocommit if no connection was explicitly specified.
202+ conn.commit()
203+
142204
143205cdef class AssociationFilter:
144206
@@ -172,19 +234,21 @@ cdef class AssociationFilter:
172234 cdef slurmdb_assoc_cond_t * ptr = self .ptr
173235
174236 make_char_list(& ptr.user_list, self .users)
175- make_char_list(& ptr.user_list , self .ids)
237+ make_char_list(& ptr.id_list , self .ids)
176238 make_char_list(& ptr.acct_list, self .accounts)
177239 make_char_list(& ptr.parent_acct_list, self .parent_accounts)
178240 make_char_list(& ptr.cluster_list, self .clusters)
179241 make_char_list(& ptr.partition_list, self .partitions)
180- # TODO: These are QOS ids, not names
242+ # TODO: These should be QOS ids, not names
181243 make_char_list(& ptr.qos_list, self .qos)
244+ # TODO: ASSOC_COND_FLAGS
182245
183246
184247cdef class Association:
185248
186249 def __cinit__ (self ):
187250 self .ptr = NULL
251+ self .owned = True
188252
189253 def __init__ (self , **kwargs ):
190254 self ._alloc_impl()
@@ -194,7 +258,8 @@ cdef class Association:
194258 setattr (self , k, v)
195259
196260 def __dealloc__ (self ):
197- self ._dealloc_impl()
261+ if self .owned:
262+ self ._dealloc_impl()
198263
199264 def _dealloc_impl (self ):
200265 slurmdb_destroy_assoc_rec(self .ptr)
@@ -228,7 +293,8 @@ cdef class Association:
228293
229294 def __eq__ (self , other ):
230295 if isinstance (other, Association):
231- return self .id == other.id and self .cluster == other.cluster
296+ # return self.id == other.id and self.cluster == other.cluster
297+ return self .cluster == other.cluster and self .partition == other.partition and self .account == other.account and self .user == other.user
232298 return NotImplemented
233299
234300 @property
@@ -351,6 +417,10 @@ cdef class Association:
351417 def parent_account_id (self ):
352418 return u32_parse(self .ptr.parent_id, zero_is_noval = False )
353419
420+ @property
421+ def lineage (self ):
422+ return cstr.to_unicode(self .ptr.lineage)
423+
354424 @property
355425 def partition (self ):
356426 return cstr.to_unicode(self .ptr.partition)
@@ -383,6 +453,10 @@ cdef class Association:
383453 def user (self , val ):
384454 cstr.fmalloc(& self .ptr.user, val)
385455
456+ @property
457+ def user_id (self ):
458+ return u32_parse(self .ptr.uid, zero_is_noval = False )
459+
386460
387461cdef _parse_assoc_ptr(Association ass):
388462 cdef:
@@ -397,13 +471,15 @@ cdef _parse_assoc_ptr(Association ass):
397471 ass.ptr.grp_tres_mins, tres)
398472 ass.max_tres_mins_per_job = TrackableResourceLimits.from_ids(
399473 ass.ptr.max_tres_mins_pj, tres)
474+ # TODO rename, remove _per_user
400475 ass.max_tres_run_mins_per_user = TrackableResourceLimits.from_ids(
401476 ass.ptr.max_tres_run_mins, tres)
402477 ass.max_tres_per_job = TrackableResourceLimits.from_ids(
403478 ass.ptr.max_tres_pj, tres)
404479 ass.max_tres_per_node = TrackableResourceLimits.from_ids(
405480 ass.ptr.max_tres_pn, tres)
406481 ass.qos = qos_list_to_pylist(ass.ptr.qos_list, qos)
482+ # TODO: default_qos
407483
408484
409485cdef _create_assoc_ptr(Association ass, conn = None ):
0 commit comments