diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-06 21:31:07 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-06 21:31:07 +0200 |
| commit | a062d934aef40829d9559a8ca83147ea4c44108e (patch) | |
| tree | 6ee1f9eb3208bea65365e63d039ad27c40111c6f | |
| parent | e9958e8a0f5aa5fbe0a4a03be42b8bf640add6f7 (diff) | |
| download | mqtt-a062d934aef40829d9559a8ca83147ea4c44108e.tar.gz mqtt-a062d934aef40829d9559a8ca83147ea4c44108e.zip | |
Massive refactoring of the internals
| -rw-r--r-- | src/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | src/client.c | 1160 | ||||
| -rw-r--r-- | src/deserialize.c | 286 | ||||
| -rw-r--r-- | src/deserialize.h | 11 | ||||
| -rw-r--r-- | src/mqtt.h | 4 | ||||
| -rw-r--r-- | src/packet.c | 76 | ||||
| -rw-r--r-- | src/packet.h | 96 | ||||
| -rw-r--r-- | src/serialize.c | 326 | ||||
| -rw-r--r-- | src/serialize.h | 11 | ||||
| -rw-r--r-- | src/stream.c | 10 | ||||
| -rw-r--r-- | src/stream.h | 2 | ||||
| -rw-r--r-- | src/stream_mqtt.h | 1 | ||||
| -rw-r--r-- | test/interop/CMakeLists.txt | 3 | ||||
| -rw-r--r-- | test/interop/ping_test.c | 27 | ||||
| -rw-r--r-- | test/interop/testclient.c | 44 | ||||
| -rw-r--r-- | test/interop/testclient.h | 7 | ||||
| -rw-r--r-- | test/interop/unsubscribe_test.c | 31 | ||||
| -rw-r--r-- | test/interop/username_and_password_test.c | 30 | ||||
| -rw-r--r-- | tools/sub.c | 5 |
19 files changed, 928 insertions, 1206 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f51fabb..5a565ca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,14 +2,14 @@ ADD_SUBDIRECTORY(lib) ADD_LIBRARY(mqtt STATIC client.c - deserialize.c misc.c packet.c - serialize.c socket.c socketstream.c stream.c stream_mqtt.c + stringstream.c + message.c $<TARGET_OBJECTS:bstrlib> ) diff --git a/src/client.c b/src/client.c index a6b0998..9238555 100644 --- a/src/client.c +++ b/src/client.c @@ -5,10 +5,11 @@ #include "socketstream.h" #include "socket.h" #include "misc.h" -#include "serialize.h" -#include "deserialize.h" #include "log.h" #include "private.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "message.h" #include "queue.h" @@ -25,9 +26,6 @@ #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; - struct MqttClient { SocketStream stream; @@ -56,9 +54,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -80,18 +78,13 @@ struct MqttClient int paused; bstring userName; bstring password; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + /* The packet we are receiving */ + MqttPacket inPacket; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -99,14 +92,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -114,7 +107,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -212,11 +205,44 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); @@ -253,44 +279,37 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } - - packet->keepAlive = client->keepAlive; + StringStreamInit(&ss); - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); - - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); } if (client->userName) { - packet->connectFlags |= 0x80; - packet->userName = bstrcpy(client->userName); - - if (client->password) + StreamWriteMqttString(client->userName, pss); + if(client->password) { - packet->connectFlags |= 0x40; - packet->password = bstrcpy(client->password); + StreamWriteMqttString(client->password, pss); } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); return 0; } @@ -339,6 +358,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -354,21 +381,12 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (MqttClientSendPacket(client) == -1) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) - { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; - } + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; } } @@ -433,10 +451,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -444,68 +464,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(packet->id, pss); - packet->qos = (int *) malloc(sizeof(int) * count); + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); + packet->payload = ss.buffer; - MqttClientQueuePacket(client, (MqttPacket *) packet); - - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); - packet->topicFilter = bfromcstr(topicFilter); + if (!packet) + return -1; + + packet->flags = 0x02; + + StringStreamInit(&ss); + + StreamWriteUint16Be(packet->id, pss); + + btfromcstr(filter, topicFilter); + + StreamWriteMqttString(&filter, pss); - MqttClientQueuePacket(client, (MqttPacket *) packet); + packet->payload = ss.buffer; - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -514,55 +588,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); + + message->topic = NULL; + message->payload = NULL; + + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -655,6 +729,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -667,128 +742,332 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(blength(packet->payload), + &client->stream.base) == -1) + { + return -1; + } + + packet->state = MqttPacketStateWritePayload; + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + if (StreamWrite(bdata(packet->payload), + blength(packet->payload), + &client->stream.base) == -1) + { + return -1; + } + } + + packet->state = MqttPacketStateWriteComplete; + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientHandlePublish(MqttClient *client) { + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; + + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + StreamReadUint16Be(&id, pss); + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; - if (MqttPacketPublishQos(packet) == 2) + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -796,268 +1075,347 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); - if (!sub) + StreamReadUint16Be(&id, pss); + + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; - - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); + int rc; - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + int rc; + + if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1) + { + LOG_ERROR("failed reading packet type: %d", rc); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->stream.base) == -1) + { + LOG_ERROR("failed to read remaining length"); + return -1; + } + client->inPacket.state = MqttPacketStateReadPayload; + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + if (StreamRead(bdata(client->inPacket.payload), + client->inPacket.remainingLength, + &client->stream.base) == -1) + { + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + client->inPacket.payload->slen = client->inPacket.remainingLength; + } + client->inPacket.state = MqttPacketStateReadComplete; + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -1072,101 +1430,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; - } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; + continue; } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -1182,30 +1528,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } diff --git a/src/deserialize.c b/src/deserialize.c deleted file mode 100644 index 96d7789..0000000 --- a/src/deserialize.c +++ /dev/null @@ -1,286 +0,0 @@ -#include "deserialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream); - -static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1) - return -1; - - if (StreamRead(&(*packet)->returnCode, 1, stream) != 1) - return -1; - - return 0; -} - -static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t i; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - return -1; - - remainingLength -= 2; - - (*packet)->returnCode = (unsigned char *) malloc( - sizeof(*(*packet)->returnCode) * remainingLength); - - for (i = 0; i < remainingLength; ++i) - { - if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t payloadSize = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadMqttString(&(*packet)->topicName, stream) == -1) - return -1; - - LOG_DEBUG("remainingLength:%lu", remainingLength); - - payloadSize = remainingLength - blength((*packet)->topicName) - 2; - - LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet), - payloadSize); - - if (MqttPacketHasId((const MqttPacket *) *packet)) - { - LOG_DEBUG("packet has id"); - payloadSize -= 2; - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - { - return -1; - } - } - - LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize); - - /* Allocate extra byte for a NULL terminator. If the user tries to print - the payload directly. */ - - (*packet)->message = bfromcstralloc(payloadSize+1, ""); - - if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1) - return -1; - - (*packet)->message->slen = payloadSize; - (*packet)->message->data[payloadSize] = '\0'; - - return 0; -} - -static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - char buffer[256]; - - (void) packet; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - while (remainingLength > 0) - { - size_t l = sizeof(buffer); - - if (remainingLength < l) - l = remainingLength; - - if (StreamRead(buffer, l, stream) != (int64_t) l) - return -1; - - remainingLength -= l; - } - - return 0; -} - -static int ValidateFlags(int type, int flags) -{ - int rv = 0; - - switch (type) - { - case MqttPacketTypePublish: - { - int qos = (flags >> 1) & 2; - if (qos >= 0 && qos <= 2) - rv = 1; - break; - } - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - if (flags == 2) - { - rv = 1; - } - break; - - default: - if (flags == 0) - { - rv = 1; - } - break; - } - - return rv; -} - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) -{ - MqttPacketDeserializeFunc deserializer = NULL; - char typeAndFlags; - int type; - int flags; - int rv; - - if (StreamRead(&typeAndFlags, 1, stream) != 1) - return -1; - - type = (typeAndFlags & 0xF0) >> 4; - flags = (typeAndFlags & 0x0F); - - if (!ValidateFlags(type, flags)) - { - return -1; - } - - switch (type) - { - case MqttPacketTypeConnect: - break; - - case MqttPacketTypeConnAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize; - break; - - case MqttPacketTypePublish: - deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize; - break; - - case MqttPacketTypePubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRec: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRel: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubComp: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypeSubscribe: - break; - - case MqttPacketTypeSubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize; - break; - - case MqttPacketTypeUnsubscribe: - break; - - case MqttPacketTypeUnsubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize; - break; - - case MqttPacketTypePingReq: - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - break; - - default: - return -1; - } - - if (!deserializer) - { - deserializer = MqttPacketGenericDeserializer; - } - - *packet = MqttPacketNew(type); - - if (!*packet) - return -1; - - if (type == MqttPacketTypePublish) - { - MqttPacketPublishDup(*packet) = (flags >> 3) & 1; - MqttPacketPublishQos(*packet) = (flags >> 1) & 3; - MqttPacketPublishRetain(*packet) = flags & 1; - } - - rv = deserializer(packet, stream); - - return rv; -} diff --git a/src/deserialize.h b/src/deserialize.h deleted file mode 100644 index 8c29b3d..0000000 --- a/src/deserialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef DESERIALIZE_H -#define DESERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream); - -#endif @@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client, typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, int id, - const char *topicFilter, - MqttSubscriptionStatus status); + int *qos, + int count); typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); diff --git a/src/packet.c b/src/packet.c index 47aa689..c833851 100644 --- a/src/packet.c +++ b/src/packet.c @@ -28,42 +28,16 @@ const char *MqttPacketName(int type) } } -static MQTT_INLINE size_t MqttPacketStructSize(int type) -{ - switch (type) - { - case MqttPacketTypeConnect: return sizeof(MqttPacketConnect); - case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck); - case MqttPacketTypePublish: return sizeof(MqttPacketPublish); - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: return sizeof(MqttPacket); - case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe); - case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck); - case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe); - case MqttPacketTypeUnsubAck: return sizeof(MqttPacket); - case MqttPacketTypePingReq: return sizeof(MqttPacket); - case MqttPacketTypePingResp: return sizeof(MqttPacket); - case MqttPacketTypeDisconnect: return sizeof(MqttPacket); - default: return (size_t) -1; - } -} - MqttPacket *MqttPacketNew(int type) { MqttPacket *packet = NULL; - packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type)); + packet = (MqttPacket *) calloc(1, sizeof(*packet)); if (!packet) return NULL; packet->type = type; - /* this will make sure that TAILQ_PREV does not segfault if a message - has not been added to a list at any point */ - packet->messages.tqe_prev = &packet->messages.tqe_next; - return packet; } @@ -78,52 +52,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id) void MqttPacketFree(MqttPacket *packet) { - if (MqttPacketType(packet) == MqttPacketTypeConnect) - { - MqttPacketConnect *p = (MqttPacketConnect *) packet; - bdestroy(p->clientId); - bdestroy(p->willTopic); - bdestroy(p->willMessage); - bdestroy(p->userName); - bdestroy(p->password); - } - else if (MqttPacketType(packet) == MqttPacketTypePublish) - { - MqttPacketPublish *p = (MqttPacketPublish *) packet; - bdestroy(p->topicName); - bdestroy(p->message); - } - else if (MqttPacketType(packet) == MqttPacketTypeSubscribe) - { - MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet; - bstrListDestroy(p->topicFilters); - } - else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe) - { - MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet; - bdestroy(p->topicFilter); - } + bdestroy(packet->payload); free(packet); } - -int MqttPacketHasId(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return MqttPacketPublishQos(packet) > 0; - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - case MqttPacketTypeSubscribe: - case MqttPacketTypeSubAck: - case MqttPacketTypeUnsubscribe: - case MqttPacketTypeUnsubAck: - return 1; - - default: - return 0; - } -} diff --git a/src/packet.h b/src/packet.h index 7ab4f73..36dc81f 100644 --- a/src/packet.h +++ b/src/packet.h @@ -29,87 +29,33 @@ enum MqttPacketTypeDisconnect = 0xE }; +enum MqttPacketState +{ + MqttPacketStateReadType, + MqttPacketStateReadRemainingLength, + MqttPacketStateReadPayload, + MqttPacketStateReadComplete, + + MqttPacketStateWriteType, + MqttPacketStateWriteRemainingLength, + MqttPacketStateWritePayload, + MqttPacketStateWriteComplete +}; + +struct MqttMessage; + typedef struct MqttPacket MqttPacket; struct MqttPacket { int type; - uint16_t id; - int state; int flags; - int64_t sentAt; + int state; + uint16_t id; + size_t remainingLength; + bstring payload; + struct MqttMessage *message; SIMPLEQ_ENTRY(MqttPacket) sendQueue; - TAILQ_ENTRY(MqttPacket) messages; -}; - -#define MqttPacketType(packet) (((MqttPacket *) (packet))->type) - -#define MqttPacketId(packet) (((MqttPacket *) (packet))->id) - -#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt) - -typedef struct MqttPacketConnect MqttPacketConnect; - -struct MqttPacketConnect -{ - MqttPacket base; - char connectFlags; - uint16_t keepAlive; - bstring clientId; - bstring willTopic; - bstring willMessage; - bstring userName; - bstring password; -}; - -typedef struct MqttPacketConnAck MqttPacketConnAck; - -struct MqttPacketConnAck -{ - MqttPacket base; - unsigned char connAckFlags; - unsigned char returnCode; -}; - -typedef struct MqttPacketPublish MqttPacketPublish; - -struct MqttPacketPublish -{ - MqttPacket base; - bstring topicName; - bstring message; - char qos; - char dup; - char retain; -}; - -#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos) -#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup) -#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain) - -typedef struct MqttPacketSubscribe MqttPacketSubscribe; - -struct MqttPacketSubscribe -{ - MqttPacket base; - struct bstrList *topicFilters; - int *qos; -}; - -typedef struct MqttPacketSubAck MqttPacketSubAck; - -struct MqttPacketSubAck -{ - MqttPacket base; - unsigned char *returnCode; -}; - -typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; - -struct MqttPacketUnsubscribe -{ - MqttPacket base; - bstring topicFilter; }; const char *MqttPacketName(int type); @@ -120,6 +66,4 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id); void MqttPacketFree(MqttPacket *packet); -int MqttPacketHasId(const MqttPacket *packet); - #endif diff --git a/src/serialize.c b/src/serialize.c deleted file mode 100644 index c1c8eb4..0000000 --- a/src/serialize.c +++ /dev/null @@ -1,326 +0,0 @@ -#include "serialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <bstrlib/bstrlib.h> - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet, - Stream *stream); - -static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); -static const char MqttProtocolLevel = 0x04; - -static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s) -{ - return 2 + blength(s); -} - -static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2; - - remainingLength += MqttStringLengthSerialized(packet->clientId); - - if (packet->connectFlags & 0x80) - remainingLength += MqttStringLengthSerialized(packet->userName); - - if (packet->connectFlags & 0x40) - remainingLength += MqttStringLengthSerialized(packet->password); - - if (packet->connectFlags & 0x04) - remainingLength += MqttStringLengthSerialized(packet->willTopic) + - MqttStringLengthSerialized(packet->willMessage); - - return remainingLength; -} - -static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) -{ - size_t remaining = 2; - int i; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]); - remaining += 1; - } - - return remaining; -} - -static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) -{ - return 2 + MqttStringLengthSerialized(packet->topicFilter); -} - -static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(packet->topicName); - - /* Packet id */ - if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2) - { - remainingLength += 2; - } - - remainingLength += blength(packet->message); - - return remainingLength; -} - -static size_t MqttPacketGetRemainingLength(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypeConnect: - return MqttPacketConnectGetRemainingLength( - (MqttPacketConnect *) packet); - - case MqttPacketTypeSubscribe: - return MqttPacketSubscribeGetRemainingLength( - (MqttPacketSubscribe *) packet); - - case MqttPacketTypePublish: - return MqttPacketPublishGetRemainingLength( - (MqttPacketPublish *) packet); - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - return 2; - - case MqttPacketTypeUnsubscribe: - return MqttPacketUnsubscribeGetRemainingLength( - (MqttPacketUnsubscribe *) packet); - - default: - return 0; - } -} - -static int MqttPacketFlags(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return ((MqttPacketPublishDup(packet) & 1) << 3) | - ((MqttPacketPublishQos(packet) & 3) << 1) | - (MqttPacketPublishRetain(packet) & 1); - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - return 0x2; - - default: - return 0; - } -} - -static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream) -{ - unsigned char typeAndFlags; - size_t remainingLength; - - typeAndFlags = ((packet->type & 0x0F) << 4) | - (MqttPacketFlags(packet) & 0x0F); - remainingLength = MqttPacketGetRemainingLength(packet); - - LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type, - MqttPacketName(packet->type), MqttPacketFlags(packet)); - - if (StreamWrite(&typeAndFlags, 1, stream) != 1) - return -1; - - if (StreamWriteRemainingLength(remainingLength, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream) -{ - assert(MqttPacketHasId((const MqttPacket *) packet)); - - if (MqttPacketBaseSerialize(packet, stream) == -1) - return -1; - - if (StreamWriteUint16Be(packet->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize(&packet->base, stream) == -1) - return -1; - - if (StreamWriteMqttString(&MqttProtocolId, stream) == -1) - return -1; - - if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1) - return -1; - - if (StreamWrite(&packet->connectFlags, 1, stream) != 1) - return -1; - - if (StreamWriteUint16Be(packet->keepAlive, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->clientId, stream) == -1) - return -1; - - if (packet->connectFlags & 0x04) - { - if (StreamWriteMqttString(packet->willTopic, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->willMessage, stream) == -1) - return -1; - } - - if (packet->connectFlags & 0x80) - { - if (StreamWriteMqttString(packet->userName, stream) == -1) - return -1; - - if (packet->connectFlags & 0x40) - { - if (StreamWriteMqttString(packet->password, stream) == -1) - return -1; - } - } - - return 0; -} - -static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) -{ - int i; - - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - unsigned char qos = (unsigned char) packet->qos[i]; - - if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1) - return -1; - - if (StreamWrite(&qos, 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream) -{ - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicFilter, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicName, stream) == -1) - return -1; - - LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet)); - - if (MqttPacketPublishQos(packet) > 0) - { - if (StreamWriteUint16Be(packet->base.id, stream) == -1) - return -1; - } - - if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1) - return -1; - - return 0; -} - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream) -{ - MqttPacketSerializeFunc f = NULL; - - switch (packet->type) - { - case MqttPacketTypeConnect: - f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize; - break; - - case MqttPacketTypeConnAck: - break; - - case MqttPacketTypePublish: - f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize; - break; - - case MqttPacketTypePubAck: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRec: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRel: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubComp: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypeSubscribe: - f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize; - break; - - case MqttPacketTypeSubAck: - break; - - case MqttPacketTypeUnsubscribe: - f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize; - break; - - case MqttPacketTypeUnsubAck: - break; - - case MqttPacketTypePingReq: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - default: - return -1; - } - - assert(f != NULL && "no serializer"); - - return f(packet, stream); -} diff --git a/src/serialize.h b/src/serialize.h deleted file mode 100644 index 7eb988f..0000000 --- a/src/serialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef SERIALIZE_H -#define SERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream); - -#endif diff --git a/src/stream.c b/src/stream.c index fd154a1..1c46668 100644 --- a/src/stream.c +++ b/src/stream.c @@ -47,6 +47,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream) return 2; } +int64_t StreamReadByte(unsigned char *byte, Stream *stream) +{ + return StreamRead(byte, sizeof(*byte), stream); +} + int64_t StreamWrite(const void *ptr, size_t size, Stream *stream) { STREAM_CHECK_OP(stream, write); @@ -65,6 +70,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream) return StreamWrite(data, sizeof(data), stream); } +int64_t StreamWriteByte(unsigned char byte, Stream *stream) +{ + return StreamWrite(&byte, sizeof(byte), stream); +} + int StreamSeek(Stream *stream, int64_t offset, int whence) { STREAM_CHECK_OP(stream, seek); diff --git a/src/stream.h b/src/stream.h index 839facb..50f1772 100644 --- a/src/stream.h +++ b/src/stream.h @@ -27,9 +27,11 @@ int StreamClose(Stream *stream); int64_t StreamRead(void *ptr, size_t size, Stream *stream); int64_t StreamReadUint16Be(uint16_t *v, Stream *stream); +int64_t StreamReadByte(unsigned char *byte, Stream *stream); int64_t StreamWrite(const void *ptr, size_t size, Stream *stream); int64_t StreamWriteUint16Be(uint16_t v, Stream *stream); +int64_t StreamWriteByte(unsigned char byte, Stream *stream); int StreamSeek(Stream *stream, int64_t offset, int whence); diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h index 9023430..a128524 100644 --- a/src/stream_mqtt.h +++ b/src/stream_mqtt.h @@ -2,6 +2,7 @@ #define STREAM_MQTT_H #include "stream.h" +#include "stringstream.h" #include <bstrlib/bstrlib.h> diff --git a/test/interop/CMakeLists.txt b/test/interop/CMakeLists.txt index e907776..d4b43d7 100644 --- a/test/interop/CMakeLists.txt +++ b/test/interop/CMakeLists.txt @@ -17,3 +17,6 @@ ADD_INTEROP_TEST(keepalive_test) ADD_INTEROP_TEST(redelivery_on_reconnect_test) ADD_INTEROP_TEST(subscribe_failure_test) ADD_INTEROP_TEST(dollar_topics_test) +ADD_INTEROP_TEST(username_and_password_test) +ADD_INTEROP_TEST(ping_test) +ADD_INTEROP_TEST(unsubscribe_test) diff --git a/test/interop/ping_test.c b/test/interop/ping_test.c new file mode 100644 index 0000000..6d699da --- /dev/null +++ b/test/interop/ping_test.c @@ -0,0 +1,27 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST ping_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT(TestClientConnect(client, "localhost", 1883, 1, 1)); + ASSERT(TestClientWait(client, 5000)); + TestClientDisconnect(client); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(ping_test); + GREATEST_MAIN_END(); +} diff --git a/test/interop/testclient.c b/test/interop/testclient.c index 8d616f6..09782b2 100644 --- a/test/interop/testclient.c +++ b/test/interop/testclient.c @@ -14,12 +14,15 @@ static void TestClientOnConnect(MqttClient *client, } static void TestClientOnSubscribe(MqttClient *client, int id, - const char *filter, - MqttSubscriptionStatus status) + int *qos, int count) { TestClient *testClient = (TestClient *) MqttClientGetUserData(client); testClient->subId = id; - testClient->subStatus[testClient->subCount++] = status; + for (testClient->subCount = 0; testClient->subCount < count; + ++testClient->subCount) + { + testClient->subStatus[testClient->subCount] = qos[testClient->subCount]; + } } static void TestClientOnPublish(MqttClient *client, int id) @@ -37,6 +40,12 @@ static void TestClientOnMessage(MqttClient *client, const char *topic, SIMPLEQ_INSERT_TAIL(&testClient->messages, msg, chain); } +static void TestClientOnUnsubscribe(MqttClient *client, int id) +{ + TestClient *testClient = (TestClient *) MqttClientGetUserData(client); + testClient->unsubId = id; +} + Message *MessageNew(const char *topic, const void *data, size_t size, int qos, int retain) { @@ -69,6 +78,8 @@ TestClient *TestClientNew(const char *clientId) { TestClient *client = calloc(1, sizeof(*client)); + client->clientId = clientId; + client->client = MqttClientNew(clientId); MqttClientSetUserData(client->client, client); @@ -79,6 +90,7 @@ TestClient *TestClientNew(const char *clientId) MqttClientSetOnSubscribe(client->client, TestClientOnSubscribe); MqttClientSetOnPublish(client->client, TestClientOnPublish); MqttClientSetOnMessage(client->client, TestClientOnMessage); + MqttClientSetOnUnsubscribe(client->client, TestClientOnUnsubscribe); SIMPLEQ_INIT(&client->messages); @@ -235,8 +247,8 @@ int TestClientWait(TestClient *client, int timeout) printf("TestClientWait timeout:%d rc:%d\n", timeout, rc); int64_t now = MqttGetCurrentTime(); int64_t elapsed = now - start; - timeout -= elapsed; printf("TestClientWait elapsed:%d\n", (int) elapsed); + timeout = timeout - elapsed; if (timeout <= 0) { break; @@ -245,3 +257,27 @@ int TestClientWait(TestClient *client, int timeout) return rc != -1; } + +int TestClientUnsubscribe(TestClient *client, const char *topic) +{ + int id = MqttClientUnsubscribe(client->client, topic); + int rc; + + client->unsubId = -1; + + while ((rc = MqttClientRunOnce(client->client, -1)) != -1) + { + if (client->unsubId != -1) + { + if (client->unsubId != id) + { + printf( + "WARNING: unsubscribe id mismatch: expected %d, got %d\n", + id, client->unsubId); + } + break; + } + } + + return rc != -1; +} diff --git a/test/interop/testclient.h b/test/interop/testclient.h index 2aa229b..3665f5e 100644 --- a/test/interop/testclient.h +++ b/test/interop/testclient.h @@ -21,6 +21,8 @@ typedef struct TestClient TestClient; struct TestClient { + const char *clientId; + MqttClient *client; /* OnConnect */ @@ -37,6 +39,9 @@ struct TestClient /* OnMessage */ SIMPLEQ_HEAD(messages, Message) messages; + + /* OnUnsubscribe */ + int unsubId; }; Message *MessageNew(const char *topic, const void *data, size_t size, @@ -65,4 +70,6 @@ int TestClientMessageCount(TestClient *client); int TestClientWait(TestClient *client, int timeout); +int TestClientUnsubscribe(TestClient *client, const char *topic); + #endif diff --git a/test/interop/unsubscribe_test.c b/test/interop/unsubscribe_test.c new file mode 100644 index 0000000..a7e4668 --- /dev/null +++ b/test/interop/unsubscribe_test.c @@ -0,0 +1,31 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST unsubscribe_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + ASSERT(TestClientSubscribe(client, topics[0], 2)); + ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg")); + ASSERT(TestClientUnsubscribe(client, topics[0])); + ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg")); + TestClientDisconnect(client); + ASSERT_EQ(1, TestClientMessageCount(client)); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(unsubscribe_test); + GREATEST_MAIN_END(); +} diff --git a/test/interop/username_and_password_test.c b/test/interop/username_and_password_test.c new file mode 100644 index 0000000..6e0eaab --- /dev/null +++ b/test/interop/username_and_password_test.c @@ -0,0 +1,30 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST username_and_password_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", NULL)); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + TestClientDisconnect(client); + ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", "mypassword")); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + TestClientDisconnect(client); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(username_and_password_test); + GREATEST_MAIN_END(); +} diff --git a/tools/sub.c b/tools/sub.c index ebf3372..e0577c9 100644 --- a/tools/sub.c +++ b/tools/sub.c @@ -21,11 +21,10 @@ void onConnect(MqttClient *client, MqttConnectionStatus status, MqttClientSubscribe(client, options->topic, options->qos); } -void onSubscribe(MqttClient *client, int id, const char *filter, - MqttSubscriptionStatus status) +void onSubscribe(MqttClient *client, int id, int *qos, int count) { (void) client; - printf("onSubscribe id=%d status=%d\n", id, status); + printf("onSubscribe id=%d\n", id); } void onMessage(MqttClient *client, const char *topic, const void *data, |
