diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/OrchestrationMiddleware.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/OrchestrationMiddleware.java index 448d5936..95a2cd27 100644 --- a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/OrchestrationMiddleware.java +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/OrchestrationMiddleware.java @@ -9,9 +9,14 @@ import com.microsoft.azure.functions.internal.spi.middleware.Middleware; import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareChain; import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareContext; +import com.microsoft.durabletask.DataConverter; import com.microsoft.durabletask.OrchestrationRunner; import com.microsoft.durabletask.OrchestratorBlockedException; +import java.util.Iterator; +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicBoolean; + /** * Durable Function Orchestration Middleware * @@ -21,14 +26,19 @@ public class OrchestrationMiddleware implements Middleware { private static final String ORCHESTRATION_TRIGGER = "DurableOrchestrationTrigger"; + private final Object dataConverterLock = new Object(); + private volatile DataConverter dataConverter; + private final AtomicBoolean oneTimeLogicExecuted = new AtomicBoolean(false); @Override public void invoke(MiddlewareContext context, MiddlewareChain chain) throws Exception { String parameterName = context.getParameterName(ORCHESTRATION_TRIGGER); - if (parameterName == null){ + if (parameterName == null) { chain.doNext(context); return; } + //invoked only for orchestrator function. + loadCustomizedDataConverterOnce(); String orchestratorRequestEncodedProtoBytes = (String) context.getParameterValue(parameterName); String orchestratorOutputEncodedProtoBytes = OrchestrationRunner.loadAndRun(orchestratorRequestEncodedProtoBytes, taskOrchestrationContext -> { try { @@ -39,12 +49,29 @@ public void invoke(MiddlewareContext context, MiddlewareChain chain) throws Exce // The OrchestratorBlockedEvent will be wrapped into InvocationTargetException by using reflection to // invoke method. Thus get the cause to check if it's OrchestratorBlockedEvent. Throwable cause = e.getCause(); - if (cause instanceof OrchestratorBlockedException){ + if (cause instanceof OrchestratorBlockedException) { throw (OrchestratorBlockedException) cause; } throw new RuntimeException("Unexpected failure in the task execution", e); } - }); + }, this.dataConverter); context.updateReturnValue(orchestratorOutputEncodedProtoBytes); } + + private void loadCustomizedDataConverterOnce() { + if (!oneTimeLogicExecuted.get()) { + synchronized (dataConverterLock) { + if (!oneTimeLogicExecuted.get()) { + Iterator iterator = ServiceLoader.load(DataConverter.class).iterator(); + if (iterator.hasNext()) { + this.dataConverter = iterator.next(); + if (iterator.hasNext()) { + throw new IllegalStateException("Multiple implementations of DataConverter found on the classpath."); + } + } + oneTimeLogicExecuted.compareAndSet(false,true); + } + } + } + } } diff --git a/client/src/main/java/com/microsoft/durabletask/OrchestrationRunner.java b/client/src/main/java/com/microsoft/durabletask/OrchestrationRunner.java index 13904901..e0ddd193 100644 --- a/client/src/main/java/com/microsoft/durabletask/OrchestrationRunner.java +++ b/client/src/main/java/com/microsoft/durabletask/OrchestrationRunner.java @@ -36,10 +36,11 @@ private OrchestrationRunner() { */ public static String loadAndRun( String base64EncodedOrchestratorRequest, - OrchestratorFunction orchestratorFunc) { + OrchestratorFunction orchestratorFunc, + DataConverter dataConverter) { // Example string: CiBhOTMyYjdiYWM5MmI0MDM5YjRkMTYxMDIwNzlmYTM1YSIaCP///////////wESCwi254qRBhDk+rgocgAicgj///////////8BEgwIs+eKkQYQzMXjnQMaVwoLSGVsbG9DaXRpZXMSACJGCiBhOTMyYjdiYWM5MmI0MDM5YjRkMTYxMDIwNzlmYTM1YRIiCiA3ODEwOTA2N2Q4Y2Q0ODg1YWU4NjQ0OTNlMmRlMGQ3OA== byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedOrchestratorRequest); - byte[] resultBytes = loadAndRun(decodedBytes, orchestratorFunc); + byte[] resultBytes = loadAndRun(decodedBytes, orchestratorFunc, dataConverter); return Base64.getEncoder().encodeToString(resultBytes); } @@ -55,7 +56,8 @@ public static String loadAndRun( */ public static byte[] loadAndRun( byte[] orchestratorRequestBytes, - OrchestratorFunction orchestratorFunc) { + OrchestratorFunction orchestratorFunc, + DataConverter dataConverter) { if (orchestratorFunc == null) { throw new IllegalArgumentException("orchestratorFunc must not be null"); } @@ -66,7 +68,7 @@ public static byte[] loadAndRun( ctx.complete(output); }; - return loadAndRun(orchestratorRequestBytes, orchestration); + return loadAndRun(orchestratorRequestBytes, orchestration, dataConverter); } /** @@ -82,7 +84,7 @@ public static String loadAndRun( String base64EncodedOrchestratorRequest, TaskOrchestration orchestration) { byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedOrchestratorRequest); - byte[] resultBytes = loadAndRun(decodedBytes, orchestration); + byte[] resultBytes = loadAndRun(decodedBytes, orchestration, null); return Base64.getEncoder().encodeToString(resultBytes); } @@ -95,7 +97,7 @@ public static String loadAndRun( * @return a protobuf-encoded payload of orchestrator actions to be interpreted by the external orchestration engine * @throws IllegalArgumentException if either parameter is {@code null} or if {@code orchestratorRequestBytes} is not valid protobuf */ - public static byte[] loadAndRun(byte[] orchestratorRequestBytes, TaskOrchestration orchestration) { + public static byte[] loadAndRun(byte[] orchestratorRequestBytes, TaskOrchestration orchestration, DataConverter dataConverter) { if (orchestratorRequestBytes == null || orchestratorRequestBytes.length == 0) { throw new IllegalArgumentException("triggerStateProtoBytes must not be null or empty"); } @@ -127,7 +129,7 @@ public TaskOrchestration create() { TaskOrchestrationExecutor taskOrchestrationExecutor = new TaskOrchestrationExecutor( orchestrationFactories, - new JacksonDataConverter(), + dataConverter != null ? dataConverter : new JacksonDataConverter(), DEFAULT_MAXIMUM_TIMER_INTERVAL, logger); diff --git a/samples-azure-functions/build.gradle b/samples-azure-functions/build.gradle index e262a338..8c23f384 100644 --- a/samples-azure-functions/build.gradle +++ b/samples-azure-functions/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation project(':azurefunctions') implementation 'com.microsoft.azure.functions:azure-functions-java-library:3.0.0' + implementation 'com.google.code.gson:gson:2.9.0' testImplementation 'org.junit.jupiter:junit-jupiter:5.6.2' testImplementation 'io.rest-assured:rest-assured:5.3.0' testImplementation 'io.rest-assured:json-path:5.3.0' diff --git a/samples-azure-functions/src/main/java/com/functions/AzureFunctions.java b/samples-azure-functions/src/main/java/com/functions/AzureFunctions.java index ad804370..8a52e4c3 100644 --- a/samples-azure-functions/src/main/java/com/functions/AzureFunctions.java +++ b/samples-azure-functions/src/main/java/com/functions/AzureFunctions.java @@ -36,12 +36,14 @@ public HttpResponseMessage startOrchestration( */ @FunctionName("Cities") public String citiesOrchestrator( - @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx, + final ExecutionContext context) { String result = ""; result += ctx.callActivity("Capitalize", "Tokyo", String.class).await() + ", "; result += ctx.callActivity("Capitalize", "London", String.class).await() + ", "; result += ctx.callActivity("Capitalize", "Seattle", String.class).await() + ", "; result += ctx.callActivity("Capitalize", "Austin", String.class).await(); + context.getLogger().info("Orchestrator function completed!"); return result; } diff --git a/samples-azure-functions/src/main/java/com/functions/CustomizeDataConverter.java b/samples-azure-functions/src/main/java/com/functions/CustomizeDataConverter.java new file mode 100644 index 00000000..88a71ed4 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/CustomizeDataConverter.java @@ -0,0 +1,59 @@ +package com.functions; + +import com.microsoft.azure.functions.ExecutionContext; +import com.microsoft.azure.functions.HttpMethod; +import com.microsoft.azure.functions.HttpRequestMessage; +import com.microsoft.azure.functions.HttpResponseMessage; +import com.microsoft.azure.functions.annotation.AuthorizationLevel; +import com.microsoft.azure.functions.annotation.FunctionName; +import com.microsoft.azure.functions.annotation.HttpTrigger; +import com.microsoft.durabletask.DurableTaskClient; +import com.microsoft.durabletask.TaskOrchestrationContext; +import com.microsoft.durabletask.azurefunctions.DurableActivityTrigger; +import com.microsoft.durabletask.azurefunctions.DurableClientContext; +import com.microsoft.durabletask.azurefunctions.DurableClientInput; +import com.microsoft.durabletask.azurefunctions.DurableOrchestrationTrigger; + +import java.time.LocalDate; +import java.util.Optional; + +public class CustomizeDataConverter { + + @FunctionName("StartCustomize") + public HttpResponseMessage startExampleProcess( + @HttpTrigger(name = "req", + methods = {HttpMethod.GET, HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) final HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") final DurableClientContext durableContext, + final ExecutionContext context) { + context.getLogger().info("Java HTTP trigger processed a request"); + + final DurableTaskClient client = durableContext.getClient(); + final String instanceId = client.scheduleNewOrchestrationInstance("Customize"); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + @FunctionName("Customize") + public ExampleResponse exampleOrchestrator( + @DurableOrchestrationTrigger(name = "taskOrchestrationContext") final TaskOrchestrationContext context, + final ExecutionContext functionContext) { + return context.callActivity("ToLower", "Foo", ExampleResponse.class).await(); + } + + @FunctionName("ToLower") + public ExampleResponse toLower( + @DurableActivityTrigger(name = "value") final String value, + final ExecutionContext context) { + return new ExampleResponse(LocalDate.now(), value.toLowerCase()); + } + + static class ExampleResponse { + private final LocalDate date; + private final String value; + + public ExampleResponse(LocalDate date, String value) { + this.date = date; + this.value = value; + } + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/converter/MyConverter.java b/samples-azure-functions/src/main/java/com/functions/converter/MyConverter.java new file mode 100644 index 00000000..d1e8cb10 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/converter/MyConverter.java @@ -0,0 +1,18 @@ +package com.functions.converter; + +import com.google.gson.Gson; +import com.microsoft.durabletask.DataConverter; + +public class MyConverter implements DataConverter { + + private static final Gson gson = new Gson(); + @Override + public String serialize(Object value) { + return gson.toJson(value); + } + + @Override + public T deserialize(String data, Class target) { + return gson.fromJson(data, target); + } +} diff --git a/samples-azure-functions/src/main/resources/META-INF/services/com.microsoft.durabletask.DataConverter b/samples-azure-functions/src/main/resources/META-INF/services/com.microsoft.durabletask.DataConverter new file mode 100644 index 00000000..4fac2b8c --- /dev/null +++ b/samples-azure-functions/src/main/resources/META-INF/services/com.microsoft.durabletask.DataConverter @@ -0,0 +1 @@ +com.functions.converter.MyConverter \ No newline at end of file diff --git a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java index fa5f64df..887b1ae4 100644 --- a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java +++ b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java @@ -15,6 +15,18 @@ @Tag("e2e") public class EndToEndTests { + private String waitForCompletion(String statusQueryGetUri) throws InterruptedException { + String runTimeStatus = null; + for (int i = 0; i < 15; i++) { + Response statusResponse = get(statusQueryGetUri); + runTimeStatus = statusResponse.jsonPath().get("runtimeStatus"); + if (!"Completed".equals(runTimeStatus)) { + Thread.sleep(1000); + } else break; + } + return runTimeStatus; + } + @Order(1) @Test public void setupHost() { @@ -82,16 +94,13 @@ public void restart(boolean restartWithNewInstanceId) throws InterruptedExceptio } } - private String waitForCompletion(String statusQueryGetUri) throws InterruptedException { - String runTimeStatus = null; - for (int i = 0; i < 15; i++) { - Response statusResponse = get(statusQueryGetUri); - runTimeStatus = statusResponse.jsonPath().get("runtimeStatus"); - if (!"Completed".equals(runTimeStatus)) { - Thread.sleep(1000); - } else break; - } - return runTimeStatus; + @Test + public void customizeDataConverter() throws InterruptedException { + String startOrchestrationPath = "/api/StartCustomize"; + Response response = post(startOrchestrationPath); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + String runTimeStatus = waitForCompletion(statusQueryGetUri); + assertEquals("Completed", runTimeStatus); } - }