diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-06 21:31:07 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-06 21:31:07 +0200 |
| commit | a062d934aef40829d9559a8ca83147ea4c44108e (patch) | |
| tree | 6ee1f9eb3208bea65365e63d039ad27c40111c6f /src/client.c | |
| parent | e9958e8a0f5aa5fbe0a4a03be42b8bf640add6f7 (diff) | |
| download | mqtt-a062d934aef40829d9559a8ca83147ea4c44108e.tar.gz mqtt-a062d934aef40829d9559a8ca83147ea4c44108e.zip | |
Massive refactoring of the internals
Diffstat (limited to 'src/client.c')
| -rw-r--r-- | src/client.c | 1160 |
1 files changed, 749 insertions, 411 deletions
diff --git a/src/client.c b/src/client.c index a6b0998..9238555 100644 --- a/src/client.c +++ b/src/client.c @@ -5,10 +5,11 @@ #include "socketstream.h" #include "socket.h" #include "misc.h" -#include "serialize.h" -#include "deserialize.h" #include "log.h" #include "private.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "message.h" #include "queue.h" @@ -25,9 +26,6 @@ #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; - struct MqttClient { SocketStream stream; @@ -56,9 +54,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -80,18 +78,13 @@ struct MqttClient int paused; bstring userName; bstring password; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + /* The packet we are receiving */ + MqttPacket inPacket; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -99,14 +92,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -114,7 +107,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -212,11 +205,44 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); @@ -253,44 +279,37 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } - - packet->keepAlive = client->keepAlive; + StringStreamInit(&ss); - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); - - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); } if (client->userName) { - packet->connectFlags |= 0x80; - packet->userName = bstrcpy(client->userName); - - if (client->password) + StreamWriteMqttString(client->userName, pss); + if(client->password) { - packet->connectFlags |= 0x40; - packet->password = bstrcpy(client->password); + StreamWriteMqttString(client->password, pss); } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); return 0; } @@ -339,6 +358,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -354,21 +381,12 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (MqttClientSendPacket(client) == -1) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) - { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; - } + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; } } @@ -433,10 +451,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -444,68 +464,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(packet->id, pss); - packet->qos = (int *) malloc(sizeof(int) * count); + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); + packet->payload = ss.buffer; - MqttClientQueuePacket(client, (MqttPacket *) packet); - - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); - packet->topicFilter = bfromcstr(topicFilter); + if (!packet) + return -1; + + packet->flags = 0x02; + + StringStreamInit(&ss); + + StreamWriteUint16Be(packet->id, pss); + + btfromcstr(filter, topicFilter); + + StreamWriteMqttString(&filter, pss); - MqttClientQueuePacket(client, (MqttPacket *) packet); + packet->payload = ss.buffer; - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -514,55 +588,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); + + message->topic = NULL; + message->payload = NULL; + + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -655,6 +729,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -667,128 +742,332 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(blength(packet->payload), + &client->stream.base) == -1) + { + return -1; + } + + packet->state = MqttPacketStateWritePayload; + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + if (StreamWrite(bdata(packet->payload), + blength(packet->payload), + &client->stream.base) == -1) + { + return -1; + } + } + + packet->state = MqttPacketStateWriteComplete; + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientHandlePublish(MqttClient *client) { + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; + + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + StreamReadUint16Be(&id, pss); + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; - if (MqttPacketPublishQos(packet) == 2) + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -796,268 +1075,347 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); - if (!sub) + StreamReadUint16Be(&id, pss); + + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; - - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); + int rc; - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + int rc; + + if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1) + { + LOG_ERROR("failed reading packet type: %d", rc); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->stream.base) == -1) + { + LOG_ERROR("failed to read remaining length"); + return -1; + } + client->inPacket.state = MqttPacketStateReadPayload; + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + if (StreamRead(bdata(client->inPacket.payload), + client->inPacket.remainingLength, + &client->stream.base) == -1) + { + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + client->inPacket.payload->slen = client->inPacket.remainingLength; + } + client->inPacket.state = MqttPacketStateReadComplete; + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -1072,101 +1430,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; - } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; + continue; } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -1182,30 +1528,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } |
