diff --git a/.gitignore b/.gitignore
index fb42a0c..139b403 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@
target/
metastore_db/
derby.log
+scalastyle-output.xml
+*.iml
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index 7495683..c2d6c10 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,11 +1,5 @@
sudo: false
-cache:
- directories:
- - $HOME/.ivy2/cache
-
-language: scala
-scala:
- - 2.10.4
+language: java
jdk:
- oraclejdk8
diff --git a/build.sbt b/build.sbt
deleted file mode 100644
index f1b0dcb..0000000
--- a/build.sbt
+++ /dev/null
@@ -1,95 +0,0 @@
-/********************
- * Identity *
- ********************/
-name := "cookie-datasets"
-organization := "ai.cookie"
-licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
-
-/********************
- * Version *
- ********************/
-scalaVersion := "2.10.4"
-sparkVersion := "1.5.0"
-crossScalaVersions := Seq("2.10.4")
-
-/********************
- * scaladocs *
- ********************/
-autoAPIMappings := true
-
-/********************
- * Test *
- ********************/
-parallelExecution in Test := false
-fork := true
-test in assembly := {}
-
-/*******************
- * Spark Packages
- ********************/
-spName := "cookieai/cookie-datasets"
-spAppendScalaVersion := true
-spIncludeMaven := true
-spIgnoreProvided := true
-
-/********************
- * Release settings *
- ********************/
-publishMavenStyle := true
-pomIncludeRepository := { _ => false }
-publishArtifact in Test := false
-publishTo := {
- val nexus = "https://oss.sonatype.org/"
- if (isSnapshot.value)
- Some("snapshots" at nexus + "content/repositories/snapshots")
- else
- Some("releases" at nexus + "service/local/staging/deploy/maven2")
-}
-pomExtra :=
-  <url>https://github.com/cookieai/cookie-datasets</url>
-  <scm>
-    <url>git@github.com:cookieai/cookie-datasets.git</url>
-    <connection>scm:git:git@github.com:cookieai/cookie-datasets.git</connection>
-  </scm>
-  <developers>
-    <developer>
-      <id>EronWright</id>
-      <name>Eron Wright</name>
-      <url>https://github.com/EronWright</url>
-    </developer>
-  </developers>
-
-/********************
- * sbt-release *
- ********************/
-releaseCrossBuild := true
-releasePublishArtifactsAction := PgpKeys.publishSigned.value
-
-import ReleaseTransformations._
-releaseProcess := Seq[ReleaseStep](
- checkSnapshotDependencies,
- inquireVersions,
- runClean,
- runTest,
- setReleaseVersion,
- commitReleaseVersion,
- tagRelease,
- ReleaseStep(action = Command.process("publishSigned", _), enableCrossBuild = true),
- setNextVersion,
- commitNextVersion,
- ReleaseStep(action = Command.process("sonatypeReleaseAll", _), enableCrossBuild = true),
- pushChanges
-)
-
-/********************
- * Dependencies *
- ********************/
-sparkComponents := Seq("core", "sql", "mllib")
-
-libraryDependencies ++= Seq(
- "org.apache.spark" %% "spark-core" % sparkVersion.value % Test force(),
- "org.apache.spark" %% "spark-sql" % sparkVersion.value % Test force(),
- "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Test force(),
- "org.scalatest" %% "scalatest" % "2.2.5" % Test,
- "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_0.2.1" % Test
-)
diff --git a/build/sbt b/build/sbt
deleted file mode 100755
index cc3203d..0000000
--- a/build/sbt
+++ /dev/null
@@ -1,156 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so
-# that we can run Hive to generate the golden answer. This is not required for normal development
-# or testing.
-for i in "$HIVE_HOME"/lib/*
-do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i"
-done
-export HADOOP_CLASSPATH
-
-realpath () {
-(
- TARGET_FILE="$1"
-
- cd "$(dirname "$TARGET_FILE")"
- TARGET_FILE="$(basename "$TARGET_FILE")"
-
- COUNT=0
- while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ]
- do
- TARGET_FILE="$(readlink "$TARGET_FILE")"
- cd $(dirname "$TARGET_FILE")
- TARGET_FILE="$(basename $TARGET_FILE)"
- COUNT=$(($COUNT + 1))
- done
-
- echo "$(pwd -P)/"$TARGET_FILE""
-)
-}
-
-. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash
-
-
-declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
-declare -r sbt_opts_file=".sbtopts"
-declare -r etc_sbt_opts_file="/etc/sbt/sbtopts"
-
-usage() {
- cat <<EOM
-
-  -h | -help         print this message
-  -v | -verbose      this runner is chattier
-  -d | -debug        set sbt log level to debug
-  -sbt-dir   <path>  path to global settings/plugins directory (default: ~/.sbt)
-  -sbt-boot  <path>  path to shared boot directory (default: ~/.sbt/boot in 0.11 series)
-  -ivy       <path>  path to local Ivy repository (default: ~/.ivy2)
-  -mem    <integer>  set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem))
-  -no-share          use all local caches; no sharing
-  -no-global         uses global caches, but does not use global ~/.sbt directory.
-  -jvm-debug <port>  Turn on JVM debugging, open at the given port.
-  -batch             Disable interactive mode
-
-  # sbt version (default: from project/build.properties if present, else latest release)
-  -sbt-version  <version>   use the specified version of sbt
-  -sbt-jar      <path>      use the specified jar as the sbt launcher
-  -sbt-rc                   use an RC version of sbt
-  -sbt-snapshot             use a snapshot version of sbt
-
-  # java version (default: java from PATH, currently $(java -version 2>&1 | grep version))
-  -java-home <path>         alternate JAVA_HOME
-
- # jvm options and output control
- JAVA_OPTS environment variable, if unset uses "$java_opts"
- SBT_OPTS environment variable, if unset uses "$default_sbt_opts"
- .sbtopts if this file exists in the current directory, it is
- prepended to the runner args
- /etc/sbt/sbtopts if this file exists, it is prepended to the runner args
- -Dkey=val pass -Dkey=val directly to the java runtime
- -J-X pass option -X directly to the java runtime
- (-J is stripped)
- -S-X add -X to sbt's scalacOptions (-S is stripped)
- -PmavenProfiles Enable a maven profile for the build.
-
-In the case of duplicated or conflicting options, the order above
-shows precedence: JAVA_OPTS lowest, command line options highest.
-EOM
-}
-
-process_my_args () {
- while [[ $# -gt 0 ]]; do
- case "$1" in
- -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;;
- -no-share) addJava "$noshare_opts" && shift ;;
- -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;;
- -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;;
- -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;;
- -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;;
-      -batch) exec </dev/null && shift ;;
-  saved_stty=$(stty -g 2>/dev/null)
- if [[ ! $? ]]; then
- saved_stty=""
- fi
-}
-
-saveSttySettings
-trap onExit INT
-
-run "$@"
-
-exit_status=$?
-onExit
diff --git a/build/sbt-launch-0.13.9.jar b/build/sbt-launch-0.13.9.jar
deleted file mode 100644
index c065b47..0000000
Binary files a/build/sbt-launch-0.13.9.jar and /dev/null differ
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
deleted file mode 100755
index 615f848..0000000
--- a/build/sbt-launch-lib.bash
+++ /dev/null
@@ -1,189 +0,0 @@
-#!/usr/bin/env bash
-#
-
-# A library to simplify using the SBT launcher from other packages.
-# Note: This should be used by tools like giter8/conscript etc.
-
-# TODO - Should we merge the main SBT script with this library?
-
-if test -z "$HOME"; then
- declare -r script_dir="$(dirname "$script_path")"
-else
- declare -r script_dir="$HOME/.sbt"
-fi
-
-declare -a residual_args
-declare -a java_args
-declare -a scalac_args
-declare -a sbt_commands
-declare -a maven_profiles
-
-if test -x "$JAVA_HOME/bin/java"; then
- echo -e "Using $JAVA_HOME as default JAVA_HOME."
- echo "Note, this will be overridden by -java-home if it is set."
- declare java_cmd="$JAVA_HOME/bin/java"
-else
- declare java_cmd=java
-fi
-
-echoerr () {
- echo 1>&2 "$@"
-}
-vlog () {
- [[ $verbose || $debug ]] && echoerr "$@"
-}
-dlog () {
- [[ $debug ]] && echoerr "$@"
-}
-
-acquire_sbt_jar () {
- SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties`
- URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
- JAR=build/sbt-launch-${SBT_VERSION}.jar
-
- sbt_jar=$JAR
-
- if [[ ! -f "$sbt_jar" ]]; then
- # Download sbt launch jar if it hasn't been downloaded yet
- if [ ! -f "${JAR}" ]; then
- # Download
- printf "Attempting to fetch sbt\n"
- JAR_DL="${JAR}.part"
- if [ $(command -v curl) ]; then
- curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\
- mv "${JAR_DL}" "${JAR}"
- elif [ $(command -v wget) ]; then
- wget --quiet ${URL1} -O "${JAR_DL}" &&\
- mv "${JAR_DL}" "${JAR}"
- else
- printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
- exit -1
- fi
- fi
- if [ ! -f "${JAR}" ]; then
- # We failed to download
- printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
- exit -1
- fi
- printf "Launching sbt from ${JAR}\n"
- fi
-}
-
-execRunner () {
- # print the arguments one to a line, quoting any containing spaces
- [[ $verbose || $debug ]] && echo "# Executing command line:" && {
- for arg; do
- if printf "%s\n" "$arg" | grep -q ' '; then
- printf "\"%s\"\n" "$arg"
- else
- printf "%s\n" "$arg"
- fi
- done
- echo ""
- }
-
- "$@"
-}
-
-addJava () {
- dlog "[addJava] arg = '$1'"
- java_args=( "${java_args[@]}" "$1" )
-}
-
-enableProfile () {
- dlog "[enableProfile] arg = '$1'"
- maven_profiles=( "${maven_profiles[@]}" "$1" )
- export SBT_MAVEN_PROFILES="${maven_profiles[@]}"
-}
-
-addSbt () {
- dlog "[addSbt] arg = '$1'"
- sbt_commands=( "${sbt_commands[@]}" "$1" )
-}
-addResidual () {
- dlog "[residual] arg = '$1'"
- residual_args=( "${residual_args[@]}" "$1" )
-}
-addDebugger () {
- addJava "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$1"
-}
-
-# a ham-fisted attempt to move some memory settings in concert
-# so they need not be dicked around with individually.
-get_mem_opts () {
- local mem=${1:-2048}
- local perm=$(( $mem / 4 ))
- (( $perm > 256 )) || perm=256
- (( $perm < 4096 )) || perm=4096
- local codecache=$(( $perm / 2 ))
-
- echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m"
-}
-
-require_arg () {
- local type="$1"
- local opt="$2"
- local arg="$3"
- if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then
- echo "$opt requires <$type> argument" 1>&2
- exit 1
- fi
-}
-
-is_function_defined() {
- declare -f "$1" > /dev/null
-}
-
-process_args () {
- while [[ $# -gt 0 ]]; do
- case "$1" in
- -h|-help) usage; exit 1 ;;
- -v|-verbose) verbose=1 && shift ;;
- -d|-debug) debug=1 && shift ;;
-
- -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;;
- -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;;
- -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;;
-      -batch) exec </dev/null && shift ;;
diff --git a/pom.xml b/pom.xml
new file mode 100644
--- /dev/null
+++ b/pom.xml
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>ai.cookie</groupId>
+  <artifactId>cookie-datasets</artifactId>
+  <version>0.2-SNAPSHOT</version>
+
+  <repositories>
+    <repository>
+      <id>mmlspark</id>
+      <name>mmlspark</name>
+      <url>https://mmlspark.azureedge.net/maven</url>
+    </repository>
+  </repositories>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_2.11</artifactId>
+      <version>2.1.1</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_2.11</artifactId>
+      <version>2.1.1</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-mllib_2.11</artifactId>
+      <version>2.1.1</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>2.11.8</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-compiler</artifactId>
+      <version>2.11.8</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-reflect</artifactId>
+      <version>2.11.8</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_2.11</artifactId>
+      <version>2.2.6</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.holdenkarau</groupId>
+      <artifactId>spark-testing-base_2.11</artifactId>
+      <version>2.1.0_0.6.0</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <pluginManagement>
+      <plugins>
+        <plugin>
+          <groupId>net.alchim31.maven</groupId>
+          <artifactId>scala-maven-plugin</artifactId>
+          <version>3.2.1</version>
+        </plugin>
+        <plugin>
+          <groupId>org.apache.maven.plugins</groupId>
+          <artifactId>maven-compiler-plugin</artifactId>
+          <version>2.0.2</version>
+        </plugin>
+      </plugins>
+    </pluginManagement>
+    <plugins>
+      <plugin>
+        <groupId>org.scalastyle</groupId>
+        <artifactId>scalastyle-maven-plugin</artifactId>
+        <version>0.8.0</version>
+        <configuration>
+          <verbose>false</verbose>
+          <failOnViolation>true</failOnViolation>
+          <includeTestSourceDirectory>true</includeTestSourceDirectory>
+          <failOnWarning>false</failOnWarning>
+          <sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
+          <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
+          <configLocation>scalastyle-config.xml</configLocation>
+          <outputFile>${basedir}/scalastyle-output.xml</outputFile>
+          <outputEncoding>UTF-8</outputEncoding>
+        </configuration>
+        <executions>
+          <execution>
+            <goals>
+              <goal>check</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>eclipse-add-source</id>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>scala-compile-first</id>
+            <phase>process-resources</phase>
+            <goals>
+              <goal>add-source</goal>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>scala-test-compile</id>
+            <phase>process-test-resources</phase>
+            <goals>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <recompileMode>incremental</recompileMode>
+          <useZincServer>true</useZincServer>
+          <args>
+            <arg>-unchecked</arg>
+            <arg>-deprecation</arg>
+            <arg>-feature</arg>
+          </args>
+          <javacArgs>
+            <javacArg>-source</javacArg>
+            <javacArg>${java.version}</javacArg>
+            <javacArg>-target</javacArg>
+            <javacArg>${java.version}</javacArg>
+          </javacArgs>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <executions>
+          <execution>
+            <phase>compile</phase>
+            <goals>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>1.0</version>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
\ No newline at end of file
diff --git a/project/build.properties b/project/build.properties
deleted file mode 100644
index 176a863..0000000
--- a/project/build.properties
+++ /dev/null
@@ -1 +0,0 @@
-sbt.version=0.13.9
\ No newline at end of file
diff --git a/project/plugins.sbt b/project/plugins.sbt
deleted file mode 100644
index b8074e9..0000000
--- a/project/plugins.sbt
+++ /dev/null
@@ -1,19 +0,0 @@
-resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)
-
-resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
-
-resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
-
-resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"
-
-addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0")
-
-addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "1.0")
-
-addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
-
-addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3")
-
-addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0")
-
-addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0")
diff --git a/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala b/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala
index c2117d4..c95aa5e 100644
--- a/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala
+++ b/src/main/scala/ai/cookie/spark/ml/attribute/attributes.scala
@@ -28,4 +28,4 @@ object AttributeKeys {
* For image data, the shape is typically 3-dimensional - numchannels, height, width.
*/
val SHAPE: String = "shape"
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala
index ad75062..d2380c1 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarInputFormat.scala
@@ -21,13 +21,13 @@ package ai.cookie.spark.sql.sources.cifar
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
-import org.apache.hadoop.mapreduce.{InputSplit => HadoopInputSplit, JobContext, RecordReader => HadoopRecordReader, TaskAttemptContext}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext, InputSplit => HadoopInputSplit, RecordReader => HadoopRecordReader}
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import ai.cookie.spark.sql.sources.mapreduce.PrunedReader
import ai.cookie.spark.sql.types.Conversions._
import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
+import org.apache.spark.mllib.util.MLUtils
private class CifarInputFormat
extends FileInputFormat[String, Row]
@@ -41,8 +41,7 @@ private class CifarInputFormat
}
private class CifarRecordReader()
- extends HadoopRecordReader[String,Row]
- with Logging {
+ extends HadoopRecordReader[String,Row] {
private var parser: CifarReader = null
@@ -60,8 +59,8 @@ private class CifarRecordReader()
// initialize parser
val format = {
CifarRecordReader.getFormat(context.getConfiguration) match {
- case Some("CIFAR-10") => CifarFormats._10
- case Some("CIFAR-100") => CifarFormats._100
+ case Some("CIFAR-10") => CifarFormats.Cifar10
+ case Some("CIFAR-100") => CifarFormats.Cifar100
case other => throw new RuntimeException(s"unsupported CIFAR format '$other'")
}
}
@@ -113,4 +112,4 @@ private object CifarRecordReader {
def getFormat(conf: HadoopConfiguration): Option[String] = {
Option(conf.get(FORMAT))
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala
index abc18d1..4321ad4 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/CifarRelation.scala
@@ -22,14 +22,15 @@ import java.nio.file.Paths
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.annotation.{Experimental, AlphaComponent}
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute}
-import org.apache.spark.{SparkContext, Partition, TaskContext, Logging}
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, RelationProvider}
+import org.apache.spark.sql.sources.{BaseRelation, PrunedScan, RelationProvider}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import ai.cookie.spark.ml.attribute.AttributeKeys
import ai.cookie.spark.sql.sources.mapreduce.PrunedReader
import ai.cookie.spark.sql.types.VectorUDT
@@ -43,11 +44,11 @@ private case class Cifar10Relation(val path: Path, val maxSplitSize: Option[Long
extends CifarRelation(path, maxSplitSize)(sqlContext) {
private lazy val labelMetadata = NominalAttribute.defaultAttr
- .withName("label").withValues(CifarFormats._10.labels).toMetadata()
+ .withName("label").withValues(CifarFormats.Cifar10.labels).toMetadata()
override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false, labelMetadata) ::
- StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil)
+ StructField("features", VectorType, nullable = false, featureMetadata) :: Nil)
CifarRecordReader.setFormat(hadoopConf, "CIFAR-10")
}
@@ -61,15 +62,15 @@ private case class Cifar100Relation(val path: Path, val maxSplitSize: Option[Lon
extends CifarRelation(path, maxSplitSize)(sqlContext) {
private lazy val coarseLabelMetadata = NominalAttribute.defaultAttr
- .withName("coarseLabel").withValues(CifarFormats._100.coarseLabels).toMetadata()
+ .withName("coarseLabel").withValues(CifarFormats.Cifar100.coarseLabels).toMetadata()
private lazy val labelMetadata = NominalAttribute.defaultAttr
- .withName("label").withValues(CifarFormats._100.fineLabels).toMetadata()
+ .withName("label").withValues(CifarFormats.Cifar100.fineLabels).toMetadata()
override def schema: StructType = StructType(
StructField("coarseLabel", DoubleType, nullable = false, coarseLabelMetadata) ::
StructField("label", DoubleType, nullable = false, labelMetadata) ::
- StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil)
+ StructField("features", VectorType, nullable = false, featureMetadata) :: Nil)
CifarRecordReader.setFormat(hadoopConf, "CIFAR-100")
}
@@ -78,7 +79,7 @@ private abstract class CifarRelation(
path: Path,
maxSplitSize: Option[Long] = None)
(val sqlContext: SQLContext)
- extends BaseRelation with PrunedScan with Logging {
+ extends BaseRelation with PrunedScan {
protected val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala b/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala
index 8c67597..7212e97 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/cifar/cifar.scala
@@ -53,8 +53,8 @@ private class CifarReader
def next(): CifarRecord = {
val row = CifarRecord(
coarseLabel = format match {
- case CifarFormats._100 => stream.readUnsignedByte()
- case CifarFormats._10 => 0
+ case CifarFormats.Cifar100 => stream.readUnsignedByte()
+ case CifarFormats.Cifar10 => 0
},
fineLabel = stream.readUnsignedByte(),
image = readImage match {
@@ -110,12 +110,12 @@ private object CifarFormats {
sealed abstract class Format(val name: String, val recordSize: Int)
- case object _10 extends Format("CIFAR-10", 1 + IMAGE_FIELD_SIZE) {
+ case object Cifar10 extends Format("CIFAR-10", 1 + IMAGE_FIELD_SIZE) {
val labels = read("batches.meta.txt")
}
- case object _100 extends Format("CIFAR-100", 1 + 1 + IMAGE_FIELD_SIZE) {
+ case object Cifar100 extends Format("CIFAR-100", 1 + 1 + IMAGE_FIELD_SIZE) {
val fineLabels = read("fine_label_names.txt")
val coarseLabels = read("coarse_label_names.txt")
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala b/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala
index ecf9baf..94a43b3 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/iris/relation.scala
@@ -18,14 +18,15 @@
package ai.cookie.spark.sql.sources.iris
import ai.cookie.spark.sql.types.VectorUDT
-import org.apache.spark.Logging
-import org.apache.spark.ml.attribute.{Attribute, NumericAttribute, AttributeGroup, NominalAttribute}
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import ai.cookie.spark.sql.sources.libsvm.LibSVMRelation
+import org.apache.spark.ml
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
/**
* Iris dataset as a Spark SQL relation.
@@ -51,11 +52,11 @@ private class IrisLibSVMRelation(override val path: String)
*/
private class IrisCsvRelation(val path: String)
(@transient override val sqlContext: SQLContext)
- extends BaseRelation with TableScan with IrisRelation with Logging {
+ extends BaseRelation with TableScan with IrisRelation {
override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false, metadata = labelMetadata) ::
- StructField("features", VectorUDT(), nullable = false, metadata = featuresMetadata) :: Nil)
+ StructField("features", VectorType, nullable = false, metadata = featuresMetadata) :: Nil)
override def buildScan(): RDD[Row] = {
val sc = sqlContext.sparkContext
@@ -70,7 +71,7 @@ private class IrisCsvRelation(val path: String)
case "Iris-virginica" => 2.0
},
// features
- Vectors.dense(a.slice(0,4).map(_.toDouble)))
+ ml.linalg.Vectors.dense(a.slice(0,4).map(_.toDouble)))
case _ => sys.error("unrecognized format")
}
}
@@ -124,4 +125,4 @@ object DefaultSource {
* The format (csv or libsvm)
*/
val Format = "format"
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala b/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala
index a57e6af..33b7043 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/libsvm/relation.scala
@@ -17,13 +17,13 @@
package ai.cookie.spark.sql.sources.libsvm
-import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}
import org.apache.spark.sql.types.{DoubleType, Metadata, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import ai.cookie.spark.sql.types.VectorUDT
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
/**
* LibSVMRelation provides a DataFrame representation of LibSVM formatted data.
@@ -32,7 +32,7 @@ import ai.cookie.spark.sql.types.VectorUDT
*/
private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[Int] = None)
(@transient val sqlContext: SQLContext)
- extends BaseRelation with TableScan with Logging {
+ extends BaseRelation with TableScan {
protected def labelMetadata: Metadata = Metadata.empty
@@ -40,7 +40,7 @@ private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[In
override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false, metadata = labelMetadata) ::
- StructField("features", VectorUDT(), nullable = false, metadata = featuresMetadata) :: Nil)
+ StructField("features", VectorType, nullable = false, metadata = featuresMetadata) :: Nil)
override def buildScan(): RDD[Row] = {
@@ -52,7 +52,7 @@ private[spark] class LibSVMRelation(val path: String, val numFeatures: Option[In
}
baseRdd.map(pt => {
- Row(pt.label, pt.features.toDense)
+ Row(pt.label, pt.features.toDense.asML)
})
}
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala
index 7792d91..1fc9a1f 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistInputFormat.scala
@@ -21,8 +21,7 @@ package ai.cookie.spark.sql.sources.mnist
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.hadoop.mapreduce.{InputSplit => HadoopInputSplit, JobContext, RecordReader => HadoopRecordReader, TaskAttemptContext}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import ai.cookie.spark.sql.sources.mapreduce.PrunedReader
import ai.cookie.spark.sql.types.Conversions._
@@ -40,7 +39,7 @@ private class MnistInputFormat
private class MnistRecordReader()
extends HadoopRecordReader[String,Row]
- with Logging {
+ {
private var imageParser: MnistImageReader = null
private var labelParser: MnistLabelReader = null
@@ -61,7 +60,7 @@ private class MnistRecordReader()
.getOrElse(throw new RuntimeException("expected labelsPath"))
imageParser = new MnistImageReader(file.getPath)
labelParser = new MnistLabelReader(labelsPath)
- //val recordRange = calculateRange(file, parser)
+ // val recordRange = calculateRange(file, parser)
// calculate the range of records to scan, based on byte-level split information
val start = imageParser.recordAt(file.getStart)
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala
index e4ddce6..b59acbc 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/MnistRelation.scala
@@ -22,16 +22,17 @@ import java.nio.file.Paths
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.annotation.{Experimental, AlphaComponent}
-import org.apache.spark.{SparkContext, Partition, TaskContext, Logging}
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, RelationProvider}
+import org.apache.spark.sql.sources.{BaseRelation, PrunedScan, RelationProvider}
import ai.cookie.spark.ml.attribute.AttributeKeys
import ai.cookie.spark.sql.sources.mapreduce.PrunedReader
import ai.cookie.spark.sql.types.VectorUDT
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types.MetadataBuilder
/**
@@ -43,7 +44,7 @@ private case class MnistRelation(
maxRecords: Option[Int] = None,
maxSplitSize: Option[Long] = None)
(@transient val sqlContext: SQLContext) extends BaseRelation
- with PrunedScan with Logging {
+ with PrunedScan {
private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
@@ -60,7 +61,7 @@ private case class MnistRelation(
override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false) ::
- StructField("features", VectorUDT(), nullable = false, featureMetadata) :: Nil)
+ StructField("features", VectorType, nullable = false, featureMetadata) :: Nil)
override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
val sc = sqlContext.sparkContext
diff --git a/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala b/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala
index cfebef3..0b33e8f 100644
--- a/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala
+++ b/src/main/scala/ai/cookie/spark/sql/sources/mnist/mnist.scala
@@ -31,8 +31,9 @@ private[mnist] class MnistLabelReader(path: Path)(implicit conf: Configuration)
fs.open(path)
}
- if(stream.readInt() != MnistLabelReader.HEADER_MAGIC)
+ if (stream.readInt() != MnistLabelReader.HEADER_MAGIC) {
throw new IOException("labels database file is unreadable")
+ }
val numLabels: Int = stream.readInt()
@@ -69,8 +70,9 @@ private[mnist] class MnistImageReader(path: Path)(implicit conf: Configuration)
fs.open(path)
}
- if(stream.readInt() != MnistImageReader.HEADER_MAGIC)
+ if (stream.readInt() != MnistImageReader.HEADER_MAGIC) {
throw new IOException("images database file is unreadable")
+ }
val numImages: Int = stream.readInt()
@@ -105,4 +107,4 @@ private[mnist] class MnistImageReader(path: Path)(implicit conf: Configuration)
private object MnistImageReader {
val HEADER_SIZE = 16
val HEADER_MAGIC = 0x00000803
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/types/package.scala b/src/main/scala/ai/cookie/spark/sql/types/package.scala
index 36686bf..ab7c2c8 100644
--- a/src/main/scala/ai/cookie/spark/sql/types/package.scala
+++ b/src/main/scala/ai/cookie/spark/sql/types/package.scala
@@ -18,8 +18,8 @@
package ai.cookie.spark.sql
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.Vectors
package object types {
private[spark] object Conversions {
@@ -39,4 +39,4 @@ package object types {
Vectors.dense(data)
}
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/ai/cookie/spark/sql/types/vectors.scala b/src/main/scala/ai/cookie/spark/sql/types/vectors.scala
index 1cff37e..9f86463 100644
--- a/src/main/scala/ai/cookie/spark/sql/types/vectors.scala
+++ b/src/main/scala/ai/cookie/spark/sql/types/vectors.scala
@@ -17,14 +17,14 @@
package ai.cookie.spark.sql.types
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{DataType, SQLUserDefinedType}
private[spark] object VectorUDT {
/**
- * Get the UDT associated with the {@code org.apache.spark.mllib.linalg.Vector} type.
+ * Get the UDT associated with the {@code org.apache.spark.ml.linalg.Vector} type.
*/
def apply(): DataType = {
classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
}
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala b/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
index ec728b3..698ea9e 100644
--- a/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
+++ b/src/test/scala/ai/cookie/spark/ml/feature/IndexToString.scala
@@ -18,11 +18,11 @@
package ai.cookie.spark.ml.feature
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, DataType}
+import org.apache.spark.sql.types.{DataType, StringType}
/**
* A transformer that maps an ML column of label indices to
@@ -41,7 +41,7 @@ class IndexToString(override val uid: String)
throw new UnsupportedOperationException();
}
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val schema = transformSchema(dataset.schema, logging = true)
val values = Attribute.fromStructField(schema($(inputCol))) match {
@@ -49,8 +49,8 @@ class IndexToString(override val uid: String)
case _ => throw new UnsupportedOperationException("input column must be a nominal column")
}
- dataset.withColumn($(outputCol),
- callUDF((index: Double) => values(index.toInt), outputDataType, dataset($(inputCol))))
+ val toStringUdf = udf((index: Double) => values(index.toInt))
+ dataset.withColumn($(outputCol), toStringUdf(dataset($(inputCol))))
}
override protected def outputDataType: DataType = StringType
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala b/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
index 1643534..00cf0c4 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/SharedSQLContext.scala
@@ -37,4 +37,4 @@ private[sql] abstract trait SharedSQLContext extends SharedSparkContext { self:
_sqlContext = null
super.afterAll()
}
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala b/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
index 4fdcea5..3111433 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/cifar/CifarRelationSuite.scala
@@ -25,7 +25,7 @@ import ai.cookie.spark.sql.sources.SharedSQLContext
import ai.cookie.spark.sql.types.Conversions._
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructField
import org.scalatest.{FunSuite, Matchers}
@@ -34,8 +34,8 @@ class CifarRelationSuite extends FunSuite
with SharedSQLContext with Matchers {
private val testDatasets = Seq(
- (CifarFormats._100, new Path("src/test/resources/cifar-100-binary/sample.bin"), 100),
- (CifarFormats._10, new Path("src/test/resources/cifar-10-batches-bin/sample.bin"), 100)
+ (CifarFormats.Cifar100, new Path("src/test/resources/cifar-100-binary/sample.bin"), 100),
+ (CifarFormats.Cifar10, new Path("src/test/resources/cifar-10-batches-bin/sample.bin"), 100)
)
private def recordStream(implicit parser: CifarReader): Stream[CifarRecord] = {
@@ -65,11 +65,11 @@ with SharedSQLContext with Matchers {
}
format match {
- case CifarFormats._10 =>
- values(df.schema("label")) should equal(CifarFormats._10.labels)
- case CifarFormats._100 =>
- values(df.schema("label")) should equal(CifarFormats._100.fineLabels)
- values(df.schema("coarseLabel")) should equal(CifarFormats._100.coarseLabels)
+ case CifarFormats.Cifar10 =>
+ values(df.schema("label")) should equal(CifarFormats.Cifar10.labels)
+ case CifarFormats.Cifar100 =>
+ values(df.schema("label")) should equal(CifarFormats.Cifar100.fineLabels)
+ values(df.schema("coarseLabel")) should equal(CifarFormats.Cifar100.coarseLabels)
}
val featureMetadata = df.schema("features").metadata
@@ -111,7 +111,7 @@ with SharedSQLContext with Matchers {
for((format, path, count) <- testDatasets) {
format match {
- case CifarFormats._10 =>
+ case CifarFormats.Cifar10 =>
val df = sqlContext.read.cifar(path.toString, format.name, Some(Long.MaxValue))
.select("label", "features")
@@ -133,7 +133,7 @@ with SharedSQLContext with Matchers {
df.stat.freqItems(Seq("label")).show
- case CifarFormats._100 =>
+ case CifarFormats.Cifar100 =>
val df = sqlContext.read.cifar(path.toString, format.name, Some(Long.MaxValue))
.select("coarseLabel", "label", "features")
diff --git a/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala b/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
index eb63e37..f22454c 100644
--- a/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
+++ b/src/test/scala/ai/cookie/spark/sql/sources/mnist/MnistRelationSuite.scala
@@ -24,7 +24,7 @@ import ai.cookie.spark.sql.sources.SharedSQLContext
import ai.cookie.spark.sql.types.Conversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.linalg.Vector
import org.scalatest.{FunSuite, Matchers}
class MnistRelationSuite extends FunSuite
diff --git a/version.sbt b/version.sbt
deleted file mode 100644
index 13416e2..0000000
--- a/version.sbt
+++ /dev/null
@@ -1 +0,0 @@
-version in ThisBuild := "0.2-SNAPSHOT"
\ No newline at end of file