diff options
| author | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:29:19 +0200 |
|---|---|---|
| committer | Oskari Timperi <oskari.timperi@iki.fi> | 2017-03-18 09:29:19 +0200 |
| commit | 7aeef53b089272f4633cc40512296bfd884a58d4 (patch) | |
| tree | 894753ced0495f725ad8362859f88d5b61e29eb7 | |
| parent | e9958e8a0f5aa5fbe0a4a03be42b8bf640add6f7 (diff) | |
| parent | 2c76b0da9e0aba2211d5b4a8e51c79e47ad9b6c8 (diff) | |
| download | mqtt-7aeef53b089272f4633cc40512296bfd884a58d4.tar.gz mqtt-7aeef53b089272f4633cc40512296bfd884a58d4.zip | |
Merge branch 'the-great-refactor'v0.5
* the-great-refactor:
Add big_message_test
Fix publish message serialization
Modify the code to use nonblocking sockets
Fix indentation
Free userName and password in MqttClientFree()
Add forgotten files
Massive refactoring of the internals
| -rw-r--r-- | src/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | src/client.c | 1310 | ||||
| -rw-r--r-- | src/deserialize.c | 286 | ||||
| -rw-r--r-- | src/deserialize.h | 11 | ||||
| -rw-r--r-- | src/message.c | 11 | ||||
| -rw-r--r-- | src/message.h | 40 | ||||
| -rw-r--r-- | src/mqtt.h | 6 | ||||
| -rw-r--r-- | src/packet.c | 76 | ||||
| -rw-r--r-- | src/packet.h | 98 | ||||
| -rw-r--r-- | src/serialize.c | 326 | ||||
| -rw-r--r-- | src/serialize.h | 11 | ||||
| -rw-r--r-- | src/socket.c | 73 | ||||
| -rw-r--r-- | src/socket.h | 26 | ||||
| -rw-r--r-- | src/stream.c | 10 | ||||
| -rw-r--r-- | src/stream.h | 2 | ||||
| -rw-r--r-- | src/stream_mqtt.c | 30 | ||||
| -rw-r--r-- | src/stream_mqtt.h | 6 | ||||
| -rw-r--r-- | src/stringstream.c | 115 | ||||
| -rw-r--r-- | src/stringstream.h | 21 | ||||
| -rw-r--r-- | test/interop/CMakeLists.txt | 8 | ||||
| -rw-r--r-- | test/interop/big_message_test.c | 59 | ||||
| -rw-r--r-- | test/interop/bstraux.c | 1161 | ||||
| -rw-r--r-- | test/interop/bstraux.h | 115 | ||||
| -rw-r--r-- | test/interop/ping_test.c | 27 | ||||
| -rw-r--r-- | test/interop/testclient.c | 44 | ||||
| -rw-r--r-- | test/interop/testclient.h | 7 | ||||
| -rw-r--r-- | test/interop/unsubscribe_test.c | 31 | ||||
| -rw-r--r-- | test/interop/username_and_password_test.c | 30 | ||||
| -rw-r--r-- | tools/sub.c | 5 |
29 files changed, 2699 insertions, 1250 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f51fabb..5a565ca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,14 +2,14 @@ ADD_SUBDIRECTORY(lib) ADD_LIBRARY(mqtt STATIC client.c - deserialize.c misc.c packet.c - serialize.c socket.c socketstream.c stream.c stream_mqtt.c + stringstream.c + message.c $<TARGET_OBJECTS:bstrlib> ) diff --git a/src/client.c b/src/client.c index a6b0998..b95c8d5 100644 --- a/src/client.c +++ b/src/client.c @@ -5,10 +5,11 @@ #include "socketstream.h" #include "socket.h" #include "misc.h" -#include "serialize.h" -#include "deserialize.h" #include "log.h" #include "private.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "message.h" #include "queue.h" @@ -25,8 +26,14 @@ #error define PRId64 for your platform #endif -TAILQ_HEAD(MessageList, MqttPacket); -typedef struct MessageList MessageList; +typedef enum MqttClientState MqttClientState; + +enum MqttClientState +{ + MqttClientStateDisconnected, + MqttClientStateConnecting, + MqttClientStateConnected, +}; struct MqttClient { @@ -56,9 +63,9 @@ struct MqttClient /* packets waiting to be sent over network */ SIMPLEQ_HEAD(, MqttPacket) sendQueue; /* sent messages that are not done yet */ - MessageList outMessages; + MqttMessageList outMessages; /* received messages that are not done yet */ - MessageList inMessages; + MqttMessageList inMessages; int sessionPresent; /* when was the last packet sent */ int64_t lastPacketSentTime; @@ -80,18 +87,14 @@ struct MqttClient int paused; bstring userName; bstring password; -}; - -enum MessageState -{ - MessageStateQueued = 100, - MessageStateSend, - MessageStateSent + /* The packet we are receiving */ + MqttPacket inPacket; + MqttClientState state; }; static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet); static int MqttClientQueueSimplePacket(MqttClient *client, int type); -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet); +static int MqttClientSendPacket(MqttClient *client); static int MqttClientRecvPacket(MqttClient *client); static uint16_t MqttClientNextPacketId(MqttClient *client); static void MqttClientProcessMessageQueue(MqttClient *client); @@ -99,14 +102,14 @@ static void MqttClientClearQueues(MqttClient *client); static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int queued = 0; int inMessagesCount = 0; int outMessagesCount = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (packet->state == MessageStateQueued) + if (msg->state == MqttMessageStateQueued) { ++queued; } @@ -114,7 +117,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client) ++outMessagesCount; } - TAILQ_FOREACH(packet, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { ++inMessagesCount; } @@ -142,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId) client->maxQueued = 0; client->maxInflight = 20; + client->state = MqttClientStateDisconnected; + TAILQ_INIT(&client->outMessages); TAILQ_INIT(&client->inMessages); SIMPLEQ_INIT(&client->sendQueue); @@ -157,6 +162,8 @@ void MqttClientFree(MqttClient *client) bdestroy(client->willTopic); bdestroy(client->willMessage); bdestroy(client->host); + bdestroy(client->userName); + bdestroy(client->password); if (client->stream.sock != -1) { @@ -212,15 +219,54 @@ void MqttClientSetOnPublish(MqttClient *client, client->onPublish = cb; } +static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); +static const char MqttProtocolLevel = 0x04; + +static unsigned char MqttClientConnectFlags(MqttClient *client) +{ + unsigned char connectFlags = 0; + + if (client->cleanSession) + { + connectFlags |= 0x02; + } + + if (client->willTopic) + { + connectFlags |= 0x04; + connectFlags |= (client->willQos & 3) << 3; + connectFlags |= (client->willRetain & 1) << 5; + } + + if (client->userName) + { + connectFlags |= 0x80; + if (client->password) + { + connectFlags |= 0x40; + } + } + + return connectFlags; +} + int MqttClientConnect(MqttClient *client, const char *host, short port, int keepAlive, int cleanSession) { int sock; - MqttPacketConnect *packet; + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(host != NULL); + if (client->state != MqttClientStateDisconnected) + { + LOG_ERROR("client must be disconnected to connect"); + return -1; + } + if (client->host) bassigncstr(client->host, host); else @@ -242,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, LOG_DEBUG("connecting"); - if ((sock = SocketConnect(host, port)) == -1) + if ((sock = SocketConnect(host, port, 1)) == -1) { - LOG_ERROR("SocketConnect failed!"); - return -1; + if (SocketErrno != SOCKET_EINPROGRESS) + { + LOG_ERROR("SocketConnect failed!"); + return -1; + } } if (SocketStreamOpen(&client->stream, sock) == -1) @@ -253,44 +302,39 @@ int MqttClientConnect(MqttClient *client, const char *host, short port, return -1; } - packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect); + packet = MqttPacketNew(MqttPacketTypeConnect); if (!packet) return -1; - if (client->cleanSession) - { - packet->connectFlags |= 0x02; - } + StringStreamInit(&ss); - packet->keepAlive = client->keepAlive; - - packet->clientId = bstrcpy(client->clientId); + StreamWriteMqttString(&MqttProtocolId, pss); + StreamWriteByte(MqttProtocolLevel, pss); + StreamWriteByte(MqttClientConnectFlags(client), pss); + StreamWriteUint16Be(client->keepAlive, pss); + StreamWriteMqttString(client->clientId, pss); if (client->willTopic) { - packet->connectFlags |= 0x04; - - packet->willTopic = bstrcpy(client->willTopic); - packet->willMessage = bstrcpy(client->willMessage); - - packet->connectFlags |= (client->willQos & 3) << 3; - packet->connectFlags |= (client->willRetain & 1) << 5; + StreamWriteMqttString(client->willTopic, pss); + StreamWriteMqttString(client->willMessage, pss); } if (client->userName) { - packet->connectFlags |= 0x80; - packet->userName = bstrcpy(client->userName); - - if (client->password) + StreamWriteMqttString(client->userName, pss); + if(client->password) { - packet->connectFlags |= 0x40; - packet->password = bstrcpy(client->password); + StreamWriteMqttString(client->password, pss); } } - MqttClientQueuePacket(client, &packet->base); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + client->state = MqttClientStateConnecting; return 0; } @@ -303,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client) int MqttClientIsConnected(MqttClient *client) { - return client->stream.sock != -1; + return client->stream.sock != -1 && + client->state == MqttClientStateConnected; } int MqttClientRunOnce(MqttClient *client, int timeout) { int rv; - int events; + int events = 0; assert(client != NULL); @@ -319,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout) return -1; } - events = EV_READ; + if (client->state == MqttClientStateConnected) + { + events = EV_READ; - /* Handle outMessages and inMessages, moving queued messages to sendQueue - if there are less than maxInflight number of messages in flight */ - MqttClientProcessMessageQueue(client); + /* Handle outMessages and inMessages, moving queued messages to sendQueue + if there are less than maxInflight number of messages in flight */ + MqttClientProcessMessageQueue(client); - if (SIMPLEQ_EMPTY(&client->sendQueue)) + if (SIMPLEQ_EMPTY(&client->sendQueue)) + { + LOG_DEBUG("nothing to write"); + } + else + { + events |= EV_WRITE; + } + } + else if (client->state == MqttClientStateConnecting) { - LOG_DEBUG("nothing to write"); + events = EV_WRITE; } else { - events |= EV_WRITE; + LOG_ERROR("not connected"); + return -1; } LOG_DEBUG("selecting"); @@ -339,6 +396,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (timeout < 0) { timeout = client->keepAlive * 1000; + if (timeout == 0) + { + timeout = 30 * 1000; + } + } + else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0) + { + timeout = client->keepAlive * 1000; } rv = SocketSelect(client->stream.sock, &events, timeout); @@ -354,22 +419,26 @@ int MqttClientRunOnce(MqttClient *client, int timeout) if (events & EV_WRITE) { - MqttPacket *packet; - LOG_DEBUG("socket writable"); - packet = SIMPLEQ_FIRST(&client->sendQueue); - - if (packet) + if (client->state == MqttClientStateConnecting) { - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (MqttClientSendPacket(client, packet) == -1) + int sockError; + SocketGetError(client->stream.sock, &sockError); + LOG_DEBUG("sockError: %d", sockError); + if (sockError == 0) { - LOG_ERROR("MqttClientSendPacket failed"); - client->stopped = 1; + LOG_DEBUG("connected!"); + client->state = MqttClientStateConnected; + return 0; } } + + if (MqttClientSendPacket(client) == -1) + { + LOG_ERROR("MqttClientSendPacket failed"); + client->stopped = 1; + } } if (events & EV_READ) @@ -433,10 +502,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, } int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count) + int *qos, size_t count) { - MqttPacketSubscribe *packet = NULL; + MqttPacket *packet = NULL; size_t i; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); assert(topicFilters != NULL); @@ -444,68 +515,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, assert(qos != NULL); assert(count > 0); - packet = (MqttPacketSubscribe *) MqttPacketWithIdNew( - MqttPacketTypeSubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe, + MqttClientNextPacketId(client)); if (!packet) return -1; - packet->topicFilters = bstrListCreate(); - bstrListAllocMin(packet->topicFilters, count); + packet->flags = 0x2; + + StringStreamInit(&ss); - packet->qos = (int *) malloc(sizeof(int) * count); + StreamWriteUint16Be(packet->id, pss); + + LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id); for (i = 0; i < count; ++i) { - packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]); - ++packet->topicFilters->qty; + struct tagbstring filter; + btfromcstr(filter, topicFilters[i]); + StreamWriteMqttString(&filter, pss); + StreamWriteByte(qos[i] & 3, pss); } - memcpy(packet->qos, qos, sizeof(int) * count); - - MqttClientQueuePacket(client, (MqttPacket *) packet); + packet->payload = ss.buffer; - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + MqttClientQueuePacket(client, packet); - return MqttPacketId(packet); + return packet->id; } int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter) { - MqttPacketUnsubscribe *packet = NULL; + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + struct tagbstring filter; assert(client != NULL); assert(topicFilter != NULL); - packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew( - MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client)); + packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe, + MqttClientNextPacketId(client)); + + if (!packet) + return -1; + + packet->flags = 0x02; + + StringStreamInit(&ss); - packet->topicFilter = bfromcstr(topicFilter); + StreamWriteUint16Be(packet->id, pss); - MqttClientQueuePacket(client, (MqttPacket *) packet); + btfromcstr(filter, topicFilter); - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages); + StreamWriteMqttString(&filter, pss); - return MqttPacketId(packet); + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return packet->id; } static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client) { - MqttPacket *packet; + MqttMessage *msg; int count = 0; - TAILQ_FOREACH(packet, &client->outMessages, messages) + TAILQ_FOREACH(msg, &client->outMessages, chain) { ++count; } return count; } +static MqttPacket *PublishToPacket(MqttMessage *msg) +{ + MqttPacket *packet = NULL; + StringStream ss; + Stream *pss = (Stream *) &ss; + + if (msg->qos > 0) + { + packet = MqttPacketWithIdNew(MqttPacketTypePublish, + msg->id); + } + else + { + packet = MqttPacketNew(MqttPacketTypePublish); + } + + if (!packet) + return NULL; + + packet->message = msg; + + StringStreamInit(&ss); + + StreamWriteMqttString(msg->topic, pss); + + if (msg->qos > 0) + { + StreamWriteUint16Be(msg->id, pss); + } + + StreamWrite(bdata(msg->payload), blength(msg->payload), pss); + + packet->payload = ss.buffer; + packet->flags = (msg->qos & 3) << 1; + packet->flags |= msg->retain & 1; + + return packet; +} + int MqttClientPublish(MqttClient *client, int qos, int retain, const char *topic, const void *data, size_t size) { - MqttPacketPublish *packet; - - assert(client != NULL); + MqttMessage *message; /* first check if the queue is already full */ if (qos > 0 && client->maxQueued > 0 && @@ -514,55 +639,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain, return -1; } - if (qos > 0) + message = calloc(1, sizeof(*message)); + if (!message) { - packet = (MqttPacketPublish *) MqttPacketWithIdNew( - MqttPacketTypePublish, MqttClientNextPacketId(client)); + return -1; } - else + + message->state = MqttMessageStateQueued; + message->qos = qos; + message->retain = retain; + message->dup = 0; + message->timestamp = MqttGetCurrentTime(); + + if (qos == 0) { - packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish); - } + /* Copy payload and topic directly from user buffers as we don't need + to keep the message data around after this function. */ + MqttPacket *packet; + struct tagbstring bttopic, btpayload; - if (!packet) - return -1; + btfromcstr(bttopic, topic); + message->topic = &bttopic; - packet->qos = qos; - packet->retain = retain; - packet->topicName = bfromcstr(topic); - packet->message = blk2bstr(data, size); + btfromblk(btpayload, data, size); + message->payload = &btpayload; - if (qos > 0) - { - /* check how many messages there are coming in and going out currently - that are not yet done */ - if (client->maxInflight == 0 || - MqttClientInflightMessageCount(client) < client->maxInflight) - { - LOG_DEBUG("setting message (%d) state to MessageStateSend", - MqttPacketId(packet)); - packet->base.state = MessageStateSend; - } - else - { - LOG_DEBUG("setting message (%d) state to MessageStateQueued", - MqttPacketId(packet)); - packet->base.state = MessageStateQueued; - } + packet = PublishToPacket(message); - /* add the message to the outMessages queue to wait for processing */ - TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, - messages); + message->topic = NULL; + message->payload = NULL; + + MqttClientQueuePacket(client, packet); + + MqttMessageFree(message); + + return 0; } else { - MqttClientQueuePacket(client, (MqttPacket *) packet); - } + /* Duplicate the user buffers as we need the data to be available + longer. */ + message->topic = bfromcstr(topic); + message->payload = blk2bstr(data, size); - if (qos > 0) - return MqttPacketId(packet); + message->id = MqttClientNextPacketId(client); - return 0; + TAILQ_INSERT_TAIL(&client->outMessages, message, chain); + + return message->id; + } } int MqttClientPublishCString(MqttClient *client, int qos, int retain, @@ -613,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName, { assert(client != NULL); - if (MqttClientIsConnected(client)) + if (client->state == MqttClientStateConnecting) { LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect"); return -1; @@ -655,6 +780,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet) { assert(client != NULL); LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type)); + packet->state = MqttPacketStateWriteType; SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue); } @@ -667,128 +793,363 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type) return 0; } -static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet) +static int MqttClientSendPacket(MqttClient *client) { - if (MqttPacketSerialize(packet, &client->stream.base) == -1) - return -1; + MqttPacket *packet; - packet->sentAt = MqttGetCurrentTime(); - client->lastPacketSentTime = packet->sentAt; + packet = SIMPLEQ_FIRST(&client->sendQueue); - if (packet->type == MqttPacketTypeDisconnect) + if (!packet) { - client->stopped = 1; + LOG_WARNING("MqttClientSendPacket called with no queued packets"); + return 0; } - /* If the packet is not on any message list, it can be removed after - sending. */ - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) + while (packet != NULL) { - LOG_DEBUG("freeing packet %s after sending", - MqttPacketName(MqttPacketType(packet))); - MqttPacketFree(packet); + switch (packet->state) + { + case MqttPacketStateWriteType: + { + unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) | + (packet->flags & 0x0F); + + if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWriteRemainingLength; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWriteRemainingLength: + { + if (StreamWriteRemainingLength(&packet->remainingLength, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + return -1; + } + + packet->state = MqttPacketStateWritePayload; + packet->remainingLength = blength(packet->payload); + + break; + } + + case MqttPacketStateWritePayload: + { + if (packet->payload) + { + int64_t offset = blength(packet->payload) - packet->remainingLength; + int64_t nwritten = 0; + int towrite = 16*1024; + + if (packet->remainingLength < 16*1024) + towrite = packet->remainingLength; + + nwritten = StreamWrite(bdataofs(packet->payload, offset), + towrite, + &client->stream.base); + + if (nwritten == -1) + { + if (SocketWouldBlock(SocketErrno)) + { + return 0; + } + return -1; + } + + packet->remainingLength -= nwritten; + + LOG_DEBUG("nwritten:%d", (int) nwritten); + } + + if (packet->remainingLength == 0) + { + LOG_DEBUG("packet payload sent"); + packet->state = MqttPacketStateWriteComplete; + } + + break; + } + + case MqttPacketStateWriteComplete: + { + client->lastPacketSentTime = MqttGetCurrentTime(); + + if (packet->type == MqttPacketTypeDisconnect) + { + client->stopped = 1; + client->state = MqttClientStateDisconnected; + } + + LOG_DEBUG("sent %s", MqttPacketName(packet->type)); + + if (packet->type == MqttPacketTypePublish && packet->message) + { + MqttMessage *msg = packet->message; + + if (msg->qos == 1) + { + msg->state = MqttMessageStateWaitPubAck; + } + else if (msg->qos == 2) + { + msg->state = MqttMessageStateWaitPubRec; + } + } + + if (packet->message) + { + packet->message->timestamp = client->lastPacketSentTime; + } + + SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); + + MqttPacketFree(packet); + + packet = SIMPLEQ_FIRST(&client->sendQueue); + + break; + } + } } return 0; } -static void MqttClientHandleConnAck(MqttClient *client, - MqttPacketConnAck *packet) +static int MqttClientHandleConnAck(MqttClient *client) { - client->sessionPresent = packet->connAckFlags & 1; + StringStream ss; + Stream *pss = (Stream *) &ss; + unsigned char flags; + unsigned char rc; + + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadByte(&flags, pss); + + StreamReadByte(&rc, pss); + + client->sessionPresent = flags & 1; LOG_DEBUG("sessionPresent:%d", client->sessionPresent); if (client->onConnect) { - LOG_DEBUG("calling onConnect rc:%d", packet->returnCode); - client->onConnect(client, packet->returnCode, client->sessionPresent); + LOG_DEBUG("calling onConnect rc:%d", rc); + client->onConnect(client, rc, client->sessionPresent); } + + return 0; } -static void MqttClientHandlePingResp(MqttClient *client) +static int MqttClientHandlePingResp(MqttClient *client) { LOG_DEBUG("got ping response"); client->pingSent = 0; + return 0; } -static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet) +static int MqttClientHandleSubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + int *qos; + StringStream ss; + Stream *pss = (Stream *) &ss; + int count; + int i; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + LOG_DEBUG("received SUBACK with id:%d", (int) id); + + count = blength(client->inPacket.payload) - StreamTell(pss); + + if (count <= 0) { - if (MqttPacketType(sub) == MqttPacketTypeSubscribe && - MqttPacketId(sub) == MqttPacketId(packet)) - { - break; - } + LOG_ERROR("number of return codes invalid"); + return -1; } - if (!sub) + qos = malloc(count * sizeof(int)); + + for (i = 0; i < count; ++i) { - LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + unsigned char byte; + StreamReadByte(&byte, pss); + qos[i] = byte; } - else + + if (client->onSubscribe) { - if (client->onSubscribe) - { - MqttPacketSubscribe *sub2; - int i; + client->onSubscribe(client, id, qos, count); + } - sub2 = (MqttPacketSubscribe *) sub; + free(qos); - for (i = 0; i < sub2->topicFilters->qty; ++i) - { - const char *filter = bdata(sub2->topicFilters->entry[i]); - int rc = packet->returnCode[i]; + return 0; +} - LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d", - MqttPacketId(packet), filter, rc); +static int MqttClientSendPubAck(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; - client->onSubscribe(client, MqttPacketId(packet), filter, rc); - } - } + packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id); - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - } + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id); + + if (!packet) + return -1; + + packet->flags = 0x2; + + StringStreamInit(&ss); + + StreamWriteUint16Be(msg->id, pss); + + packet->payload = ss.buffer; + packet->message = msg; + + MqttClientQueuePacket(client, packet); + + return 0; +} + +static int MqttClientSendPubComp(MqttClient *client, uint16_t id) +{ + MqttPacket *packet; + StringStream ss; + Stream *pss = (Stream *) &ss; + + packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id); + + if (!packet) + return -1; + + StringStreamInit(&ss); + + StreamWriteUint16Be(id, pss); + + packet->payload = ss.buffer; + + MqttClientQueuePacket(client, packet); + + return 0; } -static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet) +static int MqttClientHandlePublish(MqttClient *client) { + MqttMessage *msg; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; + MqttPacket *packet; + int qos; + int retain; + bstring topic; + void *payload; + int payloadSize; + + /* We are paused - do nothing */ if (client->paused) - return; + return 0; + + packet = &client->inPacket; + + qos = (packet->flags >> 1) & 3; + retain = packet->flags & 1; + + StringStreamInitFromBstring(&ss, packet->payload); + + StreamReadMqttString(&topic, pss); + + if (qos > 0) + { + StreamReadUint16Be(&id, pss); + } + + payload = bdataofs(ss.buffer, ss.pos); + payloadSize = blength(ss.buffer) - ss.pos; - if (MqttPacketPublishQos(packet) == 2) + if (qos == 2) { /* Check if we have sent a PUBREC previously with the same id. If we have, we have to resend the PUBREC. We must not call the onMessage callback again. */ - MqttPacket *pubRec; - - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (pubRec) + if (msg) { - LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet)); - MqttClientQueuePacket(client, pubRec); - return; + LOG_DEBUG("resending PUBREC id:%u", msg->id); + MqttClientSendPubRec(client, msg); + bdestroy(topic); + return 0; } } @@ -796,268 +1157,395 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe { LOG_DEBUG("calling onMessage"); client->onMessage(client, - bdata(packet->topicName), - bdata(packet->message), - blength(packet->message), - packet->qos, - packet->retain); + bdata(topic), + payload, + payloadSize, + qos, + retain); } - if (MqttPacketPublishQos(packet) > 0) + bdestroy(topic); + + if (qos == 1) + { + MqttClientSendPubAck(client, id); + } + else if (qos == 2) { - int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck : - MqttPacketTypePubRec; + msg = calloc(1, sizeof(*msg)); - MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet)); + msg->state = MqttMessageStateWaitPubRel; + msg->id = id; + msg->qos = qos; - if (MqttPacketPublishQos(packet) == 2) - { - /* append to inMessages as we need a reply to this response */ - TAILQ_INSERT_TAIL(&client->inMessages, resp, messages); - } + TAILQ_INSERT_TAIL(&client->inMessages, msg, chain); - MqttClientQueuePacket(client, resp); + MqttClientSendPubRec(client, msg); } + + return 0; } -static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubAck(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubAck) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); - if (client->onPublish) - { - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + if (client->onPublish) + { + client->onPublish(client, msg->id); } + + MqttMessageFree(msg); + + return 0; } -static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRec(MqttClient *client) { - MqttPacket *pub; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pub, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pub) == MqttPacketId(packet) && - MqttPacketType(pub) == MqttPacketTypePublish) + /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but + they haven't received it. */ + if (msg->id == id && + (msg->state == MqttMessageStateWaitPubRec || + msg->state == MqttMessageStateWaitPubComp)) { break; } } - if (!pub) + if (!msg) { - LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubRel; - TAILQ_REMOVE(&client->outMessages, pub, messages); - MqttPacketFree(pub); + msg->state = MqttMessageStateWaitPubComp; - pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet)); - pubRel->state = MessageStateSend; + bdestroy(msg->payload); + msg->payload = NULL; - TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages); - } + bdestroy(msg->topic); + msg->topic = NULL; + + if (MqttClientSendPubRel(client, msg) == -1) + return -1; + + return 0; } -static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubRel(MqttClient *client) { - MqttPacket *pubRec; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(pubRec, &client->inMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->inMessages, chain) { - if (MqttPacketId(pubRec) == MqttPacketId(packet) && - MqttPacketType(pubRec) == MqttPacketTypePubRec) + if (msg->id == id && + msg->state == MqttMessageStateWaitPubRel) { break; } } - if (!pubRec) + if (!msg) { - MqttPacket *pubComp; - - TAILQ_FOREACH(pubComp, &client->inMessages, messages) - { - if (MqttPacketId(pubComp) == MqttPacketId(packet) && - MqttPacketType(pubComp) == MqttPacketTypePubComp) - { - break; - } - } - - if (pubComp) - { - MqttClientQueuePacket(client, pubComp); - } - else - { - LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; - } + LOG_ERROR("no message found with id %d", (int) id); + return -1; } - else - { - MqttPacket *pubComp; - TAILQ_REMOVE(&client->inMessages, pubRec, messages); - MqttPacketFree(pubRec); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); - pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp, - MqttPacketId(packet)); - - TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages); + if (MqttClientSendPubComp(client, id) == -1) + return -1; - MqttClientQueuePacket(client, pubComp); - } + return 0; } -static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet) +static int MqttClientHandlePubComp(MqttClient *client) { - MqttPacket *pubRel; + StringStream ss; + Stream *pss = (Stream *) &ss; + uint16_t id; + MqttMessage *msg; + + assert(client != NULL); - TAILQ_FOREACH(pubRel, &client->outMessages, messages) + StringStreamInitFromBstring(&ss, client->inPacket.payload); + + StreamReadUint16Be(&id, pss); + + TAILQ_FOREACH(msg, &client->outMessages, chain) { - if (MqttPacketId(pubRel) == MqttPacketId(packet) && - MqttPacketType(pubRel) == MqttPacketTypePubRel) + if (msg->id == id && msg->state == MqttMessageStateWaitPubComp) { break; } } - if (!pubRel) + if (!msg) { - LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + LOG_WARNING("no message found with id %d", (int) id); + return 0; } - else - { - TAILQ_REMOVE(&client->outMessages, pubRel, messages); - MqttPacketFree(pubRel); - if (client->onPublish) - { - LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet)); - client->onPublish(client, MqttPacketId(packet)); - } + TAILQ_REMOVE(&client->outMessages, msg, chain); + + MqttMessageFree(msg); + + if (client->onPublish) + { + LOG_DEBUG("calling onPublish id:%d", id); + client->onPublish(client, id); } + + return 0; } -static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet) +static int MqttClientHandleUnsubAck(MqttClient *client) { - MqttPacket *sub; + uint16_t id; + StringStream ss; + Stream *pss = (Stream *) &ss; assert(client != NULL); - assert(packet != NULL); - TAILQ_FOREACH(sub, &client->outMessages, messages) - { - if (MqttPacketId(sub) == MqttPacketId(packet) && - MqttPacketType(sub) == MqttPacketTypeUnsubscribe) - { - break; - } - } + StringStreamInitFromBstring(&ss, client->inPacket.payload); - if (!sub) + StreamReadUint16Be(&id, pss); + + if (client->onUnsubscribe) { - LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet)); - client->stopped = 1; + client->onUnsubscribe(client, id); } - else - { - TAILQ_REMOVE(&client->outMessages, sub, messages); - MqttPacketFree(sub); - if (client->onUnsubscribe) - { - LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet)); - client->onUnsubscribe(client, MqttPacketId(packet)); - } - } + return 0; } -static int MqttClientRecvPacket(MqttClient *client) +static int MqttClientHandlePacket(MqttClient *client) { - MqttPacket *packet = NULL; + int rc; - if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1) - return -1; - - LOG_DEBUG("received packet %s", MqttPacketName(packet->type)); - - switch (MqttPacketType(packet)) + switch (client->inPacket.type) { case MqttPacketTypeConnAck: - MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet); + rc = MqttClientHandleConnAck(client); break; case MqttPacketTypePingResp: - MqttClientHandlePingResp(client); + rc = MqttClientHandlePingResp(client); break; case MqttPacketTypeSubAck: - MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet); + rc = MqttClientHandleSubAck(client); break; - case MqttPacketTypePublish: - MqttClientHandlePublish(client, (MqttPacketPublish *) packet); + case MqttPacketTypeUnsubAck: + rc = MqttClientHandleUnsubAck(client); break; case MqttPacketTypePubAck: - MqttClientHandlePubAck(client, packet); + rc = MqttClientHandlePubAck(client); break; case MqttPacketTypePubRec: - MqttClientHandlePubRec(client, packet); + rc = MqttClientHandlePubRec(client); break; - case MqttPacketTypePubRel: - MqttClientHandlePubRel(client, packet); + case MqttPacketTypePubComp: + rc = MqttClientHandlePubComp(client); break; - case MqttPacketTypePubComp: - MqttClientHandlePubComp(client, packet); + case MqttPacketTypePubRel: + rc = MqttClientHandlePubRel(client); break; - case MqttPacketTypeUnsubAck: - MqttClientHandleUnsubAck(client, packet); + case MqttPacketTypePublish: + rc = MqttClientHandlePublish(client); break; default: - LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet)); + LOG_ERROR("packet not handled yet"); + rc = -1; break; } - MqttPacketFree(packet); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + + client->inPacket.state = MqttPacketStateReadType; + + return rc; +} + +static int MqttClientRecvPacket(MqttClient *client) +{ + while (1) + { + switch (client->inPacket.state) + { + case MqttPacketStateReadType: + { + unsigned char typeAndFlags; + + if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet type"); + return -1; + } + + client->inPacket.type = typeAndFlags >> 4; + client->inPacket.flags = typeAndFlags & 0x0F; + + if (client->inPacket.type < MqttPacketTypeConnect || + client->inPacket.type > MqttPacketTypeDisconnect) + { + LOG_ERROR("unknown packet type: %d", client->inPacket.type); + return -1; + } + + client->inPacket.state = MqttPacketStateReadRemainingLength; + client->inPacket.remainingLength = 0; + client->inPacket.remainingLengthMul = 1; + client->inPacket.payload = NULL; + + break; + } + + case MqttPacketStateReadRemainingLength: + { + if (StreamReadRemainingLength(&client->inPacket.remainingLength, + &client->inPacket.remainingLengthMul, + &client->stream.base) == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed to read remaining length"); + return -1; + } + + LOG_DEBUG("remainingLength:%lu", + client->inPacket.remainingLength); + + client->inPacket.state = MqttPacketStateReadPayload; + + break; + } + + case MqttPacketStateReadPayload: + { + if (client->inPacket.remainingLength > 0) + { + int64_t nread, offset, toread; + + if (client->inPacket.payload == NULL) + { + unsigned char *data; + client->inPacket.payload = bfromcstr(""); + ballocmin(client->inPacket.payload, + client->inPacket.remainingLength+1); + data = client->inPacket.payload->data; + data[client->inPacket.remainingLength] = '\0'; + } + + offset = blength(client->inPacket.payload); + + toread = 16*1024; + + if (client->inPacket.remainingLength < (size_t) toread) + toread = client->inPacket.remainingLength; + + nread = StreamRead(bdataofs(client->inPacket.payload, + offset), + toread, + &client->stream.base); + + if (nread == -1) + { + if (SocketWouldBlock(SocketErrno)) + return 0; + LOG_ERROR("failed reading packet payload"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + else if (nread == 0) + { + LOG_ERROR("socket disconnected"); + bdestroy(client->inPacket.payload); + client->inPacket.payload = NULL; + return -1; + } + + client->inPacket.remainingLength -= nread; + client->inPacket.payload->slen += nread; + + LOG_DEBUG("nread:%d", (int) nread); + } + + if (client->inPacket.remainingLength == 0) + { + client->inPacket.state = MqttPacketStateReadComplete; + } + break; + } + + case MqttPacketStateReadComplete: + { + int type = client->inPacket.type; + LOG_DEBUG("received %s", MqttPacketName(type)); + return MqttClientHandlePacket(client); + } + } + } return 0; } @@ -1072,101 +1560,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client) return id; } -static int64_t MqttPacketTimeSinceSent(MqttPacket *packet) +static int64_t MqttMessageTimeSinceSent(MqttMessage *msg) { int64_t now = MqttGetCurrentTime(); - return now - packet->sentAt; + return now - msg->timestamp; } -static void MqttClientProcessInMessages(MqttClient *client) +static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg) { - MqttPacket *packet, *next; - - LOG_DEBUG("processing inMessages"); - - TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next) + if (msg->timestamp > 0 && + MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000) { - LOG_DEBUG("packet type:%s id:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - - if (MqttPacketType(packet) == MqttPacketTypePubComp) - { - int64_t elapsed = MqttPacketTimeSinceSent(packet); - if (packet->sentAt > 0 && - elapsed >= client->retryTimeout*1000) - { - LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64, - MqttPacketId(packet), elapsed); - - TAILQ_REMOVE(&client->inMessages, packet, messages); - - MqttPacketFree(packet); - } - } + return 1; } + + return 0; } -static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet) +static void MqttClientProcessInMessages(MqttClient *client) { - if (packet->sentAt > 0 && - MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000) + MqttMessage *msg, *next; + + TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next) { - return 1; - } + switch (msg->state) + { + case MqttMessageStateWaitPubRel: + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRec(client, msg); + } + break; - return 0; + default: + break; + } + } } static void MqttClientProcessOutMessages(MqttClient *client) { - MqttPacket *packet, *next; + MqttMessage *msg, *next; + MqttPacket *packet; int inflight = MqttClientInflightMessageCount(client); - LOG_DEBUG("processing outMessages inflight:%d", inflight); - - TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next) + TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next) { - LOG_DEBUG("packet type:%s id:%d state:%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet), - packet->state); - - switch (packet->state) + switch (msg->state) { - case MessageStateQueued: + case MqttMessageStateQueued: + { if (inflight >= client->maxInflight) { - LOG_DEBUG("cannot dequeue %s/%d", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - break; + continue; } - else - { - /* If there's less than maxInflight messages currently - inflight, we can dequeue some messages by falling - through to MessageStateSend. */ - LOG_DEBUG("dequeuing %s (%d)", - MqttPacketName(MqttPacketType(packet)), - MqttPacketId(packet)); - ++inflight; - } - - case MessageStateSend: - packet->state = MessageStateSent; + /* State change from MqttMessageStatePublish happens after + the packet has been sent (in MqttClientSendPacket). */ + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); MqttClientQueuePacket(client, packet); + ++inflight; break; + } - case MessageStateSent: - if (MqttPacketShouldResend(client, packet)) + case MqttMessageStateWaitPubAck: + case MqttMessageStateWaitPubRec: + { + if (MqttMessageShouldResend(client, msg)) { - packet->state = MessageStateSend; + msg->state = MqttMessageStatePublish; + packet = PublishToPacket(msg); + MqttClientQueuePacket(client, packet); } break; + } - default: + case MqttMessageStateWaitPubComp: + { + if (MqttMessageShouldResend(client, msg)) + { + MqttClientSendPubRel(client, msg); + } break; + } } } } @@ -1182,30 +1658,22 @@ static void MqttClientClearQueues(MqttClient *client) while (!SIMPLEQ_EMPTY(&client->sendQueue)) { MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue); - SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue); - - if (TAILQ_NEXT(packet, messages) == NULL && - TAILQ_PREV(packet, MessageList, messages) == NULL && - TAILQ_FIRST(&client->inMessages) != packet && - TAILQ_FIRST(&client->outMessages) != packet) - { - MqttPacketFree(packet); - } + MqttPacketFree(packet); } while (!TAILQ_EMPTY(&client->outMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->outMessages); - TAILQ_REMOVE(&client->outMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->outMessages); + TAILQ_REMOVE(&client->outMessages, msg, chain); + MqttMessageFree(msg); } while (!TAILQ_EMPTY(&client->inMessages)) { - MqttPacket *packet = TAILQ_FIRST(&client->inMessages); - TAILQ_REMOVE(&client->inMessages, packet, messages); - MqttPacketFree(packet); + MqttMessage *msg = TAILQ_FIRST(&client->inMessages); + TAILQ_REMOVE(&client->inMessages, msg, chain); + MqttMessageFree(msg); } } diff --git a/src/deserialize.c b/src/deserialize.c deleted file mode 100644 index 96d7789..0000000 --- a/src/deserialize.c +++ /dev/null @@ -1,286 +0,0 @@ -#include "deserialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream); - -static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1) - return -1; - - if (StreamRead(&(*packet)->returnCode, 1, stream) != 1) - return -1; - - return 0; -} - -static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t i; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - return -1; - - remainingLength -= 2; - - (*packet)->returnCode = (unsigned char *) malloc( - sizeof(*(*packet)->returnCode) * remainingLength); - - for (i = 0; i < remainingLength; ++i) - { - if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (remainingLength != 2) - return -1; - - if (StreamReadUint16Be(&(*packet)->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream) -{ - size_t remainingLength = 0; - size_t payloadSize = 0; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - if (StreamReadMqttString(&(*packet)->topicName, stream) == -1) - return -1; - - LOG_DEBUG("remainingLength:%lu", remainingLength); - - payloadSize = remainingLength - blength((*packet)->topicName) - 2; - - LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet), - payloadSize); - - if (MqttPacketHasId((const MqttPacket *) *packet)) - { - LOG_DEBUG("packet has id"); - payloadSize -= 2; - if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1) - { - return -1; - } - } - - LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize); - - /* Allocate extra byte for a NULL terminator. If the user tries to print - the payload directly. */ - - (*packet)->message = bfromcstralloc(payloadSize+1, ""); - - if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1) - return -1; - - (*packet)->message->slen = payloadSize; - (*packet)->message->data[payloadSize] = '\0'; - - return 0; -} - -static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream) -{ - size_t remainingLength = 0; - char buffer[256]; - - (void) packet; - - if (StreamReadRemainingLength(&remainingLength, stream) == -1) - return -1; - - while (remainingLength > 0) - { - size_t l = sizeof(buffer); - - if (remainingLength < l) - l = remainingLength; - - if (StreamRead(buffer, l, stream) != (int64_t) l) - return -1; - - remainingLength -= l; - } - - return 0; -} - -static int ValidateFlags(int type, int flags) -{ - int rv = 0; - - switch (type) - { - case MqttPacketTypePublish: - { - int qos = (flags >> 1) & 2; - if (qos >= 0 && qos <= 2) - rv = 1; - break; - } - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - if (flags == 2) - { - rv = 1; - } - break; - - default: - if (flags == 0) - { - rv = 1; - } - break; - } - - return rv; -} - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream) -{ - MqttPacketDeserializeFunc deserializer = NULL; - char typeAndFlags; - int type; - int flags; - int rv; - - if (StreamRead(&typeAndFlags, 1, stream) != 1) - return -1; - - type = (typeAndFlags & 0xF0) >> 4; - flags = (typeAndFlags & 0x0F); - - if (!ValidateFlags(type, flags)) - { - return -1; - } - - switch (type) - { - case MqttPacketTypeConnect: - break; - - case MqttPacketTypeConnAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize; - break; - - case MqttPacketTypePublish: - deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize; - break; - - case MqttPacketTypePubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRec: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubRel: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypePubComp: - deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize; - break; - - case MqttPacketTypeSubscribe: - break; - - case MqttPacketTypeSubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize; - break; - - case MqttPacketTypeUnsubscribe: - break; - - case MqttPacketTypeUnsubAck: - deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize; - break; - - case MqttPacketTypePingReq: - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - break; - - default: - return -1; - } - - if (!deserializer) - { - deserializer = MqttPacketGenericDeserializer; - } - - *packet = MqttPacketNew(type); - - if (!*packet) - return -1; - - if (type == MqttPacketTypePublish) - { - MqttPacketPublishDup(*packet) = (flags >> 3) & 1; - MqttPacketPublishQos(*packet) = (flags >> 1) & 3; - MqttPacketPublishRetain(*packet) = flags & 1; - } - - rv = deserializer(packet, stream); - - return rv; -} diff --git a/src/deserialize.h b/src/deserialize.h deleted file mode 100644 index 8c29b3d..0000000 --- a/src/deserialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef DESERIALIZE_H -#define DESERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketDeserialize(MqttPacket **packet, Stream *stream); - -#endif diff --git a/src/message.c b/src/message.c new file mode 100644 index 0000000..35d9c32 --- /dev/null +++ b/src/message.c @@ -0,0 +1,11 @@ +#include "message.h" +#include "stringstream.h" +#include "stream_mqtt.h" +#include "packet.h" + +void MqttMessageFree(MqttMessage *msg) +{ + bdestroy(msg->topic); + bdestroy(msg->payload); + free(msg); +} diff --git a/src/message.h b/src/message.h new file mode 100644 index 0000000..04a3d61 --- /dev/null +++ b/src/message.h @@ -0,0 +1,40 @@ +#ifndef MESSAGE_H +#define MESSAGE_H + +#include <stdint.h> + +#include "queue.h" +#include <bstrlib/bstrlib.h> + +enum MqttMessageState +{ + MqttMessageStateQueued, + MqttMessageStatePublish, + MqttMessageStateWaitPubAck, + MqttMessageStateWaitPubRec, + MqttMessageStateWaitPubComp, + MqttMessageStateWaitPubRel +}; + +typedef struct MqttMessage MqttMessage; + +struct MqttMessage +{ + int state; + int qos; + int retain; + int dup; + int padding; + uint16_t id; + int64_t timestamp; + bstring topic; + bstring payload; + TAILQ_ENTRY(MqttMessage) chain; +}; + +typedef struct MqttMessageList MqttMessageList; +TAILQ_HEAD(MqttMessageList, MqttMessage); + +void MqttMessageFree(MqttMessage *msg); + +#endif @@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client, typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client, int id, - const char *topicFilter, - MqttSubscriptionStatus status); + int *qos, + int count); typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id); @@ -84,7 +84,7 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter, int qos); int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters, - int *qos, size_t count); + int *qos, size_t count); int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter); diff --git a/src/packet.c b/src/packet.c index 47aa689..c833851 100644 --- a/src/packet.c +++ b/src/packet.c @@ -28,42 +28,16 @@ const char *MqttPacketName(int type) } } -static MQTT_INLINE size_t MqttPacketStructSize(int type) -{ - switch (type) - { - case MqttPacketTypeConnect: return sizeof(MqttPacketConnect); - case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck); - case MqttPacketTypePublish: return sizeof(MqttPacketPublish); - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: return sizeof(MqttPacket); - case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe); - case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck); - case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe); - case MqttPacketTypeUnsubAck: return sizeof(MqttPacket); - case MqttPacketTypePingReq: return sizeof(MqttPacket); - case MqttPacketTypePingResp: return sizeof(MqttPacket); - case MqttPacketTypeDisconnect: return sizeof(MqttPacket); - default: return (size_t) -1; - } -} - MqttPacket *MqttPacketNew(int type) { MqttPacket *packet = NULL; - packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type)); + packet = (MqttPacket *) calloc(1, sizeof(*packet)); if (!packet) return NULL; packet->type = type; - /* this will make sure that TAILQ_PREV does not segfault if a message - has not been added to a list at any point */ - packet->messages.tqe_prev = &packet->messages.tqe_next; - return packet; } @@ -78,52 +52,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id) void MqttPacketFree(MqttPacket *packet) { - if (MqttPacketType(packet) == MqttPacketTypeConnect) - { - MqttPacketConnect *p = (MqttPacketConnect *) packet; - bdestroy(p->clientId); - bdestroy(p->willTopic); - bdestroy(p->willMessage); - bdestroy(p->userName); - bdestroy(p->password); - } - else if (MqttPacketType(packet) == MqttPacketTypePublish) - { - MqttPacketPublish *p = (MqttPacketPublish *) packet; - bdestroy(p->topicName); - bdestroy(p->message); - } - else if (MqttPacketType(packet) == MqttPacketTypeSubscribe) - { - MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet; - bstrListDestroy(p->topicFilters); - } - else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe) - { - MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet; - bdestroy(p->topicFilter); - } + bdestroy(packet->payload); free(packet); } - -int MqttPacketHasId(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return MqttPacketPublishQos(packet) > 0; - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - case MqttPacketTypeSubscribe: - case MqttPacketTypeSubAck: - case MqttPacketTypeUnsubscribe: - case MqttPacketTypeUnsubAck: - return 1; - - default: - return 0; - } -} diff --git a/src/packet.h b/src/packet.h index 7ab4f73..a5e2ce7 100644 --- a/src/packet.h +++ b/src/packet.h @@ -29,87 +29,35 @@ enum MqttPacketTypeDisconnect = 0xE }; +enum MqttPacketState +{ + MqttPacketStateReadType, + MqttPacketStateReadRemainingLength, + MqttPacketStateReadPayload, + MqttPacketStateReadComplete, + + MqttPacketStateWriteType, + MqttPacketStateWriteRemainingLength, + MqttPacketStateWritePayload, + MqttPacketStateWriteComplete +}; + +struct MqttMessage; + typedef struct MqttPacket MqttPacket; struct MqttPacket { int type; - uint16_t id; - int state; int flags; - int64_t sentAt; + int state; + uint16_t id; + size_t remainingLength; + size_t remainingLengthMul; + /* TODO: maybe switch to have a StringStream here? */ + bstring payload; + struct MqttMessage *message; SIMPLEQ_ENTRY(MqttPacket) sendQueue; - TAILQ_ENTRY(MqttPacket) messages; -}; - -#define MqttPacketType(packet) (((MqttPacket *) (packet))->type) - -#define MqttPacketId(packet) (((MqttPacket *) (packet))->id) - -#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt) - -typedef struct MqttPacketConnect MqttPacketConnect; - -struct MqttPacketConnect -{ - MqttPacket base; - char connectFlags; - uint16_t keepAlive; - bstring clientId; - bstring willTopic; - bstring willMessage; - bstring userName; - bstring password; -}; - -typedef struct MqttPacketConnAck MqttPacketConnAck; - -struct MqttPacketConnAck -{ - MqttPacket base; - unsigned char connAckFlags; - unsigned char returnCode; -}; - -typedef struct MqttPacketPublish MqttPacketPublish; - -struct MqttPacketPublish -{ - MqttPacket base; - bstring topicName; - bstring message; - char qos; - char dup; - char retain; -}; - -#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos) -#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup) -#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain) - -typedef struct MqttPacketSubscribe MqttPacketSubscribe; - -struct MqttPacketSubscribe -{ - MqttPacket base; - struct bstrList *topicFilters; - int *qos; -}; - -typedef struct MqttPacketSubAck MqttPacketSubAck; - -struct MqttPacketSubAck -{ - MqttPacket base; - unsigned char *returnCode; -}; - -typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe; - -struct MqttPacketUnsubscribe -{ - MqttPacket base; - bstring topicFilter; }; const char *MqttPacketName(int type); @@ -120,6 +68,4 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id); void MqttPacketFree(MqttPacket *packet); -int MqttPacketHasId(const MqttPacket *packet); - #endif diff --git a/src/serialize.c b/src/serialize.c deleted file mode 100644 index c1c8eb4..0000000 --- a/src/serialize.c +++ /dev/null @@ -1,326 +0,0 @@ -#include "serialize.h" -#include "packet.h" -#include "stream_mqtt.h" -#include "log.h" - -#include <bstrlib/bstrlib.h> - -#include <stdlib.h> -#include <assert.h> - -typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet, - Stream *stream); - -static const struct tagbstring MqttProtocolId = bsStatic("MQTT"); -static const char MqttProtocolLevel = 0x04; - -static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s) -{ - return 2 + blength(s); -} - -static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2; - - remainingLength += MqttStringLengthSerialized(packet->clientId); - - if (packet->connectFlags & 0x80) - remainingLength += MqttStringLengthSerialized(packet->userName); - - if (packet->connectFlags & 0x40) - remainingLength += MqttStringLengthSerialized(packet->password); - - if (packet->connectFlags & 0x04) - remainingLength += MqttStringLengthSerialized(packet->willTopic) + - MqttStringLengthSerialized(packet->willMessage); - - return remainingLength; -} - -static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet) -{ - size_t remaining = 2; - int i; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]); - remaining += 1; - } - - return remaining; -} - -static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet) -{ - return 2 + MqttStringLengthSerialized(packet->topicFilter); -} - -static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet) -{ - size_t remainingLength = 0; - - remainingLength += MqttStringLengthSerialized(packet->topicName); - - /* Packet id */ - if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2) - { - remainingLength += 2; - } - - remainingLength += blength(packet->message); - - return remainingLength; -} - -static size_t MqttPacketGetRemainingLength(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypeConnect: - return MqttPacketConnectGetRemainingLength( - (MqttPacketConnect *) packet); - - case MqttPacketTypeSubscribe: - return MqttPacketSubscribeGetRemainingLength( - (MqttPacketSubscribe *) packet); - - case MqttPacketTypePublish: - return MqttPacketPublishGetRemainingLength( - (MqttPacketPublish *) packet); - - case MqttPacketTypePubAck: - case MqttPacketTypePubRec: - case MqttPacketTypePubRel: - case MqttPacketTypePubComp: - return 2; - - case MqttPacketTypeUnsubscribe: - return MqttPacketUnsubscribeGetRemainingLength( - (MqttPacketUnsubscribe *) packet); - - default: - return 0; - } -} - -static int MqttPacketFlags(const MqttPacket *packet) -{ - switch (packet->type) - { - case MqttPacketTypePublish: - return ((MqttPacketPublishDup(packet) & 1) << 3) | - ((MqttPacketPublishQos(packet) & 3) << 1) | - (MqttPacketPublishRetain(packet) & 1); - - case MqttPacketTypePubRel: - case MqttPacketTypeSubscribe: - case MqttPacketTypeUnsubscribe: - return 0x2; - - default: - return 0; - } -} - -static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream) -{ - unsigned char typeAndFlags; - size_t remainingLength; - - typeAndFlags = ((packet->type & 0x0F) << 4) | - (MqttPacketFlags(packet) & 0x0F); - remainingLength = MqttPacketGetRemainingLength(packet); - - LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type, - MqttPacketName(packet->type), MqttPacketFlags(packet)); - - if (StreamWrite(&typeAndFlags, 1, stream) != 1) - return -1; - - if (StreamWriteRemainingLength(remainingLength, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream) -{ - assert(MqttPacketHasId((const MqttPacket *) packet)); - - if (MqttPacketBaseSerialize(packet, stream) == -1) - return -1; - - if (StreamWriteUint16Be(packet->id, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize(&packet->base, stream) == -1) - return -1; - - if (StreamWriteMqttString(&MqttProtocolId, stream) == -1) - return -1; - - if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1) - return -1; - - if (StreamWrite(&packet->connectFlags, 1, stream) != 1) - return -1; - - if (StreamWriteUint16Be(packet->keepAlive, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->clientId, stream) == -1) - return -1; - - if (packet->connectFlags & 0x04) - { - if (StreamWriteMqttString(packet->willTopic, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->willMessage, stream) == -1) - return -1; - } - - if (packet->connectFlags & 0x80) - { - if (StreamWriteMqttString(packet->userName, stream) == -1) - return -1; - - if (packet->connectFlags & 0x40) - { - if (StreamWriteMqttString(packet->password, stream) == -1) - return -1; - } - } - - return 0; -} - -static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream) -{ - int i; - - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - for (i = 0; i < packet->topicFilters->qty; ++i) - { - unsigned char qos = (unsigned char) packet->qos[i]; - - if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1) - return -1; - - if (StreamWrite(&qos, 1, stream) == -1) - return -1; - } - - return 0; -} - -static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream) -{ - if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicFilter, stream) == -1) - return -1; - - return 0; -} - -static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream) -{ - if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1) - return -1; - - if (StreamWriteMqttString(packet->topicName, stream) == -1) - return -1; - - LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet)); - - if (MqttPacketPublishQos(packet) > 0) - { - if (StreamWriteUint16Be(packet->base.id, stream) == -1) - return -1; - } - - if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1) - return -1; - - return 0; -} - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream) -{ - MqttPacketSerializeFunc f = NULL; - - switch (packet->type) - { - case MqttPacketTypeConnect: - f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize; - break; - - case MqttPacketTypeConnAck: - break; - - case MqttPacketTypePublish: - f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize; - break; - - case MqttPacketTypePubAck: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRec: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubRel: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypePubComp: - f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize; - break; - - case MqttPacketTypeSubscribe: - f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize; - break; - - case MqttPacketTypeSubAck: - break; - - case MqttPacketTypeUnsubscribe: - f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize; - break; - - case MqttPacketTypeUnsubAck: - break; - - case MqttPacketTypePingReq: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - case MqttPacketTypePingResp: - break; - - case MqttPacketTypeDisconnect: - f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize; - break; - - default: - return -1; - } - - assert(f != NULL && "no serializer"); - - return f(packet, stream); -} diff --git a/src/serialize.h b/src/serialize.h deleted file mode 100644 index 7eb988f..0000000 --- a/src/serialize.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef SERIALIZE_H -#define SERIALIZE_H - -#include "config.h" - -typedef struct MqttPacket MqttPacket; -typedef struct Stream Stream; - -int MqttPacketSerialize(const MqttPacket *packet, Stream *stream); - -#endif diff --git a/src/socket.c b/src/socket.c index 64a7c01..b70f4fb 100644 --- a/src/socket.c +++ b/src/socket.c @@ -6,18 +6,6 @@ #include <assert.h> #if defined(_WIN32) -#include "win32.h" -#else -#include <sys/types.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <sys/select.h> -#include <netdb.h> -#include <unistd.h> -#include <arpa/inet.h> -#endif - -#if defined(_WIN32) static int InitializeWsa() { WSADATA wsa; @@ -33,9 +21,9 @@ static int InitializeWsa() #define close closesocket #endif -int SocketConnect(const char *host, short port) +int SocketConnect(const char *host, short port, int nonblocking) { - struct addrinfo hints, *servinfo, *p = NULL; + struct addrinfo hints, *servinfo = NULL, *p = NULL; int rv; char portstr[6]; int sock; @@ -66,8 +54,16 @@ int SocketConnect(const char *host, short port) continue; } + if (nonblocking) + { + SocketSetNonblocking(sock, 1); + } + if (connect(sock, p->ai_addr, p->ai_addrlen) == -1) { + int err = SocketErrno; + if (err == SOCKET_EINPROGRESS) + break; close(sock); continue; } @@ -75,10 +71,13 @@ int SocketConnect(const char *host, short port) break; } - freeaddrinfo(servinfo); - cleanup: + if (servinfo) + { + freeaddrinfo(servinfo); + } + if (p == NULL) { #if defined(_WIN32) @@ -178,3 +177,45 @@ int SocketSelect(int sock, int *events, int timeout) return rv; } + +void SocketSetNonblocking(int sock, int nb) +{ +#if defined(_WIN32) + unsigned int yes = nb; + ioctlsocket(s, FIONBIO, &yes); +#else + int flags = fcntl(sock, F_GETFL, 0); + if (nb) + flags |= O_NONBLOCK; + else + flags &= ~O_NONBLOCK; + fcntl(sock, F_SETFL, flags); +#endif +} + +int SocketGetOpt(int sock, int level, int name, void *val, int *len) +{ +#if defined(_WIN32) + return getsockopt(sock, level, name, (char *) val, len); +#else + socklen_t _len = *len; + int rc = getsockopt(sock, level, name, val, &_len); + *len = _len; + return rc; +#endif +} + +int SocketGetError(int sock, int *error) +{ + int len = sizeof(*error); + return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len); +} + +int SocketWouldBlock(int error) +{ +#if defined(_WIN32) + return error == WSAEWOULDBLOCK; +#else + return error == EWOULDBLOCK || error == EAGAIN; +#endif +} diff --git a/src/socket.h b/src/socket.h index e7b1a80..abc67af 100644 --- a/src/socket.h +++ b/src/socket.h @@ -6,7 +6,25 @@ #include <stdlib.h> #include <stdint.h> -int SocketConnect(const char *host, short port); +#if defined(_WIN32) +#include "win32.h" +#define SocketErrno (WSAGetLastError()) +#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK) +#else +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/select.h> +#include <netdb.h> +#include <unistd.h> +#include <arpa/inet.h> +#include <fcntl.h> +#include <errno.h> +#define SocketErrno (errno) +#define SOCKET_EINPROGRESS (EINPROGRESS) +#endif + +int SocketConnect(const char *host, short port, int nonblocking); int SocketDisconnect(int sock); @@ -24,4 +42,10 @@ int64_t SocketRecv(int sock, void *buf, size_t len, int flags); int64_t SocketSend(int sock, const void *buf, size_t len, int flags); +void SocketSetNonblocking(int sock, int nb); + +int SocketGetError(int sock, int *error); + +int SocketWouldBlock(int error); + #endif diff --git a/src/stream.c b/src/stream.c index fd154a1..1c46668 100644 --- a/src/stream.c +++ b/src/stream.c @@ -47,6 +47,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream) return 2; } +int64_t StreamReadByte(unsigned char *byte, Stream *stream) +{ + return StreamRead(byte, sizeof(*byte), stream); +} + int64_t StreamWrite(const void *ptr, size_t size, Stream *stream) { STREAM_CHECK_OP(stream, write); @@ -65,6 +70,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream) return StreamWrite(data, sizeof(data), stream); } +int64_t StreamWriteByte(unsigned char byte, Stream *stream) +{ + return StreamWrite(&byte, sizeof(byte), stream); +} + int StreamSeek(Stream *stream, int64_t offset, int whence) { STREAM_CHECK_OP(stream, seek); diff --git a/src/stream.h b/src/stream.h index 839facb..50f1772 100644 --- a/src/stream.h +++ b/src/stream.h @@ -27,9 +27,11 @@ int StreamClose(Stream *stream); int64_t StreamRead(void *ptr, size_t size, Stream *stream); int64_t StreamReadUint16Be(uint16_t *v, Stream *stream); +int64_t StreamReadByte(unsigned char *byte, Stream *stream); int64_t StreamWrite(const void *ptr, size_t size, Stream *stream); int64_t StreamWriteUint16Be(uint16_t v, Stream *stream); +int64_t StreamWriteByte(unsigned char byte, Stream *stream); int StreamSeek(Stream *stream, int64_t offset, int whence); diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c index 3864ef3..f2bd9cd 100644 --- a/src/stream_mqtt.c +++ b/src/stream_mqtt.c @@ -42,37 +42,39 @@ int64_t StreamWriteMqttString(const_bstring buf, Stream *stream) return 2 + blength(buf); } -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream) +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream) { - size_t multiplier = 1; unsigned char encodedByte; - *remainingLength = 0; do { if (StreamRead(&encodedByte, 1, stream) != 1) return -1; - *remainingLength += (encodedByte & 127) * multiplier; - if (multiplier > 128*128*128) + *remainingLength += (encodedByte & 127) * (*mul); + if ((*mul) > 128*128*128) return -1; - multiplier *= 128; + (*mul) *= 128; } while ((encodedByte & 128) != 0); + *mul = 0; return 0; } -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream) +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream) { - size_t nbytes = 0; do { - unsigned char encodedByte = remainingLength % 128; - remainingLength /= 128; - if (remainingLength > 0) + size_t tmp = *remainingLength; + unsigned char encodedByte = tmp % 128; + tmp /= 128; + if (tmp > 0) encodedByte |= 128; if (StreamWrite(&encodedByte, 1, stream) != 1) + { return -1; - ++nbytes; + } + *remainingLength = tmp; } - while (remainingLength > 0); - return nbytes; + while (*remainingLength > 0); + return 0; } diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h index 9023430..8c8ccb5 100644 --- a/src/stream_mqtt.h +++ b/src/stream_mqtt.h @@ -2,13 +2,15 @@ #define STREAM_MQTT_H #include "stream.h" +#include "stringstream.h" #include <bstrlib/bstrlib.h> int64_t StreamReadMqttString(bstring *buf, Stream *stream); int64_t StreamWriteMqttString(const_bstring buf, Stream *stream); -int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream); -int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream); +int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul, + Stream *stream); +int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream); #endif diff --git a/src/stringstream.c b/src/stringstream.c new file mode 100644 index 0000000..4353932 --- /dev/null +++ b/src/stringstream.c @@ -0,0 +1,115 @@ +#include "stringstream.h" + +#include <assert.h> + +static int StringStreamClose(Stream *base) +{ + StringStream *ss = (StringStream *) base; + bdestroy(ss->buffer); + ss->buffer = NULL; + return 0; +} + +static int64_t StringStreamRead(void *ptr, size_t size, Stream *stream) +{ + StringStream *ss = (StringStream *) stream; + int64_t available = blength(ss->buffer) - ss->pos; + void *bufptr; + + if (available <= 0) + { + return -1; + } + + if (size > (size_t) available) + size = available; + + /* Use a temp buffer pointer to make some warnings disappear when using + GCC */ + bufptr = bdataofs(ss->buffer, ss->pos); + memcpy(ptr, bufptr, size); + + ss->pos += size; + + return size; +} + +static int64_t StringStreamWrite(const void *ptr, size_t size, Stream *stream) +{ + StringStream *ss = (StringStream *) stream; + struct tagbstring buf; + if (ss->buffer->mlen <= 0) + return -1; + btfromblk(buf, ptr, size); + bsetstr(ss->buffer, ss->pos, &buf, '\0'); + ss->pos += size; + return size; +} + +int StringStreamSeek(Stream *base, int64_t offset, int whence) +{ + StringStream *ss = (StringStream *) base; + int64_t newpos = 0; + + if (whence == SEEK_SET) + { + newpos = offset; + } + else if (whence == SEEK_CUR) + { + newpos = ss->pos + offset; + } + else if (whence == SEEK_END) + { + newpos = blength(ss->buffer) - offset; + } + else + { + return -1; + } + + if (newpos > blength(ss->buffer)) + return -1; + + if (newpos < 0) + return -1; + + ss->pos = newpos; + + return 0; +} + +int64_t StringStreamTell(Stream *base) +{ + StringStream *ss = (StringStream *) base; + return ss->pos; +} + +static const StreamOps StringStreamOps = +{ + StringStreamRead, + StringStreamWrite, + StringStreamClose, + StringStreamSeek, + StringStreamTell +}; + +int StringStreamInit(StringStream *stream) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = bfromcstr(""); + stream->base.ops = &StringStreamOps; + return 0; +} + +int StringStreamInitFromBstring(StringStream *stream, bstring buffer) +{ + assert(stream != NULL); + memset(stream, 0, sizeof(*stream)); + stream->pos = 0; + stream->buffer = buffer; + stream->base.ops = &StringStreamOps; + return 0; +} diff --git a/src/stringstream.h b/src/stringstream.h new file mode 100644 index 0000000..60d42fb --- /dev/null +++ b/src/stringstream.h @@ -0,0 +1,21 @@ +#ifndef STRINGSTREAM_H +#define STRINGSTREAM_H + +#include "stream.h" +#include <bstrlib/bstrlib.h> +#include <stdio.h> + +typedef struct StringStream StringStream; + +struct StringStream +{ + Stream base; + bstring buffer; + int64_t pos; +}; + +int StringStreamInit(StringStream *stream); + +int StringStreamInitFromBstring(StringStream *stream, bstring buffer); + +#endif diff --git a/test/interop/CMakeLists.txt b/test/interop/CMakeLists.txt index e907776..b06e28b 100644 --- a/test/interop/CMakeLists.txt +++ b/test/interop/CMakeLists.txt @@ -17,3 +17,11 @@ ADD_INTEROP_TEST(keepalive_test) ADD_INTEROP_TEST(redelivery_on_reconnect_test) ADD_INTEROP_TEST(subscribe_failure_test) ADD_INTEROP_TEST(dollar_topics_test) +ADD_INTEROP_TEST(username_and_password_test) +ADD_INTEROP_TEST(ping_test) +ADD_INTEROP_TEST(unsubscribe_test) +ADD_INTEROP_TEST(big_message_test) + +ADD_LIBRARY(bstraux STATIC bstraux.c) +TARGET_INCLUDE_DIRECTORIES(bstraux PUBLIC ${PROJECT_SOURCE_DIR}/src/lib/bstrlib) +TARGET_LINK_LIBRARIES(big_message_test PRIVATE bstraux) diff --git a/test/interop/big_message_test.c b/test/interop/big_message_test.c new file mode 100644 index 0000000..16cee84 --- /dev/null +++ b/test/interop/big_message_test.c @@ -0,0 +1,59 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" +#include <bstrlib/bstrlib.h> +#include "bstraux.h" + +static const struct tagbstring message = bsStatic( +"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas eu elit vel nisl fringilla ornare. Vestibulum eget sem lobortis, molestie velit in, gravida turpis. Donec ac sapien eu neque pellentesque dictum. Maecenas sed malesuada augue, nec ullamcorper libero. Donec consectetur sit amet orci non viverra. Morbi pharetra, urna ac luctus consequat, nibh urna semper metus, nec consectetur eros sapien in lorem. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Fusce elit magna, fringilla vel velit ac, finibus interdum nibh. Donec sit amet volutpat elit. Sed sodales finibus nisl, ut vulputate tortor. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec vel egestas tellus. Aliquam eget orci eget tortor porttitor ullamcorper in vel nulla. Cras facilisis tristique turpis vel molestie. Quisque suscipit orci orci, et convallis est eleifend sit amet." +"Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. In hac habitasse platea dictumst. Donec auctor ante odio, vitae tristique nisi egestas a. Suspendisse sit amet fermentum libero, viverra tempor neque. Nunc eleifend quam ac lacus ullamcorper fermentum. Integer lorem turpis, lobortis eget risus nec, auctor convallis sapien. In laoreet mauris at mi vehicula bibendum. Lorem ipsum dolor sit amet, consectetur adipiscing elit." +"Quisque commodo nisi vel tellus sodales, nec laoreet arcu gravida. Mauris vitae ligula nisl. Maecenas in euismod odio, vel vulputate arcu. Mauris vehicula tortor nec tempus euismod. Maecenas at tortor in libero pretium consequat a sed augue. Phasellus tortor erat, hendrerit id placerat id, pulvinar eget lacus. Curabitur rhoncus lobortis augue, hendrerit sodales tellus faucibus at. Donec a eros tellus. Sed at urna a lectus scelerisque lobortis." +"Duis accumsan ut augue sit amet suscipit. Cras tincidunt quam elementum magna faucibus eleifend. Etiam magna elit, commodo a tortor tempus, tempor vestibulum lorem. Nullam volutpat, libero a semper porttitor, neque turpis auctor augue, ut consequat diam nunc non tellus. Morbi nec varius ipsum, at imperdiet nisl. Fusce a est leo. Sed vitae turpis ligula. Vivamus eget eros id magna tincidunt consequat ut vel lorem. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Nam consectetur in tellus sit amet blandit. Cras cursus dictum ex, at iaculis sem ultricies quis. Fusce vitae pretium tellus, non cursus sem. Suspendisse ac dui eu quam semper eleifend ac sit amet orci. Nunc a nibh felis. Vivamus porta fermentum diam, vel commodo sem tincidunt ac." +"Nam dapibus, tellus nec pharetra efficitur, velit mauris faucibus nulla, ac sodales enim ex interdum tortor. Sed eget metus quis dolor euismod elementum vitae non felis. Nullam gravida diam sit amet suscipit iaculis. Quisque vehicula maximus lorem non volutpat. Vestibulum nec dui eu neque sodales finibus. Pellentesque eleifend fermentum erat, a tincidunt nisl luctus ultricies. Aliquam malesuada enim metus, nec pharetra orci dictum id. Lorem ipsum dolor sit amet, consectetur adipiscing elit. In neque urna, vehicula nec ante vel, porta dignissim lorem. Duis fringilla arcu nec tellus lacinia facilisis." +); + +TEST big_message_test() +{ + TestClient *client; + bstring encodedMessage; + bstring fullMessage; + int need = 1024 * 1024 * 3.5; + + fullMessage = bstrcpy(&message); + bpattern(fullMessage, need); + + encodedMessage = bBase64Encode(fullMessage); + + printf("ENCODED MESSAGE SIZE %d\n", blength(encodedMessage)); + + client = TestClientNew("clienta"); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + ASSERT(TestClientSubscribe(client, topics[0], 1)); + ASSERT(TestClientPublish(client, 1, 0, topics[0], bdata(encodedMessage))); + ASSERT(TestClientWait(client, 2000)); + TestClientDisconnect(client); + + ASSERT_EQ(1, TestClientMessageCount(client)); + ASSERT_EQ(blength(encodedMessage), SIMPLEQ_FIRST(&client->messages)->size); + ASSERT_MEM_EQ(bdata(encodedMessage), + SIMPLEQ_FIRST(&client->messages)->data, + blength(encodedMessage)); + + TestClientFree(client); + + bdestroy(encodedMessage); + bdestroy(fullMessage); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(big_message_test); + GREATEST_MAIN_END(); +} diff --git a/test/interop/bstraux.c b/test/interop/bstraux.c new file mode 100644 index 0000000..ac97836 --- /dev/null +++ b/test/interop/bstraux.c @@ -0,0 +1,1161 @@ + +/* + * This source file is part of the bstring string library. This code was + * written by Paul Hsieh in 2002-2015, and is covered by the BSD open source + * license and the GPL. Refer to the accompanying documentation for details + * on usage and license. + */ + +/* + * bstraux.c + * + * This file is not necessarily part of the core bstring library itself, but + * is just an auxilliary module which includes miscellaneous or trivial + * functions. + */ + +#if defined (_MSC_VER) +# define _CRT_SECURE_NO_WARNINGS +#endif + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <limits.h> +#include <ctype.h> +#include "bstrlib.h" +#include "bstraux.h" + +#ifndef UNUSED +#define UNUSED(x) (void)(x) +#endif + +/* bstring bTail (bstring b, int n) + * + * Return with a string of the last n characters of b. + */ +bstring bTail (bstring b, int n) { + if (b == NULL || n < 0 || (b->mlen < b->slen && b->mlen > 0)) return NULL; + if (n >= b->slen) return bstrcpy (b); + return bmidstr (b, b->slen - n, n); +} + +/* bstring bHead (bstring b, int n) + * + * Return with a string of the first n characters of b. + */ +bstring bHead (bstring b, int n) { + if (b == NULL || n < 0 || (b->mlen < b->slen && b->mlen > 0)) return NULL; + if (n >= b->slen) return bstrcpy (b); + return bmidstr (b, 0, n); +} + +/* int bFill (bstring a, char c, int len) + * + * Fill a given bstring with the character in parameter c, for a length n. + */ +int bFill (bstring b, char c, int len) { + if (b == NULL || len < 0 || (b->mlen < b->slen && b->mlen > 0)) return -__LINE__; + b->slen = 0; + return bsetstr (b, len, NULL, c); +} + +/* int bReplicate (bstring b, int n) + * + * Replicate the contents of b end to end n times and replace it in b. + */ +int bReplicate (bstring b, int n) { + return bpattern (b, n * b->slen); +} + +/* int bReverse (bstring b) + * + * Reverse the contents of b in place. + */ +int bReverse (bstring b) { +int i, n, m; +unsigned char t; + + if (b == NULL || b->slen < 0 || b->mlen < b->slen) return -__LINE__; + n = b->slen; + if (2 <= n) { + m = ((unsigned)n) >> 1; + n--; + for (i=0; i < m; i++) { + t = b->data[n - i]; + b->data[n - i] = b->data[i]; + b->data[i] = t; + } + } + return 0; +} + +/* int bInsertChrs (bstring b, int pos, int len, unsigned char c, unsigned char fill) + * + * Insert a repeated sequence of a given character into the string at + * position pos for a length len. + */ +int bInsertChrs (bstring b, int pos, int len, unsigned char c, unsigned char fill) { + if (b == NULL || b->slen < 0 || b->mlen < b->slen || pos < 0 || len <= 0) return -__LINE__; + + if (pos > b->slen + && 0 > bsetstr (b, pos, NULL, fill)) return -__LINE__; + + if (0 > balloc (b, b->slen + len)) return -__LINE__; + if (pos < b->slen) memmove (b->data + pos + len, b->data + pos, b->slen - pos); + memset (b->data + pos, c, len); + b->slen += len; + b->data[b->slen] = (unsigned char) '\0'; + return BSTR_OK; +} + +/* int bJustifyLeft (bstring b, int space) + * + * Left justify a string. + */ +int bJustifyLeft (bstring b, int space) { +int j, i, s, t; +unsigned char c = (unsigned char) space; + + if (b == NULL || b->slen < 0 || b->mlen < b->slen) return -__LINE__; + if (space != (int) c) return BSTR_OK; + + for (s=j=i=0; i < b->slen; i++) { + t = s; + s = c != (b->data[j] = b->data[i]); + j += (t|s); + } + if (j > 0 && b->data[j-1] == c) j--; + + b->data[j] = (unsigned char) '\0'; + b->slen = j; + return BSTR_OK; +} + +/* int bJustifyRight (bstring b, int width, int space) + * + * Right justify a string to within a given width. + */ +int bJustifyRight (bstring b, int width, int space) { +int ret; + if (width <= 0) return -__LINE__; + if (0 > (ret = bJustifyLeft (b, space))) return ret; + if (b->slen <= width) + return bInsertChrs (b, 0, width - b->slen, (unsigned char) space, (unsigned char) space); + return BSTR_OK; +} + +/* int bJustifyCenter (bstring b, int width, int space) + * + * Center a string's non-white space characters to within a given width by + * inserting whitespaces at the beginning. + */ +int bJustifyCenter (bstring b, int width, int space) { +int ret; + if (width <= 0) return -__LINE__; + if (0 > (ret = bJustifyLeft (b, space))) return ret; + if (b->slen <= width) + return bInsertChrs (b, 0, (width - b->slen + 1) >> 1, (unsigned char) space, (unsigned char) space); + return BSTR_OK; +} + +/* int bJustifyMargin (bstring b, int width, int space) + * + * Stretch a string to flush against left and right margins by evenly + * distributing additional white space between words. If the line is too + * long to be margin justified, it is left justified. + */ +int bJustifyMargin (bstring b, int width, int space) { +struct bstrList * sl; +int i, l, c; + + if (b == NULL || b->slen < 0 || b->mlen == 0 || b->mlen < b->slen) return -__LINE__; + if (NULL == (sl = bsplit (b, (unsigned char) space))) return -__LINE__; + for (l=c=i=0; i < sl->qty; i++) { + if (sl->entry[i]->slen > 0) { + c ++; + l += sl->entry[i]->slen; + } + } + + if (l + c >= width || c < 2) { + bstrListDestroy (sl); + return bJustifyLeft (b, space); + } + + b->slen = 0; + for (i=0; i < sl->qty; i++) { + if (sl->entry[i]->slen > 0) { + if (b->slen > 0) { + int s = (width - l + (c / 2)) / c; + bInsertChrs (b, b->slen, s, (unsigned char) space, (unsigned char) space); + l += s; + } + bconcat (b, sl->entry[i]); + c--; + if (c <= 0) break; + } + } + + bstrListDestroy (sl); + return BSTR_OK; +} + +static size_t readNothing (void *buff, size_t elsize, size_t nelem, void *parm) { + UNUSED(buff); + UNUSED(elsize); + UNUSED(nelem); + UNUSED(parm); + return 0; /* Immediately indicate EOF. */ +} + +/* struct bStream * bsFromBstr (const_bstring b); + * + * Create a bStream whose contents are a copy of the bstring passed in. + * This allows the use of all the bStream APIs with bstrings. + */ +struct bStream * bsFromBstr (const_bstring b) { +struct bStream * s = bsopen ((bNread) readNothing, NULL); + bsunread (s, b); /* Push the bstring data into the empty bStream. */ + return s; +} + +static size_t readRef (void *buff, size_t elsize, size_t nelem, void *parm) { +struct tagbstring * t = (struct tagbstring *) parm; +size_t tsz = elsize * nelem; + + if (tsz > (size_t) t->slen) tsz = (size_t) t->slen; + if (tsz > 0) { + memcpy (buff, t->data, tsz); + t->slen -= (int) tsz; + t->data += tsz; + return tsz / elsize; + } + return 0; +} + +/* The "by reference" version of the above function. This function puts + * a number of restrictions on the call site (the passed in struct + * tagbstring *will* be modified by this function, and the source data + * must remain alive and constant for the lifetime of the bStream). + * Hence it is not presented as an extern. + */ +static struct bStream * bsFromBstrRef (struct tagbstring * t) { + if (!t) return NULL; + return bsopen ((bNread) readRef, t); +} + +/* char * bStr2NetStr (const_bstring b) + * + * Convert a bstring to a netstring. See + * http://cr.yp.to/proto/netstrings.txt for a description of netstrings. + * Note: 1) The value returned should be freed with a call to bcstrfree() at + * the point when it will no longer be referenced to avoid a memory + * leak. + * 2) If the returned value is non-NULL, then it also '\0' terminated + * in the character position one past the "," terminator. + */ +char * bStr2NetStr (const_bstring b) { +char strnum[sizeof (b->slen) * 3 + 1]; +bstring s; +unsigned char * buff; + + if (b == NULL || b->data == NULL || b->slen < 0) return NULL; + sprintf (strnum, "%d:", b->slen); + if (NULL == (s = bfromcstr (strnum)) + || bconcat (s, b) == BSTR_ERR || bconchar (s, (char) ',') == BSTR_ERR) { + bdestroy (s); + return NULL; + } + buff = s->data; + bcstrfree ((char *) s); + return (char *) buff; +} + +/* bstring bNetStr2Bstr (const char * buf) + * + * Convert a netstring to a bstring. See + * http://cr.yp.to/proto/netstrings.txt for a description of netstrings. + * Note that the terminating "," *must* be present, however a following '\0' + * is *not* required. + */ +bstring bNetStr2Bstr (const char * buff) { +int i, x; +bstring b; + if (buff == NULL) return NULL; + x = 0; + for (i=0; buff[i] != ':'; i++) { + unsigned int v = buff[i] - '0'; + if (v > 9 || x > ((INT_MAX - (signed int)v) / 10)) return NULL; + x = (x * 10) + v; + } + + /* This thing has to be properly terminated */ + if (buff[i + 1 + x] != ',') return NULL; + + if (NULL == (b = bfromcstr (""))) return NULL; + if (balloc (b, x + 1) != BSTR_OK) { + bdestroy (b); + return NULL; + } + memcpy (b->data, buff + i + 1, x); + b->data[x] = (unsigned char) '\0'; + b->slen = x; + return b; +} + +static char b64ETable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +/* bstring bBase64Encode (const_bstring b) + * + * Generate a base64 encoding. See: RFC1341 + */ +bstring bBase64Encode (const_bstring b) { +int i, c0, c1, c2, c3; +bstring out; + + if (b == NULL || b->slen < 0 || b->data == NULL) return NULL; + + out = bfromcstr (""); + for (i=0; i + 2 < b->slen; i += 3) { + if (i && ((i % 57) == 0)) { + if (bconchar (out, (char) '\015') < 0 || bconchar (out, (char) '\012') < 0) { + bdestroy (out); + return NULL; + } + } + c0 = b->data[i] >> 2; + c1 = ((b->data[i] << 4) | + (b->data[i+1] >> 4)) & 0x3F; + c2 = ((b->data[i+1] << 2) | + (b->data[i+2] >> 6)) & 0x3F; + c3 = b->data[i+2] & 0x3F; + if (bconchar (out, b64ETable[c0]) < 0 || + bconchar (out, b64ETable[c1]) < 0 || + bconchar (out, b64ETable[c2]) < 0 || + bconchar (out, b64ETable[c3]) < 0) { + bdestroy (out); + return NULL; + } + } + + if (i && ((i % 57) == 0)) { + if (bconchar (out, (char) '\015') < 0 || bconchar (out, (char) '\012') < 0) { + bdestroy (out); + return NULL; + } + } + + switch (i + 2 - b->slen) { + case 0: c0 = b->data[i] >> 2; + c1 = ((b->data[i] << 4) | + (b->data[i+1] >> 4)) & 0x3F; + c2 = (b->data[i+1] << 2) & 0x3F; + if (bconchar (out, b64ETable[c0]) < 0 || + bconchar (out, b64ETable[c1]) < 0 || + bconchar (out, b64ETable[c2]) < 0 || + bconchar (out, (char) '=') < 0) { + bdestroy (out); + return NULL; + } + break; + case 1: c0 = b->data[i] >> 2; + c1 = (b->data[i] << 4) & 0x3F; + if (bconchar (out, b64ETable[c0]) < 0 || + bconchar (out, b64ETable[c1]) < 0 || + bconchar (out, (char) '=') < 0 || + bconchar (out, (char) '=') < 0) { + bdestroy (out); + return NULL; + } + break; + case 2: break; + } + + return out; +} + +#define B64_PAD (-2) +#define B64_ERR (-1) + +static int base64DecodeSymbol (unsigned char alpha) { + if ((alpha >= 'A') && (alpha <= 'Z')) return (int)(alpha - 'A'); + else if ((alpha >= 'a') && (alpha <= 'z')) + return 26 + (int)(alpha - 'a'); + else if ((alpha >= '0') && (alpha <= '9')) + return 52 + (int)(alpha - '0'); + else if (alpha == '+') return 62; + else if (alpha == '/') return 63; + else if (alpha == '=') return B64_PAD; + else return B64_ERR; +} + +/* bstring bBase64DecodeEx (const_bstring b, int * boolTruncError) + * + * Decode a base64 block of data. All MIME headers are assumed to have been + * removed. See: RFC1341 + */ +bstring bBase64DecodeEx (const_bstring b, int * boolTruncError) { +int i, v; +unsigned char c0, c1, c2; +bstring out; + + if (b == NULL || b->slen < 0 || b->data == NULL) return NULL; + if (boolTruncError) *boolTruncError = 0; + out = bfromcstr (""); + i = 0; + for (;;) { + do { + if (i >= b->slen) return out; + if (b->data[i] == '=') { /* Bad "too early" truncation */ + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); + return NULL; + } + v = base64DecodeSymbol (b->data[i]); + i++; + } while (v < 0); + c0 = (unsigned char) (v << 2); + do { + if (i >= b->slen || b->data[i] == '=') { /* Bad "too early" truncation */ + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); + return NULL; + } + v = base64DecodeSymbol (b->data[i]); + i++; + } while (v < 0); + c0 |= (unsigned char) (v >> 4); + c1 = (unsigned char) (v << 4); + do { + if (i >= b->slen) { + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); + return NULL; + } + if (b->data[i] == '=') { + i++; + if (i >= b->slen || b->data[i] != '=' || bconchar (out, c0) < 0) { + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); /* Missing "=" at the end. */ + return NULL; + } + return out; + } + v = base64DecodeSymbol (b->data[i]); + i++; + } while (v < 0); + c1 |= (unsigned char) (v >> 2); + c2 = (unsigned char) (v << 6); + do { + if (i >= b->slen) { + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); + return NULL; + } + if (b->data[i] == '=') { + if (bconchar (out, c0) < 0 || bconchar (out, c1) < 0) { + if (boolTruncError) { + *boolTruncError = 1; + return out; + } + bdestroy (out); + return NULL; + } + if (boolTruncError) *boolTruncError = 0; + return out; + } + v = base64DecodeSymbol (b->data[i]); + i++; + } while (v < 0); + c2 |= (unsigned char) (v); + if (bconchar (out, c0) < 0 || + bconchar (out, c1) < 0 || + bconchar (out, c2) < 0) { + if (boolTruncError) { + *boolTruncError = -1; + return out; + } + bdestroy (out); + return NULL; + } + } +} + +#define UU_DECODE_BYTE(b) (((b) == (signed int)'`') ? 0 : (b) - (signed int)' ') + +struct bUuInOut { + bstring src, dst; + int * badlines; +}; + +#define UU_MAX_LINELEN 45 + +static int bUuDecLine (void * parm, int ofs, int len) { +struct bUuInOut * io = (struct bUuInOut *) parm; +bstring s = io->src; +bstring t = io->dst; +int i, llen, otlen, ret, c0, c1, c2, c3, d0, d1, d2, d3; + + if (len == 0) return 0; + llen = UU_DECODE_BYTE (s->data[ofs]); + ret = 0; + + otlen = t->slen; + + if (((unsigned) llen) > UU_MAX_LINELEN) { ret = -__LINE__; + goto bl; + } + + llen += t->slen; + + for (i=1; i < s->slen && t->slen < llen;i += 4) { + unsigned char outoctet[3]; + c0 = UU_DECODE_BYTE (d0 = (int) bchare (s, i+ofs+0, ' ' - 1)); + c1 = UU_DECODE_BYTE (d1 = (int) bchare (s, i+ofs+1, ' ' - 1)); + c2 = UU_DECODE_BYTE (d2 = (int) bchare (s, i+ofs+2, ' ' - 1)); + c3 = UU_DECODE_BYTE (d3 = (int) bchare (s, i+ofs+3, ' ' - 1)); + + if (((unsigned) (c0|c1) >= 0x40)) { if (!ret) ret = -__LINE__; + if (d0 > 0x60 || (d0 < (' ' - 1) && !isspace (d0)) || + d1 > 0x60 || (d1 < (' ' - 1) && !isspace (d1))) { + t->slen = otlen; + goto bl; + } + c0 = c1 = 0; + } + outoctet[0] = (unsigned char) ((c0 << 2) | ((unsigned) c1 >> 4)); + if (t->slen+1 >= llen) { + if (0 > bconchar (t, (char) outoctet[0])) return -__LINE__; + break; + } + if ((unsigned) c2 >= 0x40) { if (!ret) ret = -__LINE__; + if (d2 > 0x60 || (d2 < (' ' - 1) && !isspace (d2))) { + t->slen = otlen; + goto bl; + } + c2 = 0; + } + outoctet[1] = (unsigned char) ((c1 << 4) | ((unsigned) c2 >> 2)); + if (t->slen+2 >= llen) { + if (0 > bcatblk (t, outoctet, 2)) return -__LINE__; + break; + } + if ((unsigned) c3 >= 0x40) { if (!ret) ret = -__LINE__; + if (d3 > 0x60 || (d3 < (' ' - 1) && !isspace (d3))) { + t->slen = otlen; + goto bl; + } + c3 = 0; + } + outoctet[2] = (unsigned char) ((c2 << 6) | ((unsigned) c3)); + if (0 > bcatblk (t, outoctet, 3)) return -__LINE__; + } + if (t->slen < llen) { if (0 == ret) ret = -__LINE__; + t->slen = otlen; + } + bl:; + if (ret && io->badlines) { + (*io->badlines)++; + return 0; + } + return ret; +} + +/* bstring bUuDecodeEx (const_bstring src, int * badlines) + * + * Performs a UUDecode of a block of data. If there are errors in the + * decoding, they are counted up and returned in "badlines", if badlines is + * not NULL. It is assumed that the "begin" and "end" lines have already + * been stripped off. The potential security problem of writing the + * filename in the begin line is something that is beyond the scope of a + * portable library. + */ + +#ifdef _MSC_VER +#pragma warning(disable:4204) +#endif + +bstring bUuDecodeEx (const_bstring src, int * badlines) { +struct tagbstring t; +struct bStream * s; +struct bStream * d; +bstring b; + + if (!src) return NULL; + t = *src; /* Short lifetime alias to header of src */ + s = bsFromBstrRef (&t); /* t is undefined after this */ + if (!s) return NULL; + d = bsUuDecode (s, badlines); + b = bfromcstralloc (256, ""); + if (NULL == b || 0 > bsread (b, d, INT_MAX)) { + bdestroy (b); + b = NULL; + } + bsclose (d); + bsclose (s); + return b; +} + +struct bsUuCtx { + struct bUuInOut io; + struct bStream * sInp; +}; + +static size_t bsUuDecodePart (void *buff, size_t elsize, size_t nelem, void *parm) { +static struct tagbstring eol = bsStatic ("\r\n"); +struct bsUuCtx * luuCtx = (struct bsUuCtx *) parm; +size_t tsz; +int l, lret; + + if (NULL == buff || NULL == parm) return 0; + tsz = elsize * nelem; + + CheckInternalBuffer:; + /* If internal buffer has sufficient data, just output it */ + if (((size_t) luuCtx->io.dst->slen) > tsz) { + memcpy (buff, luuCtx->io.dst->data, tsz); + bdelete (luuCtx->io.dst, 0, (int) tsz); + return nelem; + } + + DecodeMore:; + if (0 <= (l = binchr (luuCtx->io.src, 0, &eol))) { + int ol = 0; + struct tagbstring t; + bstring s = luuCtx->io.src; + luuCtx->io.src = &t; + + do { + if (l > ol) { + bmid2tbstr (t, s, ol, l - ol); + lret = bUuDecLine (&luuCtx->io, 0, t.slen); + if (0 > lret) { + luuCtx->io.src = s; + goto Done; + } + } + ol = l + 1; + if (((size_t) luuCtx->io.dst->slen) > tsz) break; + l = binchr (s, ol, &eol); + } while (BSTR_ERR != l); + bdelete (s, 0, ol); + luuCtx->io.src = s; + goto CheckInternalBuffer; + } + + if (BSTR_ERR != bsreada (luuCtx->io.src, luuCtx->sInp, bsbufflength (luuCtx->sInp, BSTR_BS_BUFF_LENGTH_GET))) { + goto DecodeMore; + } + + bUuDecLine (&luuCtx->io, 0, luuCtx->io.src->slen); + + Done:; + /* Output any lingering data that has been translated */ + if (((size_t) luuCtx->io.dst->slen) > 0) { + if (((size_t) luuCtx->io.dst->slen) > tsz) goto CheckInternalBuffer; + memcpy (buff, luuCtx->io.dst->data, luuCtx->io.dst->slen); + tsz = luuCtx->io.dst->slen / elsize; + luuCtx->io.dst->slen = 0; + if (tsz > 0) return tsz; + } + + /* Deallocate once EOF becomes triggered */ + bdestroy (luuCtx->io.dst); + bdestroy (luuCtx->io.src); + free (luuCtx); + return 0; +} + +/* bStream * bsUuDecode (struct bStream * sInp, int * badlines) + * + * Creates a bStream which performs the UUDecode of an an input stream. If + * there are errors in the decoding, they are counted up and returned in + * "badlines", if badlines is not NULL. It is assumed that the "begin" and + * "end" lines have already been stripped off. The potential security + * problem of writing the filename in the begin line is something that is + * beyond the scope of a portable library. + */ + +struct bStream * bsUuDecode (struct bStream * sInp, int * badlines) { +struct bsUuCtx * luuCtx = (struct bsUuCtx *) malloc (sizeof (struct bsUuCtx)); +struct bStream * sOut; + + if (NULL == luuCtx) return NULL; + + luuCtx->io.src = bfromcstr (""); + luuCtx->io.dst = bfromcstr (""); + if (NULL == luuCtx->io.dst || NULL == luuCtx->io.src) { + CleanUpFailureToAllocate:; + bdestroy (luuCtx->io.dst); + bdestroy (luuCtx->io.src); + free (luuCtx); + return NULL; + } + luuCtx->io.badlines = badlines; + if (badlines) *badlines = 0; + + luuCtx->sInp = sInp; + + sOut = bsopen ((bNread) bsUuDecodePart, luuCtx); + if (NULL == sOut) goto CleanUpFailureToAllocate; + return sOut; +} + +#define UU_ENCODE_BYTE(b) (char) (((b) == 0) ? '`' : ((b) + ' ')) + +/* bstring bUuEncode (const_bstring src) + * + * Performs a UUEncode of a block of data. The "begin" and "end" lines are + * not appended. + */ +bstring bUuEncode (const_bstring src) { +bstring out; +int i, j, jm; +unsigned int c0, c1, c2; + if (src == NULL || src->slen < 0 || src->data == NULL) return NULL; + if ((out = bfromcstr ("")) == NULL) return NULL; + for (i=0; i < src->slen; i += UU_MAX_LINELEN) { + if ((jm = i + UU_MAX_LINELEN) > src->slen) jm = src->slen; + if (bconchar (out, UU_ENCODE_BYTE (jm - i)) < 0) { + bstrFree (out); + break; + } + for (j = i; j < jm; j += 3) { + c0 = (unsigned int) bchar (src, j ); + c1 = (unsigned int) bchar (src, j + 1); + c2 = (unsigned int) bchar (src, j + 2); + if (bconchar (out, UU_ENCODE_BYTE ( (c0 & 0xFC) >> 2)) < 0 || + bconchar (out, UU_ENCODE_BYTE (((c0 & 0x03) << 4) | ((c1 & 0xF0) >> 4))) < 0 || + bconchar (out, UU_ENCODE_BYTE (((c1 & 0x0F) << 2) | ((c2 & 0xC0) >> 6))) < 0 || + bconchar (out, UU_ENCODE_BYTE ( (c2 & 0x3F))) < 0) { + bstrFree (out); + goto End; + } + } + if (bconchar (out, (char) '\r') < 0 || bconchar (out, (char) '\n') < 0) { + bstrFree (out); + break; + } + } + End:; + return out; +} + +/* bstring bYEncode (const_bstring src) + * + * Performs a YEncode of a block of data. No header or tail info is + * appended. See: http://www.yenc.org/whatis.htm and + * http://www.yenc.org/yenc-draft.1.3.txt + */ +bstring bYEncode (const_bstring src) { +int i; +bstring out; +unsigned char c; + + if (src == NULL || src->slen < 0 || src->data == NULL) return NULL; + if ((out = bfromcstr ("")) == NULL) return NULL; + for (i=0; i < src->slen; i++) { + c = (unsigned char)(src->data[i] + 42); + if (c == '=' || c == '\0' || c == '\r' || c == '\n') { + if (0 > bconchar (out, (char) '=')) { + bdestroy (out); + return NULL; + } + c += (unsigned char) 64; + } + if (0 > bconchar (out, c)) { + bdestroy (out); + return NULL; + } + } + return out; +} + +/* bstring bYDecode (const_bstring src) + * + * Performs a YDecode of a block of data. See: + * http://www.yenc.org/whatis.htm and http://www.yenc.org/yenc-draft.1.3.txt + */ +#define MAX_OB_LEN (64) + +bstring bYDecode (const_bstring src) { +int i; +bstring out; +unsigned char c; +unsigned char octetbuff[MAX_OB_LEN]; +int obl; + + if (src == NULL || src->slen < 0 || src->data == NULL) return NULL; + if ((out = bfromcstr ("")) == NULL) return NULL; + + obl = 0; + + for (i=0; i < src->slen; i++) { + if ('=' == (c = src->data[i])) { /* The = escape mode */ + i++; + if (i >= src->slen) { + bdestroy (out); + return NULL; + } + c = (unsigned char) (src->data[i] - 64); + } else { + if ('\0' == c) { + bdestroy (out); + return NULL; + } + + /* Extraneous CR/LFs are to be ignored. */ + if (c == '\r' || c == '\n') continue; + } + + octetbuff[obl] = (unsigned char) ((int) c - 42); + obl++; + + if (obl >= MAX_OB_LEN) { + if (0 > bcatblk (out, octetbuff, obl)) { + bdestroy (out); + return NULL; + } + obl = 0; + } + } + + if (0 > bcatblk (out, octetbuff, obl)) { + bdestroy (out); + out = NULL; + } + return out; +} + +/* int bSGMLEncode (bstring b) + * + * Change the string into a version that is quotable in SGML (HTML, XML). + */ +int bSGMLEncode (bstring b) { +static struct tagbstring fr[4][2] = { + { bsStatic("&"), bsStatic("&") }, + { bsStatic("\""), bsStatic(""") }, + { bsStatic("<"), bsStatic("<") }, + { bsStatic(">"), bsStatic(">") } }; +int i; + for (i = 0; i < 4; i++) { + int ret = bfindreplace (b, &fr[i][0], &fr[i][1], 0); + if (0 > ret) return ret; + } + return 0; +} + +/* bstring bStrfTime (const char * fmt, const struct tm * timeptr) + * + * Takes a format string that is compatible with strftime and a struct tm + * pointer, formats the time according to the format string and outputs + * the bstring as a result. Note that if there is an early generation of a + * '\0' character, the bstring will be truncated to this end point. + */ +bstring bStrfTime (const char * fmt, const struct tm * timeptr) { +#if defined (__TURBOC__) && !defined (__BORLANDC__) +static struct tagbstring ns = bsStatic ("bStrfTime Not supported"); + fmt = fmt; + timeptr = timeptr; + return &ns; +#else +bstring buff; +int n; +size_t r; + + if (fmt == NULL) return NULL; + + /* Since the length is not determinable beforehand, a search is + performed using the truncating "strftime" call on increasing + potential sizes for the output result. */ + + if ((n = (int) (2*strlen (fmt))) < 16) n = 16; + buff = bfromcstralloc (n+2, ""); + + for (;;) { + if (BSTR_OK != balloc (buff, n + 2)) { + bdestroy (buff); + return NULL; + } + + r = strftime ((char *) buff->data, n + 1, fmt, timeptr); + + if (r > 0) { + buff->slen = (int) r; + break; + } + + n += n; + } + + return buff; +#endif +} + +/* int bSetCstrChar (bstring a, int pos, char c) + * + * Sets the character at position pos to the character c in the bstring a. + * If the character c is NUL ('\0') then the string is truncated at this + * point. Note: this does not enable any other '\0' character in the bstring + * as terminator indicator for the string. pos must be in the position + * between 0 and b->slen inclusive, otherwise BSTR_ERR will be returned. + */ +int bSetCstrChar (bstring b, int pos, char c) { + if (NULL == b || b->mlen <= 0 || b->slen < 0 || b->mlen < b->slen) + return BSTR_ERR; + if (pos < 0 || pos > b->slen) return BSTR_ERR; + + if (pos == b->slen) { + if ('\0' != c) return bconchar (b, c); + return 0; + } + + b->data[pos] = (unsigned char) c; + if ('\0' == c) b->slen = pos; + + return 0; +} + +/* int bSetChar (bstring b, int pos, char c) + * + * Sets the character at position pos to the character c in the bstring a. + * The string is not truncated if the character c is NUL ('\0'). pos must + * be in the position between 0 and b->slen inclusive, otherwise BSTR_ERR + * will be returned. + */ +int bSetChar (bstring b, int pos, char c) { + if (NULL == b || b->mlen <= 0 || b->slen < 0 || b->mlen < b->slen) + return BSTR_ERR; + if (pos < 0 || pos > b->slen) return BSTR_ERR; + + if (pos == b->slen) { + return bconchar (b, c); + } + + b->data[pos] = (unsigned char) c; + return 0; +} + +#define INIT_SECURE_INPUT_LENGTH (256) + +/* bstring bSecureInput (int maxlen, int termchar, + * bNgetc vgetchar, void * vgcCtx) + * + * Read input from an abstracted input interface, for a length of at most + * maxlen characters. If maxlen <= 0, then there is no length limit put + * on the input. The result is terminated early if vgetchar() return EOF + * or the user specified value termchar. + * + */ +bstring bSecureInput (int maxlen, int termchar, bNgetc vgetchar, void * vgcCtx) { +int i, m, c; +bstring b, t; + + if (!vgetchar) return NULL; + + b = bfromcstralloc (INIT_SECURE_INPUT_LENGTH, ""); + if ((c = UCHAR_MAX + 1) == termchar) c++; + + for (i=0; ; i++) { + if (termchar == c || (maxlen > 0 && i >= maxlen)) break; + c = vgetchar (vgcCtx); + if (EOF == c) break; + + if (i+1 >= b->mlen) { + + /* Double size, and deal with numeric overflows */ + + if (b->mlen <= INT_MAX / 2) m = b->mlen << 1; + else if (b->mlen <= INT_MAX - 1024) m = b->mlen + 1024; + else if (b->mlen <= INT_MAX - 16) m = b->mlen + 16; + else if (b->mlen <= INT_MAX - 1) m = b->mlen + 1; + else { + bSecureDestroy (b); /* Cleanse partial buffer */ + return NULL; + } + + t = bfromcstrrangealloc (b->mlen + 1, m, ""); + if (t) memcpy (t->data, b->data, i); + bSecureDestroy (b); /* Cleanse previous buffer */ + b = t; + if (!b) return b; + } + + b->data[i] = (unsigned char) c; + } + + b->slen = i; + b->data[i] = (unsigned char) '\0'; + return b; +} + +#define BWS_BUFF_SZ (1024) + +struct bwriteStream { + bstring buff; /* Buffer for underwrites */ + void * parm; /* The stream handle for core stream */ + bNwrite writeFn; /* fwrite work-a-like fnptr for core stream */ + int isEOF; /* track stream's EOF state */ + int minBuffSz; +}; + +/* struct bwriteStream * bwsOpen (bNwrite writeFn, void * parm) + * + * Wrap a given open stream (described by a fwrite work-a-like function + * pointer and stream handle) into an open bwriteStream suitable for write + * streaming functions. + */ +struct bwriteStream * bwsOpen (bNwrite writeFn, void * parm) { +struct bwriteStream * ws; + + if (NULL == writeFn) return NULL; + ws = (struct bwriteStream *) malloc (sizeof (struct bwriteStream)); + if (ws) { + if (NULL == (ws->buff = bfromcstr (""))) { + free (ws); + ws = NULL; + } else { + ws->parm = parm; + ws->writeFn = writeFn; + ws->isEOF = 0; + ws->minBuffSz = BWS_BUFF_SZ; + } + } + return ws; +} + +#define internal_bwswriteout(ws,b) { \ + if ((b)->slen > 0) { \ + if (1 != (ws->writeFn ((b)->data, (b)->slen, 1, ws->parm))) { \ + ws->isEOF = 1; \ + return BSTR_ERR; \ + } \ + } \ +} + +/* int bwsWriteFlush (struct bwriteStream * ws) + * + * Force any pending data to be written to the core stream. + */ +int bwsWriteFlush (struct bwriteStream * ws) { + if (NULL == ws || ws->isEOF || 0 >= ws->minBuffSz || + NULL == ws->writeFn || NULL == ws->buff) return BSTR_ERR; + internal_bwswriteout (ws, ws->buff); + ws->buff->slen = 0; + return 0; +} + +/* int bwsWriteBstr (struct bwriteStream * ws, const_bstring b) + * + * Send a bstring to a bwriteStream. If the stream is at EOF BSTR_ERR is + * returned. Note that there is no deterministic way to determine the exact + * cut off point where the core stream stopped accepting data. + */ +int bwsWriteBstr (struct bwriteStream * ws, const_bstring b) { +struct tagbstring t; +int l; + + if (NULL == ws || NULL == b || NULL == ws->buff || + ws->isEOF || 0 >= ws->minBuffSz || NULL == ws->writeFn) + return BSTR_ERR; + + /* Buffer prepacking optimization */ + if (b->slen > 0 && ws->buff->mlen - ws->buff->slen > b->slen) { + static struct tagbstring empty = bsStatic (""); + if (0 > bconcat (ws->buff, b)) return BSTR_ERR; + return bwsWriteBstr (ws, &empty); + } + + if (0 > (l = ws->minBuffSz - ws->buff->slen)) { + internal_bwswriteout (ws, ws->buff); + ws->buff->slen = 0; + l = ws->minBuffSz; + } + + if (b->slen < l) return bconcat (ws->buff, b); + + if (0 > bcatblk (ws->buff, b->data, l)) return BSTR_ERR; + internal_bwswriteout (ws, ws->buff); + ws->buff->slen = 0; + + bmid2tbstr (t, (bstring) b, l, b->slen); + + if (t.slen >= ws->minBuffSz) { + internal_bwswriteout (ws, &t); + return 0; + } + + return bassign (ws->buff, &t); +} + +/* int bwsWriteBlk (struct bwriteStream * ws, void * blk, int len) + * + * Send a block of data a bwriteStream. If the stream is at EOF BSTR_ERR is + * returned. + */ +int bwsWriteBlk (struct bwriteStream * ws, void * blk, int len) { +struct tagbstring t; + if (NULL == blk || len < 0) return BSTR_ERR; + blk2tbstr (t, blk, len); + return bwsWriteBstr (ws, &t); +} + +/* int bwsIsEOF (const struct bwriteStream * ws) + * + * Returns 0 if the stream is currently writable, 1 if the core stream has + * responded by not accepting the previous attempted write. + */ +int bwsIsEOF (const struct bwriteStream * ws) { + if (NULL == ws || NULL == ws->buff || 0 > ws->minBuffSz || + NULL == ws->writeFn) return BSTR_ERR; + return ws->isEOF; +} + +/* int bwsBuffLength (struct bwriteStream * ws, int sz) + * + * Set the length of the buffer used by the bwsStream. If sz is zero, the + * length is not set. This function returns with the previous length. + */ +int bwsBuffLength (struct bwriteStream * ws, int sz) { +int oldSz; + if (ws == NULL || sz < 0) return BSTR_ERR; + oldSz = ws->minBuffSz; + if (sz > 0) ws->minBuffSz = sz; + return oldSz; +} + +/* void * bwsClose (struct bwriteStream * s) + * + * Close the bwriteStream, and return the handle to the stream that was + * originally used to open the given stream. Note that even if the stream + * is at EOF it still needs to be closed with a call to bwsClose. + */ +void * bwsClose (struct bwriteStream * ws) { +void * parm; + if (NULL == ws || NULL == ws->buff || 0 >= ws->minBuffSz || + NULL == ws->writeFn) return NULL; + bwsWriteFlush (ws); + parm = ws->parm; + ws->parm = NULL; + ws->minBuffSz = -1; + ws->writeFn = NULL; + bstrFree (ws->buff); + free (ws); + return parm; +} diff --git a/test/interop/bstraux.h b/test/interop/bstraux.h new file mode 100644 index 0000000..9f30e3c --- /dev/null +++ b/test/interop/bstraux.h @@ -0,0 +1,115 @@ +/* + * This source file is part of the bstring string library. This code was + * written by Paul Hsieh in 2002-2015, and is covered by the BSD open source + * license and the GPL. Refer to the accompanying documentation for details + * on usage and license. + */ + +/* + * bstraux.h + * + * This file is not a necessary part of the core bstring library itself, but + * is just an auxilliary module which includes miscellaneous or trivial + * functions. + */ + +#ifndef BSTRAUX_INCLUDE +#define BSTRAUX_INCLUDE + +#include <time.h> +#include "bstrlib.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Safety mechanisms */ +#define bstrDeclare(b) bstring (b) = NULL; +#define bstrFree(b) {if ((b) != NULL && (b)->slen >= 0 && (b)->mlen >= (b)->slen) { bdestroy (b); (b) = NULL; }} + +/* Backward compatibilty with previous versions of Bstrlib */ +#if !defined(BSTRLIB_REDUCE_NAMESPACE_POLLUTION) +#define bAssign(a,b) ((bassign)((a), (b))) +#define bSubs(b,pos,len,a,c) ((breplace)((b),(pos),(len),(a),(unsigned char)(c))) +#define bStrchr(b,c) ((bstrchr)((b), (c))) +#define bStrchrFast(b,c) ((bstrchr)((b), (c))) +#define bCatCstr(b,s) ((bcatcstr)((b), (s))) +#define bCatBlk(b,s,len) ((bcatblk)((b),(s),(len))) +#define bCatStatic(b,s) bcatStatic(b,s) +#define bTrunc(b,n) ((btrunc)((b), (n))) +#define bReplaceAll(b,find,repl,pos) ((bfindreplace)((b),(find),(repl),(pos))) +#define bUppercase(b) ((btoupper)(b)) +#define bLowercase(b) ((btolower)(b)) +#define bCaselessCmp(a,b) ((bstricmp)((a), (b))) +#define bCaselessNCmp(a,b,n) ((bstrnicmp)((a), (b), (n))) +#define bBase64Decode(b) (bBase64DecodeEx ((b), NULL)) +#define bUuDecode(b) (bUuDecodeEx ((b), NULL)) +#endif + +/* Unusual functions */ +extern struct bStream * bsFromBstr (const_bstring b); +extern bstring bTail (bstring b, int n); +extern bstring bHead (bstring b, int n); +extern int bSetCstrChar (bstring a, int pos, char c); +extern int bSetChar (bstring b, int pos, char c); +extern int bFill (bstring a, char c, int len); +extern int bReplicate (bstring b, int n); +extern int bReverse (bstring b); +extern int bInsertChrs (bstring b, int pos, int len, unsigned char c, unsigned char fill); +extern bstring bStrfTime (const char * fmt, const struct tm * timeptr); +#define bAscTime(t) (bStrfTime ("%c\n", (t))) +#define bCTime(t) ((t) ? bAscTime (localtime (t)) : NULL) + +/* Spacing formatting */ +extern int bJustifyLeft (bstring b, int space); +extern int bJustifyRight (bstring b, int width, int space); +extern int bJustifyMargin (bstring b, int width, int space); +extern int bJustifyCenter (bstring b, int width, int space); + +/* Esoteric standards specific functions */ +extern char * bStr2NetStr (const_bstring b); +extern bstring bNetStr2Bstr (const char * buf); +extern bstring bBase64Encode (const_bstring b); +extern bstring bBase64DecodeEx (const_bstring b, int * boolTruncError); +extern struct bStream * bsUuDecode (struct bStream * sInp, int * badlines); +extern bstring bUuDecodeEx (const_bstring src, int * badlines); +extern bstring bUuEncode (const_bstring src); +extern bstring bYEncode (const_bstring src); +extern bstring bYDecode (const_bstring src); +extern int bSGMLEncode (bstring b); + +/* Writable stream */ +typedef int (* bNwrite) (const void * buf, size_t elsize, size_t nelem, void * parm); + +struct bwriteStream * bwsOpen (bNwrite writeFn, void * parm); +int bwsWriteBstr (struct bwriteStream * stream, const_bstring b); +int bwsWriteBlk (struct bwriteStream * stream, void * blk, int len); +int bwsWriteFlush (struct bwriteStream * stream); +int bwsIsEOF (const struct bwriteStream * stream); +int bwsBuffLength (struct bwriteStream * stream, int sz); +void * bwsClose (struct bwriteStream * stream); + +/* Security functions */ +#define bSecureDestroy(b) { \ +bstring bstr__tmp = (b); \ + if (bstr__tmp && bstr__tmp->mlen > 0 && bstr__tmp->data) { \ + (void) memset (bstr__tmp->data, 0, (size_t) bstr__tmp->mlen); \ + bdestroy (bstr__tmp); \ + } \ +} +#define bSecureWriteProtect(t) { \ + if ((t).mlen >= 0) { \ + if ((t).mlen > (t).slen)) { \ + (void) memset ((t).data + (t).slen, 0, (size_t) (t).mlen - (t).slen); \ + } \ + (t).mlen = -1; \ + } \ +} +extern bstring bSecureInput (int maxlen, int termchar, + bNgetc vgetchar, void * vgcCtx); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/test/interop/ping_test.c b/test/interop/ping_test.c new file mode 100644 index 0000000..6d699da --- /dev/null +++ b/test/interop/ping_test.c @@ -0,0 +1,27 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST ping_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT(TestClientConnect(client, "localhost", 1883, 1, 1)); + ASSERT(TestClientWait(client, 5000)); + TestClientDisconnect(client); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(ping_test); + GREATEST_MAIN_END(); +} diff --git a/test/interop/testclient.c b/test/interop/testclient.c index 8d616f6..09782b2 100644 --- a/test/interop/testclient.c +++ b/test/interop/testclient.c @@ -14,12 +14,15 @@ static void TestClientOnConnect(MqttClient *client, } static void TestClientOnSubscribe(MqttClient *client, int id, - const char *filter, - MqttSubscriptionStatus status) + int *qos, int count) { TestClient *testClient = (TestClient *) MqttClientGetUserData(client); testClient->subId = id; - testClient->subStatus[testClient->subCount++] = status; + for (testClient->subCount = 0; testClient->subCount < count; + ++testClient->subCount) + { + testClient->subStatus[testClient->subCount] = qos[testClient->subCount]; + } } static void TestClientOnPublish(MqttClient *client, int id) @@ -37,6 +40,12 @@ static void TestClientOnMessage(MqttClient *client, const char *topic, SIMPLEQ_INSERT_TAIL(&testClient->messages, msg, chain); } +static void TestClientOnUnsubscribe(MqttClient *client, int id) +{ + TestClient *testClient = (TestClient *) MqttClientGetUserData(client); + testClient->unsubId = id; +} + Message *MessageNew(const char *topic, const void *data, size_t size, int qos, int retain) { @@ -69,6 +78,8 @@ TestClient *TestClientNew(const char *clientId) { TestClient *client = calloc(1, sizeof(*client)); + client->clientId = clientId; + client->client = MqttClientNew(clientId); MqttClientSetUserData(client->client, client); @@ -79,6 +90,7 @@ TestClient *TestClientNew(const char *clientId) MqttClientSetOnSubscribe(client->client, TestClientOnSubscribe); MqttClientSetOnPublish(client->client, TestClientOnPublish); MqttClientSetOnMessage(client->client, TestClientOnMessage); + MqttClientSetOnUnsubscribe(client->client, TestClientOnUnsubscribe); SIMPLEQ_INIT(&client->messages); @@ -235,8 +247,8 @@ int TestClientWait(TestClient *client, int timeout) printf("TestClientWait timeout:%d rc:%d\n", timeout, rc); int64_t now = MqttGetCurrentTime(); int64_t elapsed = now - start; - timeout -= elapsed; printf("TestClientWait elapsed:%d\n", (int) elapsed); + timeout = timeout - elapsed; if (timeout <= 0) { break; @@ -245,3 +257,27 @@ int TestClientWait(TestClient *client, int timeout) return rc != -1; } + +int TestClientUnsubscribe(TestClient *client, const char *topic) +{ + int id = MqttClientUnsubscribe(client->client, topic); + int rc; + + client->unsubId = -1; + + while ((rc = MqttClientRunOnce(client->client, -1)) != -1) + { + if (client->unsubId != -1) + { + if (client->unsubId != id) + { + printf( + "WARNING: unsubscribe id mismatch: expected %d, got %d\n", + id, client->unsubId); + } + break; + } + } + + return rc != -1; +} diff --git a/test/interop/testclient.h b/test/interop/testclient.h index 2aa229b..3665f5e 100644 --- a/test/interop/testclient.h +++ b/test/interop/testclient.h @@ -21,6 +21,8 @@ typedef struct TestClient TestClient; struct TestClient { + const char *clientId; + MqttClient *client; /* OnConnect */ @@ -37,6 +39,9 @@ struct TestClient /* OnMessage */ SIMPLEQ_HEAD(messages, Message) messages; + + /* OnUnsubscribe */ + int unsubId; }; Message *MessageNew(const char *topic, const void *data, size_t size, @@ -65,4 +70,6 @@ int TestClientMessageCount(TestClient *client); int TestClientWait(TestClient *client, int timeout); +int TestClientUnsubscribe(TestClient *client, const char *topic); + #endif diff --git a/test/interop/unsubscribe_test.c b/test/interop/unsubscribe_test.c new file mode 100644 index 0000000..a7e4668 --- /dev/null +++ b/test/interop/unsubscribe_test.c @@ -0,0 +1,31 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST unsubscribe_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + ASSERT(TestClientSubscribe(client, topics[0], 2)); + ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg")); + ASSERT(TestClientUnsubscribe(client, topics[0])); + ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg")); + TestClientDisconnect(client); + ASSERT_EQ(1, TestClientMessageCount(client)); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(unsubscribe_test); + GREATEST_MAIN_END(); +} diff --git a/test/interop/username_and_password_test.c b/test/interop/username_and_password_test.c new file mode 100644 index 0000000..6e0eaab --- /dev/null +++ b/test/interop/username_and_password_test.c @@ -0,0 +1,30 @@ +#include "greatest.h" +#include "testclient.h" +#include "cleanup.c" +#include "topics.c" + +TEST username_and_password_test() +{ + TestClient *client; + + client = TestClientNew("clienta"); + ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", NULL)); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + TestClientDisconnect(client); + ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", "mypassword")); + ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1)); + TestClientDisconnect(client); + TestClientFree(client); + + PASS(); +} + +GREATEST_MAIN_DEFS(); + +int main(int argc, char **argv) +{ + GREATEST_MAIN_BEGIN(); + cleanup(); + RUN_TEST(username_and_password_test); + GREATEST_MAIN_END(); +} diff --git a/tools/sub.c b/tools/sub.c index ebf3372..e0577c9 100644 --- a/tools/sub.c +++ b/tools/sub.c @@ -21,11 +21,10 @@ void onConnect(MqttClient *client, MqttConnectionStatus status, MqttClientSubscribe(client, options->topic, options->qos); } -void onSubscribe(MqttClient *client, int id, const char *filter, - MqttSubscriptionStatus status) +void onSubscribe(MqttClient *client, int id, int *qos, int count) { (void) client; - printf("onSubscribe id=%d status=%d\n", id, status); + printf("onSubscribe id=%d\n", id); } void onMessage(MqttClient *client, const char *topic, const void *data, |
