Skip to content

Commit a8195db

Browse files
committed
#118 added LDAP groups support for sAMAccountName/userPrincipalName patterns
1 parent 507dd44 commit a8195db

File tree

1 file changed

+86
-24
lines changed

1 file changed

+86
-24
lines changed

src/auth/auth_ldap.py

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919

2020

2121
def _resolve_base_dn(full_username):
22-
if ',dc=' in full_username:
23-
base_dn_start = full_username.find('dc=')
24-
return full_username[base_dn_start:]
22+
if not full_username:
23+
return ''
24+
25+
username_lower = full_username.lower()
26+
if ',dc=' in username_lower:
27+
base_dn_start = username_lower.find('dc=')
28+
return username_lower[base_dn_start:]
29+
2530
elif '@' in full_username:
2631
domain_start = full_username.find('@') + 1
2732
domain = full_username[domain_start:]
@@ -32,6 +37,31 @@ def _resolve_base_dn(full_username):
3237
return ''
3338

3439

40+
def _search(dn, search_filter, attributes, connection):
41+
success = connection.search(dn, search_filter, attributes=attributes)
42+
if not success:
43+
if connection.last_error:
44+
LOGGER.warning('ldap search failed: ' + connection.last_error
45+
+ '. dn:' + dn + ', filter: ' + search_filter)
46+
return None
47+
48+
return connection.entries
49+
50+
51+
def _load_multiple_entries_values(dn, search_filter, attribute_name, connection):
52+
entries = _search(dn, search_filter, [attribute_name], connection)
53+
if entries is None:
54+
return []
55+
56+
result = []
57+
for entry in entries:
58+
value = entry[attribute_name].value
59+
if value is not None:
60+
result.append(value)
61+
62+
return result
63+
64+
3565
class LdapAuthenticator(auth_base.Authenticator):
3666
def __init__(self, params_dict, temp_folder):
3767
super().__init__()
@@ -44,19 +74,19 @@ def __init__(self, params_dict, temp_folder):
4474
else:
4575
self.username_template = None
4676

47-
groups_base_dn = params_dict.get('groups_base_dn')
48-
if groups_base_dn:
49-
self._groups_base_dn = groups_base_dn.strip()
77+
base_dn = params_dict.get('base_dn')
78+
if base_dn:
79+
self._base_dn = base_dn.strip()
5080
else:
5181
resolved_base_dn = _resolve_base_dn(username_pattern)
5282

5383
if resolved_base_dn:
54-
LOGGER.info('Resolved base dn for groups: ' + resolved_base_dn)
55-
self._groups_base_dn = resolved_base_dn
84+
LOGGER.info('Resolved base dn: ' + resolved_base_dn)
85+
self._base_dn = resolved_base_dn
5686
else:
5787
LOGGER.warning(
58-
'Cannot resolve LDAP base dn, so using empty. Please specify it using "groups_base_dn" attribute')
59-
self._groups_base_dn = ''
88+
'Cannot resolve LDAP base dn, so using empty. Please specify it using "base_dn" attribute')
89+
self._base_dn = ''
6090

6191
self.version = params_dict.get("version")
6292
if not self.version:
@@ -90,7 +120,10 @@ def authenticate(self, request_handler):
90120

91121
if connection.bound:
92122
try:
93-
user_groups = self._fetch_user_groups(username, full_username, connection)
123+
user_dn, user_uid = self._get_user_ids(full_username, connection)
124+
LOGGER.debug('user ids: ' + str((user_dn, user_uid)))
125+
126+
user_groups = self._fetch_user_groups(user_dn, user_uid, connection)
94127
LOGGER.info('Loaded groups for ' + username + ': ' + str(user_groups))
95128
self._set_user_groups(username, user_groups)
96129
except:
@@ -125,20 +158,49 @@ def _get_groups(self, user):
125158
def get_groups(self, user):
126159
return self._get_groups(user)
127160

128-
def _fetch_user_groups(self, username, full_username, connection):
129-
name_attribute = 'cn'
130-
base_dn = self._groups_base_dn
161+
def _fetch_user_groups(self, user_dn, user_uid, connection):
162+
base_dn = self._base_dn
163+
164+
result = set()
165+
166+
result.update(_load_multiple_entries_values(base_dn, '(member=%s)' % user_dn, 'cn', connection))
167+
168+
if user_uid:
169+
result.update(_load_multiple_entries_values(
170+
base_dn,
171+
'(&(objectClass=posixGroup)(memberUid=%s))' % user_uid,
172+
'cn',
173+
connection))
174+
175+
return sorted(list(result))
176+
177+
def _get_user_ids(self, full_username, connection):
178+
base_dn = self._base_dn
179+
180+
username_lower = full_username.lower()
181+
if ',dc=' in username_lower:
182+
base_dn = username_lower
183+
search_filter = '(objectClass=*)'
184+
elif '@' in full_username:
185+
search_filter = '(userPrincipalName=%s)' % full_username
186+
elif '\\' in full_username:
187+
username_index = full_username.rfind('\\') + 1
188+
username = full_username[username_index:]
189+
search_filter = '(sAMAccountName=%s)' % username
190+
else:
191+
LOGGER.warning('Unsupported username pattern for ' + full_username)
192+
return full_username, None
131193

132-
search_filter = '(|(member=%s)(&(objectClass=posixGroup)(memberUid=%s)))' % (full_username, username)
133-
success = connection.search(base_dn, search_filter, attributes=[name_attribute])
134-
if not success:
135-
return []
194+
entries = _search(base_dn, search_filter, ['uid'], connection)
195+
if not entries:
196+
return full_username, None
136197

137-
group_entries = connection.response
138-
if not group_entries:
139-
return []
198+
if len(entries) > 1:
199+
LOGGER.warning('More than one user found by filter: ' + search_filter)
200+
return full_username, None
140201

141-
return [entry['attributes'][name_attribute][0] for entry in group_entries]
202+
entry = entries[0]
203+
return entry.entry_dn, entry.uid.value
142204

143205
def _load_groups(self, groups_file):
144206
if not os.path.exists(groups_file):
@@ -150,5 +212,5 @@ def _load_groups(self, groups_file):
150212
def _set_user_groups(self, user, groups):
151213
self._user_groups[user] = groups
152214

153-
with open(self._groups_file, 'w') as fd:
154-
json.dump(self._user_groups, fd, indent=2)
215+
new_groups_content = json.dumps(self._user_groups, indent=2)
216+
file_utils.write_file(self._groups_file, new_groups_content)

0 commit comments

Comments
 (0)