
Commit 05f8cb3

Add the initial shapefile datasource prototype
1 parent 44f4072 commit 05f8cb3

10 files changed: +201 -4 lines

build.sbt

Lines changed: 5 additions & 1 deletion

@@ -128,7 +128,11 @@ lazy val datasource = project
       spark("core").value % Provided,
       spark("mllib").value % Provided,
       spark("sql").value % Provided,
-      `better-files`
+      `better-files`,
+      geotrellis("shapefile").value,
+      geotoolsMain,
+      geotoolsOpengis,
+      geotoolsShapefile
     ),
     Compile / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
     Test / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
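
The geotrellis("shapefile") dependency supplies the ShapeFileReader used by the partition reader and the test below, while gt-main, gt-opengis, and gt-shapefile are the GeoTools modules behind ShapefileDataStore; the version for the gt-* artifacts is pinned in project/RFDependenciesPlugin.scala further down.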

datasource/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 2 additions & 1 deletion

@@ -5,4 +5,5 @@ org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource
 org.locationtech.rasterframes.datasource.geojson.GeoJsonDataSource
 org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource
 org.locationtech.rasterframes.datasource.tiles.TilesDataSource
-org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
+org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
+org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource
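
Listing ShapeFileDataSource in this META-INF/services file is what lets Spark find the source by its short name: lookup goes through java.util.ServiceLoader and matches on DataSourceRegister.shortName(). A minimal sketch of that resolution, using only names defined in this commit:

import java.util.ServiceLoader
import org.apache.spark.sql.sources.DataSourceRegister

import scala.collection.JavaConverters._

// Resolve the provider for the "shapefile" short name, roughly the way
// Spark's DataSource lookup scans registered DataSourceRegister services.
val provider = ServiceLoader.load(classOf[DataSourceRegister])
  .asScala
  .find(_.shortName() == "shapefile")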

datasource/src/main/scala/org/locationtech/rasterframes/datasource/package.scala

Lines changed: 21 additions & 1 deletion

@@ -28,7 +28,8 @@ import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import sttp.model.Uri
 
-import java.net.URI
+import java.io.File
+import java.net.{URI, URL}
 import scala.util.Try
 
 /**
@@ -65,6 +66,25 @@ package object datasource {
     if(parameters.containsKey(key)) Uri.parse(parameters.get(key)).toOption
     else None
 
+  private[rasterframes]
+  def urlParam(key: String, parameters: Map[String, String]): Option[URL] =
+    parameters.get(key).flatMap { p =>
+      Try {
+        if (p.contains("://")) new URL(p)
+        else new URL(s"file://${new File(p).getAbsolutePath}")
+      }.toOption
+    }
+
+  private[rasterframes]
+  def urlParam(key: String, parameters: CaseInsensitiveStringMap): Option[URL] =
+    if(parameters.containsKey(key)) {
+      val p = parameters.get(key)
+      Try {
+        if (p.contains("://")) new URL(p)
+        else new URL(s"file://${new File(p).getAbsolutePath}")
+      }.toOption
+    } else None
+
   private[rasterframes]
   def jsonParam(key: String, parameters: Map[String, String]): Option[Json] =
     parameters.get(key).flatMap(p => parser.parse(p).toOption)
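
The two urlParam overloads normalize user input the same way: a string containing "://" is parsed as a URL as-is, while a bare path is made absolute and wrapped in a file:// URL. A standalone sketch of the same conversion (the toUrl name is only for illustration):

import java.io.File
import java.net.URL

import scala.util.Try

def toUrl(p: String): Option[URL] = Try {
  if (p.contains("://")) new URL(p)                       // already has a scheme
  else new URL(s"file://${new File(p).getAbsolutePath}")  // bare path -> file:// URL
}.toOption

toUrl("data/demo.shp")             // Some(file:///<cwd>/data/demo.shp)
toUrl("https://example.com/x.shp") // Some(https://example.com/x.shp)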

datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSource.scala

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util
+
+class ShapeFileDataSource extends TableProvider with DataSourceRegister {
+
+  def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType =
+    getTable(null, Array.empty[Transform], caseInsensitiveStringMap.asCaseSensitiveMap()).schema()
+
+  def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table =
+    new ShapeFileTable()
+
+  def shortName(): String = ShapeFileDataSource.SHORT_NAME
+}
+
+object ShapeFileDataSource {
+  final val SHORT_NAME = "shapefile"
+  final val URL_PARAM = "url"
+}
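
With the provider registered above, reading a shapefile is a standard DataFrame read; a sketch (the path is hypothetical):

// "shapefile" resolves to ShapeFileDataSource via SHORT_NAME, and the
// "url" option (URL_PARAM) accepts either a full URL or a bare path.
val geoms = spark.read
  .format("shapefile")
  .option("url", "/data/demographics/demographics.shp") // hypothetical local path
  .load()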

datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFilePartitionReader.scala

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.locationtech.rasterframes.encoders.syntax._
+
+import geotrellis.vector.Geometry
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.geotools.data.shapefile.ShapefileDataStore
+import org.geotools.data.simple.SimpleFeatureIterator
+
+import java.net.URL
+
+case class ShapeFilePartition(url: URL) extends InputPartition
+
+class ShapeFilePartitionReaderFactory extends PartitionReaderFactory {
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
+    case p: ShapeFilePartition => new ShapeFilePartitionReader(p)
+    case _ => throw new UnsupportedOperationException("Partition processing is unsupported by the reader.")
+  }
+}
+
+class ShapeFilePartitionReader(partition: ShapeFilePartition) extends PartitionReader[InternalRow] {
+  import geotrellis.shapefile.ShapeFileReader._
+
+  @transient lazy val ds = new ShapefileDataStore(partition.url)
+  @transient lazy val partitionValues: SimpleFeatureIterator = ds.getFeatureSource.getFeatures.features
+
+  def next: Boolean = partitionValues.hasNext
+
+  def get: InternalRow = partitionValues.next.geom[Geometry].toInternalRow
+
+  def close(): Unit = { partitionValues.close(); ds.dispose() }
+}
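
The reader follows Spark's pull-based PartitionReader contract: next advances the underlying SimpleFeatureIterator, get converts the current feature's geometry to an InternalRow, and close releases the GeoTools resources. Roughly how the executor side drives it (the file URL is hypothetical):

import java.net.URL

// Simplified form of the loop Spark's DataSourceV2 runtime runs on each
// executor for every input partition.
val reader = new ShapeFilePartitionReaderFactory()
  .createReader(ShapeFilePartition(new URL("file:///data/demographics.shp")))
try {
  while (reader.next()) {
    val row = reader.get // one InternalRow per shapefile feature geometry
    // ... hand the row to the downstream operator ...
  }
} finally reader.close()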

datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileScanBuilder.scala

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.locationtech.rasterframes.datasource.stac.api.encoders._
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.types.StructType
+
+import java.net.URL
+
+class ShapeFileScanBuilder(url: URL) extends ScanBuilder {
+  def build(): Scan = new ShapeFileBatchScan(url)
+}
+
+/** Batch Reading Support. The schema is repeated here as it can change after column pruning, etc. */
+class ShapeFileBatchScan(url: URL) extends Scan with Batch {
+  def readSchema(): StructType = geometryExpressionEncoder.schema
+
+  override def toBatch: Batch = this
+
+  /**
+   * Everything is loaded into a single partition: the shapefile is read
+   * sequentially through a single ShapefileDataStore, so this load is not
+   * distributed across executors.
+   * */
+  def planInputPartitions(): Array[InputPartition] = Array(ShapeFilePartition(url))
+  def createReaderFactory(): PartitionReaderFactory = new ShapeFilePartitionReaderFactory()
+}

datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileTable.scala

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.locationtech.rasterframes.datasource.stac.api.encoders._
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource.URL_PARAM
+import org.locationtech.rasterframes.datasource.urlParam
+import java.net.URL
+
+import scala.collection.JavaConverters._
+import java.util
+
+class ShapeFileTable extends Table with SupportsRead {
+  import ShapeFileTable._
+
+  def name(): String = this.getClass.toString
+
+  def schema(): StructType = geometryExpressionEncoder.schema
+
+  def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+
+  def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+    new ShapeFileScanBuilder(options.url)
+}
+
+object ShapeFileTable {
+  implicit class CaseInsensitiveStringMapOps(val options: CaseInsensitiveStringMap) extends AnyVal {
+    def url: URL = urlParam(URL_PARAM, options).getOrElse(throw new IllegalArgumentException("Missing URL."))
+  }
+}
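
The implicit CaseInsensitiveStringMapOps makes the url option mandatory at scan-building time; a quick sketch of the behavior it encodes:

import org.apache.spark.sql.util.CaseInsensitiveStringMap
import scala.collection.JavaConverters._
import ShapeFileTable._

val opts = new CaseInsensitiveStringMap(Map("url" -> "file:///data/demo.shp").asJava)
opts.url // java.net.URL; an options map without "url" throws IllegalArgumentException("Missing URL.")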

datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/package.scala

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+package org.locationtech.rasterframes.datasource
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.locationtech.jts.geom.Geometry
+
+package object shapefile extends Serializable {
+  // see org.locationtech.geomesa.spark.jts.encoders.SpatialEncoders
+  // GeometryUDT should be registered before the encoder below is used
+  // TODO: use TypedEncoders derived from UDT instances?
+  @transient implicit lazy val geometryExpressionEncoder: ExpressionEncoder[Option[Geometry]] = ExpressionEncoder()
+}
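
As the comments note, ExpressionEncoder() is derived reflectively, so GeometryUDT has to be registered before the encoder is first materialized. In RasterFrames that normally happens during session initialization; a minimal sketch, assuming the standard withRasterFrames setup:

import org.apache.spark.sql.SparkSession
import org.locationtech.rasterframes._

// withRasterFrames initializes RasterFrames, registering its UDTs
// (including GeometryUDT) on the session.
val spark = SparkSession.builder()
  .master("local[*]")
  .getOrCreate()
  .withRasterFrames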

datasource/src/test/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSourceTest.scala

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.locationtech.rasterframes._
+
+import geotrellis.shapefile.ShapeFileReader
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.rasterframes.TestEnvironment
+
+import java.net.URL
+
+class ShapeFileDataSourceTest extends TestEnvironment { self =>
+  import spark.implicits._
+
+  describe("ShapeFile Spark reader") {
+    it("should read a shapefile") {
+      val url = "https://github.com/locationtech/geotrellis/raw/master/shapefile/data/shapefiles/demographics/demographics.shp"
+      import ShapeFileReader._
+
+      val expected = ShapeFileReader
+        .readSimpleFeatures(new URL(url))
+        .map(_.geom[Geometry])
+        .take(2)
+
+      val results =
+        spark
+          .read
+          .format("shapefile")
+          .option("url", url)
+          .load()
+          .limit(2)
+
+      // results.printSchema()
+
+      results.as[Option[Geometry]].collect() shouldBe expected
+    }
+
+  }
+}

project/RFDependenciesPlugin.scala

Lines changed: 7 additions & 1 deletion

@@ -57,6 +57,11 @@ object RFDependenciesPlugin extends AutoPlugin {
     val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1"
     val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1"
     val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test
+
+    val geotoolsVersion = "25.0"
+    val geotoolsMain = "org.geotools" % "gt-main" % geotoolsVersion
+    val geotoolsShapefile = "org.geotools" % "gt-shapefile" % geotoolsVersion
+    val geotoolsOpengis = "org.geotools" % "gt-opengis" % geotoolsVersion
   }
   import autoImport._
 
@@ -67,7 +72,8 @@ object RFDependenciesPlugin extends AutoPlugin {
     "boundless-releases" at "https://repo.boundlessgeo.com/main/",
     "Open Source Geospatial Foundation Repository" at "https://download.osgeo.org/webdav/geotools/",
     "oss-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
-    "jitpack" at "https://jitpack.io"
+    "jitpack" at "https://jitpack.io",
+    "osgeo-releases" at "https://repo.osgeo.org/repository/release/"
   ),
   // dependencyOverrides += "com.azavea.gdal" % "gdal-warp-bindings" % "33.f746890",
   // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py
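
The osgeo-releases resolver is what makes the new gt-* artifacts resolvable: GeoTools releases are published to repo.osgeo.org rather than Maven Central, so without this entry the three GeoTools dependencies above would fail to resolve.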
