aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorOskari Timperi <oskari.timperi@iki.fi>2017-03-18 09:29:19 +0200
committerOskari Timperi <oskari.timperi@iki.fi>2017-03-18 09:29:19 +0200
commit7aeef53b089272f4633cc40512296bfd884a58d4 (patch)
tree894753ced0495f725ad8362859f88d5b61e29eb7 /src
parente9958e8a0f5aa5fbe0a4a03be42b8bf640add6f7 (diff)
parent2c76b0da9e0aba2211d5b4a8e51c79e47ad9b6c8 (diff)
downloadmqtt-7aeef53b089272f4633cc40512296bfd884a58d4.tar.gz
mqtt-7aeef53b089272f4633cc40512296bfd884a58d4.zip
Merge branch 'the-great-refactor'v0.5
* the-great-refactor: Add big_message_test Fix publish message serialization Modify the code to use nonblocking sockets Fix indentation Free userName and password in MqttClientFree() Add forgotten files Massive refactoring of the internals
Diffstat (limited to 'src')
-rw-r--r--src/CMakeLists.txt4
-rw-r--r--src/client.c1310
-rw-r--r--src/deserialize.c286
-rw-r--r--src/deserialize.h11
-rw-r--r--src/message.c11
-rw-r--r--src/message.h40
-rw-r--r--src/mqtt.h6
-rw-r--r--src/packet.c76
-rw-r--r--src/packet.h98
-rw-r--r--src/serialize.c326
-rw-r--r--src/serialize.h11
-rw-r--r--src/socket.c73
-rw-r--r--src/socket.h26
-rw-r--r--src/stream.c10
-rw-r--r--src/stream.h2
-rw-r--r--src/stream_mqtt.c30
-rw-r--r--src/stream_mqtt.h6
-rw-r--r--src/stringstream.c115
-rw-r--r--src/stringstream.h21
19 files changed, 1219 insertions, 1243 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f51fabb..5a565ca 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -2,14 +2,14 @@ ADD_SUBDIRECTORY(lib)
ADD_LIBRARY(mqtt STATIC
client.c
- deserialize.c
misc.c
packet.c
- serialize.c
socket.c
socketstream.c
stream.c
stream_mqtt.c
+ stringstream.c
+ message.c
$<TARGET_OBJECTS:bstrlib>
)
diff --git a/src/client.c b/src/client.c
index a6b0998..b95c8d5 100644
--- a/src/client.c
+++ b/src/client.c
@@ -5,10 +5,11 @@
#include "socketstream.h"
#include "socket.h"
#include "misc.h"
-#include "serialize.h"
-#include "deserialize.h"
#include "log.h"
#include "private.h"
+#include "stringstream.h"
+#include "stream_mqtt.h"
+#include "message.h"
#include "queue.h"
@@ -25,8 +26,14 @@
#error define PRId64 for your platform
#endif
-TAILQ_HEAD(MessageList, MqttPacket);
-typedef struct MessageList MessageList;
+typedef enum MqttClientState MqttClientState;
+
+enum MqttClientState
+{
+ MqttClientStateDisconnected,
+ MqttClientStateConnecting,
+ MqttClientStateConnected,
+};
struct MqttClient
{
@@ -56,9 +63,9 @@ struct MqttClient
/* packets waiting to be sent over network */
SIMPLEQ_HEAD(, MqttPacket) sendQueue;
/* sent messages that are not done yet */
- MessageList outMessages;
+ MqttMessageList outMessages;
/* received messages that are not done yet */
- MessageList inMessages;
+ MqttMessageList inMessages;
int sessionPresent;
/* when was the last packet sent */
int64_t lastPacketSentTime;
@@ -80,18 +87,14 @@ struct MqttClient
int paused;
bstring userName;
bstring password;
-};
-
-enum MessageState
-{
- MessageStateQueued = 100,
- MessageStateSend,
- MessageStateSent
+ /* The packet we are receiving */
+ MqttPacket inPacket;
+ MqttClientState state;
};
static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet);
static int MqttClientQueueSimplePacket(MqttClient *client, int type);
-static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet);
+static int MqttClientSendPacket(MqttClient *client);
static int MqttClientRecvPacket(MqttClient *client);
static uint16_t MqttClientNextPacketId(MqttClient *client);
static void MqttClientProcessMessageQueue(MqttClient *client);
@@ -99,14 +102,14 @@ static void MqttClientClearQueues(MqttClient *client);
static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client)
{
- MqttPacket *packet;
+ MqttMessage *msg;
int queued = 0;
int inMessagesCount = 0;
int outMessagesCount = 0;
- TAILQ_FOREACH(packet, &client->outMessages, messages)
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (packet->state == MessageStateQueued)
+ if (msg->state == MqttMessageStateQueued)
{
++queued;
}
@@ -114,7 +117,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client)
++outMessagesCount;
}
- TAILQ_FOREACH(packet, &client->inMessages, messages)
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
++inMessagesCount;
}
@@ -142,6 +145,8 @@ MqttClient *MqttClientNew(const char *clientId)
client->maxQueued = 0;
client->maxInflight = 20;
+ client->state = MqttClientStateDisconnected;
+
TAILQ_INIT(&client->outMessages);
TAILQ_INIT(&client->inMessages);
SIMPLEQ_INIT(&client->sendQueue);
@@ -157,6 +162,8 @@ void MqttClientFree(MqttClient *client)
bdestroy(client->willTopic);
bdestroy(client->willMessage);
bdestroy(client->host);
+ bdestroy(client->userName);
+ bdestroy(client->password);
if (client->stream.sock != -1)
{
@@ -212,15 +219,54 @@ void MqttClientSetOnPublish(MqttClient *client,
client->onPublish = cb;
}
+static const struct tagbstring MqttProtocolId = bsStatic("MQTT");
+static const char MqttProtocolLevel = 0x04;
+
+static unsigned char MqttClientConnectFlags(MqttClient *client)
+{
+ unsigned char connectFlags = 0;
+
+ if (client->cleanSession)
+ {
+ connectFlags |= 0x02;
+ }
+
+ if (client->willTopic)
+ {
+ connectFlags |= 0x04;
+ connectFlags |= (client->willQos & 3) << 3;
+ connectFlags |= (client->willRetain & 1) << 5;
+ }
+
+ if (client->userName)
+ {
+ connectFlags |= 0x80;
+ if (client->password)
+ {
+ connectFlags |= 0x40;
+ }
+ }
+
+ return connectFlags;
+}
+
int MqttClientConnect(MqttClient *client, const char *host, short port,
int keepAlive, int cleanSession)
{
int sock;
- MqttPacketConnect *packet;
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
assert(host != NULL);
+ if (client->state != MqttClientStateDisconnected)
+ {
+ LOG_ERROR("client must be disconnected to connect");
+ return -1;
+ }
+
if (client->host)
bassigncstr(client->host, host);
else
@@ -242,10 +288,13 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
LOG_DEBUG("connecting");
- if ((sock = SocketConnect(host, port)) == -1)
+ if ((sock = SocketConnect(host, port, 1)) == -1)
{
- LOG_ERROR("SocketConnect failed!");
- return -1;
+ if (SocketErrno != SOCKET_EINPROGRESS)
+ {
+ LOG_ERROR("SocketConnect failed!");
+ return -1;
+ }
}
if (SocketStreamOpen(&client->stream, sock) == -1)
@@ -253,44 +302,39 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
return -1;
}
- packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect);
+ packet = MqttPacketNew(MqttPacketTypeConnect);
if (!packet)
return -1;
- if (client->cleanSession)
- {
- packet->connectFlags |= 0x02;
- }
+ StringStreamInit(&ss);
- packet->keepAlive = client->keepAlive;
-
- packet->clientId = bstrcpy(client->clientId);
+ StreamWriteMqttString(&MqttProtocolId, pss);
+ StreamWriteByte(MqttProtocolLevel, pss);
+ StreamWriteByte(MqttClientConnectFlags(client), pss);
+ StreamWriteUint16Be(client->keepAlive, pss);
+ StreamWriteMqttString(client->clientId, pss);
if (client->willTopic)
{
- packet->connectFlags |= 0x04;
-
- packet->willTopic = bstrcpy(client->willTopic);
- packet->willMessage = bstrcpy(client->willMessage);
-
- packet->connectFlags |= (client->willQos & 3) << 3;
- packet->connectFlags |= (client->willRetain & 1) << 5;
+ StreamWriteMqttString(client->willTopic, pss);
+ StreamWriteMqttString(client->willMessage, pss);
}
if (client->userName)
{
- packet->connectFlags |= 0x80;
- packet->userName = bstrcpy(client->userName);
-
- if (client->password)
+ StreamWriteMqttString(client->userName, pss);
+ if(client->password)
{
- packet->connectFlags |= 0x40;
- packet->password = bstrcpy(client->password);
+ StreamWriteMqttString(client->password, pss);
}
}
- MqttClientQueuePacket(client, &packet->base);
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ client->state = MqttClientStateConnecting;
return 0;
}
@@ -303,13 +347,14 @@ int MqttClientDisconnect(MqttClient *client)
int MqttClientIsConnected(MqttClient *client)
{
- return client->stream.sock != -1;
+ return client->stream.sock != -1 &&
+ client->state == MqttClientStateConnected;
}
int MqttClientRunOnce(MqttClient *client, int timeout)
{
int rv;
- int events;
+ int events = 0;
assert(client != NULL);
@@ -319,19 +364,31 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
return -1;
}
- events = EV_READ;
+ if (client->state == MqttClientStateConnected)
+ {
+ events = EV_READ;
- /* Handle outMessages and inMessages, moving queued messages to sendQueue
- if there are less than maxInflight number of messages in flight */
- MqttClientProcessMessageQueue(client);
+ /* Handle outMessages and inMessages, moving queued messages to sendQueue
+ if there are less than maxInflight number of messages in flight */
+ MqttClientProcessMessageQueue(client);
- if (SIMPLEQ_EMPTY(&client->sendQueue))
+ if (SIMPLEQ_EMPTY(&client->sendQueue))
+ {
+ LOG_DEBUG("nothing to write");
+ }
+ else
+ {
+ events |= EV_WRITE;
+ }
+ }
+ else if (client->state == MqttClientStateConnecting)
{
- LOG_DEBUG("nothing to write");
+ events = EV_WRITE;
}
else
{
- events |= EV_WRITE;
+ LOG_ERROR("not connected");
+ return -1;
}
LOG_DEBUG("selecting");
@@ -339,6 +396,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
if (timeout < 0)
{
timeout = client->keepAlive * 1000;
+ if (timeout == 0)
+ {
+ timeout = 30 * 1000;
+ }
+ }
+ else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0)
+ {
+ timeout = client->keepAlive * 1000;
}
rv = SocketSelect(client->stream.sock, &events, timeout);
@@ -354,22 +419,26 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
if (events & EV_WRITE)
{
- MqttPacket *packet;
-
LOG_DEBUG("socket writable");
- packet = SIMPLEQ_FIRST(&client->sendQueue);
-
- if (packet)
+ if (client->state == MqttClientStateConnecting)
{
- SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
-
- if (MqttClientSendPacket(client, packet) == -1)
+ int sockError;
+ SocketGetError(client->stream.sock, &sockError);
+ LOG_DEBUG("sockError: %d", sockError);
+ if (sockError == 0)
{
- LOG_ERROR("MqttClientSendPacket failed");
- client->stopped = 1;
+ LOG_DEBUG("connected!");
+ client->state = MqttClientStateConnected;
+ return 0;
}
}
+
+ if (MqttClientSendPacket(client) == -1)
+ {
+ LOG_ERROR("MqttClientSendPacket failed");
+ client->stopped = 1;
+ }
}
if (events & EV_READ)
@@ -433,10 +502,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter,
}
int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters,
- int *qos, size_t count)
+ int *qos, size_t count)
{
- MqttPacketSubscribe *packet = NULL;
+ MqttPacket *packet = NULL;
size_t i;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
assert(topicFilters != NULL);
@@ -444,68 +515,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters,
assert(qos != NULL);
assert(count > 0);
- packet = (MqttPacketSubscribe *) MqttPacketWithIdNew(
- MqttPacketTypeSubscribe, MqttClientNextPacketId(client));
+ packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe,
+ MqttClientNextPacketId(client));
if (!packet)
return -1;
- packet->topicFilters = bstrListCreate();
- bstrListAllocMin(packet->topicFilters, count);
+ packet->flags = 0x2;
+
+ StringStreamInit(&ss);
- packet->qos = (int *) malloc(sizeof(int) * count);
+ StreamWriteUint16Be(packet->id, pss);
+
+ LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id);
for (i = 0; i < count; ++i)
{
- packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]);
- ++packet->topicFilters->qty;
+ struct tagbstring filter;
+ btfromcstr(filter, topicFilters[i]);
+ StreamWriteMqttString(&filter, pss);
+ StreamWriteByte(qos[i] & 3, pss);
}
- memcpy(packet->qos, qos, sizeof(int) * count);
-
- MqttClientQueuePacket(client, (MqttPacket *) packet);
+ packet->payload = ss.buffer;
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages);
+ MqttClientQueuePacket(client, packet);
- return MqttPacketId(packet);
+ return packet->id;
}
int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter)
{
- MqttPacketUnsubscribe *packet = NULL;
+ MqttPacket *packet = NULL;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ struct tagbstring filter;
assert(client != NULL);
assert(topicFilter != NULL);
- packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew(
- MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client));
+ packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe,
+ MqttClientNextPacketId(client));
+
+ if (!packet)
+ return -1;
+
+ packet->flags = 0x02;
+
+ StringStreamInit(&ss);
- packet->topicFilter = bfromcstr(topicFilter);
+ StreamWriteUint16Be(packet->id, pss);
- MqttClientQueuePacket(client, (MqttPacket *) packet);
+ btfromcstr(filter, topicFilter);
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages);
+ StreamWriteMqttString(&filter, pss);
- return MqttPacketId(packet);
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ return packet->id;
}
static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client)
{
- MqttPacket *packet;
+ MqttMessage *msg;
int count = 0;
- TAILQ_FOREACH(packet, &client->outMessages, messages)
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
++count;
}
return count;
}
+static MqttPacket *PublishToPacket(MqttMessage *msg)
+{
+ MqttPacket *packet = NULL;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ if (msg->qos > 0)
+ {
+ packet = MqttPacketWithIdNew(MqttPacketTypePublish,
+ msg->id);
+ }
+ else
+ {
+ packet = MqttPacketNew(MqttPacketTypePublish);
+ }
+
+ if (!packet)
+ return NULL;
+
+ packet->message = msg;
+
+ StringStreamInit(&ss);
+
+ StreamWriteMqttString(msg->topic, pss);
+
+ if (msg->qos > 0)
+ {
+ StreamWriteUint16Be(msg->id, pss);
+ }
+
+ StreamWrite(bdata(msg->payload), blength(msg->payload), pss);
+
+ packet->payload = ss.buffer;
+ packet->flags = (msg->qos & 3) << 1;
+ packet->flags |= msg->retain & 1;
+
+ return packet;
+}
+
int MqttClientPublish(MqttClient *client, int qos, int retain,
const char *topic, const void *data, size_t size)
{
- MqttPacketPublish *packet;
-
- assert(client != NULL);
+ MqttMessage *message;
/* first check if the queue is already full */
if (qos > 0 && client->maxQueued > 0 &&
@@ -514,55 +639,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain,
return -1;
}
- if (qos > 0)
+ message = calloc(1, sizeof(*message));
+ if (!message)
{
- packet = (MqttPacketPublish *) MqttPacketWithIdNew(
- MqttPacketTypePublish, MqttClientNextPacketId(client));
+ return -1;
}
- else
+
+ message->state = MqttMessageStateQueued;
+ message->qos = qos;
+ message->retain = retain;
+ message->dup = 0;
+ message->timestamp = MqttGetCurrentTime();
+
+ if (qos == 0)
{
- packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish);
- }
+ /* Copy payload and topic directly from user buffers as we don't need
+ to keep the message data around after this function. */
+ MqttPacket *packet;
+ struct tagbstring bttopic, btpayload;
- if (!packet)
- return -1;
+ btfromcstr(bttopic, topic);
+ message->topic = &bttopic;
- packet->qos = qos;
- packet->retain = retain;
- packet->topicName = bfromcstr(topic);
- packet->message = blk2bstr(data, size);
+ btfromblk(btpayload, data, size);
+ message->payload = &btpayload;
- if (qos > 0)
- {
- /* check how many messages there are coming in and going out currently
- that are not yet done */
- if (client->maxInflight == 0 ||
- MqttClientInflightMessageCount(client) < client->maxInflight)
- {
- LOG_DEBUG("setting message (%d) state to MessageStateSend",
- MqttPacketId(packet));
- packet->base.state = MessageStateSend;
- }
- else
- {
- LOG_DEBUG("setting message (%d) state to MessageStateQueued",
- MqttPacketId(packet));
- packet->base.state = MessageStateQueued;
- }
+ packet = PublishToPacket(message);
- /* add the message to the outMessages queue to wait for processing */
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet,
- messages);
+ message->topic = NULL;
+ message->payload = NULL;
+
+ MqttClientQueuePacket(client, packet);
+
+ MqttMessageFree(message);
+
+ return 0;
}
else
{
- MqttClientQueuePacket(client, (MqttPacket *) packet);
- }
+ /* Duplicate the user buffers as we need the data to be available
+ longer. */
+ message->topic = bfromcstr(topic);
+ message->payload = blk2bstr(data, size);
- if (qos > 0)
- return MqttPacketId(packet);
+ message->id = MqttClientNextPacketId(client);
- return 0;
+ TAILQ_INSERT_TAIL(&client->outMessages, message, chain);
+
+ return message->id;
+ }
}
int MqttClientPublishCString(MqttClient *client, int qos, int retain,
@@ -613,7 +738,7 @@ int MqttClientSetAuth(MqttClient *client, const char *userName,
{
assert(client != NULL);
- if (MqttClientIsConnected(client))
+ if (client->state == MqttClientStateConnecting)
{
LOG_ERROR("MqttClientSetAuth must be called before MqttClientConnect");
return -1;
@@ -655,6 +780,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet)
{
assert(client != NULL);
LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type));
+ packet->state = MqttPacketStateWriteType;
SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue);
}
@@ -667,128 +793,363 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type)
return 0;
}
-static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet)
+static int MqttClientSendPacket(MqttClient *client)
{
- if (MqttPacketSerialize(packet, &client->stream.base) == -1)
- return -1;
+ MqttPacket *packet;
- packet->sentAt = MqttGetCurrentTime();
- client->lastPacketSentTime = packet->sentAt;
+ packet = SIMPLEQ_FIRST(&client->sendQueue);
- if (packet->type == MqttPacketTypeDisconnect)
+ if (!packet)
{
- client->stopped = 1;
+ LOG_WARNING("MqttClientSendPacket called with no queued packets");
+ return 0;
}
- /* If the packet is not on any message list, it can be removed after
- sending. */
- if (TAILQ_NEXT(packet, messages) == NULL &&
- TAILQ_PREV(packet, MessageList, messages) == NULL &&
- TAILQ_FIRST(&client->inMessages) != packet &&
- TAILQ_FIRST(&client->outMessages) != packet)
+ while (packet != NULL)
{
- LOG_DEBUG("freeing packet %s after sending",
- MqttPacketName(MqttPacketType(packet)));
- MqttPacketFree(packet);
+ switch (packet->state)
+ {
+ case MqttPacketStateWriteType:
+ {
+ unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) |
+ (packet->flags & 0x0F);
+
+ if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ return -1;
+ }
+
+ packet->state = MqttPacketStateWriteRemainingLength;
+ packet->remainingLength = blength(packet->payload);
+
+ break;
+ }
+
+ case MqttPacketStateWriteRemainingLength:
+ {
+ if (StreamWriteRemainingLength(&packet->remainingLength,
+ &client->stream.base) == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ return -1;
+ }
+
+ packet->state = MqttPacketStateWritePayload;
+ packet->remainingLength = blength(packet->payload);
+
+ break;
+ }
+
+ case MqttPacketStateWritePayload:
+ {
+ if (packet->payload)
+ {
+ int64_t offset = blength(packet->payload) - packet->remainingLength;
+ int64_t nwritten = 0;
+ int towrite = 16*1024;
+
+ if (packet->remainingLength < 16*1024)
+ towrite = packet->remainingLength;
+
+ nwritten = StreamWrite(bdataofs(packet->payload, offset),
+ towrite,
+ &client->stream.base);
+
+ if (nwritten == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ {
+ return 0;
+ }
+ return -1;
+ }
+
+ packet->remainingLength -= nwritten;
+
+ LOG_DEBUG("nwritten:%d", (int) nwritten);
+ }
+
+ if (packet->remainingLength == 0)
+ {
+ LOG_DEBUG("packet payload sent");
+ packet->state = MqttPacketStateWriteComplete;
+ }
+
+ break;
+ }
+
+ case MqttPacketStateWriteComplete:
+ {
+ client->lastPacketSentTime = MqttGetCurrentTime();
+
+ if (packet->type == MqttPacketTypeDisconnect)
+ {
+ client->stopped = 1;
+ client->state = MqttClientStateDisconnected;
+ }
+
+ LOG_DEBUG("sent %s", MqttPacketName(packet->type));
+
+ if (packet->type == MqttPacketTypePublish && packet->message)
+ {
+ MqttMessage *msg = packet->message;
+
+ if (msg->qos == 1)
+ {
+ msg->state = MqttMessageStateWaitPubAck;
+ }
+ else if (msg->qos == 2)
+ {
+ msg->state = MqttMessageStateWaitPubRec;
+ }
+ }
+
+ if (packet->message)
+ {
+ packet->message->timestamp = client->lastPacketSentTime;
+ }
+
+ SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
+
+ MqttPacketFree(packet);
+
+ packet = SIMPLEQ_FIRST(&client->sendQueue);
+
+ break;
+ }
+ }
}
return 0;
}
-static void MqttClientHandleConnAck(MqttClient *client,
- MqttPacketConnAck *packet)
+static int MqttClientHandleConnAck(MqttClient *client)
{
- client->sessionPresent = packet->connAckFlags & 1;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ unsigned char flags;
+ unsigned char rc;
+
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadByte(&flags, pss);
+
+ StreamReadByte(&rc, pss);
+
+ client->sessionPresent = flags & 1;
LOG_DEBUG("sessionPresent:%d", client->sessionPresent);
if (client->onConnect)
{
- LOG_DEBUG("calling onConnect rc:%d", packet->returnCode);
- client->onConnect(client, packet->returnCode, client->sessionPresent);
+ LOG_DEBUG("calling onConnect rc:%d", rc);
+ client->onConnect(client, rc, client->sessionPresent);
}
+
+ return 0;
}
-static void MqttClientHandlePingResp(MqttClient *client)
+static int MqttClientHandlePingResp(MqttClient *client)
{
LOG_DEBUG("got ping response");
client->pingSent = 0;
+ return 0;
}
-static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet)
+static int MqttClientHandleSubAck(MqttClient *client)
{
- MqttPacket *sub;
+ uint16_t id;
+ int *qos;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ int count;
+ int i;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(sub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ LOG_DEBUG("received SUBACK with id:%d", (int) id);
+
+ count = blength(client->inPacket.payload) - StreamTell(pss);
+
+ if (count <= 0)
{
- if (MqttPacketType(sub) == MqttPacketTypeSubscribe &&
- MqttPacketId(sub) == MqttPacketId(packet))
- {
- break;
- }
+ LOG_ERROR("number of return codes invalid");
+ return -1;
}
- if (!sub)
+ qos = malloc(count * sizeof(int));
+
+ for (i = 0; i < count; ++i)
{
- LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ unsigned char byte;
+ StreamReadByte(&byte, pss);
+ qos[i] = byte;
}
- else
+
+ if (client->onSubscribe)
{
- if (client->onSubscribe)
- {
- MqttPacketSubscribe *sub2;
- int i;
+ client->onSubscribe(client, id, qos, count);
+ }
- sub2 = (MqttPacketSubscribe *) sub;
+ free(qos);
- for (i = 0; i < sub2->topicFilters->qty; ++i)
- {
- const char *filter = bdata(sub2->topicFilters->entry[i]);
- int rc = packet->returnCode[i];
+ return 0;
+}
- LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d",
- MqttPacketId(packet), filter, rc);
+static int MqttClientSendPubAck(MqttClient *client, uint16_t id)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
- client->onSubscribe(client, MqttPacketId(packet), filter, rc);
- }
- }
+ packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id);
- TAILQ_REMOVE(&client->outMessages, sub, messages);
- MqttPacketFree(sub);
- }
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(id, pss);
+
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id);
+
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(msg->id, pss);
+
+ packet->payload = ss.buffer;
+ packet->message = msg;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id);
+
+ if (!packet)
+ return -1;
+
+ packet->flags = 0x2;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(msg->id, pss);
+
+ packet->payload = ss.buffer;
+ packet->message = msg;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubComp(MqttClient *client, uint16_t id)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id);
+
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(id, pss);
+
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
}
-static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet)
+static int MqttClientHandlePublish(MqttClient *client)
{
+ MqttMessage *msg;
+ uint16_t id;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ MqttPacket *packet;
+ int qos;
+ int retain;
+ bstring topic;
+ void *payload;
+ int payloadSize;
+
+ /* We are paused - do nothing */
if (client->paused)
- return;
+ return 0;
+
+ packet = &client->inPacket;
+
+ qos = (packet->flags >> 1) & 3;
+ retain = packet->flags & 1;
+
+ StringStreamInitFromBstring(&ss, packet->payload);
+
+ StreamReadMqttString(&topic, pss);
+
+ if (qos > 0)
+ {
+ StreamReadUint16Be(&id, pss);
+ }
+
+ payload = bdataofs(ss.buffer, ss.pos);
+ payloadSize = blength(ss.buffer) - ss.pos;
- if (MqttPacketPublishQos(packet) == 2)
+ if (qos == 2)
{
/* Check if we have sent a PUBREC previously with the same id. If we
have, we have to resend the PUBREC. We must not call the onMessage
callback again. */
- MqttPacket *pubRec;
-
- TAILQ_FOREACH(pubRec, &client->inMessages, messages)
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
- if (MqttPacketId(pubRec) == MqttPacketId(packet) &&
- MqttPacketType(pubRec) == MqttPacketTypePubRec)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubRel)
{
break;
}
}
- if (pubRec)
+ if (msg)
{
- LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet));
- MqttClientQueuePacket(client, pubRec);
- return;
+ LOG_DEBUG("resending PUBREC id:%u", msg->id);
+ MqttClientSendPubRec(client, msg);
+ bdestroy(topic);
+ return 0;
}
}
@@ -796,268 +1157,395 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe
{
LOG_DEBUG("calling onMessage");
client->onMessage(client,
- bdata(packet->topicName),
- bdata(packet->message),
- blength(packet->message),
- packet->qos,
- packet->retain);
+ bdata(topic),
+ payload,
+ payloadSize,
+ qos,
+ retain);
}
- if (MqttPacketPublishQos(packet) > 0)
+ bdestroy(topic);
+
+ if (qos == 1)
+ {
+ MqttClientSendPubAck(client, id);
+ }
+ else if (qos == 2)
{
- int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck :
- MqttPacketTypePubRec;
+ msg = calloc(1, sizeof(*msg));
- MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet));
+ msg->state = MqttMessageStateWaitPubRel;
+ msg->id = id;
+ msg->qos = qos;
- if (MqttPacketPublishQos(packet) == 2)
- {
- /* append to inMessages as we need a reply to this response */
- TAILQ_INSERT_TAIL(&client->inMessages, resp, messages);
- }
+ TAILQ_INSERT_TAIL(&client->inMessages, msg, chain);
- MqttClientQueuePacket(client, resp);
+ MqttClientSendPubRec(client, msg);
}
+
+ return 0;
}
-static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubAck(MqttClient *client)
{
- MqttPacket *pub;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
+
+ assert(client != NULL);
- TAILQ_FOREACH(pub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pub) == MqttPacketId(packet) &&
- MqttPacketType(pub) == MqttPacketTypePublish)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubAck)
{
break;
}
}
- if (!pub)
+ if (!msg)
{
- LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, pub, messages);
- MqttPacketFree(pub);
- if (client->onPublish)
- {
- client->onPublish(client, MqttPacketId(packet));
- }
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+
+ if (client->onPublish)
+ {
+ client->onPublish(client, msg->id);
}
+
+ MqttMessageFree(msg);
+
+ return 0;
}
-static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubRec(MqttClient *client)
{
- MqttPacket *pub;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(pub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pub) == MqttPacketId(packet) &&
- MqttPacketType(pub) == MqttPacketTypePublish)
+ /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but
+ they haven't received it. */
+ if (msg->id == id &&
+ (msg->state == MqttMessageStateWaitPubRec ||
+ msg->state == MqttMessageStateWaitPubComp))
{
break;
}
}
- if (!pub)
+ if (!msg)
{
- LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- MqttPacket *pubRel;
- TAILQ_REMOVE(&client->outMessages, pub, messages);
- MqttPacketFree(pub);
+ msg->state = MqttMessageStateWaitPubComp;
- pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet));
- pubRel->state = MessageStateSend;
+ bdestroy(msg->payload);
+ msg->payload = NULL;
- TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages);
- }
+ bdestroy(msg->topic);
+ msg->topic = NULL;
+
+ if (MqttClientSendPubRel(client, msg) == -1)
+ return -1;
+
+ return 0;
}
-static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubRel(MqttClient *client)
{
- MqttPacket *pubRec;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(pubRec, &client->inMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
- if (MqttPacketId(pubRec) == MqttPacketId(packet) &&
- MqttPacketType(pubRec) == MqttPacketTypePubRec)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubRel)
{
break;
}
}
- if (!pubRec)
+ if (!msg)
{
- MqttPacket *pubComp;
-
- TAILQ_FOREACH(pubComp, &client->inMessages, messages)
- {
- if (MqttPacketId(pubComp) == MqttPacketId(packet) &&
- MqttPacketType(pubComp) == MqttPacketTypePubComp)
- {
- break;
- }
- }
-
- if (pubComp)
- {
- MqttClientQueuePacket(client, pubComp);
- }
- else
- {
- LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
- }
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- MqttPacket *pubComp;
- TAILQ_REMOVE(&client->inMessages, pubRec, messages);
- MqttPacketFree(pubRec);
+ TAILQ_REMOVE(&client->inMessages, msg, chain);
+ MqttMessageFree(msg);
- pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp,
- MqttPacketId(packet));
-
- TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages);
+ if (MqttClientSendPubComp(client, id) == -1)
+ return -1;
- MqttClientQueuePacket(client, pubComp);
- }
+ return 0;
}
-static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubComp(MqttClient *client)
{
- MqttPacket *pubRel;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
+
+ assert(client != NULL);
- TAILQ_FOREACH(pubRel, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pubRel) == MqttPacketId(packet) &&
- MqttPacketType(pubRel) == MqttPacketTypePubRel)
+ if (msg->id == id && msg->state == MqttMessageStateWaitPubComp)
{
break;
}
}
- if (!pubRel)
+ if (!msg)
{
- LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_WARNING("no message found with id %d", (int) id);
+ return 0;
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, pubRel, messages);
- MqttPacketFree(pubRel);
- if (client->onPublish)
- {
- LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet));
- client->onPublish(client, MqttPacketId(packet));
- }
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+
+ MqttMessageFree(msg);
+
+ if (client->onPublish)
+ {
+ LOG_DEBUG("calling onPublish id:%d", id);
+ client->onPublish(client, id);
}
+
+ return 0;
}
-static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandleUnsubAck(MqttClient *client)
{
- MqttPacket *sub;
+ uint16_t id;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(sub, &client->outMessages, messages)
- {
- if (MqttPacketId(sub) == MqttPacketId(packet) &&
- MqttPacketType(sub) == MqttPacketTypeUnsubscribe)
- {
- break;
- }
- }
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
- if (!sub)
+ StreamReadUint16Be(&id, pss);
+
+ if (client->onUnsubscribe)
{
- LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ client->onUnsubscribe(client, id);
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, sub, messages);
- MqttPacketFree(sub);
- if (client->onUnsubscribe)
- {
- LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet));
- client->onUnsubscribe(client, MqttPacketId(packet));
- }
- }
+ return 0;
}
-static int MqttClientRecvPacket(MqttClient *client)
+static int MqttClientHandlePacket(MqttClient *client)
{
- MqttPacket *packet = NULL;
+ int rc;
- if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1)
- return -1;
-
- LOG_DEBUG("received packet %s", MqttPacketName(packet->type));
-
- switch (MqttPacketType(packet))
+ switch (client->inPacket.type)
{
case MqttPacketTypeConnAck:
- MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet);
+ rc = MqttClientHandleConnAck(client);
break;
case MqttPacketTypePingResp:
- MqttClientHandlePingResp(client);
+ rc = MqttClientHandlePingResp(client);
break;
case MqttPacketTypeSubAck:
- MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet);
+ rc = MqttClientHandleSubAck(client);
break;
- case MqttPacketTypePublish:
- MqttClientHandlePublish(client, (MqttPacketPublish *) packet);
+ case MqttPacketTypeUnsubAck:
+ rc = MqttClientHandleUnsubAck(client);
break;
case MqttPacketTypePubAck:
- MqttClientHandlePubAck(client, packet);
+ rc = MqttClientHandlePubAck(client);
break;
case MqttPacketTypePubRec:
- MqttClientHandlePubRec(client, packet);
+ rc = MqttClientHandlePubRec(client);
break;
- case MqttPacketTypePubRel:
- MqttClientHandlePubRel(client, packet);
+ case MqttPacketTypePubComp:
+ rc = MqttClientHandlePubComp(client);
break;
- case MqttPacketTypePubComp:
- MqttClientHandlePubComp(client, packet);
+ case MqttPacketTypePubRel:
+ rc = MqttClientHandlePubRel(client);
break;
- case MqttPacketTypeUnsubAck:
- MqttClientHandleUnsubAck(client, packet);
+ case MqttPacketTypePublish:
+ rc = MqttClientHandlePublish(client);
break;
default:
- LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet));
+ LOG_ERROR("packet not handled yet");
+ rc = -1;
break;
}
- MqttPacketFree(packet);
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+
+ client->inPacket.state = MqttPacketStateReadType;
+
+ return rc;
+}
+
+static int MqttClientRecvPacket(MqttClient *client)
+{
+ while (1)
+ {
+ switch (client->inPacket.state)
+ {
+ case MqttPacketStateReadType:
+ {
+ unsigned char typeAndFlags;
+
+ if (StreamReadByte(&typeAndFlags, &client->stream.base) == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ LOG_ERROR("failed reading packet type");
+ return -1;
+ }
+
+ client->inPacket.type = typeAndFlags >> 4;
+ client->inPacket.flags = typeAndFlags & 0x0F;
+
+ if (client->inPacket.type < MqttPacketTypeConnect ||
+ client->inPacket.type > MqttPacketTypeDisconnect)
+ {
+ LOG_ERROR("unknown packet type: %d", client->inPacket.type);
+ return -1;
+ }
+
+ client->inPacket.state = MqttPacketStateReadRemainingLength;
+ client->inPacket.remainingLength = 0;
+ client->inPacket.remainingLengthMul = 1;
+ client->inPacket.payload = NULL;
+
+ break;
+ }
+
+ case MqttPacketStateReadRemainingLength:
+ {
+ if (StreamReadRemainingLength(&client->inPacket.remainingLength,
+ &client->inPacket.remainingLengthMul,
+ &client->stream.base) == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ LOG_ERROR("failed to read remaining length");
+ return -1;
+ }
+
+ LOG_DEBUG("remainingLength:%lu",
+ client->inPacket.remainingLength);
+
+ client->inPacket.state = MqttPacketStateReadPayload;
+
+ break;
+ }
+
+ case MqttPacketStateReadPayload:
+ {
+ if (client->inPacket.remainingLength > 0)
+ {
+ int64_t nread, offset, toread;
+
+ if (client->inPacket.payload == NULL)
+ {
+ unsigned char *data;
+ client->inPacket.payload = bfromcstr("");
+ ballocmin(client->inPacket.payload,
+ client->inPacket.remainingLength+1);
+ data = client->inPacket.payload->data;
+ data[client->inPacket.remainingLength] = '\0';
+ }
+
+ offset = blength(client->inPacket.payload);
+
+ toread = 16*1024;
+
+ if (client->inPacket.remainingLength < (size_t) toread)
+ toread = client->inPacket.remainingLength;
+
+ nread = StreamRead(bdataofs(client->inPacket.payload,
+ offset),
+ toread,
+ &client->stream.base);
+
+ if (nread == -1)
+ {
+ if (SocketWouldBlock(SocketErrno))
+ return 0;
+ LOG_ERROR("failed reading packet payload");
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+ return -1;
+ }
+ else if (nread == 0)
+ {
+ LOG_ERROR("socket disconnected");
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+ return -1;
+ }
+
+ client->inPacket.remainingLength -= nread;
+ client->inPacket.payload->slen += nread;
+
+ LOG_DEBUG("nread:%d", (int) nread);
+ }
+
+ if (client->inPacket.remainingLength == 0)
+ {
+ client->inPacket.state = MqttPacketStateReadComplete;
+ }
+ break;
+ }
+
+ case MqttPacketStateReadComplete:
+ {
+ int type = client->inPacket.type;
+ LOG_DEBUG("received %s", MqttPacketName(type));
+ return MqttClientHandlePacket(client);
+ }
+ }
+ }
return 0;
}
@@ -1072,101 +1560,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client)
return id;
}
-static int64_t MqttPacketTimeSinceSent(MqttPacket *packet)
+static int64_t MqttMessageTimeSinceSent(MqttMessage *msg)
{
int64_t now = MqttGetCurrentTime();
- return now - packet->sentAt;
+ return now - msg->timestamp;
}
-static void MqttClientProcessInMessages(MqttClient *client)
+static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg)
{
- MqttPacket *packet, *next;
-
- LOG_DEBUG("processing inMessages");
-
- TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next)
+ if (msg->timestamp > 0 &&
+ MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000)
{
- LOG_DEBUG("packet type:%s id:%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
-
- if (MqttPacketType(packet) == MqttPacketTypePubComp)
- {
- int64_t elapsed = MqttPacketTimeSinceSent(packet);
- if (packet->sentAt > 0 &&
- elapsed >= client->retryTimeout*1000)
- {
- LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64,
- MqttPacketId(packet), elapsed);
-
- TAILQ_REMOVE(&client->inMessages, packet, messages);
-
- MqttPacketFree(packet);
- }
- }
+ return 1;
}
+
+ return 0;
}
-static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet)
+static void MqttClientProcessInMessages(MqttClient *client)
{
- if (packet->sentAt > 0 &&
- MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000)
+ MqttMessage *msg, *next;
+
+ TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next)
{
- return 1;
- }
+ switch (msg->state)
+ {
+ case MqttMessageStateWaitPubRel:
+ if (MqttMessageShouldResend(client, msg))
+ {
+ MqttClientSendPubRec(client, msg);
+ }
+ break;
- return 0;
+ default:
+ break;
+ }
+ }
}
static void MqttClientProcessOutMessages(MqttClient *client)
{
- MqttPacket *packet, *next;
+ MqttMessage *msg, *next;
+ MqttPacket *packet;
int inflight = MqttClientInflightMessageCount(client);
- LOG_DEBUG("processing outMessages inflight:%d", inflight);
-
- TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next)
+ TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next)
{
- LOG_DEBUG("packet type:%s id:%d state:%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet),
- packet->state);
-
- switch (packet->state)
+ switch (msg->state)
{
- case MessageStateQueued:
+ case MqttMessageStateQueued:
+ {
if (inflight >= client->maxInflight)
{
- LOG_DEBUG("cannot dequeue %s/%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
- break;
+ continue;
}
- else
- {
- /* If there's less than maxInflight messages currently
- inflight, we can dequeue some messages by falling
- through to MessageStateSend. */
- LOG_DEBUG("dequeuing %s (%d)",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
- ++inflight;
- }
-
- case MessageStateSend:
- packet->state = MessageStateSent;
+ /* State change from MqttMessageStatePublish happens after
+ the packet has been sent (in MqttClientSendPacket). */
+ msg->state = MqttMessageStatePublish;
+ packet = PublishToPacket(msg);
MqttClientQueuePacket(client, packet);
+ ++inflight;
break;
+ }
- case MessageStateSent:
- if (MqttPacketShouldResend(client, packet))
+ case MqttMessageStateWaitPubAck:
+ case MqttMessageStateWaitPubRec:
+ {
+ if (MqttMessageShouldResend(client, msg))
{
- packet->state = MessageStateSend;
+ msg->state = MqttMessageStatePublish;
+ packet = PublishToPacket(msg);
+ MqttClientQueuePacket(client, packet);
}
break;
+ }
- default:
+ case MqttMessageStateWaitPubComp:
+ {
+ if (MqttMessageShouldResend(client, msg))
+ {
+ MqttClientSendPubRel(client, msg);
+ }
break;
+ }
}
}
}
@@ -1182,30 +1658,22 @@ static void MqttClientClearQueues(MqttClient *client)
while (!SIMPLEQ_EMPTY(&client->sendQueue))
{
MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue);
-
SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
-
- if (TAILQ_NEXT(packet, messages) == NULL &&
- TAILQ_PREV(packet, MessageList, messages) == NULL &&
- TAILQ_FIRST(&client->inMessages) != packet &&
- TAILQ_FIRST(&client->outMessages) != packet)
- {
- MqttPacketFree(packet);
- }
+ MqttPacketFree(packet);
}
while (!TAILQ_EMPTY(&client->outMessages))
{
- MqttPacket *packet = TAILQ_FIRST(&client->outMessages);
- TAILQ_REMOVE(&client->outMessages, packet, messages);
- MqttPacketFree(packet);
+ MqttMessage *msg = TAILQ_FIRST(&client->outMessages);
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+ MqttMessageFree(msg);
}
while (!TAILQ_EMPTY(&client->inMessages))
{
- MqttPacket *packet = TAILQ_FIRST(&client->inMessages);
- TAILQ_REMOVE(&client->inMessages, packet, messages);
- MqttPacketFree(packet);
+ MqttMessage *msg = TAILQ_FIRST(&client->inMessages);
+ TAILQ_REMOVE(&client->inMessages, msg, chain);
+ MqttMessageFree(msg);
}
}
diff --git a/src/deserialize.c b/src/deserialize.c
deleted file mode 100644
index 96d7789..0000000
--- a/src/deserialize.c
+++ /dev/null
@@ -1,286 +0,0 @@
-#include "deserialize.h"
-#include "packet.h"
-#include "stream_mqtt.h"
-#include "log.h"
-
-#include <stdlib.h>
-#include <assert.h>
-
-typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream);
-
-static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1)
- return -1;
-
- if (StreamRead(&(*packet)->returnCode, 1, stream) != 1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- size_t i;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
- return -1;
-
- remainingLength -= 2;
-
- (*packet)->returnCode = (unsigned char *) malloc(
- sizeof(*(*packet)->returnCode) * remainingLength);
-
- for (i = 0; i < remainingLength; ++i)
- {
- if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1)
- return -1;
- }
-
- return 0;
-}
-
-static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- size_t payloadSize = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (StreamReadMqttString(&(*packet)->topicName, stream) == -1)
- return -1;
-
- LOG_DEBUG("remainingLength:%lu", remainingLength);
-
- payloadSize = remainingLength - blength((*packet)->topicName) - 2;
-
- LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet),
- payloadSize);
-
- if (MqttPacketHasId((const MqttPacket *) *packet))
- {
- LOG_DEBUG("packet has id");
- payloadSize -= 2;
- if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
- {
- return -1;
- }
- }
-
- LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize);
-
- /* Allocate extra byte for a NULL terminator. If the user tries to print
- the payload directly. */
-
- (*packet)->message = bfromcstralloc(payloadSize+1, "");
-
- if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1)
- return -1;
-
- (*packet)->message->slen = payloadSize;
- (*packet)->message->data[payloadSize] = '\0';
-
- return 0;
-}
-
-static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- char buffer[256];
-
- (void) packet;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- while (remainingLength > 0)
- {
- size_t l = sizeof(buffer);
-
- if (remainingLength < l)
- l = remainingLength;
-
- if (StreamRead(buffer, l, stream) != (int64_t) l)
- return -1;
-
- remainingLength -= l;
- }
-
- return 0;
-}
-
-static int ValidateFlags(int type, int flags)
-{
- int rv = 0;
-
- switch (type)
- {
- case MqttPacketTypePublish:
- {
- int qos = (flags >> 1) & 2;
- if (qos >= 0 && qos <= 2)
- rv = 1;
- break;
- }
-
- case MqttPacketTypePubRel:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeUnsubscribe:
- if (flags == 2)
- {
- rv = 1;
- }
- break;
-
- default:
- if (flags == 0)
- {
- rv = 1;
- }
- break;
- }
-
- return rv;
-}
-
-int MqttPacketDeserialize(MqttPacket **packet, Stream *stream)
-{
- MqttPacketDeserializeFunc deserializer = NULL;
- char typeAndFlags;
- int type;
- int flags;
- int rv;
-
- if (StreamRead(&typeAndFlags, 1, stream) != 1)
- return -1;
-
- type = (typeAndFlags & 0xF0) >> 4;
- flags = (typeAndFlags & 0x0F);
-
- if (!ValidateFlags(type, flags))
- {
- return -1;
- }
-
- switch (type)
- {
- case MqttPacketTypeConnect:
- break;
-
- case MqttPacketTypeConnAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize;
- break;
-
- case MqttPacketTypePublish:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize;
- break;
-
- case MqttPacketTypePubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubRec:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubRel:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubComp:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypeSubscribe:
- break;
-
- case MqttPacketTypeSubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize;
- break;
-
- case MqttPacketTypeUnsubscribe:
- break;
-
- case MqttPacketTypeUnsubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize;
- break;
-
- case MqttPacketTypePingReq:
- break;
-
- case MqttPacketTypePingResp:
- break;
-
- case MqttPacketTypeDisconnect:
- break;
-
- default:
- return -1;
- }
-
- if (!deserializer)
- {
- deserializer = MqttPacketGenericDeserializer;
- }
-
- *packet = MqttPacketNew(type);
-
- if (!*packet)
- return -1;
-
- if (type == MqttPacketTypePublish)
- {
- MqttPacketPublishDup(*packet) = (flags >> 3) & 1;
- MqttPacketPublishQos(*packet) = (flags >> 1) & 3;
- MqttPacketPublishRetain(*packet) = flags & 1;
- }
-
- rv = deserializer(packet, stream);
-
- return rv;
-}
diff --git a/src/deserialize.h b/src/deserialize.h
deleted file mode 100644
index 8c29b3d..0000000
--- a/src/deserialize.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef DESERIALIZE_H
-#define DESERIALIZE_H
-
-#include "config.h"
-
-typedef struct MqttPacket MqttPacket;
-typedef struct Stream Stream;
-
-int MqttPacketDeserialize(MqttPacket **packet, Stream *stream);
-
-#endif
diff --git a/src/message.c b/src/message.c
new file mode 100644
index 0000000..35d9c32
--- /dev/null
+++ b/src/message.c
@@ -0,0 +1,11 @@
+#include "message.h"
+#include "stringstream.h"
+#include "stream_mqtt.h"
+#include "packet.h"
+
+void MqttMessageFree(MqttMessage *msg)
+{
+ bdestroy(msg->topic);
+ bdestroy(msg->payload);
+ free(msg);
+}
diff --git a/src/message.h b/src/message.h
new file mode 100644
index 0000000..04a3d61
--- /dev/null
+++ b/src/message.h
@@ -0,0 +1,40 @@
+#ifndef MESSAGE_H
+#define MESSAGE_H
+
+#include <stdint.h>
+
+#include "queue.h"
+#include <bstrlib/bstrlib.h>
+
+enum MqttMessageState
+{
+ MqttMessageStateQueued,
+ MqttMessageStatePublish,
+ MqttMessageStateWaitPubAck,
+ MqttMessageStateWaitPubRec,
+ MqttMessageStateWaitPubComp,
+ MqttMessageStateWaitPubRel
+};
+
+typedef struct MqttMessage MqttMessage;
+
+struct MqttMessage
+{
+ int state;
+ int qos;
+ int retain;
+ int dup;
+ int padding;
+ uint16_t id;
+ int64_t timestamp;
+ bstring topic;
+ bstring payload;
+ TAILQ_ENTRY(MqttMessage) chain;
+};
+
+typedef struct MqttMessageList MqttMessageList;
+TAILQ_HEAD(MqttMessageList, MqttMessage);
+
+void MqttMessageFree(MqttMessage *msg);
+
+#endif
diff --git a/src/mqtt.h b/src/mqtt.h
index 2b84962..840026e 100644
--- a/src/mqtt.h
+++ b/src/mqtt.h
@@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client,
typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client,
int id,
- const char *topicFilter,
- MqttSubscriptionStatus status);
+ int *qos,
+ int count);
typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id);
@@ -84,7 +84,7 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter,
int qos);
int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters,
- int *qos, size_t count);
+ int *qos, size_t count);
int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter);
diff --git a/src/packet.c b/src/packet.c
index 47aa689..c833851 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -28,42 +28,16 @@ const char *MqttPacketName(int type)
}
}
-static MQTT_INLINE size_t MqttPacketStructSize(int type)
-{
- switch (type)
- {
- case MqttPacketTypeConnect: return sizeof(MqttPacketConnect);
- case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck);
- case MqttPacketTypePublish: return sizeof(MqttPacketPublish);
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp: return sizeof(MqttPacket);
- case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe);
- case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck);
- case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe);
- case MqttPacketTypeUnsubAck: return sizeof(MqttPacket);
- case MqttPacketTypePingReq: return sizeof(MqttPacket);
- case MqttPacketTypePingResp: return sizeof(MqttPacket);
- case MqttPacketTypeDisconnect: return sizeof(MqttPacket);
- default: return (size_t) -1;
- }
-}
-
MqttPacket *MqttPacketNew(int type)
{
MqttPacket *packet = NULL;
- packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type));
+ packet = (MqttPacket *) calloc(1, sizeof(*packet));
if (!packet)
return NULL;
packet->type = type;
- /* this will make sure that TAILQ_PREV does not segfault if a message
- has not been added to a list at any point */
- packet->messages.tqe_prev = &packet->messages.tqe_next;
-
return packet;
}
@@ -78,52 +52,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id)
void MqttPacketFree(MqttPacket *packet)
{
- if (MqttPacketType(packet) == MqttPacketTypeConnect)
- {
- MqttPacketConnect *p = (MqttPacketConnect *) packet;
- bdestroy(p->clientId);
- bdestroy(p->willTopic);
- bdestroy(p->willMessage);
- bdestroy(p->userName);
- bdestroy(p->password);
- }
- else if (MqttPacketType(packet) == MqttPacketTypePublish)
- {
- MqttPacketPublish *p = (MqttPacketPublish *) packet;
- bdestroy(p->topicName);
- bdestroy(p->message);
- }
- else if (MqttPacketType(packet) == MqttPacketTypeSubscribe)
- {
- MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet;
- bstrListDestroy(p->topicFilters);
- }
- else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe)
- {
- MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet;
- bdestroy(p->topicFilter);
- }
+ bdestroy(packet->payload);
free(packet);
}
-
-int MqttPacketHasId(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypePublish:
- return MqttPacketPublishQos(packet) > 0;
-
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeSubAck:
- case MqttPacketTypeUnsubscribe:
- case MqttPacketTypeUnsubAck:
- return 1;
-
- default:
- return 0;
- }
-}
diff --git a/src/packet.h b/src/packet.h
index 7ab4f73..a5e2ce7 100644
--- a/src/packet.h
+++ b/src/packet.h
@@ -29,87 +29,35 @@ enum
MqttPacketTypeDisconnect = 0xE
};
+enum MqttPacketState
+{
+ MqttPacketStateReadType,
+ MqttPacketStateReadRemainingLength,
+ MqttPacketStateReadPayload,
+ MqttPacketStateReadComplete,
+
+ MqttPacketStateWriteType,
+ MqttPacketStateWriteRemainingLength,
+ MqttPacketStateWritePayload,
+ MqttPacketStateWriteComplete
+};
+
+struct MqttMessage;
+
typedef struct MqttPacket MqttPacket;
struct MqttPacket
{
int type;
- uint16_t id;
- int state;
int flags;
- int64_t sentAt;
+ int state;
+ uint16_t id;
+ size_t remainingLength;
+ size_t remainingLengthMul;
+ /* TODO: maybe switch to have a StringStream here? */
+ bstring payload;
+ struct MqttMessage *message;
SIMPLEQ_ENTRY(MqttPacket) sendQueue;
- TAILQ_ENTRY(MqttPacket) messages;
-};
-
-#define MqttPacketType(packet) (((MqttPacket *) (packet))->type)
-
-#define MqttPacketId(packet) (((MqttPacket *) (packet))->id)
-
-#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt)
-
-typedef struct MqttPacketConnect MqttPacketConnect;
-
-struct MqttPacketConnect
-{
- MqttPacket base;
- char connectFlags;
- uint16_t keepAlive;
- bstring clientId;
- bstring willTopic;
- bstring willMessage;
- bstring userName;
- bstring password;
-};
-
-typedef struct MqttPacketConnAck MqttPacketConnAck;
-
-struct MqttPacketConnAck
-{
- MqttPacket base;
- unsigned char connAckFlags;
- unsigned char returnCode;
-};
-
-typedef struct MqttPacketPublish MqttPacketPublish;
-
-struct MqttPacketPublish
-{
- MqttPacket base;
- bstring topicName;
- bstring message;
- char qos;
- char dup;
- char retain;
-};
-
-#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos)
-#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup)
-#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain)
-
-typedef struct MqttPacketSubscribe MqttPacketSubscribe;
-
-struct MqttPacketSubscribe
-{
- MqttPacket base;
- struct bstrList *topicFilters;
- int *qos;
-};
-
-typedef struct MqttPacketSubAck MqttPacketSubAck;
-
-struct MqttPacketSubAck
-{
- MqttPacket base;
- unsigned char *returnCode;
-};
-
-typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe;
-
-struct MqttPacketUnsubscribe
-{
- MqttPacket base;
- bstring topicFilter;
};
const char *MqttPacketName(int type);
@@ -120,6 +68,4 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id);
void MqttPacketFree(MqttPacket *packet);
-int MqttPacketHasId(const MqttPacket *packet);
-
#endif
diff --git a/src/serialize.c b/src/serialize.c
deleted file mode 100644
index c1c8eb4..0000000
--- a/src/serialize.c
+++ /dev/null
@@ -1,326 +0,0 @@
-#include "serialize.h"
-#include "packet.h"
-#include "stream_mqtt.h"
-#include "log.h"
-
-#include <bstrlib/bstrlib.h>
-
-#include <stdlib.h>
-#include <assert.h>
-
-typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet,
- Stream *stream);
-
-static const struct tagbstring MqttProtocolId = bsStatic("MQTT");
-static const char MqttProtocolLevel = 0x04;
-
-static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s)
-{
- return 2 + blength(s);
-}
-
-static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet)
-{
- size_t remainingLength = 0;
-
- remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2;
-
- remainingLength += MqttStringLengthSerialized(packet->clientId);
-
- if (packet->connectFlags & 0x80)
- remainingLength += MqttStringLengthSerialized(packet->userName);
-
- if (packet->connectFlags & 0x40)
- remainingLength += MqttStringLengthSerialized(packet->password);
-
- if (packet->connectFlags & 0x04)
- remainingLength += MqttStringLengthSerialized(packet->willTopic) +
- MqttStringLengthSerialized(packet->willMessage);
-
- return remainingLength;
-}
-
-static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet)
-{
- size_t remaining = 2;
- int i;
-
- for (i = 0; i < packet->topicFilters->qty; ++i)
- {
- remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]);
- remaining += 1;
- }
-
- return remaining;
-}
-
-static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet)
-{
- return 2 + MqttStringLengthSerialized(packet->topicFilter);
-}
-
-static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet)
-{
- size_t remainingLength = 0;
-
- remainingLength += MqttStringLengthSerialized(packet->topicName);
-
- /* Packet id */
- if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2)
- {
- remainingLength += 2;
- }
-
- remainingLength += blength(packet->message);
-
- return remainingLength;
-}
-
-static size_t MqttPacketGetRemainingLength(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypeConnect:
- return MqttPacketConnectGetRemainingLength(
- (MqttPacketConnect *) packet);
-
- case MqttPacketTypeSubscribe:
- return MqttPacketSubscribeGetRemainingLength(
- (MqttPacketSubscribe *) packet);
-
- case MqttPacketTypePublish:
- return MqttPacketPublishGetRemainingLength(
- (MqttPacketPublish *) packet);
-
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp:
- return 2;
-
- case MqttPacketTypeUnsubscribe:
- return MqttPacketUnsubscribeGetRemainingLength(
- (MqttPacketUnsubscribe *) packet);
-
- default:
- return 0;
- }
-}
-
-static int MqttPacketFlags(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypePublish:
- return ((MqttPacketPublishDup(packet) & 1) << 3) |
- ((MqttPacketPublishQos(packet) & 3) << 1) |
- (MqttPacketPublishRetain(packet) & 1);
-
- case MqttPacketTypePubRel:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeUnsubscribe:
- return 0x2;
-
- default:
- return 0;
- }
-}
-
-static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream)
-{
- unsigned char typeAndFlags;
- size_t remainingLength;
-
- typeAndFlags = ((packet->type & 0x0F) << 4) |
- (MqttPacketFlags(packet) & 0x0F);
- remainingLength = MqttPacketGetRemainingLength(packet);
-
- LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type,
- MqttPacketName(packet->type), MqttPacketFlags(packet));
-
- if (StreamWrite(&typeAndFlags, 1, stream) != 1)
- return -1;
-
- if (StreamWriteRemainingLength(remainingLength, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream)
-{
- assert(MqttPacketHasId((const MqttPacket *) packet));
-
- if (MqttPacketBaseSerialize(packet, stream) == -1)
- return -1;
-
- if (StreamWriteUint16Be(packet->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream)
-{
- if (MqttPacketBaseSerialize(&packet->base, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(&MqttProtocolId, stream) == -1)
- return -1;
-
- if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1)
- return -1;
-
- if (StreamWrite(&packet->connectFlags, 1, stream) != 1)
- return -1;
-
- if (StreamWriteUint16Be(packet->keepAlive, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->clientId, stream) == -1)
- return -1;
-
- if (packet->connectFlags & 0x04)
- {
- if (StreamWriteMqttString(packet->willTopic, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->willMessage, stream) == -1)
- return -1;
- }
-
- if (packet->connectFlags & 0x80)
- {
- if (StreamWriteMqttString(packet->userName, stream) == -1)
- return -1;
-
- if (packet->connectFlags & 0x40)
- {
- if (StreamWriteMqttString(packet->password, stream) == -1)
- return -1;
- }
- }
-
- return 0;
-}
-
-static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream)
-{
- int i;
-
- if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- for (i = 0; i < packet->topicFilters->qty; ++i)
- {
- unsigned char qos = (unsigned char) packet->qos[i];
-
- if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1)
- return -1;
-
- if (StreamWrite(&qos, 1, stream) == -1)
- return -1;
- }
-
- return 0;
-}
-
-static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream)
-{
- if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->topicFilter, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream)
-{
- if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->topicName, stream) == -1)
- return -1;
-
- LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet));
-
- if (MqttPacketPublishQos(packet) > 0)
- {
- if (StreamWriteUint16Be(packet->base.id, stream) == -1)
- return -1;
- }
-
- if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1)
- return -1;
-
- return 0;
-}
-
-int MqttPacketSerialize(const MqttPacket *packet, Stream *stream)
-{
- MqttPacketSerializeFunc f = NULL;
-
- switch (packet->type)
- {
- case MqttPacketTypeConnect:
- f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize;
- break;
-
- case MqttPacketTypeConnAck:
- break;
-
- case MqttPacketTypePublish:
- f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize;
- break;
-
- case MqttPacketTypePubAck:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubRec:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubRel:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubComp:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypeSubscribe:
- f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize;
- break;
-
- case MqttPacketTypeSubAck:
- break;
-
- case MqttPacketTypeUnsubscribe:
- f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize;
- break;
-
- case MqttPacketTypeUnsubAck:
- break;
-
- case MqttPacketTypePingReq:
- f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
- break;
-
- case MqttPacketTypePingResp:
- break;
-
- case MqttPacketTypeDisconnect:
- f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
- break;
-
- default:
- return -1;
- }
-
- assert(f != NULL && "no serializer");
-
- return f(packet, stream);
-}
diff --git a/src/serialize.h b/src/serialize.h
deleted file mode 100644
index 7eb988f..0000000
--- a/src/serialize.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef SERIALIZE_H
-#define SERIALIZE_H
-
-#include "config.h"
-
-typedef struct MqttPacket MqttPacket;
-typedef struct Stream Stream;
-
-int MqttPacketSerialize(const MqttPacket *packet, Stream *stream);
-
-#endif
diff --git a/src/socket.c b/src/socket.c
index 64a7c01..b70f4fb 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -6,18 +6,6 @@
#include <assert.h>
#if defined(_WIN32)
-#include "win32.h"
-#else
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <sys/select.h>
-#include <netdb.h>
-#include <unistd.h>
-#include <arpa/inet.h>
-#endif
-
-#if defined(_WIN32)
static int InitializeWsa()
{
WSADATA wsa;
@@ -33,9 +21,9 @@ static int InitializeWsa()
#define close closesocket
#endif
-int SocketConnect(const char *host, short port)
+int SocketConnect(const char *host, short port, int nonblocking)
{
- struct addrinfo hints, *servinfo, *p = NULL;
+ struct addrinfo hints, *servinfo = NULL, *p = NULL;
int rv;
char portstr[6];
int sock;
@@ -66,8 +54,16 @@ int SocketConnect(const char *host, short port)
continue;
}
+ if (nonblocking)
+ {
+ SocketSetNonblocking(sock, 1);
+ }
+
if (connect(sock, p->ai_addr, p->ai_addrlen) == -1)
{
+ int err = SocketErrno;
+ if (err == SOCKET_EINPROGRESS)
+ break;
close(sock);
continue;
}
@@ -75,10 +71,13 @@ int SocketConnect(const char *host, short port)
break;
}
- freeaddrinfo(servinfo);
-
cleanup:
+ if (servinfo)
+ {
+ freeaddrinfo(servinfo);
+ }
+
if (p == NULL)
{
#if defined(_WIN32)
@@ -178,3 +177,45 @@ int SocketSelect(int sock, int *events, int timeout)
return rv;
}
+
+void SocketSetNonblocking(int sock, int nb)
+{
+#if defined(_WIN32)
+ unsigned int yes = nb;
+ ioctlsocket(s, FIONBIO, &yes);
+#else
+ int flags = fcntl(sock, F_GETFL, 0);
+ if (nb)
+ flags |= O_NONBLOCK;
+ else
+ flags &= ~O_NONBLOCK;
+ fcntl(sock, F_SETFL, flags);
+#endif
+}
+
+int SocketGetOpt(int sock, int level, int name, void *val, int *len)
+{
+#if defined(_WIN32)
+ return getsockopt(sock, level, name, (char *) val, len);
+#else
+ socklen_t _len = *len;
+ int rc = getsockopt(sock, level, name, val, &_len);
+ *len = _len;
+ return rc;
+#endif
+}
+
+int SocketGetError(int sock, int *error)
+{
+ int len = sizeof(*error);
+ return SocketGetOpt(sock, SOL_SOCKET, SO_ERROR, error, &len);
+}
+
+int SocketWouldBlock(int error)
+{
+#if defined(_WIN32)
+ return error == WSAEWOULDBLOCK;
+#else
+ return error == EWOULDBLOCK || error == EAGAIN;
+#endif
+}
diff --git a/src/socket.h b/src/socket.h
index e7b1a80..abc67af 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -6,7 +6,25 @@
#include <stdlib.h>
#include <stdint.h>
-int SocketConnect(const char *host, short port);
+#if defined(_WIN32)
+#include "win32.h"
+#define SocketErrno (WSAGetLastError())
+#define SOCKET_EINPROGRESS (WSAEWOULDBLOCK)
+#else
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/select.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <errno.h>
+#define SocketErrno (errno)
+#define SOCKET_EINPROGRESS (EINPROGRESS)
+#endif
+
+int SocketConnect(const char *host, short port, int nonblocking);
int SocketDisconnect(int sock);
@@ -24,4 +42,10 @@ int64_t SocketRecv(int sock, void *buf, size_t len, int flags);
int64_t SocketSend(int sock, const void *buf, size_t len, int flags);
+void SocketSetNonblocking(int sock, int nb);
+
+int SocketGetError(int sock, int *error);
+
+int SocketWouldBlock(int error);
+
#endif
diff --git a/src/stream.c b/src/stream.c
index fd154a1..1c46668 100644
--- a/src/stream.c
+++ b/src/stream.c
@@ -47,6 +47,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream)
return 2;
}
+int64_t StreamReadByte(unsigned char *byte, Stream *stream)
+{
+ return StreamRead(byte, sizeof(*byte), stream);
+}
+
int64_t StreamWrite(const void *ptr, size_t size, Stream *stream)
{
STREAM_CHECK_OP(stream, write);
@@ -65,6 +70,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream)
return StreamWrite(data, sizeof(data), stream);
}
+int64_t StreamWriteByte(unsigned char byte, Stream *stream)
+{
+ return StreamWrite(&byte, sizeof(byte), stream);
+}
+
int StreamSeek(Stream *stream, int64_t offset, int whence)
{
STREAM_CHECK_OP(stream, seek);
diff --git a/src/stream.h b/src/stream.h
index 839facb..50f1772 100644
--- a/src/stream.h
+++ b/src/stream.h
@@ -27,9 +27,11 @@ int StreamClose(Stream *stream);
int64_t StreamRead(void *ptr, size_t size, Stream *stream);
int64_t StreamReadUint16Be(uint16_t *v, Stream *stream);
+int64_t StreamReadByte(unsigned char *byte, Stream *stream);
int64_t StreamWrite(const void *ptr, size_t size, Stream *stream);
int64_t StreamWriteUint16Be(uint16_t v, Stream *stream);
+int64_t StreamWriteByte(unsigned char byte, Stream *stream);
int StreamSeek(Stream *stream, int64_t offset, int whence);
diff --git a/src/stream_mqtt.c b/src/stream_mqtt.c
index 3864ef3..f2bd9cd 100644
--- a/src/stream_mqtt.c
+++ b/src/stream_mqtt.c
@@ -42,37 +42,39 @@ int64_t StreamWriteMqttString(const_bstring buf, Stream *stream)
return 2 + blength(buf);
}
-int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream)
+int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul,
+ Stream *stream)
{
- size_t multiplier = 1;
unsigned char encodedByte;
- *remainingLength = 0;
do
{
if (StreamRead(&encodedByte, 1, stream) != 1)
return -1;
- *remainingLength += (encodedByte & 127) * multiplier;
- if (multiplier > 128*128*128)
+ *remainingLength += (encodedByte & 127) * (*mul);
+ if ((*mul) > 128*128*128)
return -1;
- multiplier *= 128;
+ (*mul) *= 128;
}
while ((encodedByte & 128) != 0);
+ *mul = 0;
return 0;
}
-int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream)
+int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream)
{
- size_t nbytes = 0;
do
{
- unsigned char encodedByte = remainingLength % 128;
- remainingLength /= 128;
- if (remainingLength > 0)
+ size_t tmp = *remainingLength;
+ unsigned char encodedByte = tmp % 128;
+ tmp /= 128;
+ if (tmp > 0)
encodedByte |= 128;
if (StreamWrite(&encodedByte, 1, stream) != 1)
+ {
return -1;
- ++nbytes;
+ }
+ *remainingLength = tmp;
}
- while (remainingLength > 0);
- return nbytes;
+ while (*remainingLength > 0);
+ return 0;
}
diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h
index 9023430..8c8ccb5 100644
--- a/src/stream_mqtt.h
+++ b/src/stream_mqtt.h
@@ -2,13 +2,15 @@
#define STREAM_MQTT_H
#include "stream.h"
+#include "stringstream.h"
#include <bstrlib/bstrlib.h>
int64_t StreamReadMqttString(bstring *buf, Stream *stream);
int64_t StreamWriteMqttString(const_bstring buf, Stream *stream);
-int64_t StreamReadRemainingLength(size_t *remainingLength, Stream *stream);
-int64_t StreamWriteRemainingLength(size_t remainingLength, Stream *stream);
+int64_t StreamReadRemainingLength(size_t *remainingLength, size_t *mul,
+ Stream *stream);
+int64_t StreamWriteRemainingLength(size_t *remainingLength, Stream *stream);
#endif
diff --git a/src/stringstream.c b/src/stringstream.c
new file mode 100644
index 0000000..4353932
--- /dev/null
+++ b/src/stringstream.c
@@ -0,0 +1,115 @@
+#include "stringstream.h"
+
+#include <assert.h>
+
+static int StringStreamClose(Stream *base)
+{
+ StringStream *ss = (StringStream *) base;
+ bdestroy(ss->buffer);
+ ss->buffer = NULL;
+ return 0;
+}
+
+static int64_t StringStreamRead(void *ptr, size_t size, Stream *stream)
+{
+ StringStream *ss = (StringStream *) stream;
+ int64_t available = blength(ss->buffer) - ss->pos;
+ void *bufptr;
+
+ if (available <= 0)
+ {
+ return -1;
+ }
+
+ if (size > (size_t) available)
+ size = available;
+
+ /* Use a temp buffer pointer to make some warnings disappear when using
+ GCC */
+ bufptr = bdataofs(ss->buffer, ss->pos);
+ memcpy(ptr, bufptr, size);
+
+ ss->pos += size;
+
+ return size;
+}
+
+static int64_t StringStreamWrite(const void *ptr, size_t size, Stream *stream)
+{
+ StringStream *ss = (StringStream *) stream;
+ struct tagbstring buf;
+ if (ss->buffer->mlen <= 0)
+ return -1;
+ btfromblk(buf, ptr, size);
+ bsetstr(ss->buffer, ss->pos, &buf, '\0');
+ ss->pos += size;
+ return size;
+}
+
+int StringStreamSeek(Stream *base, int64_t offset, int whence)
+{
+ StringStream *ss = (StringStream *) base;
+ int64_t newpos = 0;
+
+ if (whence == SEEK_SET)
+ {
+ newpos = offset;
+ }
+ else if (whence == SEEK_CUR)
+ {
+ newpos = ss->pos + offset;
+ }
+ else if (whence == SEEK_END)
+ {
+ newpos = blength(ss->buffer) - offset;
+ }
+ else
+ {
+ return -1;
+ }
+
+ if (newpos > blength(ss->buffer))
+ return -1;
+
+ if (newpos < 0)
+ return -1;
+
+ ss->pos = newpos;
+
+ return 0;
+}
+
+int64_t StringStreamTell(Stream *base)
+{
+ StringStream *ss = (StringStream *) base;
+ return ss->pos;
+}
+
+static const StreamOps StringStreamOps =
+{
+ StringStreamRead,
+ StringStreamWrite,
+ StringStreamClose,
+ StringStreamSeek,
+ StringStreamTell
+};
+
+int StringStreamInit(StringStream *stream)
+{
+ assert(stream != NULL);
+ memset(stream, 0, sizeof(*stream));
+ stream->pos = 0;
+ stream->buffer = bfromcstr("");
+ stream->base.ops = &StringStreamOps;
+ return 0;
+}
+
+int StringStreamInitFromBstring(StringStream *stream, bstring buffer)
+{
+ assert(stream != NULL);
+ memset(stream, 0, sizeof(*stream));
+ stream->pos = 0;
+ stream->buffer = buffer;
+ stream->base.ops = &StringStreamOps;
+ return 0;
+}
diff --git a/src/stringstream.h b/src/stringstream.h
new file mode 100644
index 0000000..60d42fb
--- /dev/null
+++ b/src/stringstream.h
@@ -0,0 +1,21 @@
+#ifndef STRINGSTREAM_H
+#define STRINGSTREAM_H
+
+#include "stream.h"
+#include <bstrlib/bstrlib.h>
+#include <stdio.h>
+
+typedef struct StringStream StringStream;
+
+struct StringStream
+{
+ Stream base;
+ bstring buffer;
+ int64_t pos;
+};
+
+int StringStreamInit(StringStream *stream);
+
+int StringStreamInitFromBstring(StringStream *stream, bstring buffer);
+
+#endif