diff --git a/acl-groovy-dsl/src/main/resources/acl.gdsl b/acl-groovy-dsl/src/main/resources/acl.gdsl index 9da036ed..1ec2f915 100644 --- a/acl-groovy-dsl/src/main/resources/acl.gdsl +++ b/acl-groovy-dsl/src/main/resources/acl.gdsl @@ -33,7 +33,7 @@ contributor(context(scope: scriptScope())) { && !enclosingCall("allOf") && !enclosingCall("anyOf")) { method name: 'topicFilter', type: 'javasabr.mqtt.acl.groovy.dsl.builder.SubscribeRuleBuilder', - params: [string: 'javasabr.mqtt.model.acl.matcher.ValueMatcher...'], + params: [string: 'javasabr.mqtt.model.acl.matcher.TopicFilterMatcher...'], doc: 'Set of topic filters matching by rule' method name: 'match', type: 'javasabr.mqtt.model.acl.matcher.TopicFilterMatcher', diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index f330b24e..0714f562 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -21,6 +21,7 @@ import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.PublishReceivingService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; import javasabr.mqtt.service.handler.client.ExternalNetworkMqttUserReleaseHandler; @@ -29,6 +30,7 @@ import javasabr.mqtt.service.impl.DefaultMqttConnectionFactory; import javasabr.mqtt.service.impl.DefaultPublishDeliveringService; import javasabr.mqtt.service.impl.DefaultPublishReceivingService; +import javasabr.mqtt.service.impl.InMemoryRetainMessageService; import javasabr.mqtt.service.impl.DefaultTopicService; import javasabr.mqtt.service.impl.DisabledAuthorizationService; import javasabr.mqtt.service.impl.ExternalNetworkMqttUserFactory; @@ -119,6 +121,11 @@ SubscriptionService subscriptionService() { return new InMemorySubscriptionService(); } + @Bean + RetainMessageService retainMessageService() { + return new InMemoryRetainMessageService(); + } + @Bean MqttMessageOutFactory mqtt311MessageOutFactory() { return new Mqtt311MessageOutFactory(); @@ -211,7 +218,8 @@ MqttInMessageHandler publishMqttInMessageHandler( return new PublishMqttInMessageHandler( publishReceivingService, messageOutFactoryService, - topicService, authorizationService, + topicService, + authorizationService, fieldValidators); } @@ -234,8 +242,15 @@ MqttInMessageHandler disconnectMqttInMessageHandler(MessageOutFactoryService mes MqttInMessageHandler subscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService) { - return new SubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService); + TopicService topicService, + RetainMessageService retainMessageService, + PublishDeliveringService publishDeliveringService) { + return new SubscribeMqttInMessageHandler( + subscriptionService, + messageOutFactoryService, + topicService, + retainMessageService, + publishDeliveringService); } @Bean @@ -260,14 +275,12 @@ MqttPublishOutMessageHandler qos0MqttPublishOutMessageHandler(MessageOutFactoryS } @Bean - MqttPublishOutMessageHandler qos1MqttPublishOutMessageHandler( - MessageOutFactoryService messageOutFactoryService) { + MqttPublishOutMessageHandler qos1MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { return new Qos1MqttPublishOutMessageHandler(messageOutFactoryService); } @Bean - MqttPublishOutMessageHandler qos2MqttPublishOutMessageHandler( - MessageOutFactoryService messageOutFactoryService) { + MqttPublishOutMessageHandler qos2MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { return new Qos2MqttPublishOutMessageHandler(messageOutFactoryService); } @@ -281,33 +294,39 @@ PublishDeliveringService publishDeliveringService( MqttPublishInMessageHandler qos0MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { return new Qos0MqttPublishInMessageHandler( subscriptionService, publishDeliveringService, - messageOutFactoryService); + messageOutFactoryService, + retainMessageService); } @Bean MqttPublishInMessageHandler qos1MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { return new Qos1MqttPublishInMessageHandler( subscriptionService, publishDeliveringService, - messageOutFactoryService); + messageOutFactoryService, + retainMessageService); } @Bean MqttPublishInMessageHandler qos2MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { return new Qos2MqttPublishInMessageHandler( subscriptionService, publishDeliveringService, - messageOutFactoryService); + messageOutFactoryService, + retainMessageService); } @Bean diff --git a/core-service/build.gradle b/core-service/build.gradle index 04b7b9e8..155f6ee6 100644 --- a/core-service/build.gradle +++ b/core-service/build.gradle @@ -12,4 +12,5 @@ dependencies { testImplementation projects.testSupport testImplementation testFixtures(projects.network) -} \ No newline at end of file + testImplementation testFixtures(projects.model) +} diff --git a/core-service/src/main/java/javasabr/mqtt/service/RetainMessageService.java b/core-service/src/main/java/javasabr/mqtt/service/RetainMessageService.java new file mode 100644 index 00000000..d7b58d3e --- /dev/null +++ b/core-service/src/main/java/javasabr/mqtt/service/RetainMessageService.java @@ -0,0 +1,13 @@ +package javasabr.mqtt.service; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.subscription.Subscription; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.rlib.collections.array.Array; + +public interface RetainMessageService { + + void retainMessage(Publish publish); + + Array getRetainedMessages(TopicFilter topicFilter); +} diff --git a/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java index b8de7ded..6da47a5b 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java @@ -1,11 +1,11 @@ package javasabr.mqtt.service; import javasabr.mqtt.model.MqttUser; -import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode; import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscription.Subscription; +import javasabr.mqtt.model.subscription.SubscriptionResult; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.rlib.collections.array.Array; @@ -15,7 +15,7 @@ * Subscription service */ public interface SubscriptionService { - + default Array findSubscribers(TopicName topicName) { return findSubscribersTo(MutableArray.ofType(SingleSubscriber.class), topicName); } @@ -29,7 +29,7 @@ default Array findSubscribers(TopicName topicName) { * @param subscriptions the list of request to subscribe topics * @return array of subscribe ack reason codes */ - Array subscribe(MqttUser user, MqttSession session, Array subscriptions); + Array subscribe(MqttUser user, MqttSession session, Array subscriptions); /** * Removes MQTT client from listening to the topics. diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemoryRetainMessageService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemoryRetainMessageService.java new file mode 100644 index 00000000..c56a4169 --- /dev/null +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemoryRetainMessageService.java @@ -0,0 +1,33 @@ +package javasabr.mqtt.service.impl; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.tree.ConcurrentRetainedMessageTree; +import javasabr.mqtt.service.RetainMessageService; +import javasabr.rlib.collections.array.Array; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class InMemoryRetainMessageService implements RetainMessageService { + + ConcurrentRetainedMessageTree retainedMessageTree; + + public InMemoryRetainMessageService() { + this.retainedMessageTree = new ConcurrentRetainedMessageTree(); + } + + @Override + public void retainMessage(Publish publish) { + if (publish.payload().length == 0) { + retainedMessageTree.removeRetainedMessage(publish.topicName()); + } else { + retainedMessageTree.addRetainedMessage(publish); + } + } + + @Override + public Array getRetainedMessages(TopicFilter topicFilter) { + return retainedMessageTree.getRetainedMessages(topicFilter); + } +} diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index b9bd0e86..9bc639b4 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -12,6 +12,7 @@ import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.model.subscription.Subscription; +import javasabr.mqtt.model.subscription.SubscriptionResult; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; @@ -30,6 +31,13 @@ @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class InMemorySubscriptionService implements SubscriptionService { + private static final SubscriptionResult INVALID_TOPIC_FILTER_RESULT = + new SubscriptionResult(SubscribeAckReasonCode.TOPIC_FILTER_INVALID); + private static final SubscriptionResult SHARED_SUBSCRIPTION_NOT_SUPPORTED_RESULT = + new SubscriptionResult(SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED); + private static final SubscriptionResult WILDCARD_SUBSCRIPTION_NOT_SUPPORTED_RESULT = + new SubscriptionResult(SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED); + ConcurrentSubscriberTree subscriberTree; public InMemorySubscriptionService() { @@ -44,13 +52,13 @@ public Array findSubscribersTo(MutableArray } @Override - public Array subscribe( + public Array subscribe( MqttUser user, MqttSession session, Array subscriptions) { - MutableArray subscribeResults = ArrayFactory.mutableArray( - SubscribeAckReasonCode.class, + MutableArray subscribeResults = ArrayFactory.mutableArray( + SubscriptionResult.class, subscriptions.size()); for (Subscription subscription : subscriptions) { @@ -60,23 +68,25 @@ public Array subscribe( return subscribeResults; } - private SubscribeAckReasonCode addSubscription(MqttUser user, MqttSession session, Subscription subscription) { + private SubscriptionResult addSubscription(MqttUser user, MqttSession session, Subscription subscription) { MqttClientConnectionConfig connectionConfig = user.connectionConfig(); TopicFilter topicFilter = subscription.topicFilter(); if (topicFilter.isInvalid()) { - return SubscribeAckReasonCode.TOPIC_FILTER_INVALID; + return INVALID_TOPIC_FILTER_RESULT; } else if (!connectionConfig.sharedSubscriptionAvailable() && topicFilter instanceof SharedTopicFilter) { - return SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED; + return SHARED_SUBSCRIPTION_NOT_SUPPORTED_RESULT; } else if (!connectionConfig.wildcardSubscriptionAvailable() && topicFilter.wildcard()) { - return SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED; + return WILDCARD_SUBSCRIPTION_NOT_SUPPORTED_RESULT; } ActiveSubscriptions activeSubscriptions = session.activeSubscriptions(); - SingleSubscriber previous = subscriberTree.subscribe(user, subscription); - if (previous != null) { - activeSubscriptions.remove(previous.subscription()); + SingleSubscriber newSubscriber = new SingleSubscriber(user, subscription); + SingleSubscriber previousSubscriber = subscriberTree.subscribe(newSubscriber); + boolean isSubscriptionAlreadyExisted = previousSubscriber != null; + if (isSubscriptionAlreadyExisted) { + activeSubscriptions.remove(previousSubscriber.subscription()); } activeSubscriptions.add(subscription); - return subscription.qos().subscribeAckReasonCode(); + return new SubscriptionResult(newSubscriber, isSubscriptionAlreadyExisted); } @Override @@ -125,7 +135,7 @@ public void restoreSubscriptions(MqttUser user, MqttSession session) { .activeSubscriptions() .subscriptions(); for (Subscription subscription : subscriptions) { - subscriberTree.subscribe(user, subscription); + subscriberTree.subscribe(new SingleSubscriber(user, subscription)); } } } diff --git a/core-service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java index a80ab1cd..e61fee0c 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java @@ -1,5 +1,7 @@ package javasabr.mqtt.service.message.handler.impl; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST; import static javasabr.mqtt.model.reason.code.SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED; import static javasabr.mqtt.model.reason.code.SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED; @@ -7,12 +9,16 @@ import javasabr.mqtt.model.MqttClientConnectionConfig; import javasabr.mqtt.model.MqttProperties; import javasabr.mqtt.model.QoS; +import javasabr.mqtt.model.SubscribeRetainHandling; import javasabr.mqtt.model.message.MqttMessageType; +import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.reason.code.DisconnectReasonCode; import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; import javasabr.mqtt.model.session.MessageTacker; +import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscription.RequestedSubscription; import javasabr.mqtt.model.subscription.Subscription; +import javasabr.mqtt.model.subscription.SubscriptionResult; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.network.MqttConnection; import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; @@ -20,9 +26,12 @@ import javasabr.mqtt.network.message.out.MqttOutMessage; import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; +import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.ArrayCollectors; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; import lombok.AccessLevel; @@ -40,14 +49,20 @@ public class SubscribeMqttInMessageHandler extends SubscriptionService subscriptionService; TopicService topicService; + RetainMessageService retainMessageService; + PublishDeliveringService publishDeliveringService; public SubscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService) { + TopicService topicService, + RetainMessageService retainMessageService, + PublishDeliveringService publishDeliveringService) { super(ExternalNetworkMqttUser.class, SubscribeMqttInMessage.class, messageOutFactoryService); this.subscriptionService = subscriptionService; this.topicService = topicService; + this.retainMessageService = retainMessageService; + this.publishDeliveringService = publishDeliveringService; } @Override @@ -90,25 +105,30 @@ protected void processValidMessage( subscribeMessage.subscriptions(), subscriptionId); - Array subscribeResults = subscriptionService + Array subscribeResults = subscriptionService .subscribe(user, session, subscriptions); - sendSubscribeResults(user, session, subscribeMessage, subscribeResults); + sendRetainedMessages(subscribeResults); - SubscribeAckReasonCode anyReasonToDisconnect = subscribeResults + SubscriptionResult anyDisconnectResult = subscribeResults .iterations() .reversedArgs() - .findAny(DISCONNECT_CASES, Set::contains); - - if (anyReasonToDisconnect != null) { - log.info(user.clientId(), anyReasonToDisconnect, "[%s] Will be forced closing by reason:[%s]"::formatted); - DisconnectReasonCode reasonCode = DisconnectReasonCode.ofCode(anyReasonToDisconnect.code()); - user.closeWithReason(messageOutFactoryService - .resolveFactory(user) - .newDisconnect(user, reasonCode)); + .findAny(DISCONNECT_CASES, SubscribeMqttInMessageHandler::containsSubscribeAckReasonCode); + + if (anyDisconnectResult != null) { + SubscribeAckReasonCode subackReasonCode = anyDisconnectResult.subscribeAckReasonCode(); + log.info(user.clientId(), subackReasonCode, "[%s] Will be forced closing by reason:[%s]"::formatted); + DisconnectReasonCode reasonCode = DisconnectReasonCode.ofCode(subackReasonCode.code()); + user.closeWithReason(messageOutFactoryService.resolveFactory(user).newDisconnect(user, reasonCode)); } } + private static boolean containsSubscribeAckReasonCode( + Set reasonCodes, + SubscriptionResult subscriptionResult) { + return reasonCodes.contains(subscriptionResult.subscribeAckReasonCode()); + } + private Array transformSubscriptions( MqttClientConnectionConfig connectionConfig, ExternalNetworkMqttUser user, @@ -166,14 +186,46 @@ private void sendSubscribeResults( ExternalNetworkMqttUser user, NetworkMqttSession session, SubscribeMqttInMessage subscribeMessage, - Array subscribeResults) { + Array subscribeResults) { int messageId = subscribeMessage.messageId(); + Array ackReasonCodes = subscribeResults.stream() + .map(SubscriptionResult::subscribeAckReasonCode) + .collect(ArrayCollectors.toArray(SubscribeAckReasonCode.class)); MqttOutMessage response = messageOutFactoryService .resolveFactory(user) - .newSubscribeAck(messageId, subscribeResults); + .newSubscribeAck(messageId, ackReasonCodes); user.sendAsync(response) .thenAccept(_ -> session .inMessageTracker() .remove(messageId)); } + + private void sendRetainedMessages(Array subscribeResults) { + for (SubscriptionResult subscriptionResult : subscribeResults) { + SingleSubscriber subscriber = subscriptionResult.subscriber(); + if (subscriber == null || !isRetainHandlingRequired(subscriber, subscriptionResult)) { + continue; + } + Subscription subscription = subscriber.subscription(); + boolean retainAsPublished = subscription.retainAsPublished(); + var retainedMessages = retainMessageService.getRetainedMessages(subscription.topicFilter()); + for (Publish retainedMessage : retainedMessages) { + if (!retainAsPublished) { + retainedMessage = retainedMessage.withoutRetain(); + } + publishDeliveringService.startDelivering(retainedMessage, subscriber); + } + } + } + + private static boolean isRetainHandlingRequired(SingleSubscriber subscriber, SubscriptionResult subscriptionResult) { + Subscription subscription = subscriber.subscription(); + if (subscription.topicFilter().isShared()) { + return false; + } else { + SubscribeRetainHandling retainHandling = subscription.retainHandling(); + return retainHandling == SEND || (retainHandling == SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST + && subscriptionResult.isNotExistedPreviously()); + } + } } diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java index 5a0c093d..f2dcf07f 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java @@ -10,6 +10,7 @@ import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.MqttPublishInMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; @@ -29,6 +30,7 @@ public abstract class AbstractMqttPublishInMessageHandler subscribers = subscriptionService.findSubscribers(topicName); if (subscribers.isEmpty()) { diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java index 5a297f59..1a697ca9 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java @@ -10,6 +10,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; public class Qos0MqttPublishInMessageHandler extends AbstractMqttPublishInMessageHandler { @@ -17,8 +18,14 @@ public class Qos0MqttPublishInMessageHandler extends AbstractMqttPublishInMessag public Qos0MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java index b7825bc9..0e61d3f3 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java @@ -8,8 +8,7 @@ import javasabr.mqtt.service.MessageOutFactoryService; import org.jspecify.annotations.Nullable; -public class Qos0MqttPublishOutMessageHandler - extends AbstractMqttPublishOutMessageHandler { +public class Qos0MqttPublishOutMessageHandler extends AbstractMqttPublishOutMessageHandler { public Qos0MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { super(ExternalNetworkMqttUser.class, messageOutFactoryService); diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java index 76c1446d..b5b76333 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java @@ -11,6 +11,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -24,8 +25,14 @@ public class Qos1MqttPublishInMessageHandler extends TrackableMqttPublishInMessa public Qos1MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override @@ -53,10 +60,7 @@ protected boolean validateImpl(ExternalNetworkMqttUser user, NetworkMqttSession } @Override - protected void handleNoMatchedSubscribers( - ExternalNetworkMqttUser user, - NetworkMqttSession session, - Publish publish) { + protected void handleNoMatchedSubscribers(ExternalNetworkMqttUser user, NetworkMqttSession session, Publish publish) { super.handleNoMatchedSubscribers(user, session, publish); int messageId = publish.messageId(); MqttOutMessage response = messageOutFactoryService diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java index 62632168..5a55a86a 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java @@ -19,6 +19,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -34,8 +35,14 @@ public class Qos2MqttPublishInMessageHandler extends TrackableMqttPublishInMessa public Qos2MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); this.trackableMessageCallback = this::handleReceivedTrackableMessage; } @@ -67,17 +74,15 @@ protected boolean validateImpl(ExternalNetworkMqttUser user, NetworkMqttSession } @Override - protected void handleNoMatchedSubscribers( - ExternalNetworkMqttUser user, - NetworkMqttSession session, - Publish publish) { + protected void handleNoMatchedSubscribers(ExternalNetworkMqttUser user, NetworkMqttSession session, Publish publish) { super.handleNoMatchedSubscribers(user, session, publish); var reasonCode = PublishReceivedReasonCode.NO_MATCHING_SUBSCRIBERS; updateSessionState(session, publish, reasonCode); sendFeedback( - user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(publish.messageId(), reasonCode)); + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(publish.messageId(), reasonCode)); } @Override @@ -90,9 +95,10 @@ protected void handleSuccess( var reasonCode = PublishReceivedReasonCode.SUCCESS; updateSessionState(session, publish, reasonCode); sendFeedback( - user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(publish.messageId(), PublishReceivedReasonCode.SUCCESS)); + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(publish.messageId(), PublishReceivedReasonCode.SUCCESS)); } private void updateSessionState(NetworkMqttSession session, Publish publish, PublishReceivedReasonCode reasonCode) { @@ -118,22 +124,25 @@ protected void handleError( MessageTacker messageTacker = session.inMessageTracker(); messageTacker.update(messageId, MqttMessageType.PUBLISH, reasonCode); - sendFeedback(user, session, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(messageId, reasonCode), messageId); + sendFeedback( + user, + session, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(messageId, reasonCode), + messageId); } - private void handleDuplicated( - ExternalNetworkMqttUser user, - int messageId, - TrackedMessageMeta alreadyInProcess) { + private void handleDuplicated(ExternalNetworkMqttUser user, int messageId, TrackedMessageMeta alreadyInProcess) { PublishReceivedReasonCode reasonCode = PublishReceivedReasonCode.SUCCESS; if (alreadyInProcess.reasonCode() instanceof PublishReceivedReasonCode receivedReasonCode) { reasonCode = receivedReasonCode; } - sendFeedback(user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(messageId, reasonCode)); + sendFeedback( + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(messageId, reasonCode)); } private void handleMessageIdIsInUse(ExternalNetworkMqttUser user, int messageId) { @@ -163,10 +172,7 @@ private boolean handleReceivedTrackableMessage(MqttUser user, MqttSession sessio return true; } - messageTacker.update( - messageId, - MqttMessageType.PUBLISH_COMPLETE, - PublishCompletedReasonCode.SUCCESS); + messageTacker.update(messageId, MqttMessageType.PUBLISH_COMPLETE, PublishCompletedReasonCode.SUCCESS); MqttOutMessage response = messageOutFactoryService .resolveFactory(networkMqttUser) diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java index 9b768a65..57e0b1a1 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java @@ -11,17 +11,24 @@ import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; -public abstract class TrackableMqttPublishInMessageHandler - extends AbstractMqttPublishInMessageHandler { +public abstract class TrackableMqttPublishInMessageHandler extends + AbstractMqttPublishInMessageHandler { public TrackableMqttPublishInMessageHandler( Class expectedClientType, SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(expectedClientType, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + expectedClientType, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy index ca195f54..8b64ea7e 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy @@ -12,6 +12,7 @@ import javasabr.mqtt.network.user.NetworkMqttUser import javasabr.mqtt.service.impl.DefaultMessageOutFactoryService import javasabr.mqtt.service.impl.DefaultPublishDeliveringService import javasabr.mqtt.service.impl.DefaultPublishReceivingService +import javasabr.mqtt.service.impl.InMemoryRetainMessageService import javasabr.mqtt.service.impl.DefaultTopicService import javasabr.mqtt.service.impl.DisabledAuthorizationService import javasabr.mqtt.service.impl.InMemorySubscriptionService @@ -48,14 +49,11 @@ abstract class IntegrationServiceSpecification extends Specification { def testPayload = "testpayload".getBytes(StandardCharsets.UTF_8) @Shared - def clientIdGenerator = new AtomicInteger(); + def clientIdGenerator = new AtomicInteger() @Shared def defaultTopicService = new DefaultTopicService() - @Shared - def defaultSubscriptionService = new InMemorySubscriptionService() - @Shared def defaultMessageOutFactoryService = new DefaultMessageOutFactoryService([ new Mqtt311MessageOutFactory(), @@ -69,11 +67,18 @@ abstract class IntegrationServiceSpecification extends Specification { new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) ]) + @Shared + def inMemoryRetainMessageService = new InMemoryRetainMessageService() + + @Shared + def defaultSubscriptionService = new InMemorySubscriptionService() + @Shared def qos0MqttPublishInMessageHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService); + defaultMessageOutFactoryService, + inMemoryRetainMessageService) @Shared def publishReceivingService = new DefaultPublishReceivingService([ @@ -81,25 +86,27 @@ abstract class IntegrationServiceSpecification extends Specification { new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService), + defaultMessageOutFactoryService, + inMemoryRetainMessageService), new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) ]) @Shared - def defaultPublishReleaseMqttInMessageHandler = new PublishReleaseMqttInMessageHandler(defaultMessageOutFactoryService); + def defaultPublishReleaseMqttInMessageHandler = new PublishReleaseMqttInMessageHandler(defaultMessageOutFactoryService) @Shared def defaultBufferAllocator = new DefaultBufferAllocator(SimpleServerNetworkConfig.builder().build()) @Shared - def defaultMqttSessionService = new InMemoryMqttSessionService(60_000); - + def defaultMqttSessionService = new InMemoryMqttSessionService(60_000) + @Shared def disabledAclService = new DisabledAuthorizationService() - + @Shared List> publishInFieldValidators = [ new PublishRetainMqttInMessageFieldValidator(defaultMessageOutFactoryService), diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy index 6e97e1de..404cbde4 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy @@ -6,13 +6,16 @@ import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode import javasabr.mqtt.model.subscription.Subscription +import javasabr.mqtt.model.subscription.SubscriptionResult +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.model.topic.TopicName import javasabr.mqtt.service.IntegrationServiceSpecification -import javasabr.mqtt.service.SubscriptionService +import javasabr.mqtt.service.TestExternalNetworkMqttUser import javasabr.rlib.collections.array.Array class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { - SubscriptionService subscriptionService = new InMemorySubscriptionService() + def subscriptionService = new InMemorySubscriptionService() def "should subscribe with expected results in default settings"() { given: @@ -53,11 +56,11 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 - result == Array.of( + result.collect(SubscriptionResult::subscribeAckReasonCode) == [ SubscribeAckReasonCode.GRANTED_QOS_0, SubscribeAckReasonCode.GRANTED_QOS_1, SubscribeAckReasonCode.GRANTED_QOS_2, - SubscribeAckReasonCode.TOPIC_FILTER_INVALID) + SubscribeAckReasonCode.TOPIC_FILTER_INVALID] } def "should not subscribe with for not supported topic filter"() { @@ -108,12 +111,12 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 5 - result == Array.of( + result.collect(SubscriptionResult::subscribeAckReasonCode) == [ SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED, SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED, SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED, SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED, - SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED) + SubscribeAckReasonCode.SHARED_SUBSCRIPTIONS_NOT_SUPPORTED] } def "should store subscription with correct subscription id"() { @@ -156,11 +159,11 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 - result == Array.of( + result.collect(SubscriptionResult::subscribeAckReasonCode) == [ SubscribeAckReasonCode.GRANTED_QOS_0, SubscribeAckReasonCode.GRANTED_QOS_1, SubscribeAckReasonCode.GRANTED_QOS_2, - SubscribeAckReasonCode.GRANTED_QOS_2) + SubscribeAckReasonCode.GRANTED_QOS_2] when: def mqttSession = mqttUser.session() def activeSubscriptions = mqttSession.activeSubscriptions() @@ -325,4 +328,42 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { storedSubscriptions.size() == 3 storedSubscriptions ==~ resultSubscriptions } + + def "should clean and restore subscriptions"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def expectedUser = mqttConnection.user() as TestExternalNetworkMqttUser + def expectedSubscription = new Subscription( + TopicFilter.valueOf("topic"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true) + when: + subscriptionService.subscribe(expectedUser, expectedUser.session(), Array.of(expectedSubscription)) + def subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + !subscribers.isEmpty() + with(subscribers[0]) { + user() == expectedUser + subscription() == expectedSubscription + } + when: + subscriptionService.cleanSubscriptions(expectedUser, expectedUser.session()) + subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + subscribers.isEmpty() + + when: + subscriptionService.restoreSubscriptions(expectedUser, expectedUser.session()) + subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + !subscribers.isEmpty() + with(subscribers[0]) { + user() == expectedUser + subscription() == expectedSubscription + } + } } diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy index df3539ba..5f20d18c 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy @@ -2,12 +2,15 @@ package javasabr.mqtt.service.message.handler.impl import javasabr.mqtt.model.MqttVersion import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.message.MqttMessageType import javasabr.mqtt.model.reason.code.DisconnectReasonCode import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode import javasabr.mqtt.model.subscription.RequestedSubscription +import javasabr.mqtt.model.subscription.TestPublishFactory import javasabr.mqtt.network.message.in.SubscribeMqttInMessage import javasabr.mqtt.network.message.out.DisconnectMqtt5OutMessage +import javasabr.mqtt.network.message.out.PublishMqtt5OutMessage import javasabr.mqtt.network.message.out.SubscribeAckMqtt5OutMessage import javasabr.mqtt.network.util.ExtraErrorReasons import javasabr.mqtt.service.IntegrationServiceSpecification @@ -25,18 +28,25 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification LoggerManager.enable(SubscribeMqttInMessageHandler, LoggerLevel.INFO) } + SubscribeMqttInMessageHandler subscribeMessageHandler + + def setup() { + subscribeMessageHandler = new SubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService, + inMemoryRetainMessageService, + defaultPublishDeliveringService) + } + def "should close connection by reason that session is already closed"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser mqttUser.session(null) when: def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def disconnectReason = mqttUser.nextSentMessage(DisconnectMqtt5OutMessage) disconnectReason.reasonCode() == DisconnectReasonCode.UNSPECIFIED_ERROR @@ -47,10 +57,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def "should response that message id is in use"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser def session = mqttUser.session() @@ -64,7 +70,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -79,10 +85,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def serverConfig = defaultExternalServerConnectionConfig .withSubscriptionIdAvailable(false) def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -94,7 +96,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -109,10 +111,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def serverConfig = defaultExternalServerConnectionConfig .withMaxQos(QoS.AT_MOST_ONCE) def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -124,7 +122,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -139,10 +137,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def serverConfig = defaultExternalServerConnectionConfig .withWildcardSubscriptionAvailable(false) def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -153,7 +147,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification RequestedSubscription.minimal("topic1/#", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2/+", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -172,10 +166,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def serverConfig = defaultExternalServerConnectionConfig .withSharedSubscriptionAvailable(false) def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -186,7 +176,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification RequestedSubscription.minimal("\$share/group1/topic1/#", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("\$share/group1/topic2/+", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -203,14 +193,10 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def "should close connection by reason MQTT protocol error"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: def subscribeMessage = new SubscribeMqttInMessage(0 as byte) - messageHandler.processInvalidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processInvalidMessage(mqttConnection, subscribeMessage) then: def disconnectReason = mqttUser.nextSentMessage(DisconnectMqtt5OutMessage) disconnectReason.reasonCode() == DisconnectReasonCode.MALFORMED_PACKET @@ -221,10 +207,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def "should reuse the same message if from previous request"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -233,7 +215,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -247,7 +229,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage2) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage2) then: def subscribeAck2 = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes2 = subscribeAck2.reasonCodes() @@ -259,10 +241,6 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def "should response that message id is in use because previous is still in progress"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser mqttUser.returnCompletedFeatures(false) @@ -272,7 +250,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -286,7 +264,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage2) + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage2) then: def subscribeAck2 = mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes2 = subscribeAck2.reasonCodes() @@ -294,4 +272,174 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification reasonCodes2.get(0) == SubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE subscribeAck2.messageId() == expectedMessageId } + + def "should only deliver 'send-if-subscription-does-not-exist' Subscribe Retain Handling once"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + mqttUser.returnCompletedFeatures(false) + and: + def expectedMessageId = 15 + def requestedSubscriptions = Array.of(new RequestedSubscription( + "topic/filter/1", + QoS.EXACTLY_ONCE, + SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST, + true, + true)) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(requestedSubscriptions) + }} + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + inMemoryRetainMessageService.retainMessage(publishWithRetain) + when: + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.nextSentMessage(PublishMqtt5OutMessage) + mqttUser.isEmpty() + when: + subscribeMessage.messageId = ++expectedMessageId + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.isEmpty() + } + + def "should always deliver 'send' Subscribe Retain Handling"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + mqttUser.returnCompletedFeatures(false) + and: + def expectedMessageId = 15 + def requestedSubscriptions = Array.of(new RequestedSubscription( + "topic/filter/1", + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true)) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(requestedSubscriptions) + }} + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + inMemoryRetainMessageService.retainMessage(publishWithRetain) + when: + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.nextSentMessage(PublishMqtt5OutMessage) + mqttUser.isEmpty() + when: + subscribeMessage.messageId = ++expectedMessageId + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.nextSentMessage(PublishMqtt5OutMessage) + mqttUser.isEmpty() + } + + def "should not deliver 'do-not-send' Subscribe Retain Handling"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + mqttUser.returnCompletedFeatures(false) + and: + def expectedMessageId = 15 + def requestedSubscriptions = Array.of(new RequestedSubscription( + "topic/filter/1", + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.DO_NOT_SEND, + true, + true)) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(requestedSubscriptions) + }} + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + inMemoryRetainMessageService.retainMessage(publishWithRetain) + and: + def publishWithoutRetain = TestPublishFactory.makePublishWithoutRetain("topic/filter/1", "payload2") + inMemoryRetainMessageService.retainMessage(publishWithoutRetain) + when: + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.isEmpty() + when: + subscribeMessage.messageId = ++expectedMessageId + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + mqttUser.isEmpty() + } + + def "should reset retain flag if 'retain as published' is false"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + mqttUser.returnCompletedFeatures(false) + and: + def expectedMessageId = 15 + def requestedSubscriptions = Array.of(new RequestedSubscription( + "topic/filter/1", + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + false)) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(requestedSubscriptions) + }} + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + inMemoryRetainMessageService.retainMessage(publishWithRetain) + when: + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + def publishMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + mqttUser.isEmpty() + and: + !publishMessage.retain() + } + + def "should keep retain flag if 'retain as published' is true"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + mqttUser.returnCompletedFeatures(false) + and: + def expectedMessageId = 15 + def requestedSubscriptions = Array.of(new RequestedSubscription( + "topic/filter/1", + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true)) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(requestedSubscriptions) + }} + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + inMemoryRetainMessageService.retainMessage(publishWithRetain) + when: + subscribeMessageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttUser.nextSentMessage(SubscribeAckMqtt5OutMessage) + def publishMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + mqttUser.isEmpty() + and: + publishMessage.retain() + } } diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy index fd2d8600..aacc05f3 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy @@ -15,7 +15,8 @@ class Qos0MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -50,7 +51,8 @@ class Qos0MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos0MqttPublishInMessageHandlerTest/2") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy index d86426e6..6b21ab90 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy @@ -23,7 +23,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -68,7 +69,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/2") @@ -92,7 +94,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/3") @@ -115,7 +118,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/4") @@ -141,7 +145,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/5") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy index fdae20e6..d2c879a1 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy @@ -27,7 +27,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -87,7 +88,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/2") @@ -126,7 +128,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/3") @@ -149,7 +152,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/4") @@ -175,7 +179,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/5") @@ -203,7 +208,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + inMemoryRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/5") diff --git a/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java b/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java new file mode 100644 index 00000000..75f835a2 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java @@ -0,0 +1,93 @@ +package javasabr.mqtt.model; + +import java.util.Collection; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.rlib.collections.dictionary.DictionaryFactory; +import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; +import org.jspecify.annotations.Nullable; + +public abstract class AbstractTrieNode { + + @Nullable + volatile LockableRefToRefDictionary childNodes; + + protected abstract Supplier getNodeFactory(); + + private LockableRefToRefDictionary getOrCreateChildNodes() { + var current = childNodes; + if (current != null) { + return current; + } + synchronized (this) { + current = childNodes; + if (current == null) { + current = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = current; + } + return current; + } + } + + protected T getOrCreateChildNode(String segment) { + var childNodes = getOrCreateChildNodes(); + long stamp = childNodes.readLock(); + try { + T topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; + } + } finally { + childNodes.readUnlock(stamp); + } + stamp = childNodes.writeLock(); + try { + return childNodes.getOrCompute(segment, getNodeFactory()); + } finally { + childNodes.writeUnlock(stamp); + } + } + + protected void collectChildNodes(Collection resultCollection) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return; + } + long stamp = localChildNodes.readLock(); + try { + localChildNodes.values(resultCollection); + } finally { + localChildNodes.readUnlock(stamp); + } + } + + @Nullable + protected Collection getChildNodes(Supplier> resultCollectionFactory) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return null; + } + Collection resultCollection = resultCollectionFactory.get(); + collectChildNodes(resultCollection); + return resultCollection; + } + + @Nullable + protected T getChildNode(String segment) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return null; + } + long stamp = localChildNodes.readLock(); + try { + return localChildNodes.get(segment); + } finally { + localChildNodes.readUnlock(stamp); + } + } + + @Override + public String toString() { + return DebugUtils.toJsonString(this); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/QoS.java b/model/src/main/java/javasabr/mqtt/model/QoS.java index 8da7dd94..d8159f69 100644 --- a/model/src/main/java/javasabr/mqtt/model/QoS.java +++ b/model/src/main/java/javasabr/mqtt/model/QoS.java @@ -19,8 +19,7 @@ public enum QoS implements NumberedEnum { EXACTLY_ONCE(2, SubscribeAckReasonCode.GRANTED_QOS_2), INVALID(3, SubscribeAckReasonCode.IMPLEMENTATION_SPECIFIC_ERROR); - private static final NumberedEnumMap NUMBERED_MAP = - new NumberedEnumMap<>(QoS.class); + private static final NumberedEnumMap NUMBERED_MAP = new NumberedEnumMap<>(QoS.class); public static QoS ofCode(int level) { return NUMBERED_MAP.resolve(level, QoS.INVALID); @@ -45,4 +44,8 @@ public boolean isLowerThan(QoS another) { public boolean isHigherThan(QoS another) { return level > another.level; } + + public boolean isValid() { + return this != INVALID; + } } diff --git a/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java b/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java index 2e10c730..db8d646d 100644 --- a/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java +++ b/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java @@ -92,6 +92,24 @@ public Publish with(int messageId, QoS qos, boolean duplicated, int topicAlias) userProperties); } + public Publish withoutRetain() { + return new Publish( + messageId, + qos, + topicName, + responseTopicName, + payload, + duplicated, + false, + contentType, + subscriptionIds, + correlationData, + messageExpiryInterval, + topicAlias, + payloadFormat, + userProperties); + } + @Override public String toString() { return DebugUtils.toJsonString(this); diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java index 307db58c..1f6d4c74 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java @@ -2,7 +2,6 @@ import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.subscriber.SingleSubscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.rlib.collections.array.Array; @@ -22,8 +21,8 @@ public ConcurrentSubscriberTree() { } @Nullable - public SingleSubscriber subscribe(MqttUser user, Subscription subscription) { - return rootNode.subscribe(0, user, subscription, subscription.topicFilter()); + public SingleSubscriber subscribe(SingleSubscriber subscriber) { + return rootNode.subscribe(0, subscriber, subscriber.subscription().topicFilter()); } public boolean unsubscribe(MqttUser user, TopicFilter topicFilter) { diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 4a6579c2..5d3c74bb 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -5,14 +5,11 @@ import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.LockableArray; import javasabr.rlib.collections.array.MutableArray; -import javasabr.rlib.collections.dictionary.DictionaryFactory; -import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; import lombok.AccessLevel; import lombok.Getter; import lombok.experimental.Accessors; @@ -24,30 +21,33 @@ @FieldDefaults(level = AccessLevel.PRIVATE) class SubscriberNode extends SubscriberTreeBase { - private final static Supplier SUBSCRIBER_NODE_FACTORY = SubscriberNode::new; + private final static Supplier NODE_FACTORY = SubscriberNode::new; static { DebugUtils.registerIncludedFields("childNodes", "subscribers"); } - @Nullable - volatile LockableRefToRefDictionary childNodes; @Nullable volatile LockableArray subscribers; + @Override + protected Supplier getNodeFactory() { + return NODE_FACTORY; + } + /** * @return the previous subscription from the same owner */ @Nullable - public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { + protected SingleSubscriber subscribe(int level, SingleSubscriber subscriber, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { - return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); + return addSubscriber(getOrCreateSubscribers(), subscriber, topicFilter); } SubscriberNode childNode = getOrCreateChildNode(topicFilter.segment(level)); - return childNode.subscribe(level + 1, owner, subscription, topicFilter); + return childNode.subscribe(level + 1, subscriber, topicFilter); } - public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { + protected boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } @@ -67,7 +67,7 @@ private void exactlyTopicMatch( int lastLevel, MutableArray result) { String segment = topicName.segment(level); - SubscriberNode subscriberNode = childNode(segment); + SubscriberNode subscriberNode = getChildNode(segment); if (subscriberNode == null) { return; } @@ -83,7 +83,7 @@ private void singleWildcardTopicMatch( TopicName topicName, int lastLevel, MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); + SubscriberNode subscriberNode = getChildNode(TopicFilter.SINGLE_LEVEL_WILDCARD); if (subscriberNode == null) { return; } @@ -95,71 +95,24 @@ private void singleWildcardTopicMatch( } private void multiWildcardTopicMatch(MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); + SubscriberNode subscriberNode = getChildNode(TopicFilter.MULTI_LEVEL_WILDCARD); if (subscriberNode != null) { appendSubscribersTo(result, subscriberNode); } } - private SubscriberNode getOrCreateChildNode(String segment) { - LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); - long stamp = childNodes.readLock(); - try { - SubscriberNode subscriberNode = childNodes.get(segment); - if (subscriberNode != null) { - return subscriberNode; - } - } finally { - childNodes.readUnlock(stamp); - } - stamp = childNodes.writeLock(); - try { - return childNodes.getOrCompute(segment, SUBSCRIBER_NODE_FACTORY); - } finally { - childNodes.writeUnlock(stamp); - } - } - - @Nullable - private SubscriberNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); - if (childNodes == null) { - return null; - } - long stamp = childNodes.readLock(); - try { - return childNodes.get(segment); - } finally { - childNodes.readUnlock(stamp); - } - } - - private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } - } - } - //noinspection ConstantConditions - return childNodes; - } - private LockableArray getOrCreateSubscribers() { - if (subscribers == null) { - synchronized (this) { - if (subscribers == null) { - subscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); - } + LockableArray localSubscribers = subscribers; + if (localSubscribers != null) { + return localSubscribers; + } + synchronized (this) { + localSubscribers = subscribers; + if (localSubscribers == null) { + localSubscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); + subscribers = localSubscribers; } + return localSubscribers; } - //noinspection ConstantConditions - return subscribers; - } - - @Override - public String toString() { - return DebugUtils.toJsonString(this); } } diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java index 972b696b..15b4aa8d 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java @@ -1,12 +1,12 @@ package javasabr.mqtt.model.subscriber.tree; import java.util.Objects; +import javasabr.mqtt.model.AbstractTrieNode; import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.subscriber.SharedSubscriber; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.rlib.collections.array.LockableArray; @@ -18,7 +18,7 @@ @RequiredArgsConstructor @FieldDefaults(level = AccessLevel.PROTECTED, makeFinal = true) -abstract class SubscriberTreeBase { +abstract class SubscriberTreeBase extends AbstractTrieNode { /** * @return previous subscriber with the same user @@ -26,17 +26,16 @@ abstract class SubscriberTreeBase { @Nullable protected static SingleSubscriber addSubscriber( LockableArray subscribers, - MqttUser user, - Subscription subscription, + SingleSubscriber subscriber, TopicFilter topicFilter) { long stamp = subscribers.writeLock(); try { if (topicFilter instanceof SharedTopicFilter stf) { - addSharedSubscriber(subscribers, user, subscription, stf); + addSharedSubscriber(subscribers, subscriber, stf); return null; } else { - SingleSubscriber previous = removePreviousIfExist(subscribers, user); - subscribers.add(new SingleSubscriber(user, subscription)); + SingleSubscriber previous = removePreviousIfExist(subscribers, subscriber.user()); + subscribers.add(subscriber); return previous; } } finally { @@ -45,9 +44,7 @@ protected static SingleSubscriber addSubscriber( } @Nullable - private static SingleSubscriber removePreviousIfExist( - LockableArray subscribers, - MqttUser user) { + private static SingleSubscriber removePreviousIfExist(LockableArray subscribers, MqttUser user) { int index = subscribers.indexOf(Subscriber::resolveUser, user); if (index < 0) { return null; @@ -59,8 +56,7 @@ private static SingleSubscriber removePreviousIfExist( private static void addSharedSubscriber( LockableArray subscribers, - MqttUser user, - Subscription subscription, + SingleSubscriber subscriber, SharedTopicFilter sharedTopicFilter) { String group = sharedTopicFilter.shareName(); @@ -73,7 +69,7 @@ private static void addSharedSubscriber( subscribers.add(sharedSubscriber); } - sharedSubscriber.addSubscriber(new SingleSubscriber(user, subscription)); + sharedSubscriber.addSubscriber(subscriber); } protected static void appendSubscribersTo(MutableArray result, SubscriberNode subscriberNode) { @@ -84,10 +80,7 @@ protected static void appendSubscribersTo(MutableArray result, long stamp = subscribers.readLock(); try { for (Subscriber subscriber : subscribers) { - SingleSubscriber singleSubscriber = subscriber.resolveSingle(); - if (removeDuplicateWithLowerQoS(result, singleSubscriber)) { - result.add(singleSubscriber); - } + addOrReplaceIfLowerQos(result, subscriber); } } finally { subscribers.readUnlock(stamp); @@ -141,23 +134,18 @@ private static boolean isSharedSubscriberWithGroup(Subscriber subscriber, String return subscriber instanceof SharedSubscriber shared && Objects.equals(group, shared.group()); } - private static boolean removeDuplicateWithLowerQoS( - MutableArray result, SingleSubscriber candidate) { - + private static void addOrReplaceIfLowerQos(MutableArray result, Subscriber subscriber) { + SingleSubscriber candidate = subscriber.resolveSingle(); int found = result.indexOf(SingleSubscriber::user, candidate.user()); if (found == -1) { - return true; + result.add(candidate); + return; } - QoS candidateQos = candidate.qos(); - SingleSubscriber exist = result.get(found); - QoS existeQos = exist.qos(); - - if (existeQos.ordinal() < candidateQos.ordinal()) { + QoS existedQos = result.get(found).qos(); + if (existedQos.isLowerThan(candidateQos)) { result.remove(found); - return true; + result.add(candidate); } - - return false; } } diff --git a/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java b/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java index 69439978..c66a6548 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java +++ b/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java @@ -33,8 +33,12 @@ public record Subscription( boolean noLocal, /* If true, Application Messages forwarded using this subscription keep the RETAIN flag they were published with. If - false, Application Messages forwarded using this subscription have the RETAIN flag set to 0. Retained messages sent - when the subscription is established have the RETAIN flag set to 1. + false, Application Messages forwarded using this subscription have the RETAIN flag set to 0. + + Bit 3 of the Subscription Options represents the Retain As Published option. + If 1, Application Messages forwarded using this subscription keep the RETAIN flag they were published with. + If 0, Application Messages forwarded using this subscription have the RETAIN flag set to 0. + Retained messages sent when the subscription is established have the RETAIN flag set to 1. */ boolean retainAsPublished) { diff --git a/model/src/main/java/javasabr/mqtt/model/subscription/SubscriptionResult.java b/model/src/main/java/javasabr/mqtt/model/subscription/SubscriptionResult.java new file mode 100644 index 00000000..7c10963b --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/subscription/SubscriptionResult.java @@ -0,0 +1,23 @@ +package javasabr.mqtt.model.subscription; + +import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; +import javasabr.mqtt.model.subscriber.SingleSubscriber; +import org.jspecify.annotations.Nullable; + +public record SubscriptionResult( + SubscribeAckReasonCode subscribeAckReasonCode, + @Nullable SingleSubscriber subscriber, + boolean isSubscriptionAlreadyExisted) { + + public SubscriptionResult(SingleSubscriber subscriber, boolean isSubscriptionAlreadyExisted) { + this(subscriber.subscription().qos().subscribeAckReasonCode(), subscriber, isSubscriptionAlreadyExisted); + } + + public SubscriptionResult(SubscribeAckReasonCode subscribeAckReasonCode) { + this(subscribeAckReasonCode, null, false); + } + + public boolean isNotExistedPreviously(){ + return !isSubscriptionAlreadyExisted; + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java index 2b1f313e..ece98cca 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java @@ -32,6 +32,10 @@ protected AbstractTopic(String rawTopicName) { rawTopic = rawTopicName; } + public boolean isShared(){ + return false; + } + public String segment(int level) { return segments[level]; } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java b/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java index e9e3ba7d..b68c4689 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java @@ -28,6 +28,11 @@ public static SharedTopicFilter valueOf(String rawSharedTopicFilter) { return new SharedTopicFilter(rawTopicFilter, shareName); } + @Override + public boolean isShared(){ + return true; + } + public static boolean isShared(String rawTopicFilter) { return rawTopicFilter.startsWith(SharedTopicFilter.SHARE_KEYWORD); } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java new file mode 100644 index 00000000..ea4901db --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -0,0 +1,33 @@ +package javasabr.mqtt.model.topic.tree; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.common.ThreadSafe; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class ConcurrentRetainedMessageTree implements ThreadSafe { + + RetainedMessageNode rootNode; + + public ConcurrentRetainedMessageTree() { + this.rootNode = new RetainedMessageNode(); + } + + public void addRetainedMessage(Publish message) { + rootNode.addRetainedMessage(0, message, message.topicName()); + } + + public void removeRetainedMessage(TopicName topicName) { + rootNode.removeRetainedMessage(0, topicName); + } + + public Array getRetainedMessages(TopicFilter topicFilter) { + var resultArray = Array.builder(Publish.class); + rootNode.collectRetainedMessages(0, topicFilter, resultArray); + return resultArray.build(); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java new file mode 100644 index 00000000..7850afdd --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -0,0 +1,122 @@ +package javasabr.mqtt.model.topic.tree; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.AbstractTrieNode; +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.ArrayBuilder; +import javasabr.rlib.collections.array.ArrayFactory; +import javasabr.rlib.collections.array.MutableArray; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@Getter(AccessLevel.PACKAGE) +@Accessors(fluent = true, chain = false) +@FieldDefaults(level = AccessLevel.PRIVATE) +class RetainedMessageNode extends AbstractTrieNode { + + private final static Supplier NODE_FACTORY = RetainedMessageNode::new; + + static { + DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); + } + + private static MutableArray childNodesFactory() { + return ArrayFactory.mutableArray(RetainedMessageNode.class); + } + + final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); + + @Override + protected Supplier getNodeFactory() { + return NODE_FACTORY; + } + + public void addRetainedMessage(int level, Publish message, TopicName topicName) { + var child = getOrCreateChildNode(topicName.segment(level)); + boolean isLastLevel = (level + 1 == topicName.levelsCount()); + if (isLastLevel) { + child.setRetainedMessage(message); + } else { + child.addRetainedMessage(level + 1, message, topicName); + } + } + + public void removeRetainedMessage(int level, TopicName topicName) { + var child = getOrCreateChildNode(topicName.segment(level)); + boolean isLastLevel = (level + 1 == topicName.levelsCount()); + if (isLastLevel) { + child.clearRetainedMessage(); + } else { + child.removeRetainedMessage(level + 1, + topicName); + } + } + + private void setRetainedMessage(Publish value) { + retainedMessage.set(value); + } + + private void clearRetainedMessage() { + retainedMessage.set(null); + } + + public void collectRetainedMessages(int level, TopicFilter topicFilter, ArrayBuilder result) { + if (level == topicFilter.levelsCount()) { + Publish publish = retainedMessage.get(); + if (publish != null) { + result.add(publish); + } + return; + } + String segment = topicFilter.segment(level); + boolean isOneChar = segment.length() == 1; + if (isOneChar && segment.charAt(0) == TopicFilter.SINGLE_LEVEL_WILDCARD_CHAR) { + collectAllChildren(level, topicFilter, result); + } else if (isOneChar && segment.charAt(0) == TopicFilter.MULTI_LEVEL_WILDCARD_CHAR) { + collectEverything(this, result); + } else { + collectExactSegment(level, segment, topicFilter, result); + } + } + + private void collectExactSegment( + int level, + String segment, + TopicFilter topicFilter, + ArrayBuilder result) { + RetainedMessageNode retainedMessageNode = getChildNode(segment); + if (retainedMessageNode != null) { + retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + + private void collectAllChildren(int level, TopicFilter topicFilter, ArrayBuilder result) { + var localChildNodes = getChildNodes(RetainedMessageNode::childNodesFactory); + if (localChildNodes != null) { + for (RetainedMessageNode childNode : localChildNodes) { + childNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + } + + private void collectEverything(RetainedMessageNode node, ArrayBuilder result) { + Publish message = node.retainedMessage.get(); + if (message != null) { + result.add(message); + } + + var childNodes = node.getChildNodes(RetainedMessageNode::childNodesFactory); + if (childNodes != null) { + for (RetainedMessageNode childNode : childNodes) { + collectEverything(childNode, result); + } + } + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java new file mode 100644 index 00000000..1df48806 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package javasabr.mqtt.model.topic.tree; + +import org.jspecify.annotations.NullMarked; diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy new file mode 100644 index 00000000..e3c7c6cd --- /dev/null +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy @@ -0,0 +1,117 @@ +package javasabr.mqtt.model.topic.tree + + +import javasabr.mqtt.model.subscription.TestPublishFactory +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.model.topic.TopicName +import javasabr.mqtt.test.support.UnitSpecification + +class RetainedMessageTreeTest extends UnitSpecification { + + def "should fetch retained messages by topic filter"( + List messages, + String rawTopicFilter, + List expectedMessages) { + given: + ConcurrentRetainedMessageTree retainedMessageTree = new ConcurrentRetainedMessageTree() + messages.collect(TestPublishFactory::makePublish).each(retainedMessageTree::addRetainedMessage) + def topicFilter = TopicFilter.valueOf(rawTopicFilter) + when: + def retainedMessages = retainedMessageTree.getRetainedMessages(topicFilter) + then: + retainedMessages.size() == expectedMessages.size() + verifyEach(retainedMessages) { publish, index -> + publish.topicName().rawTopic() == expectedMessages[index] + } + where: + rawTopicFilter << [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/+/segment2", + "/topic/#" + ] + //noinspection GroovyAssignabilityCheck + messages << [ + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/", + "/topic" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/", + "/topic/segment2", + "/", + "/topic/segment2/segment1" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/segment3", + "/topic/segment3", + "/topic/segment3" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/segment500/segment2", + "/topic/", + "/topic" + ], + [ + "/topic1/segment1", + "/topic/segment2", + "/topic2/segment1/segment2", + "/topic/segment3", + "/topic/segment1/segment2" + ] + ] + //noinspection GroovyAssignabilityCheck + expectedMessages << [ + [ + "/topic/segment1" + ], + [ + "/topic/segment2" + ], + [ + "/topic/segment3" + ], + [ + "/topic/segment1/segment2", + "/topic/segment500/segment2" + ], + [ + "/topic/segment1/segment2", + "/topic/segment2", + "/topic/segment3" + ] + ] + } + + def "should add and remove retained messages to and from retained message tree"() { + given: + def publish = TestPublishFactory.makePublish("topic") + ConcurrentRetainedMessageTree retainedMessageTree = new ConcurrentRetainedMessageTree() + when: + retainedMessageTree.addRetainedMessage(publish) + then: + with(retainedMessageTree.getRetainedMessages(TopicFilter.valueOf("topic"))) { + size() == 1 + first() == publish + } + when: + retainedMessageTree.removeRetainedMessage(TopicName.valueOf("topic")) + then: + with(retainedMessageTree.getRetainedMessages(TopicFilter.valueOf("topic"))) { + isEmpty() + } + } +} diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy index cfb17623..fe048055 100644 --- a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy @@ -15,6 +15,18 @@ import javasabr.mqtt.test.support.UnitSpecification class SubscriberTreeTest extends UnitSpecification { + static SingleSubscriber createSubscriber(String clientId, String rawTopicFilter) { + return createSubscriber(clientId, rawTopicFilter, QoS.AT_LEAST_ONCE.number()) + } + + static SingleSubscriber createSubscriber(String clientId, String rawTopicFilter, int qos) { + return new SingleSubscriber(makeUser(clientId), makeSubscription(rawTopicFilter, qos)) + } + + static SingleSubscriber createShareSubscriber(String clientId, String rawTopicFilter) { + return new SingleSubscriber(makeUser(clientId), makeSharedSubscription(rawTopicFilter)) + } + def "should match simple topic correctly"( List subscriptions, List users, @@ -23,7 +35,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -36,6 +48,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment2", "/topic/segment3" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1"), @@ -62,6 +75,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/segment3") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -88,6 +102,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id4") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1") @@ -111,7 +126,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -124,6 +139,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment2", "/topic/segment3" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1"), @@ -156,6 +172,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic2/+") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -188,6 +205,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id8") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1"), @@ -216,7 +234,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -229,6 +247,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment3/segment4", "/topic/segment2" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1/segment2"), @@ -264,6 +283,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/segment3/#") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -299,6 +319,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id9") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1"), @@ -330,7 +351,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -342,6 +363,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment3", "/topic/segment2/" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1/segment2", 2), @@ -377,6 +399,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/#", 0) ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -412,21 +435,22 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id3") ] ] + //noinspection GroovyAssignabilityCheck expectedSubscribers << [ [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/segment1/segment2", 2)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/segment1/#", 1)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/#", 0)), + createSubscriber("id1", "/topic/segment1/segment2", 2), + createSubscriber("id2", "/topic/segment1/#", 1), + createSubscriber("id3", "/topic/#", 0), ], [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/#", 0)), + createSubscriber("id1", "/topic/#", 0), + createSubscriber("id2", "/topic/#", 0), + createSubscriber("id3", "/topic/#", 0), ], [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/segment2/#", 1)), + createSubscriber("id1", "/topic/#", 0), + createSubscriber("id2", "/topic/#", 0), + createSubscriber("id3", "/topic/segment2/#", 1), ] ] } @@ -436,16 +460,16 @@ class SubscriberTreeTest extends UnitSpecification { def group1 = ["id1", "id2", "id3", "id4", "id5"] def group2 = ["id6", "id7", "id8", "id9", "id10"] ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id4"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id5"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id6"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id7"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id8"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id9"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id10"), makeSharedSubscription('$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id1", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id2", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id3", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id4", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id5", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id6", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id7", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id8", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id9", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id10", '$share/group2/topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -469,9 +493,9 @@ class SubscriberTreeTest extends UnitSpecification { def "should subscribe and unsubscribe simple topic correctly correctly"() { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSubscription('topic/name1')) + subscriberTree.subscribe(createSubscriber("id1", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id2", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id3", 'topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -506,9 +530,9 @@ class SubscriberTreeTest extends UnitSpecification { def "should subscribe and unsubscribe shared topic correctly correctly"() { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id1", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id2", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id3", '$share/group1/topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -517,8 +541,12 @@ class SubscriberTreeTest extends UnitSpecification { then: matched.size() == 1 when: - def id2WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id2"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) - def id3WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id3"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id2WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id2"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id3WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id3"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .collect { it.user().toString() } @@ -528,8 +556,12 @@ class SubscriberTreeTest extends UnitSpecification { id2WasUnsubscribed id3WasUnsubscribed when: - def id1WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id1"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) - id3WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id3"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id1WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id1"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) + id3WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id3"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .collect { it.user().toString() } @@ -546,10 +578,10 @@ class SubscriberTreeTest extends UnitSpecification { def owner1 = makeUser("id1") def originalSub = makeSubscription('topic/name1') def replacementSub = makeSubscription('topic/name1') - subscriberTree.subscribe(makeUser("id2"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSubscription('topic/name1')) + subscriberTree.subscribe(createSubscriber("id2", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id3", 'topic/name1')) when: - def previous = subscriberTree.subscribe(owner1, originalSub) + def previous = subscriberTree.subscribe(new SingleSubscriber(owner1, originalSub)) def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() @@ -557,7 +589,7 @@ class SubscriberTreeTest extends UnitSpecification { matched.size() == 3 previous == null; when: - previous = subscriberTree.subscribe(owner1, replacementSub) + previous = subscriberTree.subscribe(new SingleSubscriber(owner1, replacementSub)) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() @@ -573,8 +605,8 @@ class SubscriberTreeTest extends UnitSpecification { ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() def owner1 = makeUser("id1") def owner2 = makeUser("id2") - subscriberTree.subscribe(owner1, makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(owner2, makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(new SingleSubscriber(owner1, makeSharedSubscription('$share/group1/topic/name1'))) + subscriberTree.subscribe(new SingleSubscriber(owner2, makeSharedSubscription('$share/group1/topic/name1'))) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -597,7 +629,7 @@ class SubscriberTreeTest extends UnitSpecification { matched.size() == 1 matched.first().user() == owner2 when: - subscriberTree.subscribe(owner1, makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(new SingleSubscriber(owner1, makeSharedSubscription('$share/group1/topic/name1'))) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() diff --git a/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy b/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy new file mode 100644 index 00000000..257ab27d --- /dev/null +++ b/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy @@ -0,0 +1,67 @@ +package javasabr.mqtt.model.subscription + +import javasabr.mqtt.model.PayloadFormat +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.publishing.Publish +import javasabr.mqtt.model.topic.TopicName +import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.IntArray + +import static java.nio.charset.StandardCharsets.UTF_8 + +class TestPublishFactory { + + static def makePublish(String topicName) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + "payload".getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } + + static def makePublishWithRetain(String topicName, String payload) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + payload.getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } + + static def makePublishWithoutRetain(String topicName, String payload) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + payload.getBytes(UTF_8), + false, + false, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } +}