diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | src/client.c | 1310 | ||||
| -rw-r--r-- | src/deserialize.c | 286 | ||||
| -rw-r--r-- | src/deserialize.h | 11 | ||||
| -rw-r--r-- | src/message.c | 11 | ||||
| -rw-r--r-- | src/message.h | 40 | ||||
| -rw-r--r-- | src/mqtt.h | 6 | ||||
| -rw-r--r-- | src/packet.c | 76 | ||||
| -rw-r--r-- | src/packet.h | 98 | ||||
| -rw-r--r-- | src/serialize.c | 326 | ||||
| -rw-r--r-- | src/serialize.h | 11 | ||||
| -rw-r--r-- | src/socket.c | 73 | ||||
| -rw-r--r-- | src/socket.h | 26 | ||||
| -rw-r--r-- | src/stream.c | 10 | ||||
| -rw-r--r-- | src/stream.h | 2 | ||||
| -rw-r--r-- | src/stream_mqtt.c | 30 | ||||
| -rw-r--r-- | src/stream_mqtt.h | 6 | ||||
| -rw-r--r-- | src/stringstream.c | 115 | ||||
| -rw-r--r-- | src/stringstream.h | 21 |
19 files changed, 1219 insertions, 1243 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f51fabb..5a565ca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,14 +2,14 @@ ADD_SUBDIRECTORY(lib) ADD_LIBRARY(mqtt STATIC client.c - deserialize.c misc.c packet.c - serialize.c socket.c socketstream.c stream.c stream_mqtt.c + stringstream.c + message.c $<TARGET_OBJECTS:bstrlib> ) diff --git a/src/client.c b/src/client.c index a6b0998..b95c8d5 100644 --- a/src/client.c +++ b/src/client.c @@ -5,10 +5,11 @@ #include "socketstream.h" #include "socket.h" #include "misc.h" -#include "serialize.h" -#include "deserialize.h" #include "log.h" #include "private.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "message.h" #include "queue.h" @@ -25,8 +26,14 @@ #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; struct MqttClient { @@ -56,9 +63,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -80,18 +87,14 @@ struct MqttClient int paused; bstring userName; bstring password; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + /* The packet we are receiving */ + MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -99,14 +102,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -114,7 +117,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -142,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -157,6 +162,8 @@ void MqttClientFree(MqttClient *client) bdestroy(client->willTopic); bdestroy(client->willMessage); bdestroy(client->host); + bdestroy(client->userName); + bdestroy(client->password); if (client->stream.sock != -1) { @@ -212,15 +219,54 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -242,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -253,44 +302,39 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } + StringStreamInit(&ss); - packet->keepAlive = client->keepAlive; - - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); - - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); } if (client->userName) { - packet->connectFlags |= 0x80; - packet->userName = bstrcpy(client->userName); - - if (client->password) + StreamWriteMqttString(client->userName, pss); + if(client->password) { - packet->connectFlags |= 0x40; - packet->password = bstrcpy(client->password); + StreamWriteMqttString(client->password, pss); } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + client->state = MqttClientStateConnecting; return 0; } @@ -303,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -319,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -339,6 +396,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -354,22 +419,26 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (client->state == MqttClientStateConnecting) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; } } + + if (MqttClientSendPacket(client) == -1) + { + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; + } } if (events & EV_READ) @@ -433,10 +502,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -444,68 +515,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); - packet->qos = (int *) malloc(sizeof(int) * count); + StreamWriteUint16Be(packet->id, pss); + + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); - - MqttClientQueuePacket(client, (MqttPacket *) packet); + packet->payload = ss.buffer; - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); + + if (!packet) + return -1; + + packet->flags = 0x02; + + StringStreamInit(&ss); - packet->topicFilter = bfromcstr(topicFilter); + StreamWriteUint16Be(packet->id, pss); - MqttClientQueuePacket(client, (MqttPacket *) packet); + btfromcstr(filter, topicFilter); - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + StreamWriteMqttString(&filter, pss); - return MqttPacketId(packet); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -514,55 +639,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + message->topic = NULL; + message->payload = NULL; + + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); + + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -613,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName, { assert(client != NULL); - if (MqttClientIsConnected(client)) + if (client->state == MqttClientStateConnecting) { LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); return -1; @@ -655,6 +780,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -667,128 +793,363 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(&packet->remainingLength, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) + { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } + return -1; + } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); + } + + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + client->state = MqttClientStateDisconnected; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientHandlePublish(MqttClient *client) { + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; + + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + if (qos > 0) + { + StreamReadUint16Be(&id, pss); + } + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; - if (MqttPacketPublishQos(packet) == 2) + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -796,268 +1157,395 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); - - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); - if (!sub) + StreamReadUint16Be(&id, pss); + + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; + int rc; - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); - - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed to read remaining length"); + return -1; + } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + + client->inPacket.state = MqttPacketStateReadPayload; + + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) + { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; + } + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -1072,101 +1560,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; + continue; } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; - } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -1182,30 +1658,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } diff --git a/src/deserialize.c b/src/deserialize.c deleted file mode 100644 index 96d7789..0000000 --- a/src/deserialize.c +++ /dev/null @@ -1,286 +0,0 @@ -#include "deserialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream); - -static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1) - return -1; - - if (StreamRead(&(*packet)->returnCode, 1, stream) != 1) - return -1; - - return 0; -} - -static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t i; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - return -1; - - remainingLength -= 2; - - (*packet)->returnCode = (unsigned char *) malloc( - sizeof(*(*packet)->returnCode) * remainingLength); - - for (i = 0; i < remainingLength; ++i) - { - if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t payloadSize = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadMqttString(&(*packet)->topicName, stream) == -1) - return -1; - - LOG_DEBUG("remainingLength:%lu", remainingLength); - - payloadSize = remainingLength - blength((*packet)->topicName) - 2; - - LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet), - payloadSize); - - if (MqttPacketHasId((const MqttPacket *) *packet)) - { - LOG_DEBUG("packet has id"); - payloadSize -= 2; - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - { - return -1; - } - } - - LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize); - - /* Allocate extra byte for a NULL terminator. If the user tries to print - the payload directly. */ - - (*packet)->message = bfromcstralloc(payloadSize+1, ""); - - if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1) - return -1; - - (*packet)->message->slen = payloadSize; - (*packet)->message->data[payloadSize] = '\0'; - - return 0; -} - -static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - char buffer[256]; - - (void) packet; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - while (remainingLength > 0) - { - size_t l = sizeof(buffer); - - if (remainingLength < l) - l = remainingLength; - - if (StreamRead(buffer, l, stream) != (int64_t) l) - return -1; - - remainingLength -= l; - } - - return 0; -} - -static int ValidateFlags(int type, int flags) -{ - int rv = 0; - - switch (type) - { - case MqttPacketTypePublish: - { - int qos = (flags >> 1) & 2; - if (qos >= 0 && qos <= 2) - rv = 1; - break; - } - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - if (flags == 2) - { - rv = 1; - } - break; - - default: - if (flags == 0) - { - rv = 1; - } - break; - } - - return rv; -} - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) -{ - MqttPacketDeserializeFunc deserializer = NULL; - char typeAndFlags; - int type; - int flags; - int rv; - - if (StreamRead(&typeAndFlags, 1, stream) != 1) - return -1; - - type = (typeAndFlags & 0xF0) >> 4; - flags = (typeAndFlags & 0x0F); - - if (!ValidateFlags(type, flags)) - { - return -1; - } - - switch (type) - { - case MqttPacketTypeConnect: - break; - - case MqttPacketTypeConnAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize; - break; - - case MqttPacketTypePublish: - deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize; - break; - - case MqttPacketTypePubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRec: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRel: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubComp: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypeSubscribe: - break; - - case MqttPacketTypeSubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize; - break; - - case MqttPacketTypeUnsubscribe: - break; - - case MqttPacketTypeUnsubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize; - break; - - case MqttPacketTypePingReq: - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - break; - - default: - return -1; - } - - if (!deserializer) - { - deserializer = MqttPacketGenericDeserializer; - } - - *packet = MqttPacketNew(type); - - if (!*packet) - return -1; - - if (type == MqttPacketTypePublish) - { - MqttPacketPublishDup(*packet) = (flags >> 3) & 1; - MqttPacketPublishQos(*packet) = (flags >> 1) & 3; - MqttPacketPublishRetain(*packet) = flags & 1; - } - - rv = deserializer(packet, stream); - - return rv; -} diff --git a/src/deserialize.h b/src/deserialize.h deleted file mode 100644 index 8c29b3d..0000000 --- a/src/deserialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef DESERIALIZE_H -#define DESERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream); - -#endif diff --git a/src/message.c b/src/message.c new file mode 100644 index 0000000..35d9c32 --- /dev/null +++ b/src/message.c @@ -0,0 +1,11 @@ +#include "message.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "packet.h" + +void MqttMessageFree(MqttMessage *msg) +{ + bdestroy(msg->topic); + bdestroy(msg->payload); + free(msg); +} diff --git a/src/message.h b/src/message.h new file mode 100644 index 0000000..04a3d61 --- /dev/null +++ b/src/message.h @@ -0,0 +1,40 @@ +#ifndef MESSAGE_H +#define MESSAGE_H + +#include <stdint.h> + +#include "queue.h" +#include <bstrlib/bstrlib.h> + +enum MqttMessageState +{ + MqttMessageStateQueued, + MqttMessageStatePublish, + MqttMessageStateWaitPubAck, + MqttMessageStateWaitPubRec, + MqttMessageStateWaitPubComp, + MqttMessageStateWaitPubRel +}; + +typedef struct MqttMessage MqttMessage; + +struct MqttMessage +{ + int state; + int qos; + int retain; + int dup; + int padding; + uint16_t id; + int64_t timestamp; + bstring topic; + bstring payload; + TAILQ_ENTRY(MqttMessage) chain; +}; + +typedef struct MqttMessageList MqttMessageList; +TAILQ_HEAD(MqttMessageList, MqttMessage); + +void MqttMessageFree(MqttMessage *msg); + +#endif @@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client, typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, int id, - const char *topicFilter, - MqttSubscriptionStatus status); + int *qos, + int count); typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); @@ -84,7 +84,7 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, int qos); int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count); + int *qos, size_t count); int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter); diff --git a/src/packet.c b/src/packet.c index 47aa689..c833851 100644 --- a/src/packet.c +++ b/src/packet.c @@ -28,42 +28,16 @@ const char *MqttPacketName(int type) } } -static MQTT_INLINE size_t MqttPacketStructSize(int type) -{ - switch (type) - { - case MqttPacketTypeConnect: return sizeof(MqttPacketConnect); - case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck); - case MqttPacketTypePublish: return sizeof(MqttPacketPublish); - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: return sizeof(MqttPacket); - case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe); - case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck); - case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe); - case MqttPacketTypeUnsubAck: return sizeof(MqttPacket); - case MqttPacketTypePingReq: return sizeof(MqttPacket); - case MqttPacketTypePingResp: return sizeof(MqttPacket); - case MqttPacketTypeDisconnect: return sizeof(MqttPacket); - default: return (size_t) -1; - } -} - MqttPacket *MqttPacketNew(int type) { MqttPacket *packet = NULL; - packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type)); + packet = (MqttPacket *) calloc(1, sizeof(*packet)); if (!packet) return NULL; packet->type = type; - /* this will make sure that TAILQ_PREV does not segfault if a message - has not been added to a list at any point */ - packet->messages.tqe_prev = &packet->messages.tqe_next; - return packet; } @@ -78,52 +52,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id) void MqttPacketFree(MqttPacket *packet) { - if (MqttPacketType(packet) == MqttPacketTypeConnect) - { - MqttPacketConnect *p = (MqttPacketConnect *) packet; - bdestroy(p->clientId); - bdestroy(p->willTopic); - bdestroy(p->willMessage); - bdestroy(p->userName); - bdestroy(p->password); - } - else if (MqttPacketType(packet) == MqttPacketTypePublish) - { - MqttPacketPublish *p = (MqttPacketPublish *) packet; - bdestroy(p->topicName); - bdestroy(p->message); - } - else if (MqttPacketType(packet) == MqttPacketTypeSubscribe) - { - MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet; - bstrListDestroy(p->topicFilters); - } - else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe) - { - MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet; - bdestroy(p->topicFilter); - } + bdestroy(packet->payload); free(packet); } - -int MqttPacketHasId(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return MqttPacketPublishQos(packet) > 0; - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - case MqttPacketTypeSubscribe: - case MqttPacketTypeSubAck: - case MqttPacketTypeUnsubscribe: - case MqttPacketTypeUnsubAck: - return 1; - - default: - return 0; - } -} diff --git a/src/packet.h b/src/packet.h index 7ab4f73..a5e2ce7 100644 --- a/src/packet.h +++ b/src/packet.h @@ -29,87 +29,35 @@ enum MqttPacketTypeDisconnect = 0xE }; +enum MqttPacketState +{ + MqttPacketStateReadType, + MqttPacketStateReadRemainingLength, + MqttPacketStateReadPayload, + MqttPacketStateReadComplete, + + MqttPacketStateWriteType, + MqttPacketStateWriteRemainingLength, + MqttPacketStateWritePayload, + MqttPacketStateWriteComplete +}; + +struct MqttMessage; + typedef struct MqttPacket MqttPacket; struct MqttPacket { int type; - uint16_t id; - int state; int flags; - int64_t sentAt; + int state; + uint16_t id; + size_t remainingLength; + size_t remainingLengthMul; + /* TODO: maybe switch to have a StringStream here? */ + bstring payload; + struct MqttMessage *message; SIMPLEQ_ENTRY(MqttPacket) sendQueue; - TAILQ_ENTRY(MqttPacket) messages; -}; - -#define MqttPacketType(packet) (((MqttPacket *) (packet))->type) - -#define MqttPacketId(packet) (((MqttPacket *) (packet))->id) - -#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt) - -typedef struct MqttPacketConnect MqttPacketConnect; - -struct MqttPacketConnect -{ - MqttPacket base; - char connectFlags; - uint16_t keepAlive; - bstring clientId; - bstring willTopic; - bstring willMessage; - bstring userName; - bstring password; -}; - -typedef struct MqttPacketConnAck MqttPacketConnAck; - -struct MqttPacketConnAck -{ - MqttPacket base; - unsigned char connAckFlags; - unsigned char returnCode; -}; - -typedef struct MqttPacketPublish MqttPacketPublish; - -struct MqttPacketPublish -{ - MqttPacket base; - bstring topicName; - bstring message; - char qos; - char dup; - char retain; -}; - -#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos) -#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup) -#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain) - -typedef struct MqttPacketSubscribe MqttPacketSubscribe; - -struct MqttPacketSubscribe -{ - MqttPacket base; - struct bstrList *topicFilters; - int *qos; -}; - -typedef struct MqttPacketSubAck MqttPacketSubAck; - -struct MqttPacketSubAck -{ - MqttPacket base; - unsigned char *returnCode; -}; - -typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; - -struct MqttPacketUnsubscribe -{ - MqttPacket base; - bstring topicFilter; }; const char *MqttPacketName(int type); @@ -120,6 +68,4 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id); void MqttPacketFree(MqttPacket *packet); -int MqttPacketHasId(const MqttPacket *packet); - #endif diff --git a/src/serialize.c b/src/serialize.c deleted file mode 100644 index c1c8eb4..0000000 --- a/src/serialize.c +++ /dev/null @@ -1,326 +0,0 @@ -#include "serialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <bstrlib/bstrlib.h> - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet, - Stream *stream); - -static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); -static const char MqttProtocolLevel = 0x04; - -static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s) -{ - return 2 + blength(s); -} - -static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2; - - remainingLength += MqttStringLengthSerialized(packet->clientId); - - if (packet->connectFlags & 0x80) - remainingLength += MqttStringLengthSerialized(packet->userName); - - if (packet->connectFlags & 0x40) - remainingLength += MqttStringLengthSerialized(packet->password); - - if (packet->connectFlags & 0x04) - remainingLength += MqttStringLengthSerialized(packet->willTopic) + - MqttStringLengthSerialized(packet->willMessage); - - return remainingLength; -} - -static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) -{ - size_t remaining = 2; - int i; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]); - remaining += 1; - } - - return remaining; -} - -static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) -{ - return 2 + MqttStringLengthSerialized(packet->topicFilter); -} - -static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(packet->topicName); - - /* Packet id */ - if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2) - { - remainingLength += 2; - } - - remainingLength += blength(packet->message); - - return remainingLength; -} - -static size_t MqttPacketGetRemainingLength(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypeConnect: - return MqttPacketConnectGetRemainingLength( - (MqttPacketConnect *) packet); - - case MqttPacketTypeSubscribe: - return MqttPacketSubscribeGetRemainingLength( - (MqttPacketSubscribe *) packet); - - case MqttPacketTypePublish: - return MqttPacketPublishGetRemainingLength( - (MqttPacketPublish *) packet); - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - return 2; - - case MqttPacketTypeUnsubscribe: - return MqttPacketUnsubscribeGetRemainingLength( - (MqttPacketUnsubscribe *) packet); - - default: - return 0; - } -} - -static int MqttPacketFlags(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return ((MqttPacketPublishDup(packet) & 1) << 3) | - ((MqttPacketPublishQos(packet) & 3) << 1) | - (MqttPacketPublishRetain(packet) & 1); - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - return 0x2; - - default: - return 0; - } -} - -static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream) -{ - unsigned char typeAndFlags; - size_t remainingLength; - - typeAndFlags = ((packet->type & 0x0F) << 4) | - (MqttPacketFlags(packet) & 0x0F); - remainingLength = MqttPacketGetRemainingLength(packet); - - LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type, - MqttPacketName(packet->type), MqttPacketFlags(packet)); - - if (StreamWrite(&typeAndFlags, 1, stream) != 1) - return -1; - - if (StreamWriteRemainingLength(remainingLength, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream) -{ - assert(MqttPacketHasId((const MqttPacket *) packet)); - - if (MqttPacketBaseSerialize(packet, stream) == -1) - return -1; - - if (StreamWriteUint16Be(packet->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize(&packet->base, stream) == -1) - return -1; - - if (StreamWriteMqttString(&MqttProtocolId, stream) == -1) - return -1; - - if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1) - return -1; - - if (StreamWrite(&packet->connectFlags, 1, stream) != 1) - return -1; - - if (StreamWriteUint16Be(packet->keepAlive, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->clientId, stream) == -1) - return -1; - - if (packet->connectFlags & 0x04) - { - if (StreamWriteMqttString(packet->willTopic, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->willMessage, stream) == -1) - return -1; - } - - if (packet->connectFlags & 0x80) - { - if (StreamWriteMqttString(packet->userName, stream) == -1) - return -1; - - if (packet->connectFlags & 0x40) - { - if (StreamWriteMqttString(packet->password, stream) == -1) - return -1; - } - } - - return 0; -} - -static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) -{ - int i; - - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - unsigned char qos = (unsigned char) packet->qos[i]; - - if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1) - return -1; - - if (StreamWrite(&qos, 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream) -{ - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicFilter, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicName, stream) == -1) - return -1; - - LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet)); - - if (MqttPacketPublishQos(packet) > 0) - { - if (StreamWriteUint16Be(packet->base.id, stream) == -1) - return -1; - } - - if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1) - return -1; - - return 0; -} - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream) -{ - MqttPacketSerializeFunc f = NULL; - - switch (packet->type) - { - case MqttPacketTypeConnect: - f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize; - break; - - case MqttPacketTypeConnAck: - break; - - case MqttPacketTypePublish: - f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize; - break; - - case MqttPacketTypePubAck: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRec: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRel: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubComp: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypeSubscribe: - f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize; - break; - - case MqttPacketTypeSubAck: - break; - - case MqttPacketTypeUnsubscribe: - f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize; - break; - - case MqttPacketTypeUnsubAck: - break; - - case MqttPacketTypePingReq: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - default: - return -1; - } - - assert(f != NULL && "no serializer"); - - return f(packet, stream); -} diff --git a/src/serialize.h b/src/serialize.h deleted file mode 100644 index 7eb988f..0000000 --- a/src/serialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef SERIALIZE_H -#define SERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream); - -#endif diff --git a/src/socket.c b/src/socket.c index 64a7c01..b70f4fb 100644 --- a/src/socket.c +++ b/src/socket.c @@ -6,18 +6,6 @@ #include <assert.h> #if defined(_WIN32) -#include "win32.h" -#else -#include <sys/types.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/select.h> -#include <netdb.h> -#include <unistd.h> -#include <arpa/inet.h> -#endif - -#if defined(_WIN32) static int InitializeWsa() { WSADATA wsa; @@ -33,9 +21,9 @@ static int InitializeWsa() #define close closesocket #endif -int SocketConnect(const char *host, short port) +int SocketConnect(const char *host, short port, int nonblocking) { - struct addrinfo hints, *servinfo, *p = NULL; + struct addrinfo hints, *servinfo = NULL, *p = NULL; int rv; char portstr[6]; int sock; @@ -66,8 +54,16 @@ int SocketConnect(const char *host, short port) continue; } + if (nonblocking) + { + SocketSetNonblocking(sock, 1); + } + if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) { + int err = SocketErrno; + if (err == SOCKET_EINPROGRESS) + break; close(sock); continue; } @@ -75,10 +71,13 @@ int SocketConnect(const char *host, short port) break; } - freeaddrinfo(servinfo); - cleanup: + if (servinfo) + { + freeaddrinfo(servinfo); + } + if (p == NULL) { #if defined(_WIN32) @@ -178,3 +177,45 @@ int SocketSelect(int sock, int *events, int timeout) return rv; } + +void SocketSetNonblocking(int sock, int nb) +{ +#if defined(_WIN32) + unsigned int yes = nb; + ioctlsocket(s, FIONBIO, &yes); +#else + int flags = fcntl(sock, F_GETFL, 0); + if (nb) + flags |= O_NONBLOCK; + else + flags &= ~O_NONBLOCK; + fcntl(sock, F_SETFL, flags); +#endif +} + +int SocketGetOpt(int sock, int level, int name, void *val, int *len) +{ +#if defined(_WIN32) + return getsockopt(sock, level, name, (char *) val, len); +#else + socklen_t _len = *len; + int rc = getsockopt(sock, level, name, val, &_len); + *len = _len; + return rc; +#endif +} + +int SocketGetError(int sock, int *error) +{ + int len = sizeof(*error); + return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len); +} + +int SocketWouldBlock(int error) +{ +#if defined(_WIN32) + return error == WSAEWOULDBLOCK; +#else + return error == EWOULDBLOCK || error == EAGAIN; +#endif +} diff --git a/src/socket.h b/src/socket.h index e7b1a80..abc67af 100644 --- a/src/socket.h +++ b/src/socket.h @@ -6,7 +6,25 @@ #include <stdlib.h> #include <stdint.h> -int SocketConnect(const char *host, short port); +#if defined(_WIN32) +#include "win32.h" +#define SocketErrno (WSAGetLastError()) +#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK) +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/select.h> +#include <netdb.h> +#include <unistd.h> +#include <arpa/inet.h> +#include <fcntl.h> +#include <errno.h> +#define SocketErrno (errno) +#define SOCKET_EINPROGRESS (EINPROGRESS) +#endif + +int SocketConnect(const char *host, short port, int nonblocking); int SocketDisconnect(int sock); @@ -24,4 +42,10 @@ int64_t SocketRecv(int sock, void *buf, size_t len, int flags); int64_t SocketSend(int sock, const void *buf, size_t len, int flags); +void SocketSetNonblocking(int sock, int nb); + +int SocketGetError(int sock, int *error); + +int SocketWouldBlock(int error); + #endif diff --git a/src/stream.c b/src/stream.c index fd154a1..1c46668 100644 --- a/src/stream.c +++ b/src/stream.c @@ -47,6 +47,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream) return 2; } +int64_t StreamReadByte(unsigned char *byte, Stream *stream) +{ + return StreamRead(byte, sizeof(*byte), stream); +} + int64_t StreamWrite(const void *ptr, size_t size, Stream *stream) { STREAM_CHECK_OP(stream, write); @@ -65,6 +70,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream) return StreamWrite(data, sizeof(data), stream); } +int64_t StreamWriteByte(unsigned char byte, Stream *stream) +{ + return StreamWrite(&byte, sizeof(byte), stream); +} + int StreamSeek(Stream *stream, int64_t offset, int whence) { STREAM_CHECK_OP(stream, seek); diff --git a/src/stream.h b/src/stream.h index 839facb..50f1772 100644 --- a/src/stream.h +++ b/src/stream.h @@ -27,9 +27,11 @@ int StreamClose(Stream *stream); int64_t StreamRead(void *ptr, size_t size, Stream *stream); int64_t StreamReadUint16Be(uint16_t *v, Stream *stream); +int64_t StreamReadByte(unsigned char *byte, Stream *stream); int64_t StreamWrite(const void *ptr, size_t size, Stream *stream); int64_t StreamWriteUint16Be(uint16_t v, Stream *stream); +int64_t StreamWriteByte(unsigned char byte, Stream *stream); int StreamSeek(Stream *stream, int64_t offset, int whence); diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c index 3864ef3..f2bd9cd 100644 --- a/src/stream_mqtt.c +++ b/src/stream_mqtt.c @@ -42,37 +42,39 @@ int64_t StreamWriteMqttString(const_bstring buf, Stream *stream) return 2 + blength(buf); } -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream) +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream) { - size_t multiplier = 1; unsigned char encodedByte; - *remainingLength = 0; do { if (StreamRead(&encodedByte, 1, stream) != 1) return -1; - *remainingLength += (encodedByte & 127) * multiplier; - if (multiplier > 128*128*128) + *remainingLength += (encodedByte & 127) * (*mul); + if ((*mul) > 128*128*128) return -1; - multiplier *= 128; + (*mul) *= 128; } while ((encodedByte & 128) != 0); + *mul = 0; return 0; } -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream) +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream) { - size_t nbytes = 0; do { - unsigned char encodedByte = remainingLength % 128; - remainingLength /= 128; - if (remainingLength > 0) + size_t tmp = *remainingLength; + unsigned char encodedByte = tmp % 128; + tmp /= 128; + if (tmp > 0) encodedByte |= 128; if (StreamWrite(&encodedByte, 1, stream) != 1) + { return -1; - ++nbytes; + } + *remainingLength = tmp; } - while (remainingLength > 0); - return nbytes; + while (*remainingLength > 0); + return 0; } diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h index 9023430..8c8ccb5 100644 --- a/src/stream_mqtt.h +++ b/src/stream_mqtt.h @@ -2,13 +2,15 @@ #define STREAM_MQTT_H #include "stream.h" +#include "stringstream.h" #include <bstrlib/bstrlib.h> int64_t StreamReadMqttString(bstring *buf, Stream *stream); int64_t StreamWriteMqttString(const_bstring buf, Stream *stream); -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream); -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream); +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream); +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream); #endif diff --git a/src/stringstream.c b/src/stringstream.c new file mode 100644 index 0000000..4353932 --- /dev/null +++ b/src/stringstream.c @@ -0,0 +1,115 @@ +#include "stringstream.h" + +#include <assert.h> + +static int StringStreamClose(Stream *base) +{ + StringStream *ss = (StringStream *) base; + bdestroy(ss->buffer); + ss->buffer = NULL; + return 0; +} + +static int64_t StringStreamRead(void *ptr, size_t size, Stream *stream) +{ + StringStream *ss = (StringStream *) stream; + int64_t available = blength(ss->buffer) - ss->pos; + void *bufptr; + + if (available <= 0) + { + return -1; + } + + if (size > (size_t) available) + size = available; + + /* Use a temp buffer pointer to make some warnings disappear when using + GCC */ + bufptr = bdataofs(ss->buffer, ss->pos); + memcpy(ptr, bufptr, size); + + ss->pos += size; + + return size; +} + +static int64_t StringStreamWrite(const void *ptr, size_t size, Stream *stream) +{ + StringStream *ss = (StringStream *) stream; + struct tagbstring buf; + if (ss->buffer->mlen <= 0) + return -1; + btfromblk(buf, ptr, size); + bsetstr(ss->buffer, ss->pos, &buf, '\0'); + ss->pos += size; + return size; +} + +int StringStreamSeek(Stream *base, int64_t offset, int whence) +{ + StringStream *ss = (StringStream *) base; + int64_t newpos = 0; + + if (whence == SEEK_SET) + { + newpos = offset; + } + else if (whence == SEEK_CUR) + { + newpos = ss->pos + offset; + } + else if (whence == SEEK_END) + { + newpos = blength(ss->buffer) - offset; + } + else + { + return -1; + } + + if (newpos > blength(ss->buffer)) + return -1; + + if (newpos < 0) + return -1; + + ss->pos = newpos; + + return 0; +} + +int64_t StringStreamTell(Stream *base) +{ + StringStream *ss = (StringStream *) base; + return ss->pos; +} + +static const StreamOps StringStreamOps = +{ + StringStreamRead, + StringStreamWrite, + StringStreamClose, + StringStreamSeek, + StringStreamTell +}; + +int StringStreamInit(StringStream *stream) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = bfromcstr(""); + stream->base.ops = &StringStreamOps; + return 0; +} + +int StringStreamInitFromBstring(StringStream *stream, bstring buffer) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = buffer; + stream->base.ops = &StringStreamOps; + return 0; +} diff --git a/src/stringstream.h b/src/stringstream.h new file mode 100644 index 0000000..60d42fb --- /dev/null +++ b/src/stringstream.h @@ -0,0 +1,21 @@ +#ifndef STRINGSTREAM_H +#define STRINGSTREAM_H + +#include "stream.h" +#include <bstrlib/bstrlib.h> +#include <stdio.h> + +typedef struct StringStream StringStream; + +struct StringStream +{ + Stream base; + bstring buffer; + int64_t pos; +}; + +int StringStreamInit(StringStream *stream); + +int StringStreamInitFromBstring(StringStream *stream, bstring buffer); + +#endif |
