diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java
new file mode 100644
index 000000000000..9b5701e3ac17
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java
@@ -0,0 +1,361 @@
+package org.apache.beam.sdk.transforms;
+
+import com.google.auto.value.AutoValue;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+/**
+ * An abstract {@link DoFn} that allows processing elements asynchronously.
+ *
+ *
This {@link DoFn} provides a framework for managing asynchronous operations, ensuring that
+ * elements are processed in a non-blocking manner while maintaining proper synchronization and
+ * handling of results.
+ *
+ *
To use this, extend {@link org.apache.beam.sdk.transforms.AsyncDoFn} and implement the {@link #asyncProcessElement} method to
+ * initiate the asynchronous operation.
+ *
+ * @param The key of KV of input elements
+ * @param The value of the KV of the input elements.
+ * @param The type of the output elements.
+ */
+ public abstract class AsyncDoFn extends DoFn, OutputT> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(org.apache.beam.sdk.transforms.AsyncDoFn.class);
+
+ /**
+ * A {@link TimerSpec} object used for managing asynchronous callbacks or recovering from
+ * lost in-memory state.
+ */
+ @TimerId("timer")
+ private final TimerSpec pollTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+ /**
+ * A {@link StateSpec} object used for storing elements that are waiting to be processed.
+ */
+ @StateId("buffer")
+ private final StateSpec> buffer = StateSpecs.bag();
+
+ static class LocalState {
+ LocalState(ExecutorService executorService, int maxItemsToBuffer) {
+ this.executorService = executorService;
+ this.bufferedElementsSemaphore = new Semaphore(maxItemsToBuffer);
+ }
+ ExecutorService executorService;
+ Semaphore bufferedElementsSemaphore;
+ static class ProcessingElements {
+ ProcessingElements(InputT element, CompletableFuture future) {
+ this.element = element;
+ this.future = future;
+ }
+ final InputT element;
+ final CompletableFuture future;
+ }
+ ConcurrentHashMap> processingElements;
+ }
+ private final MemoizingPerInstantiationSerializableSupplier>
+ localState;
+
+ /**
+ * Configurable parameters for the AsyncDoFn.
+ */
+ private final int callbackFrequencyMillis;
+ private final int maxItemsToBuffer;
+ private final int timeoutMillis;
+ private final int maxWaitTimeMillis;
+ private final SerializableFunction idFn;
+
+ protected AsyncDoFn(AsyncDoFnOptions options) {
+ this.callbackFrequencyMillis = options.getCallbackFrequencyMillis();
+ this.maxItemsToBuffer = options.getMaxItemsToBuffer();
+ this.timeoutMillis = options.getTimeoutMillis();
+ this.maxWaitTimeMillis = options.getMaxWaitTimeMillis();
+ this.idFn = options.getIdFn();
+ final int parallelism = options.getParallelism();
+ this.localState = new MemoizingPerInstantiationSerializableSupplier<>(
+ () -> {
+ ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1);
+ executor.setMaximumPoolSize(parallelism);
+ return new LocalState<>(executor, maxItemsToBuffer);
+ });
+ }
+
+ /**
+ * Processes an input element asynchronously.
+ *
+ * This method should initiate the asynchronous operation and return immediately. The result
+ * of the operation should be completed by calling the {@link #onResult} method.
+ *
+ * @param c The process context.
+ */
+ @ProcessElement
+ public void processElement(
+ ProcessContext c,
+ @TimerId("timer") Timer timer,
+ @StateId("buffer") BagState> toProcess) {
+
+ scheduleItem(c.element(), c.timestamp(), timer, toProcess, c);
+
+ // Don't output any elements. This will be done in commitFinishedItems.
+ }
+
+ private void scheduleItem(
+ KV element,
+ Instant timestamp,
+ Timer timer,
+ BagState> toProcess,
+ ProcessContext c) {
+ LocalState local = localState.get();
+
+ Object elementId = idFn.apply(element.getValue());
+ try {
+ if (local.bufferedElementsSemaphore.tryAcquire(maxWaitTimeMillis, TimeUnit.MILLISECONDS)) {
+ CompletableFuture future = new CompletableFuture<>();
+ local.processingElements.put(elementId,
+ new LocalState.ProcessingElements<>(element.getValue(), future));
+ local.executorService.execute(() -> {
+ try {
+ asyncProcessElement(element.getValue(), future);
+ } catch (Exception e) {
+ future.completeExceptionally(e);
+ }
+ });
+ } else {
+ LOG.warn("Unable to schedule due max buffered limit {}, buffering on disk and will schedule with timers.",
+ maxItemsToBuffer);
+ }
+ } catch (InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(ie);
+ }
+
+ toProcess.add(element);
+ // XXX what about the watermark?
+ timer.set(nextTimeToFire(element.getKey()));
+ }
+
+ /**
+ * Override this method to implement the asynchronous processing logic. Initiate the
+ * asynchronous operation, and when the result is available, complete the future with the
+ * result.
+ *
+ * Note that this may run
+ *
+ * @param element The input element.
+ * @param future The future to complete with the result of the asynchronous operation.
+ */
+ protected abstract void asyncProcessElement(InputT element, CompletableFuture future) throws Exception;
+
+ private Instant nextTimeToFire(K key) {
+ long base = Math.floorDiv(System.currentTimeMillis() + callbackFrequencyMillis, callbackFrequencyMillis);
+ long offset = Math.floorMod(Math.abs(key.hashCode()), callbackFrequencyMillis);
+ return Instant.ofEpochMilli(base + offset);
+ }
+
+ /**
+ * Commits finished items and synchronizes local state with runner state.
+ *
+ * Note timer firings are per key while local state contains messages for all keys. Only
+ * messages for the given key will be output/cleaned up.
+ *
+ * @param toProcess State that keeps track of queued messages for this key.
+ * @param timer Timer that initiated this commit and can be reset if not all items have finished..
+ * @return A list of elements that have finished processing for this key.
+ */
+ @OnTimer("timer")
+ public void timerCallback(
+ OnTimerContext c,
+ @AlwaysFetched @StateId("buffer") BagState> toProcess,
+ @TimerId("timer") Timer timer) throws Exception {
+ commitFinishedItems(c, toProcess, timer);
+ }
+
+ private void commitFinishedItems(
+ OnTimerContext c,
+ BagState> toProcess,
+ Timer timer) throws Exception {
+ // For all elements that are in processing state:
+ // If the element is done processing, delete it from all state and yield the output.
+ // If the element is not yet done, print it. If the element is not in local state, schedule it for processing.
+
+ HashMap bufferMap = new HashMap<>();
+ @Nullable K key = null;
+ long toProcessCount = 0;
+ for (KV elem : toProcess.read()) {
+ ++toProcessCount;
+ if (toProcessCount == 1) {
+ key = elem.getKey();
+ }
+ Object id = idFn.apply(elem.getValue());
+ InputT evicted = bufferMap.put(id, elem.getValue());
+ if (evicted != null) {
+ if (evicted.equals(elem)) {
+ LOG.error("Unexpected duplicate elements in buffer, only a single processing and output will be made.");
+ } else {
+ LOG.error("Unexpected id equality with differing elements! This is resulting in lost data. id={}, elem1={}, elem2={}",
+ id, elem, evicted);
+ }
+ }
+ }
+ if (toProcessCount == 0) {
+ // No buffered elements.
+ return;
+ }
+
+ LOG.debug("processing timer for key: {}", key);
+
+ LocalState state = localState.get();
+
+ final List outputs = new ArrayList<>();
+ final List finishedIds = new ArrayList<>();
+ final List itemsToReschedule = new ArrayList<>();
+
+
+ for (final Iterator> i = bufferMap.entrySet().iterator();
+ i.hasNext();) {
+ final Map.Entry elem = i.next();
+ final Object id = elem.getKey();
+ final InputT element = elem.getValue();
+
+ state.processingElements.compute(
+ id,
+ (@Nullable Object ignored,
+ @Nullable LocalState.ProcessingElements existing) -> {
+ if (existing == null) {
+ LOG.info(
+ "item {} found in processing state but not local state, scheduling now",
+ element);
+ itemsToReschedule.add(element);
+ return null;
+ }
+
+ if (!existing.future.isDone()) {
+ return existing;
+ }
+ try {
+ outputs.add(existing.future.get());
+ i.remove();
+ } catch (ExecutionException | InterruptedException e) {
+ LOG.error("Error processing asynchronously, retrying", e);
+ itemsToReschedule.add(element);
+ }
+ state.bufferedElementsSemaphore.release();
+ return null;
+ });
+ }
+ for (OutputT o : outputs) {
+ c.output(o);
+ }
+
+ // Update processing state to remove elements we've finished
+ if (bufferMap.size() != toProcessCount) {
+ toProcess.clear();
+ final K finalKey = key;
+ bufferMap.forEach((k, v) -> toProcess.add(KV.of(finalKey, v)));
+ }
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("items finished {}, items rescheduling {}, items in buffer {}",
+ outputs.size(), itemsToReschedule.size(),
+ maxItemsToBuffer - state.bufferedElementsSemaphore.availablePermits());
+ }
+
+ if (!bufferMap.isEmpty()) {
+ timer.set(nextTimeToFire(key));
+ }
+ }
+ }
+
+ /**
+ * Configurable parameters for the AsyncDoFn.
+ */
+ @AutoValue
+ public abstract static class AsyncDoFnOptions {
+
+ /**
+ * The maximum number of elements to process in parallel per worker for this dofn.
+ */
+ public abstract int getParallelism();
+
+ /**
+ * The frequency with which the runner will check for elements to commit.
+ */
+ public abstract int getCallbackFrequencyMillis();
+
+ /**
+ * We should ideally buffer enough to always be busy but not so much that the worker ooms.
+ */
+ public abstract int getMaxItemsToBuffer();
+
+ /**
+ * The maximum amount of time an item should try to be scheduled locally before it goes in the queue of waiting work.
+ */
+ public abstract int getTimeoutMillis();
+
+ /**
+ * The maximum amount of sleep time while attempting to schedule an item.
+ */
+ public abstract int getMaxWaitTimeMillis();
+
+ /**
+ * A function that extracts an id from an element to be used as keys in maps.
+ * The id should have a 1:1 relationship with the entire element, that is
+ * getIdFn()(input1).equals(getIdFn()(input2)) iff input1.equals(input2).
+ * By default, the id is the entire element.
+ */
+ public abstract SerializableFunction getIdFn();
+
+ public static Builder builder() {
+ return new AutoValue_AsyncDoFn_AsyncDoFnOptions.Builder()
+ .setParallelism(1)
+ .setCallbackFrequencyMillis(5)
+ .setMaxItemsToBuffer(20)
+ .setTimeoutMillis(1)
+ .setMaxWaitTimeMillis(1)
+ .setIdFn(i -> i);
+ }
+
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setParallelism(int parallelism);
+
+ public abstract Builder setCallbackFrequencyMillis(int callbackFrequencyMillis);
+
+ public abstract Builder setMaxItemsToBuffer(int maxItemsToBuffer);
+
+ public abstract Builder setTimeoutMillis(int timeoutMillis);
+
+ public abstract Builder setMaxWaitTimeMillis(double maxWaitTimeMillis);
+
+ public abstract Builder setIdFn(Function idFn);
+
+ public abstract AsyncDoFnOptions build();
+ }
+ }
+ }