diff --git a/.gitignore b/.gitignore index fb42a0c..139b403 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ target/ metastore_db/ derby.log +scalastyle-output.xml +*.iml \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 7495683..c2d6c10 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,5 @@ sudo: false -cache: - directories: - - $HOME/.ivy2/cache - -language: scala -scala: - - 2.10.4 +language: java jdk: - oraclejdk8 diff --git a/build.sbt b/build.sbt deleted file mode 100644 index f1b0dcb..0000000 --- a/build.sbt +++ /dev/null @@ -1,95 +0,0 @@ -/******************** - * Identity * - ********************/ -name := "cookie-datasets" -organization := "ai.cookie" -licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")) - -/******************** - * Version * - ********************/ -scalaVersion := "2.10.4" -sparkVersion := "1.5.0" -crossScalaVersions := Seq("2.10.4") - -/******************** - * scaladocs * - ********************/ -autoAPIMappings := true - -/******************** - * Test * - ********************/ -parallelExecution in Test := false -fork := true -test in assembly := {} - -/******************* - * Spark Packages - ********************/ -spName := "cookieai/cookie-datasets" -spAppendScalaVersion := true -spIncludeMaven := true -spIgnoreProvided := true - -/******************** - * Release settings * - ********************/ -publishMavenStyle := true -pomIncludeRepository := { _ => false } -publishArtifact in Test := false -publishTo := { - val nexus = "https://oss.sonatype.org/" - if (isSnapshot.value) - Some("snapshots" at nexus + "content/repositories/snapshots") - else - Some("releases" at nexus + "service/local/staging/deploy/maven2") -} -pomExtra := - https://github.com/cookieai/cookie-datasets - - git@github.com:cookieai/cookie-datasets.git - scm:git:git@github.com:cookieai/cookie-datasets.git - - - - EronWright - Eron Wright - https://github.com/EronWright - - - -/******************** - * sbt-release * - ********************/ -releaseCrossBuild := true -releasePublishArtifactsAction := PgpKeys.publishSigned.value - -import ReleaseTransformations._ -releaseProcess := Seq[ReleaseStep]( - checkSnapshotDependencies, - inquireVersions, - runClean, - runTest, - setReleaseVersion, - commitReleaseVersion, - tagRelease, - ReleaseStep(action = Command.process("publishSigned", _), enableCrossBuild = true), - setNextVersion, - commitNextVersion, - ReleaseStep(action = Command.process("sonatypeReleaseAll", _), enableCrossBuild = true), - pushChanges -) - -/******************** - * Dependencies * - ********************/ -sparkComponents := Seq("core", "sql", "mllib") - -libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-core" % sparkVersion.value % Test force(), - "org.apache.spark" %% "spark-sql" % sparkVersion.value % Test force(), - "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Test force(), - "org.scalatest" %% "scalatest" % "2.2.5" % Test, - "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_0.2.1" % Test -) diff --git a/build/sbt b/build/sbt deleted file mode 100755 index cc3203d..0000000 --- a/build/sbt +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so -# that we can run Hive to generate the golden answer. This is not required for normal development -# or testing. -for i in "$HIVE_HOME"/lib/* -do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" -done -export HADOOP_CLASSPATH - -realpath () { -( - TARGET_FILE="$1" - - cd "$(dirname "$TARGET_FILE")" - TARGET_FILE="$(basename "$TARGET_FILE")" - - COUNT=0 - while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] - do - TARGET_FILE="$(readlink "$TARGET_FILE")" - cd $(dirname "$TARGET_FILE") - TARGET_FILE="$(basename $TARGET_FILE)" - COUNT=$(($COUNT + 1)) - done - - echo "$(pwd -P)/"$TARGET_FILE"" -) -} - -. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash - - -declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" -declare -r sbt_opts_file=".sbtopts" -declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" - -usage() { - cat < path to global settings/plugins directory (default: ~/.sbt) - -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) - -ivy path to local Ivy repository (default: ~/.ivy2) - -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) - -no-share use all local caches; no sharing - -no-global uses global caches, but does not use global ~/.sbt directory. - -jvm-debug Turn on JVM debugging, open at the given port. - -batch Disable interactive mode - - # sbt version (default: from project/build.properties if present, else latest release) - -sbt-version use the specified version of sbt - -sbt-jar use the specified jar as the sbt launcher - -sbt-rc use an RC version of sbt - -sbt-snapshot use a snapshot version of sbt - - # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) - -java-home alternate JAVA_HOME - - # jvm options and output control - JAVA_OPTS environment variable, if unset uses "$java_opts" - SBT_OPTS environment variable, if unset uses "$default_sbt_opts" - .sbtopts if this file exists in the current directory, it is - prepended to the runner args - /etc/sbt/sbtopts if this file exists, it is prepended to the runner args - -Dkey=val pass -Dkey=val directly to the java runtime - -J-X pass option -X directly to the java runtime - (-J is stripped) - -S-X add -X to sbt's scalacOptions (-S is stripped) - -PmavenProfiles Enable a maven profile for the build. - -In the case of duplicated or conflicting options, the order above -shows precedence: JAVA_OPTS lowest, command line options highest. 
-EOM -} - -process_my_args () { - while [[ $# -gt 0 ]]; do - case "$1" in - -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; - -no-share) addJava "$noshare_opts" && shift ;; - -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; - -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; - -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; - -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; - -batch) exec /dev/null) - if [[ ! $? ]]; then - saved_stty="" - fi -} - -saveSttySettings -trap onExit INT - -run "$@" - -exit_status=$? -onExit diff --git a/build/sbt-launch-0.13.9.jar b/build/sbt-launch-0.13.9.jar deleted file mode 100644 index c065b47..0000000 Binary files a/build/sbt-launch-0.13.9.jar and /dev/null differ diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash deleted file mode 100755 index 615f848..0000000 --- a/build/sbt-launch-lib.bash +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin/env bash -# - -# A library to simplify using the SBT launcher from other packages. -# Note: This should be used by tools like giter8/conscript etc. - -# TODO - Should we merge the main SBT script with this library? - -if test -z "$HOME"; then - declare -r script_dir="$(dirname "$script_path")" -else - declare -r script_dir="$HOME/.sbt" -fi - -declare -a residual_args -declare -a java_args -declare -a scalac_args -declare -a sbt_commands -declare -a maven_profiles - -if test -x "$JAVA_HOME/bin/java"; then - echo -e "Using $JAVA_HOME as default JAVA_HOME." - echo "Note, this will be overridden by -java-home if it is set." - declare java_cmd="$JAVA_HOME/bin/java" -else - declare java_cmd=java -fi - -echoerr () { - echo 1>&2 "$@" -} -vlog () { - [[ $verbose || $debug ]] && echoerr "$@" -} -dlog () { - [[ $debug ]] && echoerr "$@" -} - -acquire_sbt_jar () { - SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - JAR=build/sbt-launch-${SBT_VERSION}.jar - - sbt_jar=$JAR - - if [[ ! -f "$sbt_jar" ]]; then - # Download sbt launch jar if it hasn't been downloaded yet - if [ ! -f "${JAR}" ]; then - # Download - printf "Attempting to fetch sbt\n" - JAR_DL="${JAR}.part" - if [ $(command -v curl) ]; then - curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ - mv "${JAR_DL}" "${JAR}" - elif [ $(command -v wget) ]; then - wget --quiet ${URL1} -O "${JAR_DL}" &&\ - mv "${JAR_DL}" "${JAR}" - else - printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" - exit -1 - fi - fi - if [ ! -f "${JAR}" ]; then - # We failed to download - printf "Our attempt to download sbt locally to ${JAR} failed. 
Please install sbt manually from http://www.scala-sbt.org/\n"
-      exit -1
-    fi
-    printf "Launching sbt from ${JAR}\n"
-  fi
-}
-
-execRunner () {
-  # print the arguments one to a line, quoting any containing spaces
-  [[ $verbose || $debug ]] && echo "# Executing command line:" && {
-    for arg; do
-      if printf "%s\n" "$arg" | grep -q ' '; then
-        printf "\"%s\"\n" "$arg"
-      else
-        printf "%s\n" "$arg"
-      fi
-    done
-    echo ""
-  }
-
-  "$@"
-}
-
-addJava () {
-  dlog "[addJava] arg = '$1'"
-  java_args=( "${java_args[@]}" "$1" )
-}
-
-enableProfile () {
-  dlog "[enableProfile] arg = '$1'"
-  maven_profiles=( "${maven_profiles[@]}" "$1" )
-  export SBT_MAVEN_PROFILES="${maven_profiles[@]}"
-}
-
-addSbt () {
-  dlog "[addSbt] arg = '$1'"
-  sbt_commands=( "${sbt_commands[@]}" "$1" )
-}
-addResidual () {
-  dlog "[residual] arg = '$1'"
-  residual_args=( "${residual_args[@]}" "$1" )
-}
-addDebugger () {
-  addJava "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$1"
-}
-
-# a ham-fisted attempt to move some memory settings in concert
-# so they need not be dicked around with individually.
-get_mem_opts () {
-  local mem=${1:-2048}
-  local perm=$(( $mem / 4 ))
-  (( $perm > 256 )) || perm=256
-  (( $perm < 4096 )) || perm=4096
-  local codecache=$(( $perm / 2 ))
-
-  echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m"
-}
-
-require_arg () {
-  local type="$1"
-  local opt="$2"
-  local arg="$3"
-  if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then
-    echo "$opt requires <$type> argument" 1>&2
-    exit 1
-  fi
-}
-
-is_function_defined() {
-  declare -f "$1" > /dev/null
-}
-
-process_args () {
-  while [[ $# -gt 0 ]]; do
-    case "$1" in
-      -h|-help) usage; exit 1 ;;
-      -v|-verbose) verbose=1 && shift ;;
-      -d|-debug) debug=1 && shift ;;
-
-      -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;;
-      -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;;
-      -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;;
-      -batch) exec
diff --git a/pom.xml b/pom.xml
new file mode 100644
--- /dev/null
+++ b/pom.xml
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>ai.cookie</groupId>
+    <artifactId>cookie-datasets</artifactId>
+    <version>0.2-SNAPSHOT</version>
+
+    <repositories>
+        <repository>
+            <id>mmlspark</id>
+            <name>mmlspark</name>
+            <url>https://mmlspark.azureedge.net/maven</url>
+        </repository>
+    </repositories>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.11</artifactId>
+            <version>2.1.1</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.11</artifactId>
+            <version>2.1.1</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-mllib_2.11</artifactId>
+            <version>2.1.1</version>
+        </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-library</artifactId>
+            <version>2.11.8</version>
+        </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-compiler</artifactId>
+            <version>2.11.8</version>
+        </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-reflect</artifactId>
+            <version>2.11.8</version>
+        </dependency>
+        <dependency>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest_2.11</artifactId>
+            <version>2.2.6</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.holdenkarau</groupId>
+            <artifactId>spark-testing-base_2.11</artifactId>
+            <version>2.1.0_0.6.0</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <pluginManagement>
+            <plugins>
+                <plugin>
+                    <groupId>net.alchim31.maven</groupId>
+                    <artifactId>scala-maven-plugin</artifactId>
+                    <version>3.2.1</version>
+                </plugin>
+                <plugin>
+                    <groupId>org.apache.maven.plugins</groupId>
+                    <artifactId>maven-compiler-plugin</artifactId>
+                    <version>2.0.2</version>
+                </plugin>
+            </plugins>
+        </pluginManagement>
+        <plugins>
+            <plugin>
+                <groupId>org.scalastyle</groupId>
+                <artifactId>scalastyle-maven-plugin</artifactId>
+                <version>0.8.0</version>
+                <configuration>
+                    <verbose>false</verbose>
+                    <failOnViolation>true</failOnViolation>
+                    <includeTestSourceDirectory>true</includeTestSourceDirectory>
+                    <failOnWarning>false</failOnWarning>
+                    <sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
+                    <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
+                    <configLocation>scalastyle-config.xml</configLocation>
+                    <outputFile>${basedir}/scalastyle-output.xml</outputFile>
+                    <outputEncoding>UTF-8</outputEncoding>
+                </configuration>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>check</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>eclipse-add-source</id>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>scala-compile-first</id>
+                        <phase>process-resources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                            <goal>compile</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>scala-test-compile</id>
+                        <phase>process-test-resources</phase>
+                        <goals>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <recompileMode>incremental</recompileMode>
+                    <useZincServer>true</useZincServer>
+                    <args>
+                        <arg>-unchecked</arg>
+                        <arg>-deprecation</arg>
+                        <arg>-feature</arg>
+                    </args>
+                    <javacArgs>
+                        <javacArg>-source</javacArg>
+                        <javacArg>${java.version}</javacArg>
+                        <javacArg>-target</javacArg>
+                        <javacArg>${java.version}</javacArg>
+                    </javacArgs>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <phase>compile</phase>
+                        <goals>
+                            <goal>compile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.scalatest</groupId>
+                <artifactId>scalatest-maven-plugin</artifactId>
+                <version>1.0</version>
+                <executions>
+                    <execution>
+                        <id>test</id>
+                        <goals>
+                            <goal>test</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/project/build.properties b/project/build.properties
deleted file mode 100644 index 176a863..0000000 --- a/project/build.properties +++ /dev/null @@ -1 +0,0 @@ -sbt.version=0.13.9 \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt deleted file mode 100644 index b8074e9..0000000 --- a/project/plugins.sbt +++ /dev/null @@ -1,19 +0,0 @@ -resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) - -resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" - -resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" - -resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven" - -addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0") - -addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "1.0") - -addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") - -addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3") - -addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0") - -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") diff --git a/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala b/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala index c2117d4..c95aa5e 100644 --- a/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala +++ b/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala @@ -28,4 +28,4 @@ object AttributeKeys { * For image data, the shape is typically 3-dimensional - numchannels, height, width. */ val SHAPE: String = "shape" -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala index ad75062..d2380c1 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala @@ -21,13 +21,13 @@ package ai.cookie.spark.sql.sources.cifar import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} -import org.apache.hadoop.mapreduce.{InputSplit => HadoopInputSplit, JobContext, RecordReader => HadoopRecordReader, TaskAttemptContext} -import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vector +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext, InputSplit => HadoopInputSplit, RecordReader => HadoopRecordReader} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.Row import ai.cookie.spark.sql.sources.mapreduce.PrunedReader import ai.cookie.spark.sql.types.Conversions._ import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} +import org.apache.spark.mllib.util.MLUtils private class CifarInputFormat extends FileInputFormat[String, Row] @@ -41,8 +41,7 @@ private class CifarInputFormat } private class CifarRecordReader() - extends HadoopRecordReader[String,Row] - with Logging { + extends HadoopRecordReader[String,Row] { private var parser: CifarReader = null @@ -60,8 +59,8 @@ private class CifarRecordReader() // initialize parser val format = { CifarRecordReader.getFormat(context.getConfiguration) match { - case Some("CIFAR-10") => CifarFormats._10 - case Some("CIFAR-100") => CifarFormats._100 + case Some("CIFAR-10") => CifarFormats.Cifar10 + case Some("CIFAR-100") => CifarFormats.Cifar100 case other => throw new RuntimeException(s"unsupported CIFAR format '$other'") } } @@ -113,4 +112,4 @@ private object 
CifarRecordReader { def getFormat(conf: HadoopConfiguration): Option[String] = { Option(conf.get(FORMAT)) } -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala index abc18d1..4321ad4 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala @@ -22,14 +22,15 @@ import java.nio.file.Paths import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute} -import org.apache.spark.{SparkContext, Partition, TaskContext, Logging} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, RelationProvider} +import org.apache.spark.sql.sources.{BaseRelation, PrunedScan, RelationProvider} +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType import ai.cookie.spark.ml.attribute.AttributeKeys import ai.cookie.spark.sql.sources.mapreduce.PrunedReader import ai.cookie.spark.sql.types.VectorUDT @@ -43,11 +44,11 @@ private case class Cifar10Relation(val path: Path, val maxSplitSize: Option[Long extends CifarRelation(path, maxSplitSize)(sqlContext) { private lazy val labelMetadata = NominalAttribute.defaultAttr - .withName("label").withValues(CifarFormats._10.labels).toMetadata() + .withName("label").withValues(CifarFormats.Cifar10.labels).toMetadata() override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false, labelMetadata) :: - StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil) + StructField("features", VectorType, nullable = false, featureMetadata) :: Nil) CifarRecordReader.setFormat(hadoopConf, "CIFAR-10") } @@ -61,15 +62,15 @@ private case class Cifar100Relation(val path: Path, val maxSplitSize: Option[Lon extends CifarRelation(path, maxSplitSize)(sqlContext) { private lazy val coarseLabelMetadata = NominalAttribute.defaultAttr - .withName("coarseLabel").withValues(CifarFormats._100.coarseLabels).toMetadata() + .withName("coarseLabel").withValues(CifarFormats.Cifar100.coarseLabels).toMetadata() private lazy val labelMetadata = NominalAttribute.defaultAttr - .withName("label").withValues(CifarFormats._100.fineLabels).toMetadata() + .withName("label").withValues(CifarFormats.Cifar100.fineLabels).toMetadata() override def schema: StructType = StructType( StructField("coarseLabel", DoubleType, nullable = false, coarseLabelMetadata) :: StructField("label", DoubleType, nullable = false, labelMetadata) :: - StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil) + StructField("features", VectorType, nullable = false, featureMetadata) :: Nil) CifarRecordReader.setFormat(hadoopConf, "CIFAR-100") } @@ -78,7 +79,7 @@ private abstract class CifarRelation( path: Path, maxSplitSize: Option[Long] = None) (val sqlContext: SQLContext) - extends BaseRelation with PrunedScan with Logging { + extends BaseRelation with PrunedScan 
{ protected val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala index 8c67597..7212e97 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala @@ -53,8 +53,8 @@ private class CifarReader def next(): CifarRecord = { val row = CifarRecord( coarseLabel = format match { - case CifarFormats._100 => stream.readUnsignedByte() - case CifarFormats._10 => 0 + case CifarFormats.Cifar100 => stream.readUnsignedByte() + case CifarFormats.Cifar10 => 0 }, fineLabel = stream.readUnsignedByte(), image = readImage match { @@ -110,12 +110,12 @@ private object CifarFormats { sealed abstract class Format(val name: String, val recordSize: Int) - case object _10 extends Format("CIFAR-10", 1 + IMAGE_FIELD_SIZE) { + case object Cifar10 extends Format("CIFAR-10", 1 + IMAGE_FIELD_SIZE) { val labels = read("batches.meta.txt") } - case object _100 extends Format("CIFAR-100", 1 + 1 + IMAGE_FIELD_SIZE) { + case object Cifar100 extends Format("CIFAR-100", 1 + 1 + IMAGE_FIELD_SIZE) { val fineLabels = read("fine_label_names.txt") val coarseLabels = read("coarse_label_names.txt") } -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala b/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala index ecf9baf..94a43b3 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala @@ -18,14 +18,15 @@ package ai.cookie.spark.sql.sources.iris import ai.cookie.spark.sql.types.VectorUDT -import org.apache.spark.Logging -import org.apache.spark.ml.attribute.{Attribute, NumericAttribute, AttributeGroup, NominalAttribute} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.sources._ import ai.cookie.spark.sql.sources.libsvm.LibSVMRelation +import org.apache.spark.ml +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType /** * Iris dataset as a Spark SQL relation. 
@@ -51,11 +52,11 @@ private class IrisLibSVMRelation(override val path: String) */ private class IrisCsvRelation(val path: String) (@transient override val sqlContext: SQLContext) - extends BaseRelation with TableScan with IrisRelation with Logging { + extends BaseRelation with TableScan with IrisRelation { override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false, metadata = labelMetadata) :: - StructField("features", VectorUDT(), nullable = false, metadata = featuresMetadata) :: Nil) + StructField("features", VectorType, nullable = false, metadata = featuresMetadata) :: Nil) override def buildScan(): RDD[Row] = { val sc = sqlContext.sparkContext @@ -70,7 +71,7 @@ private class IrisCsvRelation(val path: String) case "Iris-virginica" => 2.0 }, // features - Vectors.dense(a.slice(0,4).map(_.toDouble))) + ml.linalg.Vectors.dense(a.slice(0,4).map(_.toDouble))) case _ => sys.error("unrecognized format") } } @@ -124,4 +125,4 @@ object DefaultSource { * The format (csv or libsvm) */ val Format = "format" -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala b/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala index a57e6af..33b7043 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala @@ -17,13 +17,13 @@ package ai.cookie.spark.sql.sources.libsvm -import org.apache.spark.Logging import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan} import org.apache.spark.sql.types.{DoubleType, Metadata, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} import ai.cookie.spark.sql.types.VectorUDT +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType /** * LibSVMRelation provides a DataFrame representation of LibSVM formatted data. 
@@ -32,7 +32,7 @@ import ai.cookie.spark.sql.types.VectorUDT */ private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[Int] = None) (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging { + extends BaseRelation with TableScan { protected def labelMetadata: Metadata = Metadata.empty @@ -40,7 +40,7 @@ private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[In override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false, metadata = labelMetadata) :: - StructField("features", VectorUDT(), nullable = false, metadata = featuresMetadata) :: Nil) + StructField("features", VectorType, nullable = false, metadata = featuresMetadata) :: Nil) override def buildScan(): RDD[Row] = { @@ -52,7 +52,7 @@ private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[In } baseRdd.map(pt => { - Row(pt.label, pt.features.toDense) + Row(pt.label, pt.features.toDense.asML) }) } diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala index 7792d91..1fc9a1f 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala @@ -21,8 +21,7 @@ package ai.cookie.spark.sql.sources.mnist import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} import org.apache.hadoop.mapreduce.{InputSplit => HadoopInputSplit, JobContext, RecordReader => HadoopRecordReader, TaskAttemptContext} -import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.Row import ai.cookie.spark.sql.sources.mapreduce.PrunedReader import ai.cookie.spark.sql.types.Conversions._ @@ -40,7 +39,7 @@ private class MnistInputFormat private class MnistRecordReader() extends HadoopRecordReader[String,Row] - with Logging { + { private var imageParser: MnistImageReader = null private var labelParser: MnistLabelReader = null @@ -61,7 +60,7 @@ private class MnistRecordReader() .getOrElse(throw new RuntimeException("expected labelsPath")) imageParser = new MnistImageReader(file.getPath) labelParser = new MnistLabelReader(labelsPath) - //val recordRange = calculateRange(file, parser) + // val recordRange = calculateRange(file, parser) // calculate the range of records to scan, based on byte-level split information val start = imageParser.recordAt(file.getStart) diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala index e4ddce6..b59acbc 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala @@ -22,16 +22,17 @@ import java.nio.file.Paths import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, AlphaComponent} -import org.apache.spark.{SparkContext, Partition, TaskContext, Logging} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} 
-import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, RelationProvider} +import org.apache.spark.sql.sources.{BaseRelation, PrunedScan, RelationProvider} import ai.cookie.spark.ml.attribute.AttributeKeys import ai.cookie.spark.sql.sources.mapreduce.PrunedReader import ai.cookie.spark.sql.types.VectorUDT +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType import org.apache.spark.sql.types.MetadataBuilder /** @@ -43,7 +44,7 @@ private case class MnistRelation( maxRecords: Option[Int] = None, maxSplitSize: Option[Long] = None) (@transient val sqlContext: SQLContext) extends BaseRelation - with PrunedScan with Logging { + with PrunedScan { private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) @@ -60,7 +61,7 @@ private case class MnistRelation( override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false) :: - StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil) + StructField("features", VectorType, nullable = false, featureMetadata) :: Nil) override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val sc = sqlContext.sparkContext diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala index cfebef3..0b33e8f 100644 --- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala +++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala @@ -31,8 +31,9 @@ private[mnist] class MnistLabelReader(path: Path)(implicit conf: Configuration) fs.open(path) } - if(stream.readInt() != MnistLabelReader.HEADER_MAGIC) + if (stream.readInt() != MnistLabelReader.HEADER_MAGIC) { throw new IOException("labels database file is unreadable") + } val numLabels: Int = stream.readInt() @@ -69,8 +70,9 @@ private[mnist] class MnistImageReader(path: Path)(implicit conf: Configuration) fs.open(path) } - if(stream.readInt() != MnistImageReader.HEADER_MAGIC) + if (stream.readInt() != MnistImageReader.HEADER_MAGIC) { throw new IOException("images database file is unreadable") + } val numImages: Int = stream.readInt() @@ -105,4 +107,4 @@ private[mnist] class MnistImageReader(path: Path)(implicit conf: Configuration) private object MnistImageReader { val HEADER_SIZE = 16 val HEADER_MAGIC = 0x00000803 -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/types/package.scala b/src/main/scala/ai/cookie/spark/sql/types/package.scala index 36686bf..ab7c2c8 100644 --- a/src/main/scala/ai/cookie/spark/sql/types/package.scala +++ b/src/main/scala/ai/cookie/spark/sql/types/package.scala @@ -18,8 +18,8 @@ package ai.cookie.spark.sql -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors package object types { private[spark] object Conversions { @@ -39,4 +39,4 @@ package object types { Vectors.dense(data) } } -} \ No newline at end of file +} diff --git a/src/main/scala/ai/cookie/spark/sql/types/vectors.scala b/src/main/scala/ai/cookie/spark/sql/types/vectors.scala index 1cff37e..9f86463 100644 --- a/src/main/scala/ai/cookie/spark/sql/types/vectors.scala +++ b/src/main/scala/ai/cookie/spark/sql/types/vectors.scala @@ -17,14 +17,14 @@ package ai.cookie.spark.sql.types -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.types.{DataType, SQLUserDefinedType} private[spark] object VectorUDT { /** - * Get the UDT associated 
with the {@code org.apache.spark.mllib.linalg.Vector} type.
+   * Get the UDT associated with the {@code org.apache.spark.ml.linalg.Vector} type.
    */
   def apply(): DataType = {
     classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
   }
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala b/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
index ec728b3..698ea9e 100644
--- a/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
+++ b/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
@@ -18,11 +18,11 @@
 package ai.cookie.spark.ml.feature
 
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, DataType}
+import org.apache.spark.sql.types.{DataType, StringType}
 
 /**
  * A transformer that maps an ML column of label indices to
@@ -41,7 +41,7 @@
     throw new UnsupportedOperationException();
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val schema = transformSchema(dataset.schema, logging = true)
 
     val values = Attribute.fromStructField(schema($(inputCol))) match {
@@ -49,8 +49,8 @@
       case _ => throw new UnsupportedOperationException("input column must be a nominal column")
     }
 
-    dataset.withColumn($(outputCol),
-      callUDF((index: Double) => values(index.toInt), outputDataType, dataset($(inputCol))))
+    val toStringUdf = udf((index: Double) => values(index.toInt))
+    dataset.withColumn($(outputCol), toStringUdf(dataset($(inputCol))))
   }
 
   override protected def outputDataType: DataType = StringType
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala b/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
index 1643534..00cf0c4 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
@@ -37,4 +37,4 @@
     _sqlContext = null
     super.afterAll()
   }
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala b/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
index 4fdcea5..3111433 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
@@ -25,7 +25,7 @@
 import ai.cookie.spark.sql.sources.SharedSQLContext
 import ai.cookie.spark.sql.types.Conversions._
 import org.apache.hadoop.fs.Path
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructField
 import org.scalatest.{FunSuite, Matchers}
@@ -34,8 +34,8 @@
 
 class CifarRelationSuite extends FunSuite
 with SharedSQLContext with Matchers {
 
   private val testDatasets = Seq(
-    (CifarFormats._100, new Path("src/test/resources/cifar-100-binary/sample.bin"), 100),
-    (CifarFormats._10, new Path("src/test/resources/cifar-10-batches-bin/sample.bin"), 100)
+    (CifarFormats.Cifar100, new Path("src/test/resources/cifar-100-binary/sample.bin"), 100),
+    (CifarFormats.Cifar10, new Path("src/test/resources/cifar-10-batches-bin/sample.bin"), 100)
   )
 
   private def recordStream(implicit parser: CifarReader): Stream[CifarRecord] = {
@@ -65,11 +65,11 @@
     }
 
     format match {
-      case CifarFormats._10 =>
-        values(df.schema("label")) should equal(CifarFormats._10.labels)
-      case CifarFormats._100 =>
-        values(df.schema("label")) should equal(CifarFormats._100.fineLabels)
-        values(df.schema("coarseLabel")) should equal(CifarFormats._100.coarseLabels)
+      case CifarFormats.Cifar10 =>
+        values(df.schema("label")) should equal(CifarFormats.Cifar10.labels)
+      case CifarFormats.Cifar100 =>
+        values(df.schema("label")) should equal(CifarFormats.Cifar100.fineLabels)
+        values(df.schema("coarseLabel")) should equal(CifarFormats.Cifar100.coarseLabels)
     }
 
     val featureMetadata = df.schema("features").metadata
@@ -111,7 +111,7 @@
 
     for((format, path, count) <- testDatasets) {
       format match {
-        case CifarFormats._10 =>
+        case CifarFormats.Cifar10 =>
           val df = sqlContext.read.cifar(path.toString, format.name, Some(Long.MaxValue))
             .select("label", "features")
@@ -133,7 +133,7 @@
           df.stat.freqItems(Seq("label")).show
 
-        case CifarFormats._100 =>
+        case CifarFormats.Cifar100 =>
           val df = sqlContext.read.cifar(path.toString, format.name, Some(Long.MaxValue))
             .select("coarseLabel", "label", "features")
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala b/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
index eb63e37..f22454c 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
@@ -24,7 +24,7 @@
 import ai.cookie.spark.sql.sources.SharedSQLContext
 import ai.cookie.spark.sql.types.Conversions._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
 import org.scalatest.{FunSuite, Matchers}
 
 class MnistRelationSuite extends FunSuite
diff --git a/version.sbt b/version.sbt
deleted file mode 100644
index 13416e2..0000000
--- a/version.sbt
+++ /dev/null
@@ -1 +0,0 @@
-version in ThisBuild := "0.2-SNAPSHOT"
\ No newline at end of file
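Note: with these changes the relations emit org.apache.spark.ml.linalg vectors (exposed as SQLDataTypes.VectorType) instead of the old mllib UDT. A minimal usage sketch against the migrated reader, mirroring what CifarRelationSuite exercises; the package import that supplies the implicit read.cifar helper is an assumption and may need adjusting:

    // sketch only; the import providing the implicit `cifar` reader method is assumed
    import ai.cookie.spark.sql.sources.cifar._
    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.sql.SQLContext

    def inspectCifar10(sqlContext: SQLContext): Unit = {
      // reader signature and format name as used in CifarRelationSuite
      val df = sqlContext.read.cifar(
        "src/test/resources/cifar-10-batches-bin/sample.bin", "CIFAR-10", Some(Long.MaxValue))
      // "features" now deserializes as a spark.ml vector (3 x 32 x 32 values per image)
      val first = df.select("label", "features").head
      val features = first.getAs[Vector]("features")
      println(s"label=${first.getDouble(0)} dims=${features.size}")
    }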
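The IndexToString test helper now targets Spark 2.x: transform takes a Dataset[_] and the removed callUDF is replaced with a plain udf. A hedged sketch of applying it to a frame whose label column carries nominal metadata (for example the CIFAR label names); the explicit uid argument is an assumption, since no auxiliary constructor appears in this diff:

    import ai.cookie.spark.ml.feature.IndexToString
    import org.apache.spark.sql.DataFrame

    def withLabelNames(df: DataFrame): DataFrame = {
      // transform() reads the NominalAttribute values on "label" to map indices back to strings
      val toName = new IndexToString("labelNames")
        .setInputCol("label")
        .setOutputCol("labelName")
      toName.transform(df)
    }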
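LibSVMRelation still loads records through mllib's MLUtils, so each feature vector is bridged to the new ml.linalg type with asML before being wrapped in a Row. A small standalone sketch of that conversion (path and context handling are illustrative only):

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.linalg.{Vector => MLVector}
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row

    def libsvmRows(sc: SparkContext, path: String): RDD[Row] = {
      MLUtils.loadLibSVMFile(sc, path).map { pt =>
        // mllib DenseVector -> ml DenseVector, the same hop buildScan performs
        val features: MLVector = pt.features.toDense.asML
        Row(pt.label, features)
      }
    }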