diff --git a/zeppelin-server/pom.xml b/zeppelin-server/pom.xml index 09e377fb472..4a3a4be702f 100644 --- a/zeppelin-server/pom.xml +++ b/zeppelin-server/pom.xml @@ -45,6 +45,7 @@ 2.48.2 1.4.01 2.2 + 2.3.232 @@ -368,6 +369,13 @@ test + + com.h2database + h2 + ${h2.version} + test + + org.bitbucket.cowwoc diff-match-patch diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java b/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java index 72e6ffb9d83..14a2abe52cc 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/service/ShiroAuthenticationService.java @@ -26,7 +26,8 @@ import java.util.List; import java.util.Map; import java.util.Set; -import jakarta.inject.Inject; +import java.util.regex.Pattern; + import javax.naming.NamingEnumeration; import javax.naming.NamingException; import javax.naming.directory.Attributes; @@ -34,6 +35,8 @@ import javax.naming.directory.SearchResult; import javax.naming.ldap.LdapContext; import javax.sql.DataSource; + +import jakarta.inject.Inject; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.shiro.UnavailableSecurityManagerException; @@ -68,6 +71,12 @@ public class ShiroAuthenticationService implements AuthenticationService { private static final String ACTIVE_DIRECTORY_GROUP_REALM = "org.apache.zeppelin.realm.ActiveDirectoryGroupRealm"; private static final String JDBC_REALM = "org.apache.shiro.realm.jdbc.JdbcRealm"; + private static final Pattern VALID_SQL_NAME_IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z0-9_]+$"); + + private static boolean isValidSqlIdentifier(String name) { + return name != null && VALID_SQL_NAME_IDENTIFIER_PATTERN.matcher(name).matches(); + } + private final ZeppelinConfiguration zConf; @Inject @@ -171,10 +180,9 @@ public List getMatchedUsers(String searchText, int numUsersToFetch) { } else if (LDAP_REALM.equals(realClassName)) { usersList.addAll(getUserList((LdapRealm) realm, searchText, numUsersToFetch)); } else if (ACTIVE_DIRECTORY_GROUP_REALM.equals(realClassName)) { - usersList.addAll( - getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch)); + usersList.addAll(getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch)); } else if (JDBC_REALM.equals(realClassName)) { - usersList.addAll(getUserList((JdbcRealm) realm)); + usersList.addAll(getUserList((JdbcRealm) realm, searchText, numUsersToFetch)); } } } @@ -401,7 +409,7 @@ private List getUserList( } /** Function to extract users from JDBCs. */ - private List getUserList(JdbcRealm obj) { + private List getUserList(JdbcRealm obj, String searchText, int numUsersToFetch) { List userlist = new ArrayList<>(); Connection con = null; PreparedStatement ps = null; @@ -415,26 +423,39 @@ private List getUserList(JdbcRealm obj) { try { dataSource = (DataSource) FieldUtils.readField(obj, "dataSource", true); authQuery = (String) FieldUtils.readField(obj, "authenticationQuery", true); - LOGGER.info(authQuery); + LOGGER.debug("authenticationQuery={}", authQuery); String authQueryLowerCase = authQuery.toLowerCase(); retval = authQueryLowerCase.split("from", 2); if (retval.length >= 2) { retval = retval[1].split("with|where", 2); - tablename = retval[0]; + tablename = retval[0].strip(); retval = retval[1].split("where", 2); if (retval.length >= 2) { retval = retval[1].split("=", 2); } else { retval = retval[0].split("=", 2); } - username = retval[0]; + username = retval[0].strip(); } if (StringUtils.isBlank(username) || StringUtils.isBlank(tablename)) { return userlist; } + if (!isValidSqlIdentifier(username)) { + throw new IllegalArgumentException( + "Invalid column name in authenticationQuery to build userlist query: " + + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN + + ", name identifier: [" + username + "]"); + } + if (!isValidSqlIdentifier(tablename)) { + throw new IllegalArgumentException( + "Invalid table name in authenticationQuery to build userlist query: " + + authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN + + ", name identifier: [" + tablename + "]"); + } - userquery = "SELECT ? FROM ?"; + userquery = String.format("SELECT %s FROM %s WHERE %s LIKE ?", username, tablename, username); + LOGGER.info("Built query for user list. userquery={}", userquery); } catch (IllegalAccessException e) { LOGGER.error("Error while accessing dataSource for JDBC Realm", e); return new ArrayList<>(); @@ -443,11 +464,10 @@ private List getUserList(JdbcRealm obj) { try { con = dataSource.getConnection(); ps = con.prepareStatement(userquery); - ps.setString(1, username); - ps.setString(2, tablename); + ps.setString(1, "%" + searchText + "%"); rs = ps.executeQuery(); - while (rs.next()) { - userlist.add(rs.getString(1).trim()); + while (rs.next() && userlist.size() < numUsersToFetch) { + userlist.add(rs.getString(1)); } } catch (Exception e) { LOGGER.error("Error retrieving User list from JDBC Realm", e); diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java index c0c7fe6c94e..f82539e715d 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/service/ShiroAuthenticationServiceTest.java @@ -23,17 +23,22 @@ import java.io.IOException; import java.security.Principal; +import java.sql.Connection; +import java.sql.Statement; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.apache.commons.lang3.StringUtils; import org.apache.shiro.mgt.DefaultSecurityManager; +import org.apache.shiro.realm.jdbc.JdbcRealm; import org.apache.shiro.subject.Subject; import org.apache.shiro.util.LifecycleUtils; import org.apache.shiro.util.ThreadContext; import org.apache.zeppelin.conf.ZeppelinConfiguration; import org.apache.zeppelin.realm.jwt.KnoxJwtRealm; import org.apache.zeppelin.service.shiro.AbstractShiroTest; +import org.h2.jdbcx.JdbcDataSource; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -53,6 +58,35 @@ void setup() throws Exception { shiroSecurityService = new ShiroAuthenticationService(zConf); } + @Test + void testGetMatchedUsersWithJdbcRealm() throws Exception { + + // given in-memory jdbcRealm with some users + JdbcRealm realm = new JdbcRealm(); + JdbcDataSource dataSource = new JdbcDataSource(); + dataSource.setURL("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1"); + dataSource.setUser("sa"); + realm.setDataSource(dataSource); + + LifecycleUtils.init(realm); + DefaultSecurityManager securityManager = new DefaultSecurityManager(realm); + ThreadContext.bind(securityManager); + + try (Connection conn = dataSource.getConnection(); Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE users (username VARCHAR PRIMARY KEY, password VARCHAR)"); + stmt.execute("INSERT INTO users VALUES ('admin', '')"); + stmt.execute("INSERT INTO users VALUES ('admin1', '')"); + stmt.execute("INSERT INTO users VALUES ('test', '')"); + } + + // when + List users = shiroSecurityService.getMatchedUsers("adm", 1); + + // then + assertEquals(1, users.size()); + assertEquals("admin", users.get(0)); + } + @Test void canGetPrincipalName() { String expectedName = "java.security.Principal.getName()";