diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:29:19 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:29:19 +0200 |
| commit | 7aeef53b089272f4633cc40512296bfd884a58d4 (patch) | |
| tree | 894753ced0495f725ad8362859f88d5b61e29eb7 /src/client.c | |
| parent | e9958e8a0f5aa5fbe0a4a03be42b8bf640add6f7 (diff) | |
| parent | 2c76b0da9e0aba2211d5b4a8e51c79e47ad9b6c8 (diff) | |
| download | mqtt-0.5.tar.gz mqtt-0.5.zip | |
Merge branch 'the-great-refactor'v0.5
* the-great-refactor:
Add big_message_test
Fix publish message serialization
Modify the code to use nonblocking sockets
Fix indentation
Free userName and password in MqttClientFree()
Add forgotten files
Massive refactoring of the internals
Diffstat (limited to 'src/client.c')
| -rw-r--r-- | src/client.c | 1310 |
1 files changed, 889 insertions, 421 deletions
diff --git a/src/client.c b/src/client.c index a6b0998..b95c8d5 100644 --- a/src/client.c +++ b/src/client.c @@ -5,10 +5,11 @@ #include "socketstream.h" #include "socket.h" #include "misc.h" -#include "serialize.h" -#include "deserialize.h" #include "log.h" #include "private.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "message.h" #include "queue.h" @@ -25,8 +26,14 @@ #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; struct MqttClient { @@ -56,9 +63,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -80,18 +87,14 @@ struct MqttClient int paused; bstring userName; bstring password; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + /* The packet we are receiving */ + MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -99,14 +102,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -114,7 +117,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -142,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -157,6 +162,8 @@ void MqttClientFree(MqttClient *client) bdestroy(client->willTopic); bdestroy(client->willMessage); bdestroy(client->host); + bdestroy(client->userName); + bdestroy(client->password); if (client->stream.sock != -1) { @@ -212,15 +219,54 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -242,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -253,44 +302,39 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } + StringStreamInit(&ss); - packet->keepAlive = client->keepAlive; - - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); - - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); } if (client->userName) { - packet->connectFlags |= 0x80; - packet->userName = bstrcpy(client->userName); - - if (client->password) + StreamWriteMqttString(client->userName, pss); + if(client->password) { - packet->connectFlags |= 0x40; - packet->password = bstrcpy(client->password); + StreamWriteMqttString(client->password, pss); } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + client->state = MqttClientStateConnecting; return 0; } @@ -303,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -319,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -339,6 +396,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -354,22 +419,26 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (client->state == MqttClientStateConnecting) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; } } + + if (MqttClientSendPacket(client) == -1) + { + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; + } } if (events & EV_READ) @@ -433,10 +502,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -444,68 +515,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); - packet->qos = (int *) malloc(sizeof(int) * count); + StreamWriteUint16Be(packet->id, pss); + + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); - - MqttClientQueuePacket(client, (MqttPacket *) packet); + packet->payload = ss.buffer; - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); + + if (!packet) + return -1; + + packet->flags = 0x02; + + StringStreamInit(&ss); - packet->topicFilter = bfromcstr(topicFilter); + StreamWriteUint16Be(packet->id, pss); - MqttClientQueuePacket(client, (MqttPacket *) packet); + btfromcstr(filter, topicFilter); - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + StreamWriteMqttString(&filter, pss); - return MqttPacketId(packet); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -514,55 +639,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + message->topic = NULL; + message->payload = NULL; + + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); + + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -613,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName, { assert(client != NULL); - if (MqttClientIsConnected(client)) + if (client->state == MqttClientStateConnecting) { LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); return -1; @@ -655,6 +780,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -667,128 +793,363 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(&packet->remainingLength, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) + { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } + return -1; + } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); + } + + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + client->state = MqttClientStateDisconnected; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientHandlePublish(MqttClient *client) { + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; + + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + if (qos > 0) + { + StreamReadUint16Be(&id, pss); + } + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; - if (MqttPacketPublishQos(packet) == 2) + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -796,268 +1157,395 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); - - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); - if (!sub) + StreamReadUint16Be(&id, pss); + + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; + int rc; - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); - - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed to read remaining length"); + return -1; + } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + + client->inPacket.state = MqttPacketStateReadPayload; + + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) + { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; + } + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -1072,101 +1560,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; + continue; } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; - } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -1182,30 +1658,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } |
