Skip to content

[ZEPPELIN-6186] fix searching user using JdbcRealm #4926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions zeppelin-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<selenium.java.version>2.48.2</selenium.java.version>
<xml.apis.version>1.4.01</xml.apis.version>
<hamcrest.version>2.2</hamcrest.version>
<h2.version>2.3.232</h2.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -368,6 +369,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>${h2.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.bitbucket.cowwoc</groupId>
<artifactId>diff-match-patch</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import jakarta.inject.Inject;
import java.util.regex.Pattern;

import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attributes;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import javax.sql.DataSource;

import jakarta.inject.Inject;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.shiro.UnavailableSecurityManagerException;
Expand Down Expand Up @@ -68,6 +71,12 @@ public class ShiroAuthenticationService implements AuthenticationService {
private static final String ACTIVE_DIRECTORY_GROUP_REALM = "org.apache.zeppelin.realm.ActiveDirectoryGroupRealm";
private static final String JDBC_REALM = "org.apache.shiro.realm.jdbc.JdbcRealm";

private static final Pattern VALID_SQL_NAME_IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z0-9_]+$");

private static boolean isValidSqlIdentifier(String name) {
return name != null && VALID_SQL_NAME_IDENTIFIER_PATTERN.matcher(name).matches();
}

private final ZeppelinConfiguration zConf;

@Inject
Expand Down Expand Up @@ -171,10 +180,9 @@ public List<String> getMatchedUsers(String searchText, int numUsersToFetch) {
} else if (LDAP_REALM.equals(realClassName)) {
usersList.addAll(getUserList((LdapRealm) realm, searchText, numUsersToFetch));
} else if (ACTIVE_DIRECTORY_GROUP_REALM.equals(realClassName)) {
usersList.addAll(
getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch));
usersList.addAll(getUserList((ActiveDirectoryGroupRealm) realm, searchText, numUsersToFetch));
} else if (JDBC_REALM.equals(realClassName)) {
usersList.addAll(getUserList((JdbcRealm) realm));
usersList.addAll(getUserList((JdbcRealm) realm, searchText, numUsersToFetch));
}
}
}
Expand Down Expand Up @@ -401,7 +409,7 @@ private List<String> getUserList(
}

/** Function to extract users from JDBCs. */
private List<String> getUserList(JdbcRealm obj) {
private List<String> getUserList(JdbcRealm obj, String searchText, int numUsersToFetch) {
List<String> userlist = new ArrayList<>();
Connection con = null;
PreparedStatement ps = null;
Expand All @@ -415,26 +423,39 @@ private List<String> getUserList(JdbcRealm obj) {
try {
dataSource = (DataSource) FieldUtils.readField(obj, "dataSource", true);
authQuery = (String) FieldUtils.readField(obj, "authenticationQuery", true);
LOGGER.info(authQuery);
LOGGER.debug("authenticationQuery={}", authQuery);
String authQueryLowerCase = authQuery.toLowerCase();
retval = authQueryLowerCase.split("from", 2);
if (retval.length >= 2) {
retval = retval[1].split("with|where", 2);
tablename = retval[0];
tablename = retval[0].strip();
retval = retval[1].split("where", 2);
if (retval.length >= 2) {
retval = retval[1].split("=", 2);
} else {
retval = retval[0].split("=", 2);
}
username = retval[0];
username = retval[0].strip();
}

if (StringUtils.isBlank(username) || StringUtils.isBlank(tablename)) {
return userlist;
}
if (!isValidSqlIdentifier(username)) {
throw new IllegalArgumentException(
"Invalid column name in authenticationQuery to build userlist query: "
+ authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
+ ", name identifier: [" + username + "]");
}
if (!isValidSqlIdentifier(tablename)) {
throw new IllegalArgumentException(
"Invalid table name in authenticationQuery to build userlist query: "
+ authQuery + ", allowed pattern: " + VALID_SQL_NAME_IDENTIFIER_PATTERN
+ ", name identifier: [" + tablename + "]");
}

userquery = "SELECT ? FROM ?";
userquery = String.format("SELECT %s FROM %s WHERE %s LIKE ?", username, tablename, username);
LOGGER.info("Built query for user list. userquery={}", userquery);
} catch (IllegalAccessException e) {
LOGGER.error("Error while accessing dataSource for JDBC Realm", e);
return new ArrayList<>();
Expand All @@ -443,11 +464,10 @@ private List<String> getUserList(JdbcRealm obj) {
try {
con = dataSource.getConnection();
ps = con.prepareStatement(userquery);
ps.setString(1, username);
ps.setString(2, tablename);
ps.setString(1, "%" + searchText + "%");
rs = ps.executeQuery();
while (rs.next()) {
userlist.add(rs.getString(1).trim());
while (rs.next() && userlist.size() < numUsersToFetch) {
userlist.add(rs.getString(1));
}
} catch (Exception e) {
LOGGER.error("Error retrieving User list from JDBC Realm", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,22 @@

import java.io.IOException;
import java.security.Principal;
import java.sql.Connection;
import java.sql.Statement;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.mgt.DefaultSecurityManager;
import org.apache.shiro.realm.jdbc.JdbcRealm;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.LifecycleUtils;
import org.apache.shiro.util.ThreadContext;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.realm.jwt.KnoxJwtRealm;
import org.apache.zeppelin.service.shiro.AbstractShiroTest;
import org.h2.jdbcx.JdbcDataSource;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -53,6 +58,35 @@ void setup() throws Exception {
shiroSecurityService = new ShiroAuthenticationService(zConf);
}

@Test
void testGetMatchedUsersWithJdbcRealm() throws Exception {

// given in-memory jdbcRealm with some users
JdbcRealm realm = new JdbcRealm();
JdbcDataSource dataSource = new JdbcDataSource();
dataSource.setURL("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");
dataSource.setUser("sa");
realm.setDataSource(dataSource);

LifecycleUtils.init(realm);
DefaultSecurityManager securityManager = new DefaultSecurityManager(realm);
ThreadContext.bind(securityManager);

try (Connection conn = dataSource.getConnection(); Statement stmt = conn.createStatement()) {
stmt.execute("CREATE TABLE users (username VARCHAR PRIMARY KEY, password VARCHAR)");
stmt.execute("INSERT INTO users VALUES ('admin', '')");
stmt.execute("INSERT INTO users VALUES ('admin1', '')");
stmt.execute("INSERT INTO users VALUES ('test', '')");
}

// when
List<String> users = shiroSecurityService.getMatchedUsers("adm", 1);

// then
assertEquals(1, users.size());
assertEquals("admin", users.get(0));
}

@Test
void canGetPrincipalName() {
String expectedName = "java.security.Principal.getName()";
Expand Down
Loading