26
26
import java .util .List ;
27
27
import java .util .Map ;
28
28
import java .util .Set ;
29
- import jakarta .inject .Inject ;
29
+ import java .util .regex .Pattern ;
30
+
30
31
import javax .naming .NamingEnumeration ;
31
32
import javax .naming .NamingException ;
32
33
import javax .naming .directory .Attributes ;
33
34
import javax .naming .directory .SearchControls ;
34
35
import javax .naming .directory .SearchResult ;
35
36
import javax .naming .ldap .LdapContext ;
36
37
import javax .sql .DataSource ;
38
+
39
+ import jakarta .inject .Inject ;
37
40
import org .apache .commons .lang3 .StringUtils ;
38
41
import org .apache .commons .lang3 .reflect .FieldUtils ;
39
42
import org .apache .shiro .UnavailableSecurityManagerException ;
@@ -68,6 +71,12 @@ public class ShiroAuthenticationService implements AuthenticationService {
68
71
private static final String ACTIVE_DIRECTORY_GROUP_REALM = "org.apache.zeppelin.realm.ActiveDirectoryGroupRealm" ;
69
72
private static final String JDBC_REALM = "org.apache.shiro.realm.jdbc.JdbcRealm" ;
70
73
74
+ private static final Pattern VALID_SQL_NAME_IDENTIFIER_PATTERN = Pattern .compile ("^[a-zA-Z0-9_]+$" );
75
+
76
+ private static boolean isValidSqlIdentifier (String name ) {
77
+ return name != null && VALID_SQL_NAME_IDENTIFIER_PATTERN .matcher (name ).matches ();
78
+ }
79
+
71
80
private final ZeppelinConfiguration zConf ;
72
81
73
82
@ Inject
@@ -174,7 +183,7 @@ public List<String> getMatchedUsers(String searchText, int numUsersToFetch) {
174
183
usersList .addAll (
175
184
getUserList ((ActiveDirectoryGroupRealm ) realm , searchText , numUsersToFetch ));
176
185
} else if (JDBC_REALM .equals (realClassName )) {
177
- usersList .addAll (getUserList ((JdbcRealm ) realm ));
186
+ usersList .addAll (getUserList ((JdbcRealm ) realm , searchText , numUsersToFetch ));
178
187
}
179
188
}
180
189
}
@@ -401,7 +410,7 @@ private List<String> getUserList(
401
410
}
402
411
403
412
/** Function to extract users from JDBCs. */
404
- private List <String > getUserList (JdbcRealm obj ) {
413
+ private List <String > getUserList (JdbcRealm obj , String searchText , int numUsersToFetch ) {
405
414
List <String > userlist = new ArrayList <>();
406
415
Connection con = null ;
407
416
PreparedStatement ps = null ;
@@ -415,26 +424,39 @@ private List<String> getUserList(JdbcRealm obj) {
415
424
try {
416
425
dataSource = (DataSource ) FieldUtils .readField (obj , "dataSource" , true );
417
426
authQuery = (String ) FieldUtils .readField (obj , "authenticationQuery" , true );
418
- LOGGER .info ( authQuery );
427
+ LOGGER .debug ( "authenticationQuery={}" , authQuery );
419
428
String authQueryLowerCase = authQuery .toLowerCase ();
420
429
retval = authQueryLowerCase .split ("from" , 2 );
421
430
if (retval .length >= 2 ) {
422
431
retval = retval [1 ].split ("with|where" , 2 );
423
- tablename = retval [0 ];
432
+ tablename = retval [0 ]. trim () ;
424
433
retval = retval [1 ].split ("where" , 2 );
425
434
if (retval .length >= 2 ) {
426
435
retval = retval [1 ].split ("=" , 2 );
427
436
} else {
428
437
retval = retval [0 ].split ("=" , 2 );
429
438
}
430
- username = retval [0 ];
439
+ username = retval [0 ]. trim () ;
431
440
}
432
441
433
442
if (StringUtils .isBlank (username ) || StringUtils .isBlank (tablename )) {
434
443
return userlist ;
435
444
}
445
+ if (!isValidSqlIdentifier (username )) {
446
+ throw new IllegalArgumentException (
447
+ "Invalid column name in authenticationQuery to build userlist query: "
448
+ + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
449
+ + ", name identifier: [" + username + "]" );
450
+ }
451
+ if (!isValidSqlIdentifier (tablename )) {
452
+ throw new IllegalArgumentException (
453
+ "Invalid table name in authenticationQuery to build userlist query: "
454
+ + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
455
+ + ", name identifier: [" + tablename + "]" );
456
+ }
436
457
437
- userquery = "SELECT ? FROM ?" ;
458
+ userquery = String .format ("SELECT %s FROM %s WHERE %s LIKE ?" , username , tablename , username );
459
+ LOGGER .info ("Built query for user list. userquery={}" , userquery );
438
460
} catch (IllegalAccessException e ) {
439
461
LOGGER .error ("Error while accessing dataSource for JDBC Realm" , e );
440
462
return new ArrayList <>();
@@ -443,11 +465,12 @@ private List<String> getUserList(JdbcRealm obj) {
443
465
try {
444
466
con = dataSource .getConnection ();
445
467
ps = con .prepareStatement (userquery );
446
- ps .setString (1 , username );
447
- ps .setString (2 , tablename );
468
+ ps .setString (1 , "%" + searchText + "%" );
448
469
rs = ps .executeQuery ();
449
- while (rs .next ()) {
450
- userlist .add (rs .getString (1 ).trim ());
470
+ int count = 0 ;
471
+ while (rs .next () && count < numUsersToFetch ) {
472
+ userlist .add (rs .getString (1 ).trim ());
473
+ count ++;
451
474
}
452
475
} catch (Exception e ) {
453
476
LOGGER .error ("Error retrieving User list from JDBC Realm" , e );
0 commit comments