From b0d5900639f73d09ce5dbfd1c01a757eea2c2498 Mon Sep 17 00:00:00 2001 From: Andrea Selva Date: Sun, 22 Jan 2023 11:57:53 +0100 Subject: [PATCH] [MQTT 5] Avoid to pub retries on timeout (#697) This PR stores the MQTT version into the Session instance and use that to keep the existing behavior for inflight resends (happening on a timeout basis on ACK received) in case the version is MQTT 3.1 or MQTT 3.1.1. When the version of the Session is MQTT 5 it removes the resend on PUB ACK timeouts and switch to send only in case the same client reconnects with cleanStart = 0 and there is any peding publishes in the flight zone to get acknowledged. To test this use the raw Client has been extended, so now can also subscribe and collect publish messages. --- .github/workflows/maven_build.yml | 4 +- broker/pom.xml | 8 + .../io/moquette/broker/MQTTConnection.java | 12 +- .../main/java/io/moquette/broker/Session.java | 78 ++++++--- .../io/moquette/broker/SessionRegistry.java | 10 +- .../main/java/io/moquette/broker/Utils.java | 7 + .../java/io/moquette/broker/SessionTest.java | 3 +- .../integration/mqtt5/ConnectTest.java | 154 ++++++++++++++++++ .../java/io/moquette/testclient/Client.java | 109 +++++++++++-- 9 files changed, 333 insertions(+), 52 deletions(-) create mode 100644 broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java diff --git a/.github/workflows/maven_build.yml b/.github/workflows/maven_build.yml index dee3b1054..edba9df4b 100644 --- a/.github/workflows/maven_build.yml +++ b/.github/workflows/maven_build.yml @@ -5,9 +5,9 @@ name: Java CI with Maven on: push: - branches: [ main ] + branches: [ main, mqtt5_development ] pull_request: - branches: [ main ] + branches: [ main, mqtt5_development ] jobs: test: diff --git a/broker/pom.xml b/broker/pom.xml index f5cf8923e..78a71df2b 100644 --- a/broker/pom.xml +++ b/broker/pom.xml @@ -19,6 +19,7 @@ https://github.com/netty/netty/blob/netty-4.1.85.Final/pom.xml#L594 --> 2.0.54.Final 1.2.5 + 1.3.0 2.1.212 @@ -146,6 +147,13 @@ test + + com.hivemq + hivemq-mqtt-client + ${hivemqclient.version} + test + + org.eclipse.jetty.websocket websocket-client diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index 2e55101d8..a0bf5799f 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -149,7 +149,10 @@ PostOffice.RouteResult processConnect(MqttConnectMessage msg) { final String username = payload.userName(); LOG.trace("Processing CONNECT message. CId: {} username: {}", clientId, username); - if (isNotProtocolVersion(msg, MqttVersion.MQTT_3_1) && isNotProtocolVersion(msg, MqttVersion.MQTT_3_1_1)) { + if (isNotProtocolVersion(msg, MqttVersion.MQTT_3_1) && + isNotProtocolVersion(msg, MqttVersion.MQTT_3_1_1) && + isNotProtocolVersion(msg, MqttVersion.MQTT_5) + ) { LOG.warn("MQTT protocol version is not valid. CId: {}", clientId); abortConnection(CONNECTION_REFUSED_UNACCEPTABLE_PROTOCOL_VERSION); return PostOffice.RouteResult.failed(clientId); @@ -236,11 +239,14 @@ public void operationComplete(ChannelFuture future) throws Exception { // OK continue with sending queued messages and normal flow if (result.mode == SessionRegistry.CreationModeEnum.REOPEN_EXISTING) { - result.session.sendQueuedMessagesWhileOffline(); + result.session.reconnectSession(); } initializeKeepAliveTimeout(channel, msg, clientIdUsed); - setupInflightResender(channel); + if (isNotProtocolVersion(msg, MqttVersion.MQTT_5)) { + // In MQTT5 MQTT-4.4.0-1 avoid retries messages on timer base. + setupInflightResender(channel); + } postOffice.dispatchConnection(msg); LOG.trace("dispatch connection: {}", msg); diff --git a/broker/src/main/java/io/moquette/broker/Session.java b/broker/src/main/java/io/moquette/broker/Session.java index f5b1561be..78f5ee59d 100644 --- a/broker/src/main/java/io/moquette/broker/Session.java +++ b/broker/src/main/java/io/moquette/broker/Session.java @@ -42,6 +42,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; class Session { @@ -103,18 +104,20 @@ static final class Will { private MQTTConnection mqttConnection; private final Set subscriptions = new HashSet<>(); private final Map inflightWindow = new HashMap<>(); + // used only in MQTT3 where resends are done on timeout of ACKs. private final DelayQueue inflightTimeouts = new DelayQueue<>(); private final Map qos2Receiving = new HashMap<>(); private final AtomicInteger inflightSlots = new AtomicInteger(INFLIGHT_WINDOW_SIZE); // this should be configurable private final Instant created; private final int expiryInterval; + private final boolean resendInflightOnTimeout; - Session(String clientId, boolean clean, Will will, SessionMessageQueue sessionQueue) { - this(clientId, clean, sessionQueue); + Session(String clientId, boolean clean, MqttVersion protocolVersion, Will will, SessionMessageQueue sessionQueue) { + this(clientId, clean, protocolVersion, sessionQueue); this.will = will; } - Session(String clientId, boolean clean, SessionMessageQueue sessionQueue) { + Session(String clientId, boolean clean, MqttVersion protocolVersion, SessionMessageQueue sessionQueue) { if (sessionQueue == null) { throw new IllegalArgumentException("sessionQueue parameter can't be null"); } @@ -124,6 +127,7 @@ static final class Will { this.created = Instant.now(); // in MQTT3 cleanSession = true means expiryInterval=0 else infinite expiryInterval = clean ? 0 : 0xFFFFFFFF; + this.resendInflightOnTimeout = protocolVersion != MqttVersion.MQTT_5; } public boolean expireImmediately() { @@ -228,7 +232,9 @@ public void processPubRec(int pubRecPacketId) { return; } inflightWindow.put(pubRecPacketId, new SessionRegistry.PubRelMarker()); - inflightTimeouts.add(new InFlightPacket(pubRecPacketId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(pubRecPacketId, FLIGHT_BEFORE_RESEND_MS)); + } MqttMessage pubRel = MQTTConnection.pubrel(pubRecPacketId); mqttConnection.sendIfWritableElseDrop(pubRel); @@ -297,7 +303,9 @@ private void sendPublishQos1(Topic topic, MqttQoS qos, ByteBuf payload, boolean old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + } MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage(topic.toString(), qos, payload, packetId); @@ -331,8 +339,9 @@ private void sendPublishQos2(Topic topic, MqttQoS qos, ByteBuf payload, boolean old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); - + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + } MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage(topic.toString(), qos, payload, packetId); localMqttConnectionRef.sendPublish(publishMsg); @@ -354,7 +363,7 @@ private boolean canSkipQueue(MQTTConnection localMqttConnectionRef) { localMqttConnectionRef.channel.isWritable(); } - private boolean inflighHasSlotsAndConnectionIsUp() { + private boolean inflightHasSlotsAndConnectionIsUp() { return inflightSlots.get() > 0 && connected() && mqttConnection.channel.isWritable(); @@ -378,20 +387,30 @@ public void flushAllQueuedMessages() { } public void resendInflightNotAcked() { - Collection expired = new ArrayList<>(INFLIGHT_WINDOW_SIZE); - inflightTimeouts.drainTo(expired); + Collection nonAckPacketIds; + if (resendInflightOnTimeout) { + // MQTT3 behavior, resend on timeout + Collection expired = new ArrayList<>(INFLIGHT_WINDOW_SIZE); + inflightTimeouts.drainTo(expired); + nonAckPacketIds = expired.stream().map(p -> p.packetId).collect(Collectors.toList()); + } else { + // MQTT5 behavior resend only not acked present in reopened session. + nonAckPacketIds = inflightWindow.keySet(); + } - debugLogPacketIds(expired); + debugLogPacketIds(nonAckPacketIds); - for (InFlightPacket notAckPacketId : expired) { - final SessionRegistry.EnqueuedMessage msg = inflightWindow.get(notAckPacketId.packetId); + for (Integer notAckPacketId : nonAckPacketIds) { + final SessionRegistry.EnqueuedMessage msg = inflightWindow.get(notAckPacketId); if (msg == null) { // Already acked... continue; } if (msg instanceof SessionRegistry.PubRelMarker) { - MqttMessage pubRel = MQTTConnection.pubrel(notAckPacketId.packetId); - inflightTimeouts.add(new InFlightPacket(notAckPacketId.packetId, FLIGHT_BEFORE_RESEND_MS)); + MqttMessage pubRel = MQTTConnection.pubrel(notAckPacketId); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(notAckPacketId, FLIGHT_BEFORE_RESEND_MS)); + } mqttConnection.sendIfWritableElseDrop(pubRel); } else { final SessionRegistry.PublishedMessage pubMsg = (SessionRegistry.PublishedMessage) msg; @@ -400,34 +419,36 @@ public void resendInflightNotAcked() { final ByteBuf payload = pubMsg.payload; // message fetched from map, but not removed from map. No need to duplicate or release. MqttPublishMessage publishMsg = publishNotRetainedDuplicated(notAckPacketId, topic, qos, payload); - inflightTimeouts.add(new InFlightPacket(notAckPacketId.packetId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(notAckPacketId, FLIGHT_BEFORE_RESEND_MS)); + } mqttConnection.sendPublish(publishMsg); } } } - private void debugLogPacketIds(Collection expired) { - if (!LOG.isDebugEnabled() || expired.isEmpty()) { + private void debugLogPacketIds(Collection packetIds) { + if (!LOG.isDebugEnabled() || packetIds.isEmpty()) { return; } StringBuilder sb = new StringBuilder(); - for (InFlightPacket packet : expired) { - sb.append(packet.packetId).append(", "); + for (Integer packetId : packetIds) { + sb.append(packetId).append(", "); } - LOG.debug("Resending {} in flight packets [{}]", expired.size(), sb); + LOG.debug("Resending {} in flight packets [{}]", packetIds.size(), sb); } - private MqttPublishMessage publishNotRetainedDuplicated(InFlightPacket notAckPacketId, Topic topic, MqttQoS qos, + private MqttPublishMessage publishNotRetainedDuplicated(int packetId, Topic topic, MqttQoS qos, ByteBuf payload) { MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PUBLISH, true, qos, false, 0); - MqttPublishVariableHeader varHeader = new MqttPublishVariableHeader(topic.toString(), notAckPacketId.packetId); + MqttPublishVariableHeader varHeader = new MqttPublishVariableHeader(topic.toString(), packetId); return new MqttPublishMessage(fixedHeader, varHeader, payload); } private void drainQueueToConnection() { // consume the queue - while (!sessionQueue.isEmpty() && inflighHasSlotsAndConnectionIsUp()) { + while (!sessionQueue.isEmpty() && inflightHasSlotsAndConnectionIsUp()) { final SessionRegistry.EnqueuedMessage msg = sessionQueue.dequeue(); if (msg == null) { // Our message was already fetched by another Thread. @@ -442,7 +463,9 @@ private void drainQueueToConnection() { old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(sendPacketId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(sendPacketId, FLIGHT_BEFORE_RESEND_MS)); + } final SessionRegistry.PublishedMessage msgPub = (SessionRegistry.PublishedMessage) msg; MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage( msgPub.topic.toString(), @@ -458,8 +481,11 @@ public void writabilityChanged() { drainQueueToConnection(); } - public void sendQueuedMessagesWhileOffline() { + public void reconnectSession() { LOG.trace("Republishing all saved messages for session {}", this); + resendInflightNotAcked(); + + // send queued messages while offline drainQueueToConnection(); } diff --git a/broker/src/main/java/io/moquette/broker/SessionRegistry.java b/broker/src/main/java/io/moquette/broker/SessionRegistry.java index aba4f8d83..83fe355b3 100644 --- a/broker/src/main/java/io/moquette/broker/SessionRegistry.java +++ b/broker/src/main/java/io/moquette/broker/SessionRegistry.java @@ -23,6 +23,7 @@ import io.netty.buffer.Unpooled; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttVersion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -133,7 +134,7 @@ private void recreateSessionPool() { if (queueRepository.containsQueue(clientId)) { final SessionMessageQueue persistentQueue = queueRepository.getOrCreateQueue(clientId); queues.remove(clientId); - Session rehydrated = new Session(clientId, false, persistentQueue); + Session rehydrated = new Session(clientId, false, MqttVersion.MQTT_3_1, persistentQueue); pool.put(clientId, rehydrated); } } @@ -228,11 +229,14 @@ private Session createNewSession(MqttConnectMessage msg, String clientId) { } else { queue = new InMemoryQueue(); } + + final MqttVersion mqttVersion = Utils.versionFromConnect(msg); + if (msg.variableHeader().isWillFlag()) { final Session.Will will = createWill(msg); - newSession = new Session(clientId, clean, will, queue); + newSession = new Session(clientId, clean, mqttVersion, will, queue); } else { - newSession = new Session(clientId, clean, queue); + newSession = new Session(clientId, clean, mqttVersion, queue); } newSession.markConnecting(); diff --git a/broker/src/main/java/io/moquette/broker/Utils.java b/broker/src/main/java/io/moquette/broker/Utils.java index 8c178e117..3d671c266 100644 --- a/broker/src/main/java/io/moquette/broker/Utils.java +++ b/broker/src/main/java/io/moquette/broker/Utils.java @@ -17,8 +17,11 @@ package io.moquette.broker; import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader; +import io.netty.handler.codec.mqtt.MqttVersion; + import java.util.Map; /** @@ -46,6 +49,10 @@ public static byte[] readBytesAndRewind(ByteBuf payload) { return payloadContent; } + public static MqttVersion versionFromConnect(MqttConnectMessage msg) { + return MqttVersion.fromProtocolNameAndLevel(msg.variableHeader().name(), (byte) msg.variableHeader().version()); + } + private Utils() { } } diff --git a/broker/src/test/java/io/moquette/broker/SessionTest.java b/broker/src/test/java/io/moquette/broker/SessionTest.java index e440f6cac..ab2104774 100644 --- a/broker/src/test/java/io/moquette/broker/SessionTest.java +++ b/broker/src/test/java/io/moquette/broker/SessionTest.java @@ -6,6 +6,7 @@ import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttVersion; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -27,7 +28,7 @@ public class SessionTest { public void setUp() { testChannel = new EmbeddedChannel(); queuedMessages = new InMemoryQueue(); - client = new Session(CLIENT_ID, true, null, queuedMessages); + client = new Session(CLIENT_ID, true, MqttVersion.MQTT_3_1, null, queuedMessages); createConnection(client); } diff --git a/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java new file mode 100644 index 000000000..ea6bd4d15 --- /dev/null +++ b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java @@ -0,0 +1,154 @@ +package io.moquette.integration.mqtt5; + +import com.hivemq.client.mqtt.MqttClient; +import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAck; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAckReasonCode; +import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5PublishResult; +import io.moquette.broker.Server; +import io.moquette.broker.config.IConfig; +import io.moquette.broker.config.MemoryConfig; +import io.moquette.integration.IntegrationUtils; +import io.moquette.testclient.Client; +import io.netty.handler.codec.mqtt.MqttConnAckMessage; +import io.netty.handler.codec.mqtt.MqttConnectReturnCode; +import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttPublishMessage; +import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttSubAckMessage; +import org.awaitility.Awaitility; +import org.awaitility.Durations; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ConnectTest { + private static final Logger LOG = LoggerFactory.getLogger(ConnectTest.class); + + Server broker; + IConfig config; + + @TempDir + Path tempFolder; + private String dbPath; + private Client lowLevelClient; + + protected void startServer(String dbPath) throws IOException { + broker = new Server(); + final Properties configProps = IntegrationUtils.prepareTestProperties(dbPath); + config = new MemoryConfig(configProps); + broker.startServer(config); + } + + @BeforeAll + public static void beforeTests() { + Awaitility.setDefaultTimeout(Durations.ONE_SECOND); + } + + @BeforeEach + public void setUp() throws Exception { + dbPath = IntegrationUtils.tempH2Path(tempFolder); + startServer(dbPath); + + lowLevelClient = new Client("localhost").clientId("subscriber"); + } + + @AfterEach + public void tearDown() throws Exception { + stopServer(); + } + + private void stopServer() { + broker.stopServer(); + } + + @Test + public void simpleConnect() { + Mqtt5BlockingClient client = MqttClient.builder() + .useMqttVersion5() + .identifier("simple_connect_test") + .serverHost("localhost") + .serverPort(1883) + .buildBlocking(); + final Mqtt5ConnAck connectAck = client.connect(); + assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, connectAck.getReasonCode(), "Accept plain connection"); + + client.disconnect(); + } + + @Test + public void sendConnectOnDisconnectedConnection() { + MqttConnAckMessage connAck = lowLevelClient.connectV5(); + assertConnectionAccepted(connAck, "Connection must be accepted"); + lowLevelClient.disconnect(); + + try { + lowLevelClient.connectV5(); + fail("Connect on Disconnected TCP socket can't happen"); + } catch (RuntimeException rex) { + assertEquals("Cannot receive ConnAck in 200 ms", rex.getMessage()); + } + } + + @Test + public void receiveInflightPublishesAfterAReconnect() { + final Mqtt5BlockingClient publisher = MqttClient.builder() + .useMqttVersion5() + .identifier("publisher") + .serverHost("localhost") + .serverPort(1883) + .buildBlocking(); + Mqtt5ConnAck connectAck = publisher.connect(); + assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, connectAck.getReasonCode(), "Publisher connected"); + + final MqttConnAckMessage connAck = lowLevelClient.connectV5(); + assertConnectionAccepted(connAck, "Connection must be accepted"); + lowLevelClient.subscribe("/test", MqttQoS.AT_LEAST_ONCE); + + final Mqtt5PublishResult pubResult = publisher.publishWith() + .topic("/test") + .qos(com.hivemq.client.mqtt.datatypes.MqttQos.AT_LEAST_ONCE) + .payload("Hello".getBytes(StandardCharsets.UTF_8)) + .send(); + assertFalse(pubResult.getError().isPresent(), "Publisher published"); + + lowLevelClient.disconnect(); + + // reconnect the raw subscriber + final Client reconnectingSubscriber = new Client("localhost").clientId("subscriber"); + assertConnectionAccepted(reconnectingSubscriber.connectV5(), "Connection must be accepted"); + + Awaitility.await() + .atMost(2, TimeUnit.SECONDS) + .until(reconnectingSubscriber::hasReceivedMessages); + + final String publishPayload = reconnectingSubscriber.nextQueuedMessage() + .filter(m -> m instanceof MqttPublishMessage) + .map(m -> (MqttPublishMessage) m) + .map(m -> m.payload().toString(StandardCharsets.UTF_8)) + .orElse("Fake Payload"); + assertEquals("Hello", publishPayload, "The inflight payload from previous subscription MUST be received"); + + reconnectingSubscriber.disconnect(); + } + + private void assertConnectionAccepted(MqttConnAckMessage connAck, String message) { + assertEquals(MqttConnectReturnCode.CONNECTION_ACCEPTED, connAck.variableHeader().connectReturnCode(), message); + } +} diff --git a/broker/src/test/java/io/moquette/testclient/Client.java b/broker/src/test/java/io/moquette/testclient/Client.java index 7e714226a..95fbf906c 100644 --- a/broker/src/test/java/io/moquette/testclient/Client.java +++ b/broker/src/test/java/io/moquette/testclient/Client.java @@ -27,8 +27,12 @@ import org.slf4j.LoggerFactory; import java.nio.charset.Charset; +import java.util.Optional; +import java.util.Queue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.channel.ChannelFutureListener.CLOSE_ON_FAILURE; import static io.netty.channel.ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE; @@ -51,7 +55,8 @@ public interface ICallback { private boolean m_connectionLost; private ICallback callback; private String clientId; - private MqttMessage receivedMsg; + private AtomicReference receivedMsg = new AtomicReference<>(); + private final Queue receivedMessages = new LinkedBlockingQueue<>(); public Client(String host) { this(host, BrokerConstants.PORT); @@ -118,15 +123,6 @@ public void connect(String willTestamentTopic, String willTestamentMsg) { mqttConnectVariableHeader, mqttConnectPayload); - /* - * ConnectMessage connectMessage = new ConnectMessage(); - * connectMessage.setProtocolVersion((byte) 3); connectMessage.setClientID(this.clientId); - * connectMessage.setKeepAlive(2); //secs connectMessage.setWillFlag(true); - * connectMessage.setWillMessage(willTestamentMsg.getBytes()); - * connectMessage.setWillTopic(willTestamentTopic); - * connectMessage.setWillQos(MqttQoS.AT_MOST_ONCE.byteValue()); - */ - doConnect(connectMessage); } @@ -138,24 +134,90 @@ public void connect() { doConnect(connectMessage); } - private void doConnect(MqttConnectMessage connectMessage) { + public MqttConnAckMessage connectV5() { + MqttConnectMessage connectMessage = MqttMessageBuilders.connect().protocolVersion(MqttVersion.MQTT_5) + .clientId(clientId) + .keepAlive(2) // secs + .willFlag(false) + .willQoS(MqttQoS.AT_MOST_ONCE) + .build(); + + return doConnect(connectMessage); + } + + private MqttConnAckMessage doConnect(MqttConnectMessage connectMessage) { final CountDownLatch latch = new CountDownLatch(1); this.setCallback(msg -> { - receivedMsg = msg; + receivedMsg.getAndSet(msg); + LOG.info("Connect callback invocation, received message {}", msg.fixedHeader().messageType()); latch.countDown(); + + // clear the callback + setCallback(null); }); this.sendMessage(connectMessage); + boolean waitElapsed; try { - latch.await(200, TimeUnit.MILLISECONDS); + waitElapsed = !latch.await(2_000, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { - throw new RuntimeException("Cannot receive message in 200 ms", e); + throw new RuntimeException("Interrupted while waiting", e); + } + + if (waitElapsed) { + throw new RuntimeException("Cannot receive ConnAck in 200 ms"); } - if (!(this.receivedMsg instanceof MqttConnAckMessage)) { - MqttMessageType messageType = this.receivedMsg.fixedHeader().messageType(); + + final MqttMessage connAckMessage = this.receivedMsg.get(); + if (!(connAckMessage instanceof MqttConnAckMessage)) { + MqttMessageType messageType = connAckMessage.fixedHeader().messageType(); throw new RuntimeException("Expected a CONN_ACK message but received " + messageType); } + return (MqttConnAckMessage) connAckMessage; + } + + public MqttSubAckMessage subscribe(String topic, MqttQoS qos) { + final MqttSubscribeMessage subscribeMessage = MqttMessageBuilders.subscribe() + .messageId(1) + .addSubscription(qos, topic) + .build(); + + final CountDownLatch subscribeAckLatch = new CountDownLatch(1); + this.setCallback(msg -> { + receivedMsg.getAndSet(msg); + LOG.debug("Subscribe callback invocation, received message {}", msg.fixedHeader().messageType()); + subscribeAckLatch.countDown(); + + // clear the callback + setCallback(null); + }); + + LOG.debug("Sending SUBSCRIBE message"); + sendMessage(subscribeMessage); + LOG.debug("Sent SUBSCRIBE message"); + + boolean waitElapsed; + try { + waitElapsed = !subscribeAckLatch.await(200, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while waiting", e); + } + + if (waitElapsed) { + throw new RuntimeException("Cannot receive SubscribeAck in 200 ms"); + } + final MqttMessage subAckMessage = this.receivedMsg.get(); + if (!(subAckMessage instanceof MqttSubAckMessage)) { + MqttMessageType messageType = subAckMessage.fixedHeader().messageType(); + throw new RuntimeException("Expected a SUB_ACK message but received " + messageType); + } + return (MqttSubAckMessage) subAckMessage; + } + + public void disconnect() { + final MqttMessage disconnectMessage = MqttMessageBuilders.disconnect().build(); + sendMessage(disconnectMessage); } public void setCallback(ICallback callback) { @@ -167,16 +229,22 @@ public void sendMessage(MqttMessage msg) { } public MqttMessage lastReceivedMessage() { - return this.receivedMsg; + return this.receivedMsg.get(); } void messageReceived(MqttMessage msg) { LOG.info("Received message {}", msg); if (this.callback != null) { this.callback.call(msg); + } else { + receivedMessages.add(msg); } } + public boolean hasReceivedMessages() { + return !receivedMessages.isEmpty(); + } + void setConnectionLost(boolean status) { m_connectionLost = status; } @@ -185,6 +253,13 @@ public boolean isConnectionLost() { return m_connectionLost; } + public Optional nextQueuedMessage() { + if (receivedMessages.isEmpty()) { + return Optional.empty(); + } + return Optional.of(receivedMessages.poll()); + } + @SuppressWarnings("FutureReturnValueIgnored") public void close() throws InterruptedException { // Wait until the connection is closed.