1919
2020
2121def _resolve_base_dn (full_username ):
22- if ',dc=' in full_username :
23- base_dn_start = full_username .find ('dc=' )
24- return full_username [base_dn_start :]
22+ if not full_username :
23+ return ''
24+
25+ username_lower = full_username .lower ()
26+ if ',dc=' in username_lower :
27+ base_dn_start = username_lower .find ('dc=' )
28+ return username_lower [base_dn_start :]
29+
2530 elif '@' in full_username :
2631 domain_start = full_username .find ('@' ) + 1
2732 domain = full_username [domain_start :]
@@ -32,6 +37,31 @@ def _resolve_base_dn(full_username):
3237 return ''
3338
3439
40+ def _search (dn , search_filter , attributes , connection ):
41+ success = connection .search (dn , search_filter , attributes = attributes )
42+ if not success :
43+ if connection .last_error :
44+ LOGGER .warning ('ldap search failed: ' + connection .last_error
45+ + '. dn:' + dn + ', filter: ' + search_filter )
46+ return None
47+
48+ return connection .entries
49+
50+
51+ def _load_multiple_entries_values (dn , search_filter , attribute_name , connection ):
52+ entries = _search (dn , search_filter , [attribute_name ], connection )
53+ if entries is None :
54+ return []
55+
56+ result = []
57+ for entry in entries :
58+ value = entry [attribute_name ].value
59+ if value is not None :
60+ result .append (value )
61+
62+ return result
63+
64+
3565class LdapAuthenticator (auth_base .Authenticator ):
3666 def __init__ (self , params_dict , temp_folder ):
3767 super ().__init__ ()
@@ -44,19 +74,19 @@ def __init__(self, params_dict, temp_folder):
4474 else :
4575 self .username_template = None
4676
47- groups_base_dn = params_dict .get ('groups_base_dn ' )
48- if groups_base_dn :
49- self ._groups_base_dn = groups_base_dn .strip ()
77+ base_dn = params_dict .get ('base_dn ' )
78+ if base_dn :
79+ self ._base_dn = base_dn .strip ()
5080 else :
5181 resolved_base_dn = _resolve_base_dn (username_pattern )
5282
5383 if resolved_base_dn :
54- LOGGER .info ('Resolved base dn for groups : ' + resolved_base_dn )
55- self ._groups_base_dn = resolved_base_dn
84+ LOGGER .info ('Resolved base dn: ' + resolved_base_dn )
85+ self ._base_dn = resolved_base_dn
5686 else :
5787 LOGGER .warning (
58- 'Cannot resolve LDAP base dn, so using empty. Please specify it using "groups_base_dn " attribute' )
59- self ._groups_base_dn = ''
88+ 'Cannot resolve LDAP base dn, so using empty. Please specify it using "base_dn " attribute' )
89+ self ._base_dn = ''
6090
6191 self .version = params_dict .get ("version" )
6292 if not self .version :
@@ -90,7 +120,10 @@ def authenticate(self, request_handler):
90120
91121 if connection .bound :
92122 try :
93- user_groups = self ._fetch_user_groups (username , full_username , connection )
123+ user_dn , user_uid = self ._get_user_ids (full_username , connection )
124+ LOGGER .debug ('user ids: ' + str ((user_dn , user_uid )))
125+
126+ user_groups = self ._fetch_user_groups (user_dn , user_uid , connection )
94127 LOGGER .info ('Loaded groups for ' + username + ': ' + str (user_groups ))
95128 self ._set_user_groups (username , user_groups )
96129 except :
@@ -125,20 +158,49 @@ def _get_groups(self, user):
125158 def get_groups (self , user ):
126159 return self ._get_groups (user )
127160
128- def _fetch_user_groups (self , username , full_username , connection ):
129- name_attribute = 'cn'
130- base_dn = self ._groups_base_dn
161+ def _fetch_user_groups (self , user_dn , user_uid , connection ):
162+ base_dn = self ._base_dn
163+
164+ result = set ()
165+
166+ result .update (_load_multiple_entries_values (base_dn , '(member=%s)' % user_dn , 'cn' , connection ))
167+
168+ if user_uid :
169+ result .update (_load_multiple_entries_values (
170+ base_dn ,
171+ '(&(objectClass=posixGroup)(memberUid=%s))' % user_uid ,
172+ 'cn' ,
173+ connection ))
174+
175+ return sorted (list (result ))
176+
177+ def _get_user_ids (self , full_username , connection ):
178+ base_dn = self ._base_dn
179+
180+ username_lower = full_username .lower ()
181+ if ',dc=' in username_lower :
182+ base_dn = username_lower
183+ search_filter = '(objectClass=*)'
184+ elif '@' in full_username :
185+ search_filter = '(userPrincipalName=%s)' % full_username
186+ elif '\\ ' in full_username :
187+ username_index = full_username .rfind ('\\ ' ) + 1
188+ username = full_username [username_index :]
189+ search_filter = '(sAMAccountName=%s)' % username
190+ else :
191+ LOGGER .warning ('Unsupported username pattern for ' + full_username )
192+ return full_username , None
131193
132- search_filter = '(|(member=%s)(&(objectClass=posixGroup)(memberUid=%s)))' % (full_username , username )
133- success = connection .search (base_dn , search_filter , attributes = [name_attribute ])
134- if not success :
135- return []
194+ entries = _search (base_dn , search_filter , ['uid' ], connection )
195+ if not entries :
196+ return full_username , None
136197
137- group_entries = connection . response
138- if not group_entries :
139- return []
198+ if len ( entries ) > 1 :
199+ LOGGER . warning ( 'More than one user found by filter: ' + search_filter )
200+ return full_username , None
140201
141- return [entry ['attributes' ][name_attribute ][0 ] for entry in group_entries ]
202+ entry = entries [0 ]
203+ return entry .entry_dn , entry .uid .value
142204
143205 def _load_groups (self , groups_file ):
144206 if not os .path .exists (groups_file ):
@@ -150,5 +212,5 @@ def _load_groups(self, groups_file):
150212 def _set_user_groups (self , user , groups ):
151213 self ._user_groups [user ] = groups
152214
153- with open (self ._groups_file , 'w' ) as fd :
154- json . dump (self ._user_groups , fd , indent = 2 )
215+ new_groups_content = json . dumps (self ._user_groups , indent = 2 )
216+ file_utils . write_file (self ._groups_file , new_groups_content )
0 commit comments