|
18 | 18 | package org.apache.toree.magic.builtin
|
19 | 19 |
|
20 | 20 | import java.io.{File, PrintStream}
|
21 |
| -import java.net.URL |
| 21 | +import java.net.{URL, URI} |
22 | 22 | import java.nio.file.{Files, Paths}
|
23 |
| - |
24 | 23 | import org.apache.toree.magic._
|
25 | 24 | import org.apache.toree.magic.builtin.AddJar._
|
26 | 25 | import org.apache.toree.magic.dependencies._
|
27 | 26 | import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils}
|
28 | 27 | import com.typesafe.config.Config
|
| 28 | +import org.apache.hadoop.fs.Path |
29 | 29 | import org.apache.toree.plugins.annotations.Event
|
30 | 30 |
|
31 | 31 | object AddJar {
|
| 32 | + val HADOOP_FS_SCHEMES = Set("hdfs", "s3", "s3n", "file") |
32 | 33 |
|
33 | 34 | private var jarDir:Option[String] = None
|
34 | 35 |
|
@@ -63,18 +64,18 @@ class AddJar
|
63 | 64 | private def printStream = new PrintStream(outputStream)
|
64 | 65 |
|
/**
 * Retrieves the file name from a URI.
 *
 * The file name is the last "/"-separated segment of the URI's path
 * component; the query string and fragment are not part of the path and
 * are therefore never included.
 *
 * @param location The location of the resource, expressed as a URI string
 * @return The last segment of the URI's path, or an empty string if the URI
 *         has no path (e.g. an opaque URI such as "mailto:a@b.c") or the
 *         path has no segments (e.g. "/")
 * @throws java.lang.NullPointerException if location is null
 * @throws java.net.URISyntaxException if location is not a valid URI
 */
def getFileFromLocation(location: String): String = {
  val uri = new URI(location)
  // URI.getPath returns null for opaque URIs, so wrap it in Option instead
  // of dereferencing it directly. String.split drops trailing empty strings,
  // so a path of "/" yields an empty array and lastOption is None.
  Option(uri.getPath)
    .flatMap(_.split("/").lastOption)
    .getOrElse("")
}
|
80 | 81 |
|
@@ -122,10 +123,27 @@ class AddJar
|
122 | 123 | // Report beginning of download
|
123 | 124 | printStream.println(s"Starting download from $jarRemoteLocation")
|
124 | 125 |
|
125 |
| - downloadFile( |
126 |
| - new URL(jarRemoteLocation), |
127 |
| - new File(downloadLocation).toURI.toURL |
128 |
| - ) |
| 126 | + val jar = URI.create(jarRemoteLocation) |
| 127 | + if (HADOOP_FS_SCHEMES.contains(jar.getScheme)) { |
| 128 | + val conf = kernel.sparkContext.hadoopConfiguration |
| 129 | + val jarPath = new Path(jarRemoteLocation) |
| 130 | + val fs = jarPath.getFileSystem(conf) |
| 131 | + val destPath = if (downloadLocation.startsWith("file:")) { |
| 132 | + new Path(downloadLocation) |
| 133 | + } else { |
| 134 | + new Path("file:" + downloadLocation) |
| 135 | + } |
| 136 | + |
| 137 | + fs.copyToLocalFile( |
| 138 | + false /* keep original file */, |
| 139 | + jarPath, destPath, |
| 140 | + true /* don't create checksum files */) |
| 141 | + } else { |
| 142 | + downloadFile( |
| 143 | + new URL(jarRemoteLocation), |
| 144 | + new File(downloadLocation).toURI.toURL |
| 145 | + ) |
| 146 | + } |
129 | 147 |
|
130 | 148 | // Report download finished
|
131 | 149 | printStream.println(s"Finished download of $jarName")
|
|
0 commit comments