@@ -162,6 +162,10 @@ public long toMillis() {
return duration().toMillis();
}

public TtlDuration minus(final long otherMs) {
return TtlDuration.of(duration().minusMillis(otherMs));
}

@Override
public boolean equals(final Object o) {
if (this == o) {
@@ -308,7 +308,8 @@ public BoundStatement insert(
final long epochMillis
) {
if (ttlResolver.isPresent()) {
final Optional<TtlDuration> rowTtl = ttlResolver.get().computeTtl(key, value);
final Optional<TtlDuration> rowTtl =
ttlResolver.get().computeInsertTtl(key, value, epochMillis);

// If the user happens to return the same ttl value as the default, skip applying it at
// the row level, since this is less efficient in Scylla
@@ -342,7 +343,7 @@ public byte[] get(final int kafkaPartition, final Bytes key, long streamTimeMs)
} else if (ttlResolver.get().needsValueToComputeTtl()) {
return postFilterGet(key, streamTimeMs);
} else {
final TtlDuration ttl = ttlResolver.get().resolveTtl(key, null);
final TtlDuration ttl = ttlResolver.get().resolveRowTtl(key, null);
if (ttl.isFinite()) {
final long minValidTimeMs = streamTimeMs - ttl.toMillis();
return preFilterGet(key, minValidTimeMs);
@@ -398,7 +399,7 @@ private byte[] postFilterGet(final Bytes key, long streamTimeMs) {

final Row rowResult = result.get(0);
final byte[] value = getValueFromRow(rowResult);
final TtlDuration ttl = ttlResolver.get().resolveTtl(key, value);
final TtlDuration ttl = ttlResolver.get().resolveRowTtl(key, value);

if (ttl.isFinite()) {
final long minValidTsFromValue = streamTimeMs - ttl.toMillis();
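A quick illustration of the read-path cutoff applied in the two get() hunks above: once the resolved ttl is finite, any row written before streamTimeMs - ttl is treated as expired. The sketch below is illustrative only (plain longs and a made-up class name, not the PR's TtlDuration type).

import java.time.Duration;

public class ReadPathCutoffSketch {
  public static void main(String[] args) {
    final long ttlMs = Duration.ofMinutes(10).toMillis();          // resolved row ttl
    final long streamTimeMs = Duration.ofMinutes(15).toMillis();   // current stream time
    final long rowWriteTimeMs = Duration.ofMinutes(4).toMillis();  // when the row was written

    final long minValidTimeMs = streamTimeMs - ttlMs;              // 5 min: anything older is expired
    final boolean expired = rowWriteTimeMs < minValidTimeMs;

    System.out.println(expired); // true -- written at 4 min, cutoff is 5 min
  }
}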
@@ -55,7 +55,7 @@ public byte[] get(final int kafkaPartition, final Bytes key, final long streamTi
}

if (ttlResolver.isPresent()) {
final TtlDuration rowTtl = ttlResolver.get().resolveTtl(key, value.value());
final TtlDuration rowTtl = ttlResolver.get().resolveRowTtl(key, value.value());
if (rowTtl.isFinite()) {
final long minValidTs = streamTimeMs - rowTtl.toMillis();
if (value.epochMillis < minValidTs) {
@@ -98,7 +98,8 @@ public void init(final ProcessorContext context, final StateStore root) {
@Override
public void init(final StateStoreContext storeContext, final StateStore root) {
try {
final TaskType taskType = asInternalProcessorContext(storeContext).taskType();
final var internalProcessorContext = asInternalProcessorContext(storeContext);
final TaskType taskType = internalProcessorContext.taskType();
log = new LogContext(
String.format(
"%sstore [%s] ",
@@ -117,7 +118,8 @@ public void init(final StateStoreContext storeContext, final StateStore root) {
final StateSerdes<?, ?> stateSerdes = StoreAccessorUtil.extractKeyValueStoreSerdes(root);
final Optional<TtlResolver<?, ?>> ttlResolver = TtlResolver.fromTtlProviderAndStateSerdes(
stateSerdes,
params.ttlProvider()
params.ttlProvider(),
internalProcessorContext
);

operations = opsProvider.provide(params, ttlResolver, storeContext, taskType);
@@ -21,6 +21,7 @@
import dev.responsive.kafka.internal.utils.StateDeserializer;
import java.util.Optional;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.StateSerdes;

public class TtlResolver<K, V> {
@@ -29,11 +30,13 @@ public class TtlResolver<K, V> {

private final StateDeserializer<K, V> stateDeserializer;
private final TtlProvider<K, V> ttlProvider;
private final ProcessorContext processorContext;

@SuppressWarnings("unchecked")
public static <K, V> Optional<TtlResolver<?, ?>> fromTtlProviderAndStateSerdes(
final StateSerdes<?, ?> stateSerdes,
final Optional<TtlProvider<?, ?>> ttlProvider
final Optional<TtlProvider<?, ?>> ttlProvider,
final ProcessorContext processorContext
) {
return ttlProvider.isPresent()
? Optional.of(
@@ -42,17 +45,20 @@ public class TtlResolver<K, V> {
stateSerdes.topic(),
stateSerdes.keyDeserializer(),
stateSerdes.valueDeserializer()),
(TtlProvider<K, V>) ttlProvider.get()
(TtlProvider<K, V>) ttlProvider.get(),
processorContext
))
: Optional.empty();
}

public TtlResolver(
final StateDeserializer<K, V> stateDeserializer,
final TtlProvider<K, V> ttlProvider
final TtlProvider<K, V> ttlProvider,
final ProcessorContext processorContext
) {
this.stateDeserializer = stateDeserializer;
this.ttlProvider = ttlProvider;
this.processorContext = processorContext;
}

public TtlDuration defaultTtl() {
@@ -68,18 +74,33 @@ public boolean needsValueToComputeTtl() {
}

/**
* @return the raw result from the user's ttl computation function for this row
* @return the raw result from the user's ttl computation function for this row,
* adjusted by the difference between the current time and the record timestamp.
* Used for writes.
*/
public Optional<TtlDuration> computeTtl(final Bytes keyBytes, final byte[] valueBytes) {
return ttlProvider.computeTtl(keyBytes.get(), valueBytes, stateDeserializer);
public Optional<TtlDuration> computeInsertTtl(
final Bytes keyBytes,
final byte[] valueBytes,
final long timestampMs
) {
return ttlProvider.computeTtl(keyBytes.get(), valueBytes, stateDeserializer)
.map(ttl -> {
if (ttl.isFinite()) {
return ttl.minus(processorContext.currentSystemTimeMs() - timestampMs);
} else {
return ttl;
}
});
}

/**
* @return the actual ttl for this row after resolving the raw result returned by the user
* (eg applying the default value)
* (e.g. applying the default value). Used for reads.
*/
public TtlDuration resolveTtl(final Bytes keyBytes, final byte[] valueBytes) {
final Optional<TtlDuration> ttl = computeTtl(keyBytes, valueBytes);
public TtlDuration resolveRowTtl(final Bytes keyBytes, final byte[] valueBytes) {
final Optional<TtlDuration> ttl =
ttlProvider.computeTtl(keyBytes.get(), valueBytes, stateDeserializer);

return ttl.orElse(defaultTtl());
}

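The core of this change is the write-path adjustment in computeInsertTtl: the user-supplied ttl is reduced by the lag between the wall clock and the record timestamp, so the row expires at recordTimestamp + ttl rather than insertTime + ttl. Below is a minimal, self-contained sketch of that arithmetic (illustrative only: plain java.time.Duration and a made-up class name, not the TtlDuration/TtlResolver types above).

import java.time.Duration;

public class InsertTtlAdjustmentSketch {
  public static void main(String[] args) {
    final Duration userTtl = Duration.ofMinutes(10);                   // ttl returned by the TtlProvider
    final long recordTimestampMs = 0L;                                 // the record's epochMillis
    final long currentSystemTimeMs = Duration.ofMinutes(2).toMillis(); // wall clock when the insert runs

    // adjusted ttl = userTtl - (now - recordTimestamp)
    final Duration adjusted = userTtl.minusMillis(currentSystemTimeMs - recordTimestampMs);

    // The row written at the 2 min mark gets an 8 min ttl, so it still expires at the
    // 10 min mark relative to the record timestamp.
    System.out.println(adjusted.toMinutes()); // 8
  }
}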
@@ -46,6 +46,7 @@
import java.util.function.Function;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.processor.MockProcessorContext;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -56,6 +57,8 @@
@ExtendWith(ResponsiveExtension.class)
class CassandraFactTableIntegrationTest {

private final MockProcessorContext mockContext = new MockProcessorContext();

private String storeName; // ie the "kafkaName", NOT the "cassandraName"
private ResponsiveKeyValueParams params;
private CassandraClient client;
@@ -214,7 +217,8 @@ public void shouldRespectSemanticKeyBasedTtl() throws Exception {
defaultPartitioner(),
Optional.of(new TtlResolver<>(
new StateDeserializer<>("ignored", new StringDeserializer(), new StringDeserializer()),
ttlProvider))
ttlProvider,
mockContext))
));

table.init(1);
@@ -228,6 +232,7 @@ public void shouldRespectSemanticKeyBasedTtl() throws Exception {

// When:
final long insertTimeMs = 0L;
mockContext.setCurrentSystemTimeMs(insertTimeMs);
client.execute(table.insert(1, noTtlKey, val, insertTimeMs));
client.execute(table.insert(1, defaultTtlKey, val, insertTimeMs));
client.execute(table.insert(1, tenMinTtlKey, val, insertTimeMs));
@@ -281,7 +286,8 @@ public void shouldRespectSemanticKeyValueBasedTtl() throws Exception {
defaultPartitioner(),
Optional.of(new TtlResolver<>(
new StateDeserializer<>("ignored", new StringDeserializer(), new StringDeserializer()),
ttlProvider))
ttlProvider,
mockContext))
));

table.init(1);
@@ -301,6 +307,7 @@ public void shouldRespectSemanticKeyValueBasedTtl() throws Exception {

// When
long insertTimeMs = 0L;
mockContext.setCurrentSystemTimeMs(insertTimeMs);
client.execute(table.insert(1, tenMinTtlKey, val, insertTimeMs));
client.execute(table.insert(1, defaultTtlKey, defaultTtlValue, insertTimeMs));
client.execute(table.insert(1, noTtlKey, noTtlValue, insertTimeMs));
@@ -361,7 +368,8 @@ public void shouldRespectOverridesWithValueBasedTtl() throws Exception {
defaultPartitioner(),
Optional.of(new TtlResolver<>(
new StateDeserializer<>("ignored", new StringDeserializer(), new StringDeserializer()),
ttlProvider)
ttlProvider,
mockContext)
)));

table.init(1);
@@ -375,6 +383,7 @@ public void shouldRespectOverridesWithValueBasedTtl() throws Exception {

// When
long currentTimeMs = 0L;
mockContext.setCurrentSystemTimeMs(currentTimeMs);
// first record set to expire at 3 min
client.execute(table.insert(1, key, threeMinTtlValue, currentTimeMs));

@@ -384,10 +393,12 @@ public void shouldRespectOverridesWithValueBasedTtl() throws Exception {

// insert new record with 3 min ttl -- now set to expire at 10 min
currentTimeMs = Duration.ofMinutes(7).toMillis();
mockContext.setCurrentSystemTimeMs(currentTimeMs);
client.execute(table.insert(1, key, threeMinTtlValue, currentTimeMs));

// override with 10 min ttl -- now set to expire at 18 min
currentTimeMs = Duration.ofMinutes(8).toMillis();
mockContext.setCurrentSystemTimeMs(currentTimeMs);
client.execute(table.insert(1, key, tenMinTtlValue, currentTimeMs));

// record should still exist after 10 min
@@ -396,6 +407,7 @@ public void shouldRespectOverridesWithValueBasedTtl() throws Exception {

// override with default ttl (30 min) -- now set to expire at 45 min
currentTimeMs = Duration.ofMinutes(15).toMillis();
mockContext.setCurrentSystemTimeMs(currentTimeMs);
client.execute(table.insert(1, key, defaultTtlValue, currentTimeMs));

// record should still exist after 18 min
@@ -404,11 +416,13 @@ public void shouldRespectOverridesWithValueBasedTtl() throws Exception {

// override with no ttl -- now set to never expire
currentTimeMs = Duration.ofMinutes(30).toMillis();
mockContext.setCurrentSystemTimeMs(currentTimeMs);
client.execute(table.insert(1, key, noTtlValue, currentTimeMs));

// record should still exist after 45 min
currentTimeMs = Duration.ofMinutes(50).toMillis();
assertThat(table.get(1, key, currentTimeMs), is(noTtlValue));
}


}
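A note on why every insert in these tests now pins the mocked wall clock first: when currentSystemTimeMs equals the record timestamp, the write-path adjustment (currentSystemTimeMs - epochMillis) is zero, so the user-supplied ttl is applied unchanged and the original expiry expectations still hold. A small sketch follows (the class name is illustrative; setCurrentSystemTimeMs and currentSystemTimeMs are the MockProcessorContext/ProcessorContext methods already used in this diff).

import java.time.Duration;
import org.apache.kafka.streams.processor.MockProcessorContext;

public class MockClockAlignmentSketch {
  public static void main(String[] args) {
    final MockProcessorContext mockContext = new MockProcessorContext();

    final long insertTimeMs = Duration.ofMinutes(7).toMillis();
    mockContext.setCurrentSystemTimeMs(insertTimeMs);

    // Mocked wall clock matches the record timestamp, so no ttl is subtracted.
    final long adjustmentMs = mockContext.currentSystemTimeMs() - insertTimeMs;
    System.out.println(adjustmentMs); // 0 -> the user-supplied ttl is applied as-is
  }
}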
@@ -90,7 +90,7 @@ public MockResponsiveKafkaStreams(
public static Optional<TtlResolver<?, ?>> defaultOnlyTtl(final Duration ttl) {
return Optional.of(new TtlResolver<>(
new StateDeserializer<>("ignored", null, null),
TtlProvider.withDefault(ttl))
TtlProvider.withDefault(ttl), null)
);
}

@@ -99,7 +99,7 @@ public MockResponsiveKafkaStreams(
) {
return Optional.of(new TtlResolver<>(
new StateDeserializer<>("ignored", null, null),
ttlProvider)
ttlProvider, null)
);
}

@@ -108,7 +108,8 @@ public MockResponsiveKafkaStreams(
) {
return ttlProvider.isPresent()
? Optional.of(
new TtlResolver<>(new StateDeserializer<>("ignored", null, null), ttlProvider.get()))
new TtlResolver<>(new StateDeserializer<>("ignored", null, null), ttlProvider.get(),
null))
: Optional.empty();
}
