diff --git a/.github/workflows/maven_build.yml b/.github/workflows/maven_build.yml index dee3b1054..edba9df4b 100644 --- a/.github/workflows/maven_build.yml +++ b/.github/workflows/maven_build.yml @@ -5,9 +5,9 @@ name: Java CI with Maven on: push: - branches: [ main ] + branches: [ main, mqtt5_development ] pull_request: - branches: [ main ] + branches: [ main, mqtt5_development ] jobs: test: diff --git a/broker/pom.xml b/broker/pom.xml index f5cf8923e..78a71df2b 100644 --- a/broker/pom.xml +++ b/broker/pom.xml @@ -19,6 +19,7 @@ https://github.com/netty/netty/blob/netty-4.1.85.Final/pom.xml#L594 --> 2.0.54.Final 1.2.5 + 1.3.0 2.1.212 @@ -146,6 +147,13 @@ test + + com.hivemq + hivemq-mqtt-client + ${hivemqclient.version} + test + + org.eclipse.jetty.websocket websocket-client diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index 2e55101d8..a0bf5799f 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -149,7 +149,10 @@ PostOffice.RouteResult processConnect(MqttConnectMessage msg) { final String username = payload.userName(); LOG.trace("Processing CONNECT message. CId: {} username: {}", clientId, username); - if (isNotProtocolVersion(msg, MqttVersion.MQTT_3_1) && isNotProtocolVersion(msg, MqttVersion.MQTT_3_1_1)) { + if (isNotProtocolVersion(msg, MqttVersion.MQTT_3_1) && + isNotProtocolVersion(msg, MqttVersion.MQTT_3_1_1) && + isNotProtocolVersion(msg, MqttVersion.MQTT_5) + ) { LOG.warn("MQTT protocol version is not valid. CId: {}", clientId); abortConnection(CONNECTION_REFUSED_UNACCEPTABLE_PROTOCOL_VERSION); return PostOffice.RouteResult.failed(clientId); @@ -236,11 +239,14 @@ public void operationComplete(ChannelFuture future) throws Exception { // OK continue with sending queued messages and normal flow if (result.mode == SessionRegistry.CreationModeEnum.REOPEN_EXISTING) { - result.session.sendQueuedMessagesWhileOffline(); + result.session.reconnectSession(); } initializeKeepAliveTimeout(channel, msg, clientIdUsed); - setupInflightResender(channel); + if (isNotProtocolVersion(msg, MqttVersion.MQTT_5)) { + // In MQTT5 MQTT-4.4.0-1 avoid retries messages on timer base. + setupInflightResender(channel); + } postOffice.dispatchConnection(msg); LOG.trace("dispatch connection: {}", msg); diff --git a/broker/src/main/java/io/moquette/broker/Session.java b/broker/src/main/java/io/moquette/broker/Session.java index f5b1561be..78f5ee59d 100644 --- a/broker/src/main/java/io/moquette/broker/Session.java +++ b/broker/src/main/java/io/moquette/broker/Session.java @@ -42,6 +42,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; class Session { @@ -103,18 +104,20 @@ static final class Will { private MQTTConnection mqttConnection; private final Set subscriptions = new HashSet<>(); private final Map inflightWindow = new HashMap<>(); + // used only in MQTT3 where resends are done on timeout of ACKs. private final DelayQueue inflightTimeouts = new DelayQueue<>(); private final Map qos2Receiving = new HashMap<>(); private final AtomicInteger inflightSlots = new AtomicInteger(INFLIGHT_WINDOW_SIZE); // this should be configurable private final Instant created; private final int expiryInterval; + private final boolean resendInflightOnTimeout; - Session(String clientId, boolean clean, Will will, SessionMessageQueue sessionQueue) { - this(clientId, clean, sessionQueue); + Session(String clientId, boolean clean, MqttVersion protocolVersion, Will will, SessionMessageQueue sessionQueue) { + this(clientId, clean, protocolVersion, sessionQueue); this.will = will; } - Session(String clientId, boolean clean, SessionMessageQueue sessionQueue) { + Session(String clientId, boolean clean, MqttVersion protocolVersion, SessionMessageQueue sessionQueue) { if (sessionQueue == null) { throw new IllegalArgumentException("sessionQueue parameter can't be null"); } @@ -124,6 +127,7 @@ static final class Will { this.created = Instant.now(); // in MQTT3 cleanSession = true means expiryInterval=0 else infinite expiryInterval = clean ? 0 : 0xFFFFFFFF; + this.resendInflightOnTimeout = protocolVersion != MqttVersion.MQTT_5; } public boolean expireImmediately() { @@ -228,7 +232,9 @@ public void processPubRec(int pubRecPacketId) { return; } inflightWindow.put(pubRecPacketId, new SessionRegistry.PubRelMarker()); - inflightTimeouts.add(new InFlightPacket(pubRecPacketId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(pubRecPacketId, FLIGHT_BEFORE_RESEND_MS)); + } MqttMessage pubRel = MQTTConnection.pubrel(pubRecPacketId); mqttConnection.sendIfWritableElseDrop(pubRel); @@ -297,7 +303,9 @@ private void sendPublishQos1(Topic topic, MqttQoS qos, ByteBuf payload, boolean old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + } MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage(topic.toString(), qos, payload, packetId); @@ -331,8 +339,9 @@ private void sendPublishQos2(Topic topic, MqttQoS qos, ByteBuf payload, boolean old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); - + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(packetId, FLIGHT_BEFORE_RESEND_MS)); + } MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage(topic.toString(), qos, payload, packetId); localMqttConnectionRef.sendPublish(publishMsg); @@ -354,7 +363,7 @@ private boolean canSkipQueue(MQTTConnection localMqttConnectionRef) { localMqttConnectionRef.channel.isWritable(); } - private boolean inflighHasSlotsAndConnectionIsUp() { + private boolean inflightHasSlotsAndConnectionIsUp() { return inflightSlots.get() > 0 && connected() && mqttConnection.channel.isWritable(); @@ -378,20 +387,30 @@ public void flushAllQueuedMessages() { } public void resendInflightNotAcked() { - Collection expired = new ArrayList<>(INFLIGHT_WINDOW_SIZE); - inflightTimeouts.drainTo(expired); + Collection nonAckPacketIds; + if (resendInflightOnTimeout) { + // MQTT3 behavior, resend on timeout + Collection expired = new ArrayList<>(INFLIGHT_WINDOW_SIZE); + inflightTimeouts.drainTo(expired); + nonAckPacketIds = expired.stream().map(p -> p.packetId).collect(Collectors.toList()); + } else { + // MQTT5 behavior resend only not acked present in reopened session. + nonAckPacketIds = inflightWindow.keySet(); + } - debugLogPacketIds(expired); + debugLogPacketIds(nonAckPacketIds); - for (InFlightPacket notAckPacketId : expired) { - final SessionRegistry.EnqueuedMessage msg = inflightWindow.get(notAckPacketId.packetId); + for (Integer notAckPacketId : nonAckPacketIds) { + final SessionRegistry.EnqueuedMessage msg = inflightWindow.get(notAckPacketId); if (msg == null) { // Already acked... continue; } if (msg instanceof SessionRegistry.PubRelMarker) { - MqttMessage pubRel = MQTTConnection.pubrel(notAckPacketId.packetId); - inflightTimeouts.add(new InFlightPacket(notAckPacketId.packetId, FLIGHT_BEFORE_RESEND_MS)); + MqttMessage pubRel = MQTTConnection.pubrel(notAckPacketId); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(notAckPacketId, FLIGHT_BEFORE_RESEND_MS)); + } mqttConnection.sendIfWritableElseDrop(pubRel); } else { final SessionRegistry.PublishedMessage pubMsg = (SessionRegistry.PublishedMessage) msg; @@ -400,34 +419,36 @@ public void resendInflightNotAcked() { final ByteBuf payload = pubMsg.payload; // message fetched from map, but not removed from map. No need to duplicate or release. MqttPublishMessage publishMsg = publishNotRetainedDuplicated(notAckPacketId, topic, qos, payload); - inflightTimeouts.add(new InFlightPacket(notAckPacketId.packetId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(notAckPacketId, FLIGHT_BEFORE_RESEND_MS)); + } mqttConnection.sendPublish(publishMsg); } } } - private void debugLogPacketIds(Collection expired) { - if (!LOG.isDebugEnabled() || expired.isEmpty()) { + private void debugLogPacketIds(Collection packetIds) { + if (!LOG.isDebugEnabled() || packetIds.isEmpty()) { return; } StringBuilder sb = new StringBuilder(); - for (InFlightPacket packet : expired) { - sb.append(packet.packetId).append(", "); + for (Integer packetId : packetIds) { + sb.append(packetId).append(", "); } - LOG.debug("Resending {} in flight packets [{}]", expired.size(), sb); + LOG.debug("Resending {} in flight packets [{}]", packetIds.size(), sb); } - private MqttPublishMessage publishNotRetainedDuplicated(InFlightPacket notAckPacketId, Topic topic, MqttQoS qos, + private MqttPublishMessage publishNotRetainedDuplicated(int packetId, Topic topic, MqttQoS qos, ByteBuf payload) { MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PUBLISH, true, qos, false, 0); - MqttPublishVariableHeader varHeader = new MqttPublishVariableHeader(topic.toString(), notAckPacketId.packetId); + MqttPublishVariableHeader varHeader = new MqttPublishVariableHeader(topic.toString(), packetId); return new MqttPublishMessage(fixedHeader, varHeader, payload); } private void drainQueueToConnection() { // consume the queue - while (!sessionQueue.isEmpty() && inflighHasSlotsAndConnectionIsUp()) { + while (!sessionQueue.isEmpty() && inflightHasSlotsAndConnectionIsUp()) { final SessionRegistry.EnqueuedMessage msg = sessionQueue.dequeue(); if (msg == null) { // Our message was already fetched by another Thread. @@ -442,7 +463,9 @@ private void drainQueueToConnection() { old.release(); inflightSlots.incrementAndGet(); } - inflightTimeouts.add(new InFlightPacket(sendPacketId, FLIGHT_BEFORE_RESEND_MS)); + if (resendInflightOnTimeout) { + inflightTimeouts.add(new InFlightPacket(sendPacketId, FLIGHT_BEFORE_RESEND_MS)); + } final SessionRegistry.PublishedMessage msgPub = (SessionRegistry.PublishedMessage) msg; MqttPublishMessage publishMsg = MQTTConnection.createNotRetainedPublishMessage( msgPub.topic.toString(), @@ -458,8 +481,11 @@ public void writabilityChanged() { drainQueueToConnection(); } - public void sendQueuedMessagesWhileOffline() { + public void reconnectSession() { LOG.trace("Republishing all saved messages for session {}", this); + resendInflightNotAcked(); + + // send queued messages while offline drainQueueToConnection(); } diff --git a/broker/src/main/java/io/moquette/broker/SessionRegistry.java b/broker/src/main/java/io/moquette/broker/SessionRegistry.java index aba4f8d83..83fe355b3 100644 --- a/broker/src/main/java/io/moquette/broker/SessionRegistry.java +++ b/broker/src/main/java/io/moquette/broker/SessionRegistry.java @@ -23,6 +23,7 @@ import io.netty.buffer.Unpooled; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttVersion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -133,7 +134,7 @@ private void recreateSessionPool() { if (queueRepository.containsQueue(clientId)) { final SessionMessageQueue persistentQueue = queueRepository.getOrCreateQueue(clientId); queues.remove(clientId); - Session rehydrated = new Session(clientId, false, persistentQueue); + Session rehydrated = new Session(clientId, false, MqttVersion.MQTT_3_1, persistentQueue); pool.put(clientId, rehydrated); } } @@ -228,11 +229,14 @@ private Session createNewSession(MqttConnectMessage msg, String clientId) { } else { queue = new InMemoryQueue(); } + + final MqttVersion mqttVersion = Utils.versionFromConnect(msg); + if (msg.variableHeader().isWillFlag()) { final Session.Will will = createWill(msg); - newSession = new Session(clientId, clean, will, queue); + newSession = new Session(clientId, clean, mqttVersion, will, queue); } else { - newSession = new Session(clientId, clean, queue); + newSession = new Session(clientId, clean, mqttVersion, queue); } newSession.markConnecting(); diff --git a/broker/src/main/java/io/moquette/broker/Utils.java b/broker/src/main/java/io/moquette/broker/Utils.java index 8c178e117..3d671c266 100644 --- a/broker/src/main/java/io/moquette/broker/Utils.java +++ b/broker/src/main/java/io/moquette/broker/Utils.java @@ -17,8 +17,11 @@ package io.moquette.broker; import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader; +import io.netty.handler.codec.mqtt.MqttVersion; + import java.util.Map; /** @@ -46,6 +49,10 @@ public static byte[] readBytesAndRewind(ByteBuf payload) { return payloadContent; } + public static MqttVersion versionFromConnect(MqttConnectMessage msg) { + return MqttVersion.fromProtocolNameAndLevel(msg.variableHeader().name(), (byte) msg.variableHeader().version()); + } + private Utils() { } } diff --git a/broker/src/test/java/io/moquette/broker/SessionTest.java b/broker/src/test/java/io/moquette/broker/SessionTest.java index e440f6cac..ab2104774 100644 --- a/broker/src/test/java/io/moquette/broker/SessionTest.java +++ b/broker/src/test/java/io/moquette/broker/SessionTest.java @@ -6,6 +6,7 @@ import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttVersion; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -27,7 +28,7 @@ public class SessionTest { public void setUp() { testChannel = new EmbeddedChannel(); queuedMessages = new InMemoryQueue(); - client = new Session(CLIENT_ID, true, null, queuedMessages); + client = new Session(CLIENT_ID, true, MqttVersion.MQTT_3_1, null, queuedMessages); createConnection(client); } diff --git a/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java new file mode 100644 index 000000000..ea6bd4d15 --- /dev/null +++ b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java @@ -0,0 +1,154 @@ +package io.moquette.integration.mqtt5; + +import com.hivemq.client.mqtt.MqttClient; +import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAck; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAckReasonCode; +import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5PublishResult; +import io.moquette.broker.Server; +import io.moquette.broker.config.IConfig; +import io.moquette.broker.config.MemoryConfig; +import io.moquette.integration.IntegrationUtils; +import io.moquette.testclient.Client; +import io.netty.handler.codec.mqtt.MqttConnAckMessage; +import io.netty.handler.codec.mqtt.MqttConnectReturnCode; +import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttPublishMessage; +import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttSubAckMessage; +import org.awaitility.Awaitility; +import org.awaitility.Durations; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ConnectTest { + private static final Logger LOG = LoggerFactory.getLogger(ConnectTest.class); + + Server broker; + IConfig config; + + @TempDir + Path tempFolder; + private String dbPath; + private Client lowLevelClient; + + protected void startServer(String dbPath) throws IOException { + broker = new Server(); + final Properties configProps = IntegrationUtils.prepareTestProperties(dbPath); + config = new MemoryConfig(configProps); + broker.startServer(config); + } + + @BeforeAll + public static void beforeTests() { + Awaitility.setDefaultTimeout(Durations.ONE_SECOND); + } + + @BeforeEach + public void setUp() throws Exception { + dbPath = IntegrationUtils.tempH2Path(tempFolder); + startServer(dbPath); + + lowLevelClient = new Client("localhost").clientId("subscriber"); + } + + @AfterEach + public void tearDown() throws Exception { + stopServer(); + } + + private void stopServer() { + broker.stopServer(); + } + + @Test + public void simpleConnect() { + Mqtt5BlockingClient client = MqttClient.builder() + .useMqttVersion5() + .identifier("simple_connect_test") + .serverHost("localhost") + .serverPort(1883) + .buildBlocking(); + final Mqtt5ConnAck connectAck = client.connect(); + assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, connectAck.getReasonCode(), "Accept plain connection"); + + client.disconnect(); + } + + @Test + public void sendConnectOnDisconnectedConnection() { + MqttConnAckMessage connAck = lowLevelClient.connectV5(); + assertConnectionAccepted(connAck, "Connection must be accepted"); + lowLevelClient.disconnect(); + + try { + lowLevelClient.connectV5(); + fail("Connect on Disconnected TCP socket can't happen"); + } catch (RuntimeException rex) { + assertEquals("Cannot receive ConnAck in 200 ms", rex.getMessage()); + } + } + + @Test + public void receiveInflightPublishesAfterAReconnect() { + final Mqtt5BlockingClient publisher = MqttClient.builder() + .useMqttVersion5() + .identifier("publisher") + .serverHost("localhost") + .serverPort(1883) + .buildBlocking(); + Mqtt5ConnAck connectAck = publisher.connect(); + assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, connectAck.getReasonCode(), "Publisher connected"); + + final MqttConnAckMessage connAck = lowLevelClient.connectV5(); + assertConnectionAccepted(connAck, "Connection must be accepted"); + lowLevelClient.subscribe("/test", MqttQoS.AT_LEAST_ONCE); + + final Mqtt5PublishResult pubResult = publisher.publishWith() + .topic("/test") + .qos(com.hivemq.client.mqtt.datatypes.MqttQos.AT_LEAST_ONCE) + .payload("Hello".getBytes(StandardCharsets.UTF_8)) + .send(); + assertFalse(pubResult.getError().isPresent(), "Publisher published"); + + lowLevelClient.disconnect(); + + // reconnect the raw subscriber + final Client reconnectingSubscriber = new Client("localhost").clientId("subscriber"); + assertConnectionAccepted(reconnectingSubscriber.connectV5(), "Connection must be accepted"); + + Awaitility.await() + .atMost(2, TimeUnit.SECONDS) + .until(reconnectingSubscriber::hasReceivedMessages); + + final String publishPayload = reconnectingSubscriber.nextQueuedMessage() + .filter(m -> m instanceof MqttPublishMessage) + .map(m -> (MqttPublishMessage) m) + .map(m -> m.payload().toString(StandardCharsets.UTF_8)) + .orElse("Fake Payload"); + assertEquals("Hello", publishPayload, "The inflight payload from previous subscription MUST be received"); + + reconnectingSubscriber.disconnect(); + } + + private void assertConnectionAccepted(MqttConnAckMessage connAck, String message) { + assertEquals(MqttConnectReturnCode.CONNECTION_ACCEPTED, connAck.variableHeader().connectReturnCode(), message); + } +} diff --git a/broker/src/test/java/io/moquette/testclient/Client.java b/broker/src/test/java/io/moquette/testclient/Client.java index 7e714226a..95fbf906c 100644 --- a/broker/src/test/java/io/moquette/testclient/Client.java +++ b/broker/src/test/java/io/moquette/testclient/Client.java @@ -27,8 +27,12 @@ import org.slf4j.LoggerFactory; import java.nio.charset.Charset; +import java.util.Optional; +import java.util.Queue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.channel.ChannelFutureListener.CLOSE_ON_FAILURE; import static io.netty.channel.ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE; @@ -51,7 +55,8 @@ public interface ICallback { private boolean m_connectionLost; private ICallback callback; private String clientId; - private MqttMessage receivedMsg; + private AtomicReference receivedMsg = new AtomicReference<>(); + private final Queue receivedMessages = new LinkedBlockingQueue<>(); public Client(String host) { this(host, BrokerConstants.PORT); @@ -118,15 +123,6 @@ public void connect(String willTestamentTopic, String willTestamentMsg) { mqttConnectVariableHeader, mqttConnectPayload); - /* - * ConnectMessage connectMessage = new ConnectMessage(); - * connectMessage.setProtocolVersion((byte) 3); connectMessage.setClientID(this.clientId); - * connectMessage.setKeepAlive(2); //secs connectMessage.setWillFlag(true); - * connectMessage.setWillMessage(willTestamentMsg.getBytes()); - * connectMessage.setWillTopic(willTestamentTopic); - * connectMessage.setWillQos(MqttQoS.AT_MOST_ONCE.byteValue()); - */ - doConnect(connectMessage); } @@ -138,24 +134,90 @@ public void connect() { doConnect(connectMessage); } - private void doConnect(MqttConnectMessage connectMessage) { + public MqttConnAckMessage connectV5() { + MqttConnectMessage connectMessage = MqttMessageBuilders.connect().protocolVersion(MqttVersion.MQTT_5) + .clientId(clientId) + .keepAlive(2) // secs + .willFlag(false) + .willQoS(MqttQoS.AT_MOST_ONCE) + .build(); + + return doConnect(connectMessage); + } + + private MqttConnAckMessage doConnect(MqttConnectMessage connectMessage) { final CountDownLatch latch = new CountDownLatch(1); this.setCallback(msg -> { - receivedMsg = msg; + receivedMsg.getAndSet(msg); + LOG.info("Connect callback invocation, received message {}", msg.fixedHeader().messageType()); latch.countDown(); + + // clear the callback + setCallback(null); }); this.sendMessage(connectMessage); + boolean waitElapsed; try { - latch.await(200, TimeUnit.MILLISECONDS); + waitElapsed = !latch.await(2_000, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { - throw new RuntimeException("Cannot receive message in 200 ms", e); + throw new RuntimeException("Interrupted while waiting", e); + } + + if (waitElapsed) { + throw new RuntimeException("Cannot receive ConnAck in 200 ms"); } - if (!(this.receivedMsg instanceof MqttConnAckMessage)) { - MqttMessageType messageType = this.receivedMsg.fixedHeader().messageType(); + + final MqttMessage connAckMessage = this.receivedMsg.get(); + if (!(connAckMessage instanceof MqttConnAckMessage)) { + MqttMessageType messageType = connAckMessage.fixedHeader().messageType(); throw new RuntimeException("Expected a CONN_ACK message but received " + messageType); } + return (MqttConnAckMessage) connAckMessage; + } + + public MqttSubAckMessage subscribe(String topic, MqttQoS qos) { + final MqttSubscribeMessage subscribeMessage = MqttMessageBuilders.subscribe() + .messageId(1) + .addSubscription(qos, topic) + .build(); + + final CountDownLatch subscribeAckLatch = new CountDownLatch(1); + this.setCallback(msg -> { + receivedMsg.getAndSet(msg); + LOG.debug("Subscribe callback invocation, received message {}", msg.fixedHeader().messageType()); + subscribeAckLatch.countDown(); + + // clear the callback + setCallback(null); + }); + + LOG.debug("Sending SUBSCRIBE message"); + sendMessage(subscribeMessage); + LOG.debug("Sent SUBSCRIBE message"); + + boolean waitElapsed; + try { + waitElapsed = !subscribeAckLatch.await(200, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while waiting", e); + } + + if (waitElapsed) { + throw new RuntimeException("Cannot receive SubscribeAck in 200 ms"); + } + final MqttMessage subAckMessage = this.receivedMsg.get(); + if (!(subAckMessage instanceof MqttSubAckMessage)) { + MqttMessageType messageType = subAckMessage.fixedHeader().messageType(); + throw new RuntimeException("Expected a SUB_ACK message but received " + messageType); + } + return (MqttSubAckMessage) subAckMessage; + } + + public void disconnect() { + final MqttMessage disconnectMessage = MqttMessageBuilders.disconnect().build(); + sendMessage(disconnectMessage); } public void setCallback(ICallback callback) { @@ -167,16 +229,22 @@ public void sendMessage(MqttMessage msg) { } public MqttMessage lastReceivedMessage() { - return this.receivedMsg; + return this.receivedMsg.get(); } void messageReceived(MqttMessage msg) { LOG.info("Received message {}", msg); if (this.callback != null) { this.callback.call(msg); + } else { + receivedMessages.add(msg); } } + public boolean hasReceivedMessages() { + return !receivedMessages.isEmpty(); + } + void setConnectionLost(boolean status) { m_connectionLost = status; } @@ -185,6 +253,13 @@ public boolean isConnectionLost() { return m_connectionLost; } + public Optional nextQueuedMessage() { + if (receivedMessages.isEmpty()) { + return Optional.empty(); + } + return Optional.of(receivedMessages.poll()); + } + @SuppressWarnings("FutureReturnValueIgnored") public void close() throws InterruptedException { // Wait until the connection is closed.