diff --git a/docs/layouts/shortcodes/generated/execution_config_configuration.html b/docs/layouts/shortcodes/generated/execution_config_configuration.html index 87b89981bc51a..6d28c767f6e14 100644 --- a/docs/layouts/shortcodes/generated/execution_config_configuration.html +++ b/docs/layouts/shortcodes/generated/execution_config_configuration.html @@ -129,6 +129,24 @@ Boolean Set whether to compact the changes sent downstream in row-time mini-batch. If true, Flink will compact changes and send only the latest change downstream. Note that if the downstream needs the details of versioned data, this optimization cannot be applied. If false, Flink will send all changes to downstream just like when the mini-batch is not enabled. + +
table.exec.delta-join.cache-enabled

Streaming + true + Boolean + Whether to enable the cache of delta join. If enabled, the delta join caches the records from remote dim table. Default is true. + + +
table.exec.delta-join.left.cache-size

Streaming + 10000 + Long + The cache size used to cache the lookup results of the left table in delta join. This value must be positive when enabling cache. Default is 10000. + + +
table.exec.delta-join.right.cache-size

Streaming + 10000 + Long + The cache size used to cache the lookup results of the right table in delta join. This value must be positive when enabling cache. Default is 10000. +
table.exec.disabled-operators

Batch (none) diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java index 1e0fb43ebfec0..aca9853a38eeb 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java @@ -725,6 +725,33 @@ public class ExecutionConfigOptions { "Set whether to use the SQL/Table operators based on the asynchronous state api. " + "Default value is false."); + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED = + key("table.exec.delta-join.cache-enabled") + .booleanType() + .defaultValue(true) + .withDescription( + "Whether to enable the cache of delta join. If enabled, the delta join caches the " + + "records from remote dim table. Default is true."); + + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption TABLE_EXEC_DELTA_JOIN_LEFT_CACHE_SIZE = + key("table.exec.delta-join.left.cache-size") + .longType() + .defaultValue(10000L) + .withDescription( + "The cache size used to cache the lookup results of the left table in delta join. " + + "This value must be positive when enabling cache. Default is 10000."); + + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption TABLE_EXEC_DELTA_JOIN_RIGHT_CACHE_SIZE = + key("table.exec.delta-join.right.cache-size") + .longType() + .defaultValue(10000L) + .withDescription( + "The cache size used to cache the lookup results of the right table in delta join. " + + "This value must be positive when enabling cache. 
Default is 10000."); + // ------------------------------------------------------------------------------------------ // Enum option types // ------------------------------------------------------------------------------------------ diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index 6d368618c9130..962cc5505414d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -20,10 +20,12 @@ import org.apache.flink.FlinkVersion; import org.apache.flink.api.dag.Transformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.ReadableConfig; import org.apache.flink.streaming.api.functions.async.AsyncFunction; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import org.apache.flink.table.api.config.ExecutionConfigOptions; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.conversion.DataStructureConverter; @@ -83,6 +85,7 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecDeltaJoin.DELTA_JOIN_TRANSFORMATION; import static org.apache.flink.table.planner.plan.utils.DeltaJoinUtil.getUnwrappedAsyncLookupFunction; @@ -234,11 +237,17 @@ protected Transformation translateToPlanInternal( RowDataKeySelector leftJoinKeySelector = KeySelectorUtil.getRowDataSelector( classLoader, leftJoinKeys, 
InternalTypeInfo.of(leftStreamType)); + // currently, delta join only supports consuming INSERT-ONLY stream + RowDataKeySelector leftUpsertKeySelector = + getUpsertKeySelector(new int[0], leftStreamType, classLoader); // right side selector RowDataKeySelector rightJoinKeySelector = KeySelectorUtil.getRowDataSelector( classLoader, rightJoinKeys, InternalTypeInfo.of(rightStreamType)); + // currently, delta join only supports consuming INSERT-ONLY stream + RowDataKeySelector rightUpsertKeySelector = + getUpsertKeySelector(new int[0], rightStreamType, classLoader); StreamOperatorFactory operatorFactory = createAsyncLookupDeltaJoin( @@ -252,7 +261,9 @@ protected Transformation translateToPlanInternal( leftStreamType, rightStreamType, leftJoinKeySelector, + leftUpsertKeySelector, rightJoinKeySelector, + rightUpsertKeySelector, classLoader); final TwoInputTransformation transform = @@ -282,7 +293,9 @@ private StreamOperatorFactory createAsyncLookupDeltaJoin( RowType leftStreamType, RowType rightStreamType, RowDataKeySelector leftJoinKeySelector, + RowDataKeySelector leftUpsertKeySelector, RowDataKeySelector rightJoinKeySelector, + RowDataKeySelector rightUpsertKeySelector, ClassLoader classLoader) { DataTypeFactory dataTypeFactory = @@ -299,6 +312,10 @@ private StreamOperatorFactory createAsyncLookupDeltaJoin( leftStreamType, rightStreamType, leftLookupKeys, + leftJoinKeySelector, + leftUpsertKeySelector, + rightJoinKeySelector, + rightUpsertKeySelector, false); AsyncDeltaJoinRunner rightLookupTableAsyncFunction = @@ -312,8 +329,14 @@ private StreamOperatorFactory createAsyncLookupDeltaJoin( leftStreamType, rightStreamType, rightLookupKeys, + leftJoinKeySelector, + leftUpsertKeySelector, + rightJoinKeySelector, + rightUpsertKeySelector, true); + Tuple2 leftRightCacheSize = getCacheSize(config); + return new StreamingDeltaJoinOperatorFactory( rightLookupTableAsyncFunction, leftLookupTableAsyncFunction, @@ -321,6 +344,8 @@ private StreamOperatorFactory 
createAsyncLookupDeltaJoin( rightJoinKeySelector, asyncLookupOptions.asyncTimeout, asyncLookupOptions.asyncBufferCapacity, + leftRightCacheSize.f0, + leftRightCacheSize.f1, leftStreamType, rightStreamType); } @@ -336,6 +361,10 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner( RowType leftStreamSideType, RowType rightStreamSideType, Map lookupKeys, + RowDataKeySelector leftJoinKeySelector, + RowDataKeySelector leftUpsertKeySelector, + RowDataKeySelector rightJoinKeySelector, + RowDataKeySelector rightUpsertKeySelector, boolean treatRightAsLookupTable) { RelOptTable lookupTable = treatRightAsLookupTable ? rightTempTable : leftTempTable; RowType streamSideType = treatRightAsLookupTable ? leftStreamSideType : rightStreamSideType; @@ -409,8 +438,13 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner( (DataStructureConverter) lookupSideFetcherConverter, lookupSideGeneratedResultFuture, InternalSerializers.create(lookupTableSourceRowType), + leftJoinKeySelector, + leftUpsertKeySelector, + rightJoinKeySelector, + rightUpsertKeySelector, asyncLookupOptions.asyncBufferCapacity, - treatRightAsLookupTable); + treatRightAsLookupTable, + enableCache(config)); } /** @@ -449,4 +483,33 @@ public RexNode visitInputRef(RexInputRef inputRef) { return condition.accept(converter); } + + private RowDataKeySelector getUpsertKeySelector( + int[] upsertKey, RowType rowType, ClassLoader classLoader) { + final int[] rightUpsertKeys; + if (upsertKey.length > 0) { + rightUpsertKeys = upsertKey; + } else { + rightUpsertKeys = IntStream.range(0, rowType.getFields().size()).toArray(); + } + return KeySelectorUtil.getRowDataSelector( + classLoader, rightUpsertKeys, InternalTypeInfo.of(rowType)); + } + + private boolean enableCache(ReadableConfig config) { + return config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED); + } + + /** Get the left cache size and right cache size. 
*/ + private Tuple2 getCacheSize(ReadableConfig config) { + long leftCacheSize = + config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_LEFT_CACHE_SIZE); + long rightCacheSize = + config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_RIGHT_CACHE_SIZE); + if ((leftCacheSize <= 0 || rightCacheSize <= 0) && enableCache(config)) { + throw new IllegalArgumentException( + "Cache size in delta join must be positive when enabling cache."); + } + return Tuple2.of(leftCacheSize, rightCacheSize); + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala index e4b15abb463cf..063fe973c67e9 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala @@ -20,17 +20,19 @@ package org.apache.flink.table.planner.runtime.stream.sql import org.apache.flink.core.execution.CheckpointingMode import org.apache.flink.table.api.Schema import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl -import org.apache.flink.table.api.config.OptimizerConfigOptions +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy import org.apache.flink.table.catalog.{CatalogTable, ObjectPath, ResolvedCatalogTable} import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AsyncTestValueLookupFunction import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow import org.apache.flink.table.planner.runtime.utils.{FailingCollectionSource, StreamingTestBase} +import 
org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters} import org.apache.flink.types.Row import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy} -import org.junit.jupiter.api.{BeforeEach, Test} +import org.junit.jupiter.api.{BeforeEach, TestTemplate} +import org.junit.jupiter.api.extension.ExtendWith import javax.annotation.Nullable @@ -39,7 +41,8 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ -class DeltaJoinITCase extends StreamingTestBase { +@ExtendWith(Array(classOf[ParameterizedTestExtension])) +class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { @BeforeEach override def before(): Unit = { @@ -49,10 +52,14 @@ class DeltaJoinITCase extends StreamingTestBase { OptimizerConfigOptions.TABLE_OPTIMIZER_DELTA_JOIN_STRATEGY, DeltaJoinStrategy.FORCE) + tEnv.getConfig.set( + ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED, + Boolean.box(enableCache)) + AsyncTestValueLookupFunction.invokeCount.set(0) } - @Test + @TestTemplate def testJoinKeyEqualsIndex(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -77,7 +84,7 @@ class DeltaJoinITCase extends StreamingTestBase { testUpsertResult(List("a1"), List("b1"), data1, data2, "a1 = b1", expected, 6) } - @Test + @TestTemplate def testJoinKeyContainsIndex(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -102,7 +109,7 @@ class DeltaJoinITCase extends StreamingTestBase { testUpsertResult(List("a1"), List("b1"), data1, data2, "a1 = b1 and a2 = b2", expected, 6) } - @Test + @TestTemplate def testJoinKeyNotContainsIndex(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -129,7 +136,72 @@ class DeltaJoinITCase extends StreamingTestBase { .hasMessageContaining("The current sql doesn't support to do delta 
join optimization.") } - @Test + @TestTemplate + def testSameJoinKeyColValuesWhileJoinKeyEqualsIndex(): Unit = { + val data1 = List( + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + // mismatch + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2023, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected = List( + "+I[1.0, 1, 2022-02-02T02:02:02, 1, 1.0, 2022-02-02T02:02:22]" + ) + testUpsertResult( + List("a1"), + List("b1"), + data1, + data2, + "a1 = b1", + expected, + if (enableCache) 4 else 6) + } + + @TestTemplate + def testSameJoinKeyColValuesWhileJoinKeyContainsIndex(): Unit = { + val data1 = List( + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Double.box(1.0), Int.box(2), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + // mismatch + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2023, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Int.box(2), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected = List( + "+I[1.0, 1, 2021-01-01T01:01:01, 1, 1.0, 2021-01-01T01:01:01]", + "+I[1.0, 1, 2021-01-01T01:01:01, 2, 1.0, 
2021-01-01T01:01:01]", + "+I[1.0, 2, 2021-01-01T01:01:01, 1, 1.0, 2021-01-01T01:01:01]", + "+I[1.0, 2, 2021-01-01T01:01:01, 2, 1.0, 2021-01-01T01:01:01]" + ) + testUpsertResult( + List("a1"), + List("b1"), + data1, + data2, + "a1 = b1 and a2 = b2", + expected, + if (enableCache) 4 else 6) + } + + @TestTemplate def testWithNonEquiCondition1(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -159,7 +231,7 @@ class DeltaJoinITCase extends StreamingTestBase { 6) } - @Test + @TestTemplate def testWithNonEquiCondition2(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -206,7 +278,7 @@ class DeltaJoinITCase extends StreamingTestBase { .hasMessageContaining("The current sql doesn't support to do delta join optimization.") } - @Test + @TestTemplate def testFilterFieldsBeforeJoin(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -245,7 +317,7 @@ class DeltaJoinITCase extends StreamingTestBase { .hasMessageContaining("The current sql doesn't support to do delta join optimization.") } - @Test + @TestTemplate def testProjectFieldsAfterJoin(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -271,8 +343,19 @@ class DeltaJoinITCase extends StreamingTestBase { ) tEnv - .executeSql( - s"insert into testSnk select a1, a0 + 1, a2, b0 + 2, b1, b2 from testLeft join testRight on a0 = b0") + .executeSql(""" + |insert into testSnk + | select + | a1, + | a0 + 1, + | a2, + | b0 + 2, + | b1, + | b2 + | from testLeft + | join testRight + | on a0 = b0 + |""".stripMargin) .await() val result = TestValuesTableFactory.getResultsAsStrings("testSnk") @@ -280,7 +363,7 @@ class DeltaJoinITCase extends StreamingTestBase { assertThat(AsyncTestValueLookupFunction.invokeCount.get()).isEqualTo(6) } - @Test + @TestTemplate def 
testProjectFieldsBeforeJoin(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -311,18 +394,23 @@ class DeltaJoinITCase extends StreamingTestBase { |) |""".stripMargin) - // could not optimize into delta join because there is ProjectPushDownSpec between join and source - assertThatThrownBy( - () => - tEnv - .executeSql( - s"insert into projectedSink select testLeft.a0, testRight.b0, testLeft.a1, testLeft.a2 " + - s"from testLeft join testRight on a0 = b0") - .await()) + // could not optimize into delta join + // because there is ProjectPushDownSpec between join and source + assertThatThrownBy(() => tEnv.executeSql(""" + |insert into projectedSink + | select + | testLeft.a0, + | testRight.b0, + | testLeft.a1, + | testLeft.a2 + | from testLeft + | join testRight + | on a0 = b0 + |""".stripMargin)) .hasMessageContaining("The current sql doesn't support to do delta join optimization.") } - @Test + @TestTemplate def testProjectFieldsBeforeJoin2(): Unit = { val data1 = List( changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), @@ -354,7 +442,7 @@ class DeltaJoinITCase extends StreamingTestBase { .hasMessageContaining("The current sql doesn't support to do delta join optimization.") } - @Test + @TestTemplate def testFailOverAndRestore(): Unit = { // enable checkpoint, we are using failing source to force have a complete checkpoint // and cover restore path @@ -504,3 +592,10 @@ class DeltaJoinITCase extends StreamingTestBase { } } + +object DeltaJoinITCase { + @Parameters(name = "EnableCache={0}") + def parameters(): java.util.Collection[Boolean] = { + Seq[Boolean](true, false) + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/StreamingDeltaJoinOperatorFactory.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/StreamingDeltaJoinOperatorFactory.java index 
c3d2c97daf5a7..9dabd3176780f 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/StreamingDeltaJoinOperatorFactory.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/StreamingDeltaJoinOperatorFactory.java @@ -36,13 +36,17 @@ public class StreamingDeltaJoinOperatorFactory extends AbstractStreamOperatorFac YieldingOperatorFactory { private final AsyncDeltaJoinRunner rightLookupTableAsyncFunction; - private final RowDataKeySelector rightJoinKeySelector; private final AsyncDeltaJoinRunner leftLookupTableAsyncFunction; + private final RowDataKeySelector leftJoinKeySelector; + private final RowDataKeySelector rightJoinKeySelector; private final long timeout; private final int capacity; + private final long leftSideCacheSize; + private final long rightSideCacheSize; + private final RowType leftStreamType; private final RowType rightStreamType; @@ -53,6 +57,8 @@ public StreamingDeltaJoinOperatorFactory( RowDataKeySelector rightJoinKeySelector, long timeout, int capacity, + long leftSideCacheSize, + long rightSideCacheSize, RowType leftStreamType, RowType rightStreamType) { this.rightLookupTableAsyncFunction = rightLookupTableAsyncFunction; @@ -61,6 +67,8 @@ public StreamingDeltaJoinOperatorFactory( this.rightJoinKeySelector = rightJoinKeySelector; this.timeout = timeout; this.capacity = capacity; + this.leftSideCacheSize = leftSideCacheSize; + this.rightSideCacheSize = rightSideCacheSize; this.leftStreamType = leftStreamType; this.rightStreamType = rightStreamType; } @@ -79,6 +87,8 @@ public > T createStreamOperator( capacity, processingTimeService, mailboxExecutor, + leftSideCacheSize, + rightSideCacheSize, leftStreamType, rightStreamType); deltaJoinOperator.setup( diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java index 6882314bc75ac..d73bfd315db35 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java @@ -31,15 +31,21 @@ import org.apache.flink.table.runtime.collector.TableFunctionResultFuture; import org.apache.flink.table.runtime.generated.GeneratedFunction; import org.apache.flink.table.runtime.generated.GeneratedResultFuture; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -64,6 +70,22 @@ public class AsyncDeltaJoinRunner extends RichAsyncFunction { private final boolean treatRightAsLookupTable; + private final boolean enableCache; + + /** Selector to get join key from left input. */ + private final RowDataKeySelector leftJoinKeySelector; + + /** Selector to get upsert key from left input. */ + private final RowDataKeySelector leftUpsertKeySelector; + + /** Selector to get join key from right input. */ + private final RowDataKeySelector rightJoinKeySelector; + + /** Selector to get upsert key from right input. */ + private final RowDataKeySelector rightUpsertKeySelector; + + private transient DeltaJoinCache cache; + /** * Buffers {@link ResultFuture} to avoid newInstance cost when processing elements every time. 
* We use {@link BlockingQueue} to make sure the head {@link ResultFuture}s are available. @@ -85,14 +107,28 @@ public AsyncDeltaJoinRunner( DataStructureConverter fetcherConverter, GeneratedResultFuture> generatedResultFuture, RowDataSerializer lookupSideRowSerializer, + RowDataKeySelector leftJoinKeySelector, + RowDataKeySelector leftUpsertKeySelector, + RowDataKeySelector rightJoinKeySelector, + RowDataKeySelector rightUpsertKeySelector, int asyncBufferCapacity, - boolean treatRightAsLookupTable) { + boolean treatRightAsLookupTable, + boolean enableCache) { this.generatedFetcher = generatedFetcher; this.fetcherConverter = fetcherConverter; this.generatedResultFuture = generatedResultFuture; this.lookupSideRowSerializer = lookupSideRowSerializer; + this.leftJoinKeySelector = leftJoinKeySelector; + this.leftUpsertKeySelector = leftUpsertKeySelector; + this.rightJoinKeySelector = rightJoinKeySelector; + this.rightUpsertKeySelector = rightUpsertKeySelector; this.asyncBufferCapacity = asyncBufferCapacity; this.treatRightAsLookupTable = treatRightAsLookupTable; + this.enableCache = enableCache; + } + + public void setCache(DeltaJoinCache cache) { + this.cache = cache; } @Override @@ -121,7 +157,11 @@ public void open(OpenContext openContext) throws Exception { resultFutureBuffer, createFetcherResultFuture(openContext), fetcherConverter, - treatRightAsLookupTable); + treatRightAsLookupTable, + leftUpsertKeySelector, + rightUpsertKeySelector, + enableCache, + cache); // add will throw exception immediately if the queue is full which should never happen resultFutureBuffer.add(rf); allResultFutures.add(rf); @@ -141,8 +181,27 @@ public void open(OpenContext openContext) throws Exception { public void asyncInvoke(RowData input, ResultFuture resultFuture) throws Exception { JoinedRowResultFuture outResultFuture = resultFutureBuffer.take(); + RowData streamJoinKey = null; + if (enableCache) { + if (treatRightAsLookupTable) { + streamJoinKey = leftJoinKeySelector.getKey(input); 
+ cache.requestRightCache(); + } else { + streamJoinKey = rightJoinKeySelector.getKey(input); + cache.requestLeftCache(); + } + } + // the input row is copied when object reuse in StreamDeltaJoinOperator - outResultFuture.reset(input, resultFuture); + outResultFuture.reset(streamJoinKey, input, resultFuture); + + if (enableCache) { + Optional> dataFromCache = tryGetDataFromCache(streamJoinKey); + if (dataFromCache.isPresent()) { + outResultFuture.complete(dataFromCache.get()); + return; + } + } long startTime = System.currentTimeMillis(); // fetcher has copied the input field when object reuse is enabled @@ -177,6 +236,30 @@ public AsyncFunction getFetcher() { return fetcher; } + @VisibleForTesting + public DeltaJoinCache getCache() { + return cache; + } + + private Optional> tryGetDataFromCache(RowData joinKey) { + Preconditions.checkState(enableCache); + + if (treatRightAsLookupTable) { + LinkedHashMap rightRows = cache.getData(joinKey, true); + if (rightRows != null) { + cache.hitRightCache(); + return Optional.of(rightRows.values()); + } + } else { + LinkedHashMap leftRows = cache.getData(joinKey, false); + if (leftRows != null) { + cache.hitLeftCache(); + return Optional.of(leftRows.values()); + } + } + return Optional.empty(); + } + /** * The {@link JoinedRowResultFuture} is used to combine left {@link RowData} and right {@link * RowData} into {@link JoinedRowData}. 
@@ -191,9 +274,16 @@ public static final class JoinedRowResultFuture implements ResultFuture private final TableFunctionResultFuture joinConditionResultFuture; private final DataStructureConverter resultConverter; + private final boolean enableCache; + private final DeltaJoinCache cache; + private final DelegateResultFuture delegate; private final boolean treatRightAsLookupTable; + private final RowDataKeySelector leftUpsertKeySelector; + private final RowDataKeySelector rightUpsertKeySelector; + + private @Nullable RowData streamJoinKey; private RowData streamRow; private ResultFuture realOutput; @@ -201,16 +291,28 @@ private JoinedRowResultFuture( BlockingQueue resultFutureBuffer, TableFunctionResultFuture joinConditionResultFuture, DataStructureConverter resultConverter, - boolean treatRightAsLookupTable) { + boolean treatRightAsLookupTable, + RowDataKeySelector leftUpsertKeySelector, + RowDataKeySelector rightUpsertKeySelector, + boolean enableCache, + DeltaJoinCache cache) { this.resultFutureBuffer = resultFutureBuffer; this.joinConditionResultFuture = joinConditionResultFuture; this.resultConverter = resultConverter; + this.enableCache = enableCache; + this.cache = cache; this.delegate = new DelegateResultFuture(); this.treatRightAsLookupTable = treatRightAsLookupTable; + this.leftUpsertKeySelector = leftUpsertKeySelector; + this.rightUpsertKeySelector = rightUpsertKeySelector; } - public void reset(RowData row, ResultFuture realOutput) { + public void reset( + @Nullable RowData joinKey, RowData row, ResultFuture realOutput) { + Preconditions.checkState( + (enableCache && joinKey != null) || (!enableCache && joinKey == null)); this.realOutput = realOutput; + this.streamJoinKey = joinKey; this.streamRow = row; joinConditionResultFuture.setInput(row); joinConditionResultFuture.setResultFuture(delegate); @@ -219,6 +321,19 @@ public void reset(RowData row, ResultFuture realOutput) { @Override public void complete(Collection result) { + if (result == null) { + 
result = Collections.emptyList(); + } + + // now we have received the rows from the lookup table, try to set them to the cache + try { + updateCacheIfNecessary(result); + } catch (Throwable t) { + LOG.info("Failed to update the cache", t); + completeExceptionally(t); + return; + } + Collection rowDataCollection = convertToInternalData(result); // call condition collector first, // the filtered result will be routed to the delegateCollector @@ -273,6 +388,61 @@ public void close() throws Exception { joinConditionResultFuture.close(); } + private void updateCacheIfNecessary(Collection lookupRows) throws Exception { + if (!enableCache) { + return; + } + + // 1. build the cache in lookup side if not exists + // 2. update the cache in stream side if exists + if (treatRightAsLookupTable) { + if (cache.getData(streamJoinKey, true) == null) { + cache.buildCache(streamJoinKey, buildMapWithUkAsKeys(lookupRows, true), true); + } + + LinkedHashMap leftCacheData = cache.getData(streamJoinKey, false); + if (leftCacheData != null) { + RowData uk = leftUpsertKeySelector.getKey(streamRow); + cache.upsertCache(streamJoinKey, uk, streamRow, false); + } + } else { + if (cache.getData(streamJoinKey, false) == null) { + cache.buildCache(streamJoinKey, buildMapWithUkAsKeys(lookupRows, false), false); + } + + LinkedHashMap rightCacheData = cache.getData(streamJoinKey, true); + if (rightCacheData != null) { + RowData uk = rightUpsertKeySelector.getKey(streamRow); + cache.upsertCache(streamJoinKey, uk, streamRow, true); + } + } + } + + private LinkedHashMap buildMapWithUkAsKeys( + Collection lookupRows, boolean treatRightAsLookupTable) throws Exception { + LinkedHashMap map = new LinkedHashMap<>(); + for (Object lookupRow : lookupRows) { + RowData rowData = convertToInternalData(lookupRow); + RowData uk; + if (treatRightAsLookupTable) { + uk = rightUpsertKeySelector.getKey(rowData); + map.put(uk, lookupRow); + } else { + uk = leftUpsertKeySelector.getKey(rowData); + map.put(uk, lookupRow); 
+ } + } + return map; + } + + private RowData convertToInternalData(Object data) { + if (resultConverter.isIdentityConversion()) { + return (RowData) data; + } else { + return resultConverter.toInternal(data); + } + } + @SuppressWarnings({"unchecked", "rawtypes"}) private Collection convertToInternalData(Collection data) { if (resultConverter.isIdentityConversion()) { diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/DeltaJoinCache.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/DeltaJoinCache.java new file mode 100644 index 0000000000000..ffaca5736bd70 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/DeltaJoinCache.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.deltajoin; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.table.data.RowData; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.guava33.com.google.common.cache.Cache; +import org.apache.flink.shaded.guava33.com.google.common.cache.CacheBuilder; +import org.apache.flink.shaded.guava33.com.google.common.cache.RemovalListener; +import org.apache.flink.shaded.guava33.com.google.common.cache.RemovalNotification; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Cache for both sides in delta join. + * + *

Note: This cache is not thread-safe although its inner {@link Cache} is thread-safe. + */ +@NotThreadSafe +public class DeltaJoinCache { + + private static final String LEFT_CACHE_METRIC_PREFIX = "deltaJoin.leftCache."; + private static final String RIGHT_CACHE_METRIC_PREFIX = "deltaJoin.rightCache."; + + private static final String METRIC_HIT_RATE = "hitRate"; + private static final String METRIC_REQUEST_COUNT = "requestCount"; + private static final String METRIC_HIT_COUNT = "hitCount"; + private static final String METRIC_KEY_SIZE = "keySize"; + private static final String METRIC_TOTAL_NON_EMPTY_VALUE_SIZE = "totalNonEmptyValues"; + + // use LinkedHashMap to keep order + private final Cache> leftCache; + private final Cache> rightCache; + + // metrics + private final AtomicLong leftTotalSize = new AtomicLong(0L); + private final AtomicLong rightTotalSize = new AtomicLong(0L); + private final AtomicLong leftHitCount = new AtomicLong(0L); + private final AtomicLong leftRequestCount = new AtomicLong(0L); + private final AtomicLong rightHitCount = new AtomicLong(0L); + private final AtomicLong rightRequestCount = new AtomicLong(0L); + + public DeltaJoinCache(long leftCacheMaxSize, long rightCacheMaxSize) { + this.leftCache = + CacheBuilder.newBuilder() + .maximumSize(leftCacheMaxSize) + .removalListener(new DeltaJoinCacheRemovalListener(true)) + .build(); + this.rightCache = + CacheBuilder.newBuilder() + .maximumSize(rightCacheMaxSize) + .removalListener(new DeltaJoinCacheRemovalListener(false)) + .build(); + } + + public void registerMetrics(MetricGroup metricGroup) { + // left cache metric + metricGroup.>gauge( + LEFT_CACHE_METRIC_PREFIX + METRIC_HIT_RATE, + () -> + leftRequestCount.get() == 0 + ? 
0.0 + : Long.valueOf(leftHitCount.get()).doubleValue() + / leftRequestCount.get()); + metricGroup.>gauge( + LEFT_CACHE_METRIC_PREFIX + METRIC_REQUEST_COUNT, leftRequestCount::get); + metricGroup.>gauge( + LEFT_CACHE_METRIC_PREFIX + METRIC_HIT_COUNT, leftHitCount::get); + metricGroup.>gauge( + LEFT_CACHE_METRIC_PREFIX + METRIC_KEY_SIZE, leftCache::size); + + metricGroup.>gauge( + LEFT_CACHE_METRIC_PREFIX + METRIC_TOTAL_NON_EMPTY_VALUE_SIZE, leftTotalSize::get); + + // right cache metric + metricGroup.>gauge( + RIGHT_CACHE_METRIC_PREFIX + METRIC_HIT_RATE, + () -> + rightRequestCount.get() == 0 + ? 0.0 + : Long.valueOf(rightHitCount.get()).doubleValue() + / rightRequestCount.get()); + metricGroup.>gauge( + RIGHT_CACHE_METRIC_PREFIX + METRIC_REQUEST_COUNT, rightRequestCount::get); + metricGroup.>gauge( + RIGHT_CACHE_METRIC_PREFIX + METRIC_HIT_COUNT, rightHitCount::get); + metricGroup.>gauge( + RIGHT_CACHE_METRIC_PREFIX + METRIC_KEY_SIZE, rightCache::size); + metricGroup.>gauge( + RIGHT_CACHE_METRIC_PREFIX + METRIC_TOTAL_NON_EMPTY_VALUE_SIZE, rightTotalSize::get); + } + + @Nullable + public LinkedHashMap getData(RowData key, boolean requestRightCache) { + return requestRightCache ? 
rightCache.getIfPresent(key) : leftCache.getIfPresent(key); + } + + public void buildCache( + RowData key, LinkedHashMap ukDataMap, boolean buildRightCache) { + Preconditions.checkState(getData(key, buildRightCache) == null); + if (buildRightCache) { + rightCache.put(key, ukDataMap); + rightTotalSize.addAndGet(ukDataMap.size()); + } else { + leftCache.put(key, ukDataMap); + leftTotalSize.addAndGet(ukDataMap.size()); + } + } + + public void upsertCache(RowData key, RowData uk, Object data, boolean upsertRightCache) { + if (upsertRightCache) { + upsert(rightCache, key, uk, data, rightTotalSize); + } else { + upsert(leftCache, key, uk, data, leftTotalSize); + } + } + + private void upsert( + Cache> cache, + RowData key, + RowData uk, + Object data, + AtomicLong cacheTotalSize) { + cache.asMap() + .computeIfPresent( + key, + (k, v) -> { + Object oldData = v.put(uk, data); + if (oldData == null) { + cacheTotalSize.incrementAndGet(); + } + return v; + }); + } + + public void requestLeftCache() { + leftRequestCount.incrementAndGet(); + } + + public void requestRightCache() { + rightRequestCount.incrementAndGet(); + } + + public void hitLeftCache() { + leftHitCount.incrementAndGet(); + } + + public void hitRightCache() { + rightHitCount.incrementAndGet(); + } + + private class DeltaJoinCacheRemovalListener + implements RemovalListener> { + + private final boolean isLeftCache; + + public DeltaJoinCacheRemovalListener(boolean isLeftCache) { + this.isLeftCache = isLeftCache; + } + + @Override + public void onRemoval( + RemovalNotification> removalNotification) { + if (removalNotification.getValue() == null) { + return; + } + + if (isLeftCache) { + leftTotalSize.addAndGet(-removalNotification.getValue().size()); + } else { + rightTotalSize.addAndGet(-removalNotification.getValue().size()); + } + } + } + + // ===== visible for test ===== + + @VisibleForTesting + public Cache> getLeftCache() { + return leftCache; + } + + @VisibleForTesting + public Cache> getRightCache() { + 
return rightCache; + } + + @VisibleForTesting + public AtomicLong getLeftTotalSize() { + return leftTotalSize; + } + + @VisibleForTesting + public AtomicLong getRightTotalSize() { + return rightTotalSize; + } + + @VisibleForTesting + public AtomicLong getLeftHitCount() { + return leftHitCount; + } + + @VisibleForTesting + public AtomicLong getLeftRequestCount() { + return leftRequestCount; + } + + @VisibleForTesting + public AtomicLong getRightHitCount() { + return rightHitCount; + } + + @VisibleForTesting + public AtomicLong getRightRequestCount() { + return rightRequestCount; + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperator.java index 61edefbc6182b..45fc04b239c72 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperator.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperator.java @@ -99,7 +99,7 @@ public class StreamingDeltaJoinOperator private static final String METRIC_DELTA_JOIN_OP_TOTAL_IN_FLIGHT_NUM = "deltaJoinOpTotalInFlightNum"; - private static final String METRIC_DELTA_JOIN_ASYNC_IO_TIME = "deltaJoinAsyncIoTime"; + private static final String METRIC_DELTA_JOIN_ASYNC_IO_TIME = "deltaJoinAsyncIOTime"; private final StreamRecord leftEmptyStreamRecord; private final StreamRecord rightEmptyStreamRecord; @@ -116,6 +116,10 @@ public class StreamingDeltaJoinOperator /** Max number of inflight invocation. */ private final int capacity; + private final long leftSideCacheSize; + + private final long rightSideCacheSize; + private transient boolean needDeepCopy; /** {@link TypeSerializer} for left side inputs while making snapshots. 
*/ @@ -138,6 +142,8 @@ public class StreamingDeltaJoinOperator private transient TableAsyncExecutionController asyncExecutionController; + private transient DeltaJoinCache cache; + /** Mailbox executor used to yield while waiting for buffers to empty. */ private final transient MailboxExecutor mailboxExecutor; @@ -172,6 +178,8 @@ public StreamingDeltaJoinOperator( int capacity, ProcessingTimeService processingTimeService, MailboxExecutor mailboxExecutor, + long leftSideCacheSize, + long rightSideCacheSize, RowType leftStreamType, RowType rightStreamType) { // rightLookupTableAsyncFunction is an udx used for left records @@ -184,6 +192,8 @@ public StreamingDeltaJoinOperator( this.processingTimeService = checkNotNull(processingTimeService); this.mailboxExecutor = mailboxExecutor; this.isInputEnded = new boolean[2]; + this.leftSideCacheSize = leftSideCacheSize; + this.rightSideCacheSize = rightSideCacheSize; this.leftEmptyStreamRecord = new StreamRecord<>(new GenericRowData(leftStreamType.getFieldCount())); this.rightEmptyStreamRecord = @@ -224,6 +234,11 @@ public void setup( isLeft(inputIndex) ? leftJoinKeySelector : rightJoinKeySelector; return keySelector.getKey(record.getValue()); }); + + this.cache = new DeltaJoinCache(leftSideCacheSize, rightSideCacheSize); + + leftTriggeredUserFunction.setCache(cache); + rightTriggeredUserFunction.setCache(cache); } @Override @@ -275,6 +290,9 @@ public void open() throws Exception { getRuntimeContext() .getMetricGroup() .gauge(METRIC_DELTA_JOIN_ASYNC_IO_TIME, asyncIOTime::get); + // 3. cache metric + cache.registerMetrics(getRuntimeContext().getMetricGroup()); + // asyncBufferCapacity + 1 as the queue size in order to avoid // blocking on the queue when taking a collector. 
this.resultHandlerBuffer = new ArrayBlockingQueue<>(capacity + 1); diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java index e09450fd9f27f..17b286e9819aa 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java @@ -49,17 +49,27 @@ import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.VarCharType; import org.apache.flink.table.utils.HandwrittenSelectorUtil; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameter; +import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; +import org.apache.flink.util.Preconditions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentLinkedQueue; @@ -68,18 +78,20 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; import java.util.stream.Stream; +import 
static org.apache.flink.table.runtime.util.StreamRecordUtils.binaryrow; import static org.apache.flink.table.runtime.util.StreamRecordUtils.insertRecord; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Test class for {@link StreamingDeltaJoinOperator}. */ +@ExtendWith(ParameterizedTestExtension.class) public class StreamingDeltaJoinOperatorTest { - private KeyedTwoInputStreamOperatorTestHarness testHarness; - private static final int AEC_CAPACITY = 100; + private static final int CACHE_SIZE = 10; // the data snapshot of the left/right table when joining private static final LinkedList leftTableCurrentData = new LinkedList<>(); @@ -92,7 +104,7 @@ public class StreamingDeltaJoinOperatorTest { * CREATE TABLE leftSrc( * left_value INT, * left_jk1 BOOLEAN, - * left_jk2_lk VARCHAR, + * left_jk2_lk STRING, * INDEX(left_jk2_lk) * ) * @@ -101,7 +113,7 @@ public class StreamingDeltaJoinOperatorTest { * CREATE TABLE rightSrc( * right_jk2 STRING, * right_value INT, - * right_jk1_lk VARCHAR, + * right_jk1_lk BOOLEAN, * INDEX(right_jk1_lk) * ) * @@ -123,37 +135,45 @@ public class StreamingDeltaJoinOperatorTest { // left join key: // left lookup key: - private static final InternalTypeInfo leftTypeInfo = - InternalTypeInfo.of( - RowType.of( - new LogicalType[] { - new IntType(), new BooleanType(), VarCharType.STRING_TYPE - }, - new String[] {"left_value", "left_jk1", "left_jk2_lk"})); + private static final RowType leftRowType = + RowType.of( + new LogicalType[] {new IntType(), new BooleanType(), VarCharType.STRING_TYPE}, + new String[] {"left_value", "left_jk1", "left_jk2_lk"}); + + private static final InternalTypeInfo leftTypeInfo = InternalTypeInfo.of(leftRowType); private static final int[] leftJoinKeyIndices = new int[] {1, 2}; // right join key: // right lookup key: + private static final RowType rightRowType = + RowType.of( + new LogicalType[] {VarCharType.STRING_TYPE, new IntType(), new 
BooleanType()}, + new String[] {"right_jk2", "right_value", "right_jk1_lk"}); + private static final InternalTypeInfo rightTypeInfo = - InternalTypeInfo.of( - RowType.of( - new LogicalType[] { - VarCharType.STRING_TYPE, new IntType(), new BooleanType() - }, - new String[] {"right_jk2", "right_value", "right_jk1_lk"})); + InternalTypeInfo.of(rightRowType); + private static final int[] rightJoinKeyIndices = new int[] {2, 0}; private static final RowDataKeySelector leftJoinKeySelector = HandwrittenSelectorUtil.getRowDataSelector( - leftJoinKeyIndices, - leftTypeInfo.toRowType().getChildren().toArray(new LogicalType[0])); + leftJoinKeyIndices, leftRowType.getChildren().toArray(new LogicalType[0])); private static final RowDataKeySelector rightJoinKeySelector = HandwrittenSelectorUtil.getRowDataSelector( - rightJoinKeyIndices, - rightTypeInfo.toRowType().getChildren().toArray(new LogicalType[0])); + rightJoinKeyIndices, rightRowType.getChildren().toArray(new LogicalType[0])); + + private static final int[] outputFieldIndices = + IntStream.range(0, leftTypeInfo.getArity() + rightTypeInfo.getArity()).toArray(); - private static final int[] outputUpsertKeyIndices = leftJoinKeyIndices; + @Parameters(name = "EnableCache = {0}") + public static List parameters() { + return Arrays.asList(false, true); + } + + @Parameter public boolean enableCache; + + private KeyedTwoInputStreamOperatorTestHarness testHarness; private RowDataHarnessAssertor assertor; @@ -182,7 +202,7 @@ public void beforeEach() throws Exception { getOutputType().getChildren().toArray(new LogicalType[0]), // sort the result by the output upsert key (o1, o2) -> { - for (int keyIndex : outputUpsertKeyIndices) { + for (int keyIndex : outputFieldIndices) { LogicalType type = getOutputType().getChildren().get(keyIndex); RowData.FieldGetter getter = RowData.createFieldGetter(type, keyIndex); @@ -215,7 +235,7 @@ public void afterEach() throws Exception { MyAsyncFunction.clearExpectedThrownException(); } - @Test + 
@TestTemplate void testJoinBothAppendOnlyTables() throws Exception { StreamRecord leftRecord1 = insertRecord(100, true, "jklk1"); StreamRecord leftRecord2 = insertRecord(100, false, "jklk2"); @@ -270,9 +290,56 @@ void testJoinBothAppendOnlyTables() throws Exception { assertThat(aec.getBlockingSize()).isEqualTo(0); assertThat(aec.getInFlightSize()).isEqualTo(0); assertThat(aec.getFinishSize()).isEqualTo(0); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + Map> expectedLeftCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(leftRecord1.getValue(), leftRowType), + leftRecord1.getValue(), + toBinary(leftRecord3.getValue(), leftRowType), + leftRecord3.getValue(), + toBinary(leftRecord5.getValue(), leftRowType), + leftRecord5.getValue()), + binaryrow(false, "jklk2"), + newHashMap( + toBinary(leftRecord2.getValue(), leftRowType), + leftRecord2.getValue(), + toBinary(leftRecord4.getValue(), leftRowType), + leftRecord4.getValue(), + toBinary(leftRecord6.getValue(), leftRowType), + leftRecord6.getValue()), + binaryrow(false, "unknown"), + Collections.emptyMap()); + + Map> expectedRightCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(rightRecord1.getValue(), rightRowType), + rightRecord1.getValue(), + toBinary(rightRecord4.getValue(), rightRowType), + rightRecord4.getValue()), + binaryrow(false, "jklk2"), + newHashMap( + toBinary(rightRecord2.getValue(), rightRowType), + rightRecord2.getValue(), + toBinary(rightRecord5.getValue(), rightRowType), + rightRecord5.getValue())); + + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 5, 2, 6, 4); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(2); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(3); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(6); + 
assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(5); + } } - @Test + @TestTemplate void testBlockingWithSameJoinKey() throws Exception { // block the async function MyAsyncFunction.block(); @@ -341,6 +408,47 @@ void testBlockingWithSameJoinKey() throws Exception { assertThat(recordsBuffer.getActiveBuffer()).isEmpty(); assertThat(recordsBuffer.getBlockingBuffer()).isEmpty(); assertThat(recordsBuffer.getFinishedBuffer()).isEmpty(); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + Map> expectedLeftCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(leftRecord1.getValue(), leftRowType), + leftRecord1.getValue(), + toBinary(leftRecord3.getValue(), leftRowType), + leftRecord3.getValue()), + binaryrow(false, "jklk2"), + newHashMap( + toBinary(leftRecord2.getValue(), leftRowType), + leftRecord2.getValue(), + toBinary(leftRecord4.getValue(), leftRowType), + leftRecord4.getValue(), + toBinary(leftRecord5.getValue(), leftRowType), + leftRecord5.getValue()), + binaryrow(false, "unknown"), + Collections.emptyMap()); + + Map> expectedRightCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(rightRecord1.getValue(), rightRowType), + rightRecord1.getValue()), + binaryrow(false, "jklk2"), + newHashMap( + toBinary(rightRecord2.getValue(), rightRowType), + rightRecord2.getValue())); + + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 3, 0, 5, 3); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(2); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(3); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(5); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(3); + } } /** @@ -348,7 +456,7 @@ void testBlockingWithSameJoinKey() throws Exception { * the left table that has not been sent to the delta-join operator (maybe is in flight 
between * source and delta-join). */ - @Test + @TestTemplate void testTableDataVisibleBeforeJoin() throws Exception { MyAsyncExecutionControllerDelegate.insertTableDataAfterEmit = false; @@ -414,9 +522,42 @@ void testTableDataVisibleBeforeJoin() throws Exception { assertThat(aec.getBlockingSize()).isEqualTo(0); assertThat(aec.getInFlightSize()).isEqualTo(0); assertThat(aec.getFinishSize()).isEqualTo(0); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + Map> expectedLeftCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(leftRecord1.getValue(), leftRowType), + leftRecord1.getValue(), + toBinary(leftRecord2.getValue(), leftRowType), + leftRecord2.getValue(), + toBinary(leftRecord3.getValue(), leftRowType), + leftRecord3.getValue()), + binaryrow(false, "jklk2"), + Collections.emptyMap()); + + Map> expectedRightCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(rightRecord1.getValue(), rightRowType), + rightRecord1.getValue(), + toBinary(rightRecord3.getValue(), rightRowType), + rightRecord3.getValue())); + + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 3, 1, 3, 2); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(1); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(2); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(3); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(3); + } } - @Test + @TestTemplate void testCheckpointAndRestore() throws Exception { // block the async function MyAsyncFunction.block(); @@ -453,6 +594,9 @@ void testCheckpointAndRestore() throws Exception { MyAsyncFunction.release(); testHarness.close(); + MyAsyncFunction.leftInvokeCount.set(0); + MyAsyncFunction.rightInvokeCount.set(0); + MyAsyncFunction.block(); // restoring testHarness = createDeltaJoinOperatorTestHarness(); @@ -494,9 +638,38 @@ 
void testCheckpointAndRestore() throws Exception { assertThat(recordsBuffer.getActiveBuffer()).isEmpty(); assertThat(recordsBuffer.getBlockingBuffer()).isEmpty(); assertThat(recordsBuffer.getFinishedBuffer()).isEmpty(); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + Map> expectedLeftCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(leftRecord1.getValue(), leftRowType), + toBinary(leftRecord1.getValue(), leftRowType), + toBinary(leftRecord2.getValue(), leftRowType), + toBinary(leftRecord2.getValue(), leftRowType)), + binaryrow(false, "unknown"), + Collections.emptyMap()); + + Map> expectedRightCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(rightRecord1.getValue(), rightRowType), + toBinary(rightRecord1.getValue(), rightRowType))); + + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 2, 0, 2, 1); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(1); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(2); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(2); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(2); + } } - @Test + @TestTemplate void testClearLegacyStateWhenCheckpointing() throws Exception { // block the async function MyAsyncFunction.block(); @@ -547,7 +720,7 @@ void testClearLegacyStateWhenCheckpointing() throws Exception { "result mismatch", expectedOutput, testHarness.getOutput()); } - @Test + @TestTemplate void testMeetExceptionWhenLookup() throws Exception { Throwable expectedException = new IllegalStateException("Mock to fail"); MyAsyncFunction.setExpectedThrownException(expectedException); @@ -566,6 +739,78 @@ void testMeetExceptionWhenLookup() throws Exception { .isEqualTo(expectedException); } + private void verifyCacheData( + DeltaJoinCache actualCache, + Map> expectedLeftCacheData, + Map> 
expectedRightCacheData, + long expectedLeftCacheRequestCount, + long expectedLeftCacheHitCount, + long expectedRightCacheRequestCount, + long expectedRightCacheHitCount) { + // assert left cache + verifyCacheData( + actualCache, + expectedLeftCacheData, + expectedLeftCacheRequestCount, + expectedLeftCacheHitCount, + true); + + // assert right cache + verifyCacheData( + actualCache, + expectedRightCacheData, + expectedRightCacheRequestCount, + expectedRightCacheHitCount, + false); + } + + private void verifyCacheData( + DeltaJoinCache actualCache, + Map> expectedCacheData, + long expectedCacheRequestCount, + long expectedCacheHitCount, + boolean testLeftCache) { + String errorPrefix = testLeftCache ? "left cache " : "right cache "; + + Map> actualCacheData = + testLeftCache + ? actualCache.getLeftCache().asMap() + : actualCache.getRightCache().asMap(); + assertThat(actualCacheData).as(errorPrefix + "data mismatch").isEqualTo(expectedCacheData); + + long actualCacheSize = + testLeftCache + ? actualCache.getLeftCache().size() + : actualCache.getRightCache().size(); + assertThat(actualCacheSize) + .as(errorPrefix + "size mismatch") + .isEqualTo(expectedCacheData.size()); + + long actualTotalSize = + testLeftCache + ? actualCache.getLeftTotalSize().get() + : actualCache.getRightTotalSize().get(); + assertThat(actualTotalSize) + .as(errorPrefix + "total size mismatch") + .isEqualTo(expectedCacheData.values().stream().mapToInt(Map::size).sum()); + + long actualRequestCount = + testLeftCache + ? actualCache.getLeftRequestCount().get() + : actualCache.getRightRequestCount().get(); + assertThat(actualRequestCount) + .as(errorPrefix + "request count mismatch") + .isEqualTo(expectedCacheRequestCount); + + long actualHitCount = + testLeftCache + ? 
actualCache.getLeftHitCount().get() + : actualCache.getRightHitCount().get(); + assertThat(actualHitCount) + .as(errorPrefix + "hit count mismatch") + .isEqualTo(expectedCacheHitCount); + } + private void waitAllDataProcessed() throws Exception { testHarness.endAllInputs(); if (latestException.isPresent()) { @@ -584,14 +829,22 @@ private void waitAllDataProcessed() throws Exception { (DataStructureConverter) DataStructureConverters.getConverter(leftTypeInfo.getDataType()); + RowDataKeySelector leftUpsertKeySelector = getUpsertKeySelector(leftRowType, null); + RowDataKeySelector rightUpsertKeySelector = getUpsertKeySelector(rightRowType, null); + AsyncDeltaJoinRunner leftAsyncFunction = new AsyncDeltaJoinRunner( new GeneratedFunctionWrapper<>(new MyAsyncFunction()), leftFetcherConverter, new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()), leftTypeInfo.toRowSerializer(), + leftJoinKeySelector, + leftUpsertKeySelector, + rightJoinKeySelector, + rightUpsertKeySelector, AEC_CAPACITY, - false); + false, + enableCache); DataStructureConverter rightFetcherConverter = (DataStructureConverter) @@ -603,8 +856,13 @@ private void waitAllDataProcessed() throws Exception { rightFetcherConverter, new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()), rightTypeInfo.toRowSerializer(), + leftJoinKeySelector, + leftUpsertKeySelector, + rightJoinKeySelector, + rightUpsertKeySelector, AEC_CAPACITY, - true); + true, + enableCache); InternalTypeInfo joinKeyTypeInfo = leftJoinKeySelector.getProducedType(); @@ -619,8 +877,10 @@ private void waitAllDataProcessed() throws Exception { new TestProcessingTimeService(), new MailboxExecutorImpl( mailbox, 0, StreamTaskActionExecutor.IMMEDIATE, mailboxProcessor), - (RowType) leftTypeInfo.toLogicalType(), - (RowType) rightTypeInfo.toLogicalType()); + CACHE_SIZE, + CACHE_SIZE, + leftRowType, + rightRowType); return new KeyedTwoInputStreamOperatorTestHarness<>( operator, @@ -634,6 +894,14 @@ private void 
waitAllDataProcessed() throws Exception { rightTypeInfo.toSerializer()); } + private RowDataKeySelector getUpsertKeySelector(RowType rowType, @Nullable int[] upsertKey) { + if (upsertKey == null) { + upsertKey = IntStream.range(0, rowType.getFieldCount()).toArray(); + } + return HandwrittenSelectorUtil.getRowDataSelector( + upsertKey, rowType.getChildren().toArray(new LogicalType[0])); + } + private void prepareOperatorRuntimeInfo(StreamingDeltaJoinOperator operator) { unwrapAsyncFunction(operator, true).tagInvokingSideDuringRuntime(true); unwrapAsyncFunction(operator, false).tagInvokingSideDuringRuntime(false); @@ -660,15 +928,28 @@ private StreamingDeltaJoinOperator unwrapOperator( return (StreamingDeltaJoinOperator) testHarness.getOperator(); } + private DeltaJoinCache unwrapCache( + KeyedTwoInputStreamOperatorTestHarness + testHarness) { + DeltaJoinCache cacheInLeftRunner = + unwrapOperator(testHarness).getLeftTriggeredUserFunction().getCache(); + DeltaJoinCache cacheInRightRunner = + unwrapOperator(testHarness).getRightTriggeredUserFunction().getCache(); + + // the object ref must be the same + assertThat(cacheInLeftRunner == cacheInRightRunner).isTrue(); + return cacheInLeftRunner; + } + private RowType getOutputType() { return RowType.of( Stream.concat( - leftTypeInfo.toRowType().getChildren().stream(), - rightTypeInfo.toRowType().getChildren().stream()) + leftRowType.getChildren().stream(), + rightRowType.getChildren().stream()) .toArray(LogicalType[]::new), Stream.concat( - leftTypeInfo.toRowType().getFieldNames().stream(), - rightTypeInfo.toRowType().getFieldNames().stream()) + leftRowType.getFieldNames().stream(), + rightRowType.getFieldNames().stream()) .toArray(String[]::new)); } @@ -697,6 +978,28 @@ private static void insertTableData(StreamRecord record, boolean insert } } + private Map newHashMap(Object... 
data) { + Preconditions.checkArgument(data.length % 2 == 0); + Map map = new HashMap<>(); + for (int i = 0; i < data.length; i = i + 2) { + Preconditions.checkArgument( + data[i] instanceof RowData, "The key of the map must be RowData"); + RowData key = (RowData) data[i]; + Preconditions.checkArgument(!map.containsKey(key), "Duplicate key"); + map.put(key, (T) data[i + 1]); + } + return map; + } + + private RowData toBinary(RowData row, RowType rowType) { + int size = row.getArity(); + Object[] fields = new Object[size]; + for (int i = 0; i < size; i++) { + fields[i] = RowData.createFieldGetter(rowType.getTypeAt(i), i).getFieldOrNull(row); + } + return binaryrow(fields); + } + /** An async function used for test. */ public static class MyAsyncFunction extends RichAsyncFunction { diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java index 717c04fb9d61e..40929ed6ac64b 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/StreamRecordUtils.java @@ -144,6 +144,8 @@ public static BinaryRowData binaryrow(Object... fields) { writer.writeInt(j, (Integer) value); } else if (value instanceof String) { writer.writeString(j, StringData.fromString((String) value)); + } else if (value instanceof StringData) { + writer.writeString(j, (StringData) value); } else if (value instanceof Double) { writer.writeDouble(j, (Double) value); } else if (value instanceof Float) {