Add Flink module (#606)
* [wip] Add FlinkSource

# Conflicts:
#	build.sbt

* current wip

* Minimal working solution with tests

* Refactor + add SparkEval tests

* Scalafmt fixes

* Update Flink to only build on 2.12

* Tweaks to make build happy

* Try using getExecutionEnv to see if it passes CI

* Yank flaky test for now

* Refactor GroupByServingInfoParsed

# Conflicts:
#	flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
#	flink/src/main/scala/ai/chronon/flink/FlinkJob.scala

* Fix build

* Use version matrix for Flink

* Address PR comments

* Add scaladocs, fix more review comments

# Conflicts:
#	flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala
piyushn-stripe authored Nov 22, 2023
1 parent 020e8f9 commit 882db59
Showing 12 changed files with 744 additions and 2 deletions.
24 changes: 23 additions & 1 deletion build.sbt
@@ -80,7 +80,7 @@ enablePlugins(GitVersioning, GitBranchPrompt)
lazy val supportedVersions = List(scala211, scala212, scala213)

lazy val root = (project in file("."))
  .aggregate(api, aggregator, online, spark_uber, spark_embedded)
  .aggregate(api, aggregator, online, spark_uber, spark_embedded, flink)
  .settings(
    publish / skip := true,
    crossScalaVersions := Nil,
@@ -161,6 +161,17 @@ val VersionMatrix: Map[String, VersionDependency] = Map(
    Some("1.8.2"),
    Some("1.10.2")
  ),
  "flink" -> VersionDependency(
    Seq(
      "org.apache.flink" %% "flink-streaming-scala",
      "org.apache.flink" % "flink-metrics-dropwizard",
      "org.apache.flink" % "flink-clients",
      "org.apache.flink" % "flink-test-utils"
    ),
    None,
    Some("1.16.1"),
    None
  ),
  "netty-buffer" -> VersionDependency(
    Seq(
      "io.netty" % "netty-buffer"
@@ -345,6 +356,17 @@ lazy val spark_embedded = (project in file("spark"))
    Test / test := {}
  )

lazy val flink = (project in file("flink"))
  .dependsOn(aggregator.%("compile->compile;test->test"), online)
  .settings(
    crossScalaVersions := List(scala212),
    libraryDependencies ++= fromMatrix(scalaVersion.value,
                                       "avro",
                                       "spark-all/provided",
                                       "scala-parallel-collections",
                                       "flink")
  )

// Build Sphinx documentation
lazy val sphinx = taskKey[Unit]("Build Sphinx Documentation")
sphinx := {
117 changes: 117 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala
@@ -0,0 +1,117 @@
package ai.chronon.flink

import ai.chronon.online.{Api, KVStore}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.configuration.Configuration
import org.apache.flink.metrics.Counter
import org.apache.flink.streaming.api.functions.async.{ResultFuture, RichAsyncFunction}
import org.apache.flink.streaming.api.datastream.AsyncDataStream
import org.apache.flink.streaming.api.scala.DataStream

import java.util
import java.util.concurrent.TimeUnit
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

case class WriteResponse(putRequest: PutRequest, status: Boolean)

object AsyncKVStoreWriter {
  private val kvStoreConcurrency = 10
  private val defaultTimeoutMillis = 1000L

  def withUnorderedWaits(inputDS: DataStream[PutRequest],
                         kvStoreWriterFn: RichAsyncFunction[PutRequest, WriteResponse],
                         featureGroupName: String,
                         timeoutMillis: Long = defaultTimeoutMillis,
                         capacity: Int = kvStoreConcurrency): DataStream[WriteResponse] = {
    // We use the Java API here as we have encountered issues in integration tests in the
    // past using the Scala async datastream API.
    new DataStream(
      AsyncDataStream
        .unorderedWait(
          inputDS.javaStream,
          kvStoreWriterFn,
          timeoutMillis,
          TimeUnit.MILLISECONDS,
          capacity
        )
        .uid(s"kvstore-writer-async-$featureGroupName")
        .name(s"async kvstore writes for $featureGroupName")
        .setParallelism(inputDS.parallelism)
    )
  }

  /**
   * This was moved to flink-rpc-akka in Flink 1.16 and made private, so we reproduce the direct execution context here
   */
  private class DirectExecutionContext extends ExecutionContext {
    override def execute(runnable: Runnable): Unit =
      runnable.run()

    override def reportFailure(cause: Throwable): Unit =
      throw new IllegalStateException("Error in direct execution context.", cause)

    override def prepare: ExecutionContext = this
  }

  private val ExecutionContextInstance: ExecutionContext = new DirectExecutionContext
}

/**
 * Async Flink writer function to help us write to the KV store.
 * @param onlineImpl Instantiation of the Chronon API to help create KV store objects
 * @param featureGroupName Name of the feature group we're writing to
 */
class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
    extends RichAsyncFunction[PutRequest, WriteResponse] {

  @transient private var kvStore: KVStore = _

  @transient private var errorCounter: Counter = _
  @transient private var successCounter: Counter = _

  // The context used for the future callbacks
  implicit lazy val executor: ExecutionContext = AsyncKVStoreWriter.ExecutionContextInstance

  protected def getKVStore: KVStore = {
    onlineImpl.genKvStore
  }

  override def open(configuration: Configuration): Unit = {
    val group = getRuntimeContext.getMetricGroup
      .addGroup("chronon")
      .addGroup("feature_group", featureGroupName)
    errorCounter = group.counter("kvstore_writer.errors")
    successCounter = group.counter("kvstore_writer.successes")

    kvStore = getKVStore
  }

  override def timeout(input: PutRequest, resultFuture: ResultFuture[WriteResponse]): Unit = {
    println(s"Timed out writing to KV Store for object: $input")
    errorCounter.inc()
    resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = false)))
  }

  override def asyncInvoke(input: PutRequest, resultFuture: ResultFuture[WriteResponse]): Unit = {
    val resultFutureRequested: Future[Seq[Boolean]] = kvStore.multiPut(Seq(input))
    resultFutureRequested.onComplete {
      case Success(l) =>
        val succeeded = l.forall(identity)
        if (succeeded) {
          successCounter.inc()
        } else {
          errorCounter.inc()
          println(s"Failed to write to KVStore for object: $input")
        }
        resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = succeeded)))
      case Failure(exception) =>
        // this should be rare and indicates we have an uncaught exception
        // in the KVStore - we log the exception and skip the object to
        // not fail the app
        errorCounter.inc()
        println(s"Caught exception writing to KVStore for object: $input - $exception")
        resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = false)))
    }
  }
}
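
For reference, a minimal sketch of how this writer could be attached to an upstream stream of PutRequests. The Api instance and the `putRequests` stream are assumed to exist elsewhere and are not part of this commit:

package ai.chronon.flink

import ai.chronon.online.Api
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.streaming.api.scala.DataStream

object KVStoreWriterSketch {
  // Wire the async writer onto an existing PutRequest stream; the returned stream
  // carries one WriteResponse per attempted write (successes and failures).
  def attachWriter(putRequests: DataStream[PutRequest],
                   api: Api,
                   featureGroupName: String): DataStream[WriteResponse] =
    AsyncKVStoreWriter.withUnorderedWaits(
      putRequests,
      new AsyncKVStoreWriter(api, featureGroupName),
      featureGroupName
    )
}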
100 changes: 100 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
@@ -0,0 +1,100 @@
package ai.chronon.flink

import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.{Constants, DataModel, Query, StructType => ChrononStructType}
import ai.chronon.online.{AvroConversions, GroupByServingInfoParsed}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.configuration.Configuration
import org.apache.flink.metrics.Counter
import org.apache.flink.util.Collector

import scala.jdk.CollectionConverters._

/**
 * A Flink function responsible for converting the Spark expr eval output to a form
 * that can be written out to the KV store (a PutRequest object).
 * @param groupByServingInfoParsed The GroupBy we are working with
 * @tparam T The input data type
 */
case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
    extends RichFlatMapFunction[Map[String, Any], PutRequest] {

  @transient protected var avroConversionErrorCounter: Counter = _

  protected val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query
  protected val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset

  // TODO: update to use constant names that are company specific
  protected val timeColumnAlias: String = Constants.TimeColumn
  protected val timeColumn: String = Option(query.timeColumn).getOrElse(timeColumnAlias)

  protected val (keyToBytes, valueToBytes): (Any => Array[Byte], Any => Array[Byte]) =
    getKVSerializers(groupByServingInfoParsed)
  protected val (keyColumns, valueColumns): (Array[String], Array[String]) = getKVColumns
  protected val extraneousRecord: Any => Array[Any] = {
    case x: Map[_, _] if x.keys.forall(_.isInstanceOf[String]) =>
      x.flatMap { case (key, value) => Array(key, value) }.toArray
  }

  private lazy val getKVSerializers = (
      groupByServingInfoParsed: GroupByServingInfoParsed
  ) => {
    val keyZSchema: ChrononStructType = groupByServingInfoParsed.keyChrononSchema
    val valueZSchema: ChrononStructType = groupByServingInfoParsed.groupBy.dataModel match {
      case DataModel.Events => groupByServingInfoParsed.valueChrononSchema
      case _ =>
        throw new IllegalArgumentException(
          s"Only the events based data model is supported at the moment - ${groupByServingInfoParsed.groupBy}"
        )
    }

    (
      AvroConversions.encodeBytes(keyZSchema, extraneousRecord),
      AvroConversions.encodeBytes(valueZSchema, extraneousRecord)
    )
  }

  private lazy val getKVColumns: (Array[String], Array[String]) = {
    val keyColumns = groupByServingInfoParsed.groupBy.keyColumns.asScala.toArray
    val (additionalColumns, _) = groupByServingInfoParsed.groupBy.dataModel match {
      case DataModel.Events =>
        Seq.empty[String] -> timeColumn
      case _ =>
        throw new IllegalArgumentException(
          s"Only the events based data model is supported at the moment - ${groupByServingInfoParsed.groupBy}"
        )
    }
    val valueColumns = groupByServingInfoParsed.groupBy.aggregationInputs ++ additionalColumns
    (keyColumns, valueColumns)
  }

  override def open(configuration: Configuration): Unit = {
    super.open(configuration)
    val metricsGroup = getRuntimeContext.getMetricGroup
      .addGroup("chronon")
      .addGroup("feature_group", groupByServingInfoParsed.groupBy.getMetaData.getName)
    avroConversionErrorCounter = metricsGroup.counter("avro_conversion_errors")
  }

  override def close(): Unit = super.close()

  override def flatMap(value: Map[String, Any], out: Collector[PutRequest]): Unit =
    try {
      out.collect(avroConvertMapToPutRequest(value))
    } catch {
      case e: Exception =>
        // To improve availability, we don't rethrow the exception. We just drop the event
        // and track the errors in a metric. If there are too many errors we'll get alerted/paged.
        println(s"Error converting to Avro bytes - $e")
        avroConversionErrorCounter.inc()
    }

  def avroConvertMapToPutRequest(in: Map[String, Any]): PutRequest = {
    val tsMills = in(timeColumnAlias).asInstanceOf[Long]
    val keyBytes = keyToBytes(keyColumns.map(in.get(_).get))
    val valueBytes = valueToBytes(valueColumns.map(in.get(_).get))
    PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
  }

}
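
As a rough sketch (names illustrative, not part of this commit), the conversion step can also be exercised directly on a single projected row, assuming a GroupByServingInfoParsed fetched through the usual Chronon metadata path:

package ai.chronon.flink

import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.KVStore.PutRequest

object AvroCodecSketch {
  // Convert one Spark-expr-eval output row (column name -> value) into a PutRequest.
  def convertOne(servingInfo: GroupByServingInfoParsed, row: Map[String, Any]): PutRequest =
    AvroCodecFn[Map[String, Any]](servingInfo).avroConvertMapToPutRequest(row)
}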
73 changes: 73 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -0,0 +1,73 @@
package ai.chronon.flink

import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.spark.sql.Encoder
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction

/**
 * Flink job that processes a single streaming GroupBy and writes out the results
 * (raw events in the untiled case, pre-aggregates in the tiled case) to the KV store.
 * At a high level, the operators are structured as follows:
 *   Kafka source -> Spark expression eval -> Avro conversion -> KV store writer
 *   Kafka source - reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic
 *   Spark expression eval - evaluates the Spark SQL expressions in the GroupBy and projects and filters the input data
 *   Avro conversion - converts the Spark expr eval output to a form that can be written out to the KV store (PutRequest object)
 *   KV store writer - writes the PutRequest objects to the KV store using the AsyncDataStream API
 *
 * In the untiled version there are no shuffles, so this ends up being a single node in the Flink DAG
 * (with the above 4 operators and parallelism as injected by the user).
 *
 * @param eventSrc Provider of a Flink DataStream[T] for the given topic and feature group
 * @param sinkFn Async Flink writer function to help us write to the KV store
 * @param groupByServingInfoParsed The GroupBy we are working with
 * @param encoder Spark Encoder for the input data type
 * @param parallelism Parallelism to use for the Flink job
 * @tparam T The input data type
 */
class FlinkJob[T](eventSrc: FlinkSource[T],
                  sinkFn: RichAsyncFunction[PutRequest, WriteResponse],
                  groupByServingInfoParsed: GroupByServingInfoParsed,
                  encoder: Encoder[T],
                  parallelism: Int) {

  protected val exprEval: SparkExpressionEvalFn[T] =
    new SparkExpressionEvalFn[T](encoder, groupByServingInfoParsed.groupBy)
  val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName

  if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) {
    throw new IllegalArgumentException(
      s"Invalid feature group: $featureGroupName. No streaming source"
    )
  }

  // The source of our Flink application is a Kafka topic
  val kafkaTopic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic

  def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
    val sourceStream: DataStream[T] =
      eventSrc
        .getDataStream(kafkaTopic, featureGroupName)(env, parallelism)

    val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream
      .flatMap(exprEval)
      .uid(s"spark-expr-eval-flatmap-$featureGroupName")
      .name(s"Spark expression eval for $featureGroupName")
      .setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator

    val putRecordDS: DataStream[PutRequest] = sparkExprEvalDS
      .flatMap(AvroCodecFn[T](groupByServingInfoParsed))
      .uid(s"avro-conversion-$featureGroupName")
      .name(s"Avro conversion for $featureGroupName")
      .setParallelism(sourceStream.parallelism)

    AsyncKVStoreWriter.withUnorderedWaits(
      putRecordDS,
      sinkFn,
      featureGroupName
    )
  }
}
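
A sketch of how the pieces above might be assembled into a runnable job. The entry point, Api implementation, source, and encoder are assumptions for illustration; the commit itself does not add a main class:

package ai.chronon.flink

import ai.chronon.online.{Api, GroupByServingInfoParsed}
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.spark.sql.Encoder

object FlinkJobSketch {
  def run[T](source: FlinkSource[T],
             api: Api,
             servingInfo: GroupByServingInfoParsed,
             encoder: Encoder[T]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val featureGroupName = servingInfo.groupBy.getMetaData.getName
    val job = new FlinkJob[T](
      eventSrc = source,
      sinkFn = new AsyncKVStoreWriter(api, featureGroupName),
      groupByServingInfoParsed = servingInfo,
      encoder = encoder,
      parallelism = 4 // illustrative value
    )
    // Builds source -> expr eval -> Avro conversion -> async KV writes; the returned
    // DataStream[WriteResponse] is available here if callers want to attach extra sinks.
    job.runGroupByJob(env)
    env.execute(s"chronon-flink-$featureGroupName")
  }
}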
14 changes: 14 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/FlinkSource.scala
@@ -0,0 +1,14 @@
package ai.chronon.flink

import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}

abstract class FlinkSource[T] extends Serializable {

  /**
   * Return a Flink DataStream for the given topic and feature group.
   */
  def getDataStream(topic: String, groupName: String)(
      env: StreamExecutionEnvironment,
      parallelism: Int
  ): DataStream[T]
}
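
A minimal sketch of a concrete source, assuming an in-memory collection (useful for local experimentation); a production implementation would presumably be Kafka-backed instead:

package ai.chronon.flink

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}

class InMemoryFlinkSource[T: TypeInformation](elements: Seq[T]) extends FlinkSource[T] {
  override def getDataStream(topic: String, groupName: String)(
      env: StreamExecutionEnvironment,
      parallelism: Int
  ): DataStream[T] =
    // fromCollection is a non-parallel source, so the parallelism hint is not applied here.
    env
      .fromCollection(elements)
      .name(s"in-memory source for $groupName")
}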