Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.FileInputStream;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
Expand All @@ -28,13 +29,15 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.NullOutputStream;
import org.apache.commons.lang3.StringUtils;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.interpreter.recovery.RecoveryStorage;
Expand Down Expand Up @@ -272,22 +275,57 @@ private String detectSparkScalaVersion(String sparkHome, Map<String, String> env
builder.environment().putAll(env);
File processOutputFile = File.createTempFile("zeppelin-spark", ".out");
builder.redirectError(processOutputFile);

Process process = builder.start();
process.waitFor();
String processOutput = IOUtils.toString(new FileInputStream(processOutputFile), StandardCharsets.UTF_8);
Pattern pattern = Pattern.compile(".*Using Scala version (.*),.*");
Matcher matcher = pattern.matcher(processOutput);
if (matcher.find()) {
String scalaVersion = matcher.group(1);
if (scalaVersion.startsWith("2.12")) {
return "2.12";
} else if (scalaVersion.startsWith("2.13")) {
return "2.13";
try {
// Consume stdout to prevent buffer overflow
try (InputStream stdout = process.getInputStream()) {
IOUtils.copy(stdout, NullOutputStream.NULL_OUTPUT_STREAM);
}

// Wait with timeout (30 seconds)
boolean finished = process.waitFor(30, TimeUnit.SECONDS);
if (!finished) {
process.destroyForcibly();
throw new IOException("spark-submit --version command timed out after 30 seconds");
}

// Check exit value
int exitValue = process.exitValue();
if (exitValue != 0) {
LOGGER.warn("spark-submit --version exited with non-zero code: {}", exitValue);
}

// Read the output from the file
String processOutput;
try (FileInputStream in = new FileInputStream(processOutputFile)) {
processOutput = IOUtils.toString(in, StandardCharsets.UTF_8);
}

Pattern pattern = Pattern.compile(".*Using Scala version (.*),.*");
Matcher matcher = pattern.matcher(processOutput);
if (matcher.find()) {
String scalaVersion = matcher.group(1);
if (scalaVersion.startsWith("2.12")) {
return "2.12";
} else if (scalaVersion.startsWith("2.13")) {
return "2.13";
} else {
throw new Exception("Unsupported scala version: " + scalaVersion);
}
} else {
throw new Exception("Unsupported scala version: " + scalaVersion);
LOGGER.debug("Could not detect Scala version from spark-submit output, falling back to jar inspection");
return detectSparkScalaVersionByReplClass(sparkHome);
}
} finally {
// Ensure process is cleaned up
if (process.isAlive()) {
process.destroyForcibly();
}
// Clean up temporary file
if (!processOutputFile.delete() && processOutputFile.exists()) {
LOGGER.warn("Failed to delete temporary file: {}", processOutputFile.getAbsolutePath());
}
} else {
return detectSparkScalaVersionByReplClass(sparkHome);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -325,4 +328,27 @@ void testYarnClusterMode_3() throws IOException {
}
FileUtils.deleteDirectory(localRepoPath.toFile());
}

/**
 * Invokes the private {@code detectSparkScalaVersion} method several times in a row and checks
 * that every invocation reports a supported Scala version ("2.12" or "2.13").
 */
@Test
void testDetectSparkScalaVersionProcessManagement() throws Exception {
  SparkInterpreterLauncher launcher = new SparkInterpreterLauncher(zConf, null);

  // The method under test is private, so look it up and open it via reflection.
  Method detector = SparkInterpreterLauncher.class.getDeclaredMethod(
      "detectSparkScalaVersion", String.class, Map.class);
  detector.setAccessible(true);

  Map<String, String> emptyEnv = new HashMap<>();

  // Repeat the call so that a failure on a subsequent invocation would surface here.
  for (int attempt = 1; attempt <= 3; attempt++) {
    Object result = detector.invoke(launcher, sparkHome, emptyEnv);
    String detected = (String) result;
    boolean supported = detected.equals("2.12") || detected.equals("2.13");
    assertTrue(supported, "Expected scala version 2.12 or 2.13 but got: " + detected);
  }

  // Note: this test does not verify that child processes are destroyed or that stdout is
  // consumed; doing so would need ProcessBuilder to be mockable, which requires refactoring.
  // It only confirms that the method keeps returning a valid version across repeated calls.
}
}
Loading