Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 36 additions & 41 deletions src/main/java/io/kestra/plugin/aws/sns/Publish.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.property.Data;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
Expand All @@ -24,6 +26,7 @@
import java.net.URI;
import java.util.List;
import jakarta.validation.constraints.NotNull;
import java.util.Map;

import static io.kestra.core.utils.Rethrow.throwFunction;

Expand Down Expand Up @@ -95,8 +98,7 @@
)
}
)
public class Publish extends AbstractSns implements RunnableTask<Publish.Output> {
@PluginProperty(dynamic = true)
public class Publish extends AbstractSns implements RunnableTask<Publish.Output>,Data.From {
@NotNull
@Schema(
title = "The source of the published data.",
Expand All @@ -109,37 +111,38 @@ public class Publish extends AbstractSns implements RunnableTask<Publish.Output>
public Publish.Output run(RunContext runContext) throws Exception {
var topicArn = runContext.render(getTopicArn()).as(String.class).orElseThrow();
try (var snsClient = this.client(runContext)) {
Integer count;
Flux<Message> flowable;
Flux<Integer> resultFlowable;

if (this.from instanceof String) {
URI from = new URI(runContext.render((String) this.from));
if (!from.getScheme().equals("kestra")) {
throw new Exception("Invalid 'from' parameter, must be a Kestra internal storage URI");
}

try (BufferedReader inputStream = new BufferedReader(new InputStreamReader(runContext.storage().getFile(from)))) {
flowable = FileSerde.readAll(inputStream, Message.class);
resultFlowable = this.buildFlowable(flowable, snsClient, topicArn, runContext);

count = resultFlowable.reduce(Integer::sum).blockOptional().orElse(0);
}

} else if (this.from instanceof List) {
flowable = Flux
.fromIterable((List<?>) this.from)
.map(map -> JacksonMapper.toMap(map, Message.class));

resultFlowable = this.buildFlowable(flowable, snsClient, topicArn, runContext);

count = resultFlowable.reduce(Integer::sum).blockOptional().orElse(0);
} else {
var msg = JacksonMapper.toMap(this.from, Message.class);
snsClient.publish(msg.to(PublishRequest.builder().topicArn(topicArn), runContext));

count = 1;
}
Integer count = Data.from(from).read(runContext)
.map(throwFunction(raw -> {
Message message;

if (raw instanceof Message) {
message = (Message) raw;
} else if (raw instanceof Map) {
message = JacksonMapper.ofJson().convertValue(raw, Message.class);
} else if (raw instanceof String || raw instanceof Map) {
String str = raw.toString();
try {
message = JacksonMapper.ofJson().readValue(str, Message.class);
} catch (Exception e) {
message = Message.builder()
.data(str)
.build();
}
} else {
throw new IllegalArgumentException("Unsupported message type: " + raw.getClass());
}
snsClient.publish(PublishRequest.builder()
.topicArn(topicArn)
.message(message.getData())
.subject(message.getSubject())
.build()
);

return 1;
}))
.reduce(Integer::sum)
.blockOptional()
.orElse(0);

// metrics
runContext.metric(Counter.of("sns.publish.messages", count, "topic", topicArn));
Expand All @@ -150,14 +153,6 @@ public Publish.Output run(RunContext runContext) throws Exception {
}
}

private Flux<Integer> buildFlowable(Flux<Message> flowable, SnsClient snsClient, String topicArn, RunContext runContext) throws IllegalVariableEvaluationException {
return flowable
.map(throwFunction(message -> {
snsClient.publish(message.to(PublishRequest.builder().topicArn(topicArn), runContext));
return 1;
}));
}

@Builder
@Getter
public static class Output implements io.kestra.core.models.tasks.Output {
Expand Down
81 changes: 39 additions & 42 deletions src/main/java/io/kestra/plugin/aws/sqs/Publish.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.property.Data;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
Expand Down Expand Up @@ -92,8 +94,7 @@
)
}
)
public class Publish extends AbstractSqs implements RunnableTask<Publish.Output> {
@PluginProperty(dynamic = true)
public class Publish extends AbstractSqs implements RunnableTask<Publish.Output>,Data.From {
@NotNull
@Schema(
title = "The source of the published data.",
Expand All @@ -107,39 +108,42 @@ public class Publish extends AbstractSqs implements RunnableTask<Publish.Output>
public Output run(RunContext runContext) throws Exception {
var queueUrl = runContext.render(getQueueUrl()).as(String.class).orElseThrow();
try (var sqsClient = this.client(runContext)) {
Integer count;
Flux<Message> flowable;
Flux<Integer> resultFlowable;

if (this.from instanceof String) {
URI from = new URI(runContext.render((String) this.from));
if (!from.getScheme().equals("kestra")) {
throw new Exception("Invalid from parameter, must be a Kestra internal storage URI");
}


try (BufferedReader inputStream = new BufferedReader(new InputStreamReader(runContext.storage().getFile(from)))) {
flowable = FileSerde.readAll(inputStream, Message.class);
resultFlowable = this.buildFlowable(flowable, sqsClient, queueUrl, runContext);

count = resultFlowable.reduce(Integer::sum).blockOptional().orElse(0);
}

} else if (this.from instanceof List) {
flowable = Flux
.fromIterable((List<?>) this.from)
.map(map -> JacksonMapper.toMap(map, Message.class));

resultFlowable = this.buildFlowable(flowable, sqsClient, queueUrl, runContext);

count = resultFlowable.reduce(Integer::sum).blockOptional().orElse(0);
} else {
var msg = JacksonMapper.toMap(this.from, Message.class);
sqsClient.sendMessage(msg.to(SendMessageRequest.builder().queueUrl(queueUrl), runContext));

count = 1;
}

Integer count = Data.from(from).read(runContext)
.map(throwFunction(raw -> {
Message message;

if (raw instanceof Message) {
message = (Message) raw;
} else if (raw instanceof Map) {
message = JacksonMapper.ofJson().convertValue(raw, Message.class);
} else if (raw instanceof String || raw instanceof Map) {
String str = raw.toString();
try {
message = JacksonMapper.ofJson().readValue(str, Message.class);
} catch (Exception e) {
message = Message.builder()
.data(str)
.build();
}
} else {
throw new IllegalArgumentException("Unsupported message type: " + raw.getClass());
}

var builder = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody(message.getData());

if (message.getDelaySeconds() != null) {
builder.delaySeconds(message.getDelaySeconds());
}

sqsClient.sendMessage(builder.build());
return 1;
}))
.reduce(Integer::sum)
.blockOptional()
.orElse(0);

// metrics
runContext.metric(Counter.of("sqs.publish.messages", count, "queue", queueUrl));

Expand All @@ -149,13 +153,6 @@ public Output run(RunContext runContext) throws Exception {
}
}

private Flux<Integer> buildFlowable(Flux<Message> flowable, SqsClient sqsClient, String queueUrl, RunContext runContext) throws IllegalVariableEvaluationException {
return flowable
.map(throwFunction(message -> {
sqsClient.sendMessage(message.to(SendMessageRequest.builder().queueUrl(queueUrl), runContext));
return 1;
}));
}

@Builder
@Getter
Expand Down
Loading