diff --git a/zeppelin-server/pom.xml b/zeppelin-server/pom.xml
index 09e377fb472..4a3a4be702f 100644
--- a/zeppelin-server/pom.xml
+++ b/zeppelin-server/pom.xml
@@ -45,6 +45,7 @@
2.48.2
1.4.01
2.2
+
2.3.232
@@ -368,6 +369,13 @@
test
+
+ com.h2database
+ h2
+ ${h2.version}
+ test
+
+
org.bitbucket.cowwoc
diff-match-patch
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java b/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java
index 72e6ffb9d83..14a2abe52cc 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java
@@ -26,7 +26,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import jakarta.inject.Inject;
+import java.util.regex.Pattern;
+
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attributes;
@@ -34,6 +35,8 @@
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import javax.sql.DataSource;
+
+import jakarta.inject.Inject;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.shiro.UnavailableSecurityManagerException;
@@ -68,6 +71,12 @@ public class ShiroAuthenticationService implements AuthenticationService {
private static final String ACTIVE_DIRECTORY_GROUP_REALM = "org.apache.zeppelin.realm.ActiveDirectoryGroupRealm";
private static final String JDBC_REALM = "org.apache.shiro.realm.jdbc.JdbcRealm";
+ private static final Pattern VALID_SQL_NAME_IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z0-9_]+$");
+
+ private static boolean isValidSqlIdentifier(String name) {
+ return name != null && VALID_SQL_NAME_IDENTIFIER_PATTERN.matcher(name).matches();
+ }
+
private final ZeppelinConfiguration zConf;
@Inject
@@ -171,10 +180,9 @@ public List getMatchedUsers(String searchText, int numUsersToFetch) {
} else if (LDAP_REALM.equals(realClassName)) {
usersList.addAll(getUserList((LdapRealm) realm, searchText, numUsersToFetch));
} else if (ACTIVE_DIRECTORY_GROUP_REALM.equals(realClassName)) {
- usersList.addAll(
- getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch));
+ usersList.addAll(getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch));
} else if (JDBC_REALM.equals(realClassName)) {
- usersList.addAll(getUserList((JdbcRealm) realm));
+ usersList.addAll(getUserList((JdbcRealm) realm, searchText, numUsersToFetch));
}
}
}
@@ -401,7 +409,7 @@ private List getUserList(
}
/** Function to extract users from JDBCs. */
- private List getUserList(JdbcRealm obj) {
+ private List getUserList(JdbcRealm obj, String searchText, int numUsersToFetch) {
List userlist = new ArrayList<>();
Connection con = null;
PreparedStatement ps = null;
@@ -415,26 +423,39 @@ private List getUserList(JdbcRealm obj) {
try {
dataSource = (DataSource) FieldUtils.readField(obj, "dataSource", true);
authQuery = (String) FieldUtils.readField(obj, "authenticationQuery", true);
- LOGGER.info(authQuery);
+ LOGGER.debug("authenticationQuery={}", authQuery);
String authQueryLowerCase = authQuery.toLowerCase();
retval = authQueryLowerCase.split("from", 2);
if (retval.length >= 2) {
retval = retval[1].split("with|where", 2);
- tablename = retval[0];
+ tablename = retval[0].strip();
retval = retval[1].split("where", 2);
if (retval.length >= 2) {
retval = retval[1].split("=", 2);
} else {
retval = retval[0].split("=", 2);
}
- username = retval[0];
+ username = retval[0].strip();
}
if (StringUtils.isBlank(username) || StringUtils.isBlank(tablename)) {
return userlist;
}
+ if (!isValidSqlIdentifier(username)) {
+ throw new IllegalArgumentException(
+ "Invalid column name in authenticationQuery to build userlist query: "
+ + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
+ + ", name identifier: [" + username + "]");
+ }
+ if (!isValidSqlIdentifier(tablename)) {
+ throw new IllegalArgumentException(
+ "Invalid table name in authenticationQuery to build userlist query: "
+ + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
+ + ", name identifier: [" + tablename + "]");
+ }
- userquery = "SELECT ? FROM ?";
+ userquery = String.format("SELECT %s FROM %s WHERE %s LIKE ?", username, tablename, username);
+ LOGGER.info("Built query for user list. userquery={}", userquery);
} catch (IllegalAccessException e) {
LOGGER.error("Error while accessing dataSource for JDBC Realm", e);
return new ArrayList<>();
@@ -443,11 +464,10 @@ private List getUserList(JdbcRealm obj) {
try {
con = dataSource.getConnection();
ps = con.prepareStatement(userquery);
- ps.setString(1, username);
- ps.setString(2, tablename);
+ ps.setString(1, "%" + searchText + "%");
rs = ps.executeQuery();
- while (rs.next()) {
- userlist.add(rs.getString(1).trim());
+ while (rs.next() && userlist.size() < numUsersToFetch) {
+ userlist.add(rs.getString(1));
}
} catch (Exception e) {
LOGGER.error("Error retrieving User list from JDBC Realm", e);
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java
index c0c7fe6c94e..f82539e715d 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java
@@ -23,17 +23,22 @@
import java.io.IOException;
import java.security.Principal;
+import java.sql.Connection;
+import java.sql.Statement;
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.mgt.DefaultSecurityManager;
+import org.apache.shiro.realm.jdbc.JdbcRealm;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.LifecycleUtils;
import org.apache.shiro.util.ThreadContext;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.realm.jwt.KnoxJwtRealm;
import org.apache.zeppelin.service.shiro.AbstractShiroTest;
+import org.h2.jdbcx.JdbcDataSource;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -53,6 +58,35 @@ void setup() throws Exception {
shiroSecurityService = new ShiroAuthenticationService(zConf);
}
+ @Test
+ void testGetMatchedUsersWithJdbcRealm() throws Exception {
+
+ // given in-memory jdbcRealm with some users
+ JdbcRealm realm = new JdbcRealm();
+ JdbcDataSource dataSource = new JdbcDataSource();
+ dataSource.setURL("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");
+ dataSource.setUser("sa");
+ realm.setDataSource(dataSource);
+
+ LifecycleUtils.init(realm);
+ DefaultSecurityManager securityManager = new DefaultSecurityManager(realm);
+ ThreadContext.bind(securityManager);
+
+ try (Connection conn = dataSource.getConnection(); Statement stmt = conn.createStatement()) {
+ stmt.execute("CREATE TABLE users (username VARCHAR PRIMARY KEY, password VARCHAR)");
+ stmt.execute("INSERT INTO users VALUES ('admin', '')");
+ stmt.execute("INSERT INTO users VALUES ('admin1', '')");
+ stmt.execute("INSERT INTO users VALUES ('test', '')");
+ }
+
+ // when
+ List users = shiroSecurityService.getMatchedUsers("adm", 1);
+
+ // then
+ assertEquals(1, users.size());
+ assertEquals("admin", users.get(0));
+ }
+
@Test
void canGetPrincipalName() {
String expectedName = "java.security.Principal.getName()";