aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/CMakeLists.txt4
-rw-r--r--src/client.c1160
-rw-r--r--src/deserialize.c286
-rw-r--r--src/deserialize.h11
-rw-r--r--src/mqtt.h4
-rw-r--r--src/packet.c76
-rw-r--r--src/packet.h96
-rw-r--r--src/serialize.c326
-rw-r--r--src/serialize.h11
-rw-r--r--src/stream.c10
-rw-r--r--src/stream.h2
-rw-r--r--src/stream_mqtt.h1
-rw-r--r--test/interop/CMakeLists.txt3
-rw-r--r--test/interop/ping_test.c27
-rw-r--r--test/interop/testclient.c44
-rw-r--r--test/interop/testclient.h7
-rw-r--r--test/interop/unsubscribe_test.c31
-rw-r--r--test/interop/username_and_password_test.c30
-rw-r--r--tools/sub.c5
19 files changed, 928 insertions, 1206 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f51fabb..5a565ca 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -2,14 +2,14 @@ ADD_SUBDIRECTORY(lib)
ADD_LIBRARY(mqtt STATIC
client.c
- deserialize.c
misc.c
packet.c
- serialize.c
socket.c
socketstream.c
stream.c
stream_mqtt.c
+ stringstream.c
+ message.c
$<TARGET_OBJECTS:bstrlib>
)
diff --git a/src/client.c b/src/client.c
index a6b0998..9238555 100644
--- a/src/client.c
+++ b/src/client.c
@@ -5,10 +5,11 @@
#include "socketstream.h"
#include "socket.h"
#include "misc.h"
-#include "serialize.h"
-#include "deserialize.h"
#include "log.h"
#include "private.h"
+#include "stringstream.h"
+#include "stream_mqtt.h"
+#include "message.h"
#include "queue.h"
@@ -25,9 +26,6 @@
#error define PRId64 for your platform
#endif
-TAILQ_HEAD(MessageList, MqttPacket);
-typedef struct MessageList MessageList;
-
struct MqttClient
{
SocketStream stream;
@@ -56,9 +54,9 @@ struct MqttClient
/* packets waiting to be sent over network */
SIMPLEQ_HEAD(, MqttPacket) sendQueue;
/* sent messages that are not done yet */
- MessageList outMessages;
+ MqttMessageList outMessages;
/* received messages that are not done yet */
- MessageList inMessages;
+ MqttMessageList inMessages;
int sessionPresent;
/* when was the last packet sent */
int64_t lastPacketSentTime;
@@ -80,18 +78,13 @@ struct MqttClient
int paused;
bstring userName;
bstring password;
-};
-
-enum MessageState
-{
- MessageStateQueued = 100,
- MessageStateSend,
- MessageStateSent
+ /* The packet we are receiving */
+ MqttPacket inPacket;
};
static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet);
static int MqttClientQueueSimplePacket(MqttClient *client, int type);
-static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet);
+static int MqttClientSendPacket(MqttClient *client);
static int MqttClientRecvPacket(MqttClient *client);
static uint16_t MqttClientNextPacketId(MqttClient *client);
static void MqttClientProcessMessageQueue(MqttClient *client);
@@ -99,14 +92,14 @@ static void MqttClientClearQueues(MqttClient *client);
static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client)
{
- MqttPacket *packet;
+ MqttMessage *msg;
int queued = 0;
int inMessagesCount = 0;
int outMessagesCount = 0;
- TAILQ_FOREACH(packet, &client->outMessages, messages)
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (packet->state == MessageStateQueued)
+ if (msg->state == MqttMessageStateQueued)
{
++queued;
}
@@ -114,7 +107,7 @@ static MQTT_INLINE int MqttClientInflightMessageCount(MqttClient *client)
++outMessagesCount;
}
- TAILQ_FOREACH(packet, &client->inMessages, messages)
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
++inMessagesCount;
}
@@ -212,11 +205,44 @@ void MqttClientSetOnPublish(MqttClient *client,
client->onPublish = cb;
}
+static const struct tagbstring MqttProtocolId = bsStatic("MQTT");
+static const char MqttProtocolLevel = 0x04;
+
+static unsigned char MqttClientConnectFlags(MqttClient *client)
+{
+ unsigned char connectFlags = 0;
+
+ if (client->cleanSession)
+ {
+ connectFlags |= 0x02;
+ }
+
+ if (client->willTopic)
+ {
+ connectFlags |= 0x04;
+ connectFlags |= (client->willQos & 3) << 3;
+ connectFlags |= (client->willRetain & 1) << 5;
+ }
+
+ if (client->userName)
+ {
+ connectFlags |= 0x80;
+ if (client->password)
+ {
+ connectFlags |= 0x40;
+ }
+ }
+
+ return connectFlags;
+}
+
int MqttClientConnect(MqttClient *client, const char *host, short port,
int keepAlive, int cleanSession)
{
int sock;
- MqttPacketConnect *packet;
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
assert(host != NULL);
@@ -253,44 +279,37 @@ int MqttClientConnect(MqttClient *client, const char *host, short port,
return -1;
}
- packet = (MqttPacketConnect *) MqttPacketNew(MqttPacketTypeConnect);
+ packet = MqttPacketNew(MqttPacketTypeConnect);
if (!packet)
return -1;
- if (client->cleanSession)
- {
- packet->connectFlags |= 0x02;
- }
-
- packet->keepAlive = client->keepAlive;
+ StringStreamInit(&ss);
- packet->clientId = bstrcpy(client->clientId);
+ StreamWriteMqttString(&MqttProtocolId, pss);
+ StreamWriteByte(MqttProtocolLevel, pss);
+ StreamWriteByte(MqttClientConnectFlags(client), pss);
+ StreamWriteUint16Be(client->keepAlive, pss);
+ StreamWriteMqttString(client->clientId, pss);
if (client->willTopic)
{
- packet->connectFlags |= 0x04;
-
- packet->willTopic = bstrcpy(client->willTopic);
- packet->willMessage = bstrcpy(client->willMessage);
-
- packet->connectFlags |= (client->willQos & 3) << 3;
- packet->connectFlags |= (client->willRetain & 1) << 5;
+ StreamWriteMqttString(client->willTopic, pss);
+ StreamWriteMqttString(client->willMessage, pss);
}
if (client->userName)
{
- packet->connectFlags |= 0x80;
- packet->userName = bstrcpy(client->userName);
-
- if (client->password)
+ StreamWriteMqttString(client->userName, pss);
+ if(client->password)
{
- packet->connectFlags |= 0x40;
- packet->password = bstrcpy(client->password);
+ StreamWriteMqttString(client->password, pss);
}
}
- MqttClientQueuePacket(client, &packet->base);
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
return 0;
}
@@ -339,6 +358,14 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
if (timeout < 0)
{
timeout = client->keepAlive * 1000;
+ if (timeout == 0)
+ {
+ timeout = 30 * 1000;
+ }
+ }
+ else if (timeout > (client->keepAlive * 1000) && client->keepAlive > 0)
+ {
+ timeout = client->keepAlive * 1000;
}
rv = SocketSelect(client->stream.sock, &events, timeout);
@@ -354,21 +381,12 @@ int MqttClientRunOnce(MqttClient *client, int timeout)
if (events & EV_WRITE)
{
- MqttPacket *packet;
-
LOG_DEBUG("socket writable");
- packet = SIMPLEQ_FIRST(&client->sendQueue);
-
- if (packet)
+ if (MqttClientSendPacket(client) == -1)
{
- SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
-
- if (MqttClientSendPacket(client, packet) == -1)
- {
- LOG_ERROR("MqttClientSendPacket failed");
- client->stopped = 1;
- }
+ LOG_ERROR("MqttClientSendPacket failed");
+ client->stopped = 1;
}
}
@@ -433,10 +451,12 @@ int MqttClientSubscribe(MqttClient *client, const char *topicFilter,
}
int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters,
- int *qos, size_t count)
+ int *qos, size_t count)
{
- MqttPacketSubscribe *packet = NULL;
+ MqttPacket *packet = NULL;
size_t i;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
assert(topicFilters != NULL);
@@ -444,68 +464,122 @@ int MqttClientSubscribeMany(MqttClient *client, const char **topicFilters,
assert(qos != NULL);
assert(count > 0);
- packet = (MqttPacketSubscribe *) MqttPacketWithIdNew(
- MqttPacketTypeSubscribe, MqttClientNextPacketId(client));
+ packet = MqttPacketWithIdNew(MqttPacketTypeSubscribe,
+ MqttClientNextPacketId(client));
if (!packet)
return -1;
- packet->topicFilters = bstrListCreate();
- bstrListAllocMin(packet->topicFilters, count);
+ packet->flags = 0x2;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(packet->id, pss);
- packet->qos = (int *) malloc(sizeof(int) * count);
+ LOG_DEBUG("SUBSCRIBE id:%d", (int) packet->id);
for (i = 0; i < count; ++i)
{
- packet->topicFilters->entry[i] = bfromcstr(topicFilters[i]);
- ++packet->topicFilters->qty;
+ struct tagbstring filter;
+ btfromcstr(filter, topicFilters[i]);
+ StreamWriteMqttString(&filter, pss);
+ StreamWriteByte(qos[i] & 3, pss);
}
- memcpy(packet->qos, qos, sizeof(int) * count);
+ packet->payload = ss.buffer;
- MqttClientQueuePacket(client, (MqttPacket *) packet);
-
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages);
+ MqttClientQueuePacket(client, packet);
- return MqttPacketId(packet);
+ return packet->id;
}
int MqttClientUnsubscribe(MqttClient *client, const char *topicFilter)
{
- MqttPacketUnsubscribe *packet = NULL;
+ MqttPacket *packet = NULL;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ struct tagbstring filter;
assert(client != NULL);
assert(topicFilter != NULL);
- packet = (MqttPacketUnsubscribe *) MqttPacketWithIdNew(
- MqttPacketTypeUnsubscribe, MqttClientNextPacketId(client));
+ packet = MqttPacketWithIdNew(MqttPacketTypeUnsubscribe,
+ MqttClientNextPacketId(client));
- packet->topicFilter = bfromcstr(topicFilter);
+ if (!packet)
+ return -1;
+
+ packet->flags = 0x02;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(packet->id, pss);
+
+ btfromcstr(filter, topicFilter);
+
+ StreamWriteMqttString(&filter, pss);
- MqttClientQueuePacket(client, (MqttPacket *) packet);
+ packet->payload = ss.buffer;
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet, messages);
+ MqttClientQueuePacket(client, packet);
- return MqttPacketId(packet);
+ return packet->id;
}
static MQTT_INLINE int MqttClientOutMessagesLen(MqttClient *client)
{
- MqttPacket *packet;
+ MqttMessage *msg;
int count = 0;
- TAILQ_FOREACH(packet, &client->outMessages, messages)
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
++count;
}
return count;
}
+static MqttPacket *PublishToPacket(MqttMessage *msg)
+{
+ MqttPacket *packet = NULL;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ if (msg->qos > 0)
+ {
+ packet = MqttPacketWithIdNew(MqttPacketTypePublish,
+ msg->id);
+ }
+ else
+ {
+ packet = MqttPacketNew(MqttPacketTypePublish);
+ }
+
+ if (!packet)
+ return NULL;
+
+ packet->message = msg;
+
+ StringStreamInit(&ss);
+
+ StreamWriteMqttString(msg->topic, pss);
+
+ if (msg->qos > 0)
+ {
+ StreamWriteUint16Be(msg->id, pss);
+ }
+
+ StreamWrite(bdata(msg->payload), blength(msg->payload), pss);
+
+ packet->payload = ss.buffer;
+ packet->flags = (msg->qos & 3) << 1;
+ packet->flags |= msg->retain & 1;
+
+ return packet;
+}
+
int MqttClientPublish(MqttClient *client, int qos, int retain,
const char *topic, const void *data, size_t size)
{
- MqttPacketPublish *packet;
-
- assert(client != NULL);
+ MqttMessage *message;
/* first check if the queue is already full */
if (qos > 0 && client->maxQueued > 0 &&
@@ -514,55 +588,55 @@ int MqttClientPublish(MqttClient *client, int qos, int retain,
return -1;
}
- if (qos > 0)
+ message = calloc(1, sizeof(*message));
+ if (!message)
{
- packet = (MqttPacketPublish *) MqttPacketWithIdNew(
- MqttPacketTypePublish, MqttClientNextPacketId(client));
+ return -1;
}
- else
+
+ message->state = MqttMessageStateQueued;
+ message->qos = qos;
+ message->retain = retain;
+ message->dup = 0;
+ message->timestamp = MqttGetCurrentTime();
+
+ if (qos == 0)
{
- packet = (MqttPacketPublish *) MqttPacketNew(MqttPacketTypePublish);
- }
+ /* Copy payload and topic directly from user buffers as we don't need
+ to keep the message data around after this function. */
+ MqttPacket *packet;
+ struct tagbstring bttopic, btpayload;
- if (!packet)
- return -1;
+ btfromcstr(bttopic, topic);
+ message->topic = &bttopic;
- packet->qos = qos;
- packet->retain = retain;
- packet->topicName = bfromcstr(topic);
- packet->message = blk2bstr(data, size);
+ btfromblk(btpayload, data, size);
+ message->payload = &btpayload;
- if (qos > 0)
- {
- /* check how many messages there are coming in and going out currently
- that are not yet done */
- if (client->maxInflight == 0 ||
- MqttClientInflightMessageCount(client) < client->maxInflight)
- {
- LOG_DEBUG("setting message (%d) state to MessageStateSend",
- MqttPacketId(packet));
- packet->base.state = MessageStateSend;
- }
- else
- {
- LOG_DEBUG("setting message (%d) state to MessageStateQueued",
- MqttPacketId(packet));
- packet->base.state = MessageStateQueued;
- }
+ packet = PublishToPacket(message);
+
+ message->topic = NULL;
+ message->payload = NULL;
+
+ MqttClientQueuePacket(client, packet);
+
+ MqttMessageFree(message);
- /* add the message to the outMessages queue to wait for processing */
- TAILQ_INSERT_TAIL(&client->outMessages, (MqttPacket *) packet,
- messages);
+ return 0;
}
else
{
- MqttClientQueuePacket(client, (MqttPacket *) packet);
- }
+ /* Duplicate the user buffers as we need the data to be available
+ longer. */
+ message->topic = bfromcstr(topic);
+ message->payload = blk2bstr(data, size);
- if (qos > 0)
- return MqttPacketId(packet);
+ message->id = MqttClientNextPacketId(client);
- return 0;
+ TAILQ_INSERT_TAIL(&client->outMessages, message, chain);
+
+ return message->id;
+ }
}
int MqttClientPublishCString(MqttClient *client, int qos, int retain,
@@ -655,6 +729,7 @@ static void MqttClientQueuePacket(MqttClient *client, MqttPacket *packet)
{
assert(client != NULL);
LOG_DEBUG("queuing packet %s", MqttPacketName(packet->type));
+ packet->state = MqttPacketStateWriteType;
SIMPLEQ_INSERT_TAIL(&client->sendQueue, packet, sendQueue);
}
@@ -667,128 +742,332 @@ static int MqttClientQueueSimplePacket(MqttClient *client, int type)
return 0;
}
-static int MqttClientSendPacket(MqttClient *client, MqttPacket *packet)
+static int MqttClientSendPacket(MqttClient *client)
{
- if (MqttPacketSerialize(packet, &client->stream.base) == -1)
- return -1;
+ MqttPacket *packet;
- packet->sentAt = MqttGetCurrentTime();
- client->lastPacketSentTime = packet->sentAt;
+ packet = SIMPLEQ_FIRST(&client->sendQueue);
- if (packet->type == MqttPacketTypeDisconnect)
+ if (!packet)
{
- client->stopped = 1;
+ LOG_WARNING("MqttClientSendPacket called with no queued packets");
+ return 0;
}
- /* If the packet is not on any message list, it can be removed after
- sending. */
- if (TAILQ_NEXT(packet, messages) == NULL &&
- TAILQ_PREV(packet, MessageList, messages) == NULL &&
- TAILQ_FIRST(&client->inMessages) != packet &&
- TAILQ_FIRST(&client->outMessages) != packet)
+ while (packet != NULL)
{
- LOG_DEBUG("freeing packet %s after sending",
- MqttPacketName(MqttPacketType(packet)));
- MqttPacketFree(packet);
+ switch (packet->state)
+ {
+ case MqttPacketStateWriteType:
+ {
+ unsigned char typeAndFlags = ((packet->type & 0x0F) << 4) |
+ (packet->flags & 0x0F);
+
+ if (StreamWriteByte(typeAndFlags, &client->stream.base) == -1)
+ {
+ return -1;
+ }
+
+ packet->state = MqttPacketStateWriteRemainingLength;
+
+ break;
+ }
+
+ case MqttPacketStateWriteRemainingLength:
+ {
+ if (StreamWriteRemainingLength(blength(packet->payload),
+ &client->stream.base) == -1)
+ {
+ return -1;
+ }
+
+ packet->state = MqttPacketStateWritePayload;
+
+ break;
+ }
+
+ case MqttPacketStateWritePayload:
+ {
+ if (packet->payload)
+ {
+ if (StreamWrite(bdata(packet->payload),
+ blength(packet->payload),
+ &client->stream.base) == -1)
+ {
+ return -1;
+ }
+ }
+
+ packet->state = MqttPacketStateWriteComplete;
+
+ break;
+ }
+
+ case MqttPacketStateWriteComplete:
+ {
+ client->lastPacketSentTime = MqttGetCurrentTime();
+
+ if (packet->type == MqttPacketTypeDisconnect)
+ {
+ client->stopped = 1;
+ }
+
+ LOG_DEBUG("sent %s", MqttPacketName(packet->type));
+
+ if (packet->type == MqttPacketTypePublish && packet->message)
+ {
+ MqttMessage *msg = packet->message;
+
+ if (msg->qos == 1)
+ {
+ msg->state = MqttMessageStateWaitPubAck;
+ }
+ else if (msg->qos == 2)
+ {
+ msg->state = MqttMessageStateWaitPubRec;
+ }
+ }
+
+ if (packet->message)
+ {
+ packet->message->timestamp = client->lastPacketSentTime;
+ }
+
+ SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
+
+ MqttPacketFree(packet);
+
+ packet = SIMPLEQ_FIRST(&client->sendQueue);
+
+ break;
+ }
+ }
}
return 0;
}
-static void MqttClientHandleConnAck(MqttClient *client,
- MqttPacketConnAck *packet)
+static int MqttClientHandleConnAck(MqttClient *client)
{
- client->sessionPresent = packet->connAckFlags & 1;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ unsigned char flags;
+ unsigned char rc;
+
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadByte(&flags, pss);
+
+ StreamReadByte(&rc, pss);
+
+ client->sessionPresent = flags & 1;
LOG_DEBUG("sessionPresent:%d", client->sessionPresent);
if (client->onConnect)
{
- LOG_DEBUG("calling onConnect rc:%d", packet->returnCode);
- client->onConnect(client, packet->returnCode, client->sessionPresent);
+ LOG_DEBUG("calling onConnect rc:%d", rc);
+ client->onConnect(client, rc, client->sessionPresent);
}
+
+ return 0;
}
-static void MqttClientHandlePingResp(MqttClient *client)
+static int MqttClientHandlePingResp(MqttClient *client)
{
LOG_DEBUG("got ping response");
client->pingSent = 0;
+ return 0;
}
-static void MqttClientHandleSubAck(MqttClient *client, MqttPacketSubAck *packet)
+static int MqttClientHandleSubAck(MqttClient *client)
{
- MqttPacket *sub;
+ uint16_t id;
+ int *qos;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ int count;
+ int i;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(sub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ LOG_DEBUG("received SUBACK with id:%d", (int) id);
+
+ count = blength(client->inPacket.payload) - StreamTell(pss);
+
+ if (count <= 0)
{
- if (MqttPacketType(sub) == MqttPacketTypeSubscribe &&
- MqttPacketId(sub) == MqttPacketId(packet))
- {
- break;
- }
+ LOG_ERROR("number of return codes invalid");
+ return -1;
}
- if (!sub)
+ qos = malloc(count * sizeof(int));
+
+ for (i = 0; i < count; ++i)
{
- LOG_ERROR("SUBSCRIBE with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ unsigned char byte;
+ StreamReadByte(&byte, pss);
+ qos[i] = byte;
}
- else
+
+ if (client->onSubscribe)
{
- if (client->onSubscribe)
- {
- MqttPacketSubscribe *sub2;
- int i;
+ client->onSubscribe(client, id, qos, count);
+ }
- sub2 = (MqttPacketSubscribe *) sub;
+ free(qos);
- for (i = 0; i < sub2->topicFilters->qty; ++i)
- {
- const char *filter = bdata(sub2->topicFilters->entry[i]);
- int rc = packet->returnCode[i];
+ return 0;
+}
- LOG_DEBUG("calling onSubscribe id:%d filter:'%s' rc:%d",
- MqttPacketId(packet), filter, rc);
+static int MqttClientSendPubAck(MqttClient *client, uint16_t id)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
- client->onSubscribe(client, MqttPacketId(packet), filter, rc);
- }
- }
+ packet = MqttPacketWithIdNew(MqttPacketTypePubAck, id);
- TAILQ_REMOVE(&client->outMessages, sub, messages);
- MqttPacketFree(sub);
- }
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(id, pss);
+
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubRec(MqttClient *client, MqttMessage *msg)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubRec, msg->id);
+
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(msg->id, pss);
+
+ packet->payload = ss.buffer;
+ packet->message = msg;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubRel(MqttClient *client, MqttMessage *msg)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubRel, msg->id);
+
+ if (!packet)
+ return -1;
+
+ packet->flags = 0x2;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(msg->id, pss);
+
+ packet->payload = ss.buffer;
+ packet->message = msg;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
+}
+
+static int MqttClientSendPubComp(MqttClient *client, uint16_t id)
+{
+ MqttPacket *packet;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+
+ packet = MqttPacketWithIdNew(MqttPacketTypePubComp, id);
+
+ if (!packet)
+ return -1;
+
+ StringStreamInit(&ss);
+
+ StreamWriteUint16Be(id, pss);
+
+ packet->payload = ss.buffer;
+
+ MqttClientQueuePacket(client, packet);
+
+ return 0;
}
-static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packet)
+static int MqttClientHandlePublish(MqttClient *client)
{
+ MqttMessage *msg;
+ uint16_t id;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ MqttPacket *packet;
+ int qos;
+ int retain;
+ bstring topic;
+ void *payload;
+ int payloadSize;
+
+ /* We are paused - do nothing */
if (client->paused)
- return;
+ return 0;
+
+ packet = &client->inPacket;
+
+ qos = (packet->flags >> 1) & 3;
+ retain = packet->flags & 1;
+
+ StringStreamInitFromBstring(&ss, packet->payload);
+
+ StreamReadMqttString(&topic, pss);
+
+ StreamReadUint16Be(&id, pss);
+
+ payload = bdataofs(ss.buffer, ss.pos);
+ payloadSize = blength(ss.buffer) - ss.pos;
- if (MqttPacketPublishQos(packet) == 2)
+ if (qos == 2)
{
/* Check if we have sent a PUBREC previously with the same id. If we
have, we have to resend the PUBREC. We must not call the onMessage
callback again. */
- MqttPacket *pubRec;
-
- TAILQ_FOREACH(pubRec, &client->inMessages, messages)
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
- if (MqttPacketId(pubRec) == MqttPacketId(packet) &&
- MqttPacketType(pubRec) == MqttPacketTypePubRec)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubRel)
{
break;
}
}
- if (pubRec)
+ if (msg)
{
- LOG_DEBUG("resending PUBREC id:%d", MqttPacketId(packet));
- MqttClientQueuePacket(client, pubRec);
- return;
+ LOG_DEBUG("resending PUBREC id:%u", msg->id);
+ MqttClientSendPubRec(client, msg);
+ bdestroy(topic);
+ return 0;
}
}
@@ -796,268 +1075,347 @@ static void MqttClientHandlePublish(MqttClient *client, MqttPacketPublish *packe
{
LOG_DEBUG("calling onMessage");
client->onMessage(client,
- bdata(packet->topicName),
- bdata(packet->message),
- blength(packet->message),
- packet->qos,
- packet->retain);
+ bdata(topic),
+ payload,
+ payloadSize,
+ qos,
+ retain);
}
- if (MqttPacketPublishQos(packet) > 0)
+ bdestroy(topic);
+
+ if (qos == 1)
+ {
+ MqttClientSendPubAck(client, id);
+ }
+ else if (qos == 2)
{
- int type = (MqttPacketPublishQos(packet) == 1) ? MqttPacketTypePubAck :
- MqttPacketTypePubRec;
+ msg = calloc(1, sizeof(*msg));
- MqttPacket *resp = MqttPacketWithIdNew(type, MqttPacketId(packet));
+ msg->state = MqttMessageStateWaitPubRel;
+ msg->id = id;
+ msg->qos = qos;
- if (MqttPacketPublishQos(packet) == 2)
- {
- /* append to inMessages as we need a reply to this response */
- TAILQ_INSERT_TAIL(&client->inMessages, resp, messages);
- }
+ TAILQ_INSERT_TAIL(&client->inMessages, msg, chain);
- MqttClientQueuePacket(client, resp);
+ MqttClientSendPubRec(client, msg);
}
+
+ return 0;
}
-static void MqttClientHandlePubAck(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubAck(MqttClient *client)
{
- MqttPacket *pub;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
+
+ assert(client != NULL);
- TAILQ_FOREACH(pub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pub) == MqttPacketId(packet) &&
- MqttPacketType(pub) == MqttPacketTypePublish)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubAck)
{
break;
}
}
- if (!pub)
+ if (!msg)
{
- LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, pub, messages);
- MqttPacketFree(pub);
- if (client->onPublish)
- {
- client->onPublish(client, MqttPacketId(packet));
- }
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+
+ if (client->onPublish)
+ {
+ client->onPublish(client, msg->id);
}
+
+ MqttMessageFree(msg);
+
+ return 0;
}
-static void MqttClientHandlePubRec(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubRec(MqttClient *client)
{
- MqttPacket *pub;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(pub, &client->outMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pub) == MqttPacketId(packet) &&
- MqttPacketType(pub) == MqttPacketTypePublish)
+ /* Also check if we are waiting for PUBCOMP, if we have sent PUBREL but
+ they haven't received it. */
+ if (msg->id == id &&
+ (msg->state == MqttMessageStateWaitPubRec ||
+ msg->state == MqttMessageStateWaitPubComp))
{
break;
}
}
- if (!pub)
+ if (!msg)
{
- LOG_ERROR("PUBLISH with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- MqttPacket *pubRel;
- TAILQ_REMOVE(&client->outMessages, pub, messages);
- MqttPacketFree(pub);
+ msg->state = MqttMessageStateWaitPubComp;
- pubRel = MqttPacketWithIdNew(MqttPacketTypePubRel, MqttPacketId(packet));
- pubRel->state = MessageStateSend;
+ bdestroy(msg->payload);
+ msg->payload = NULL;
- TAILQ_INSERT_TAIL(&client->outMessages, pubRel, messages);
- }
+ bdestroy(msg->topic);
+ msg->topic = NULL;
+
+ if (MqttClientSendPubRel(client, msg) == -1)
+ return -1;
+
+ return 0;
}
-static void MqttClientHandlePubRel(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubRel(MqttClient *client)
{
- MqttPacket *pubRec;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(pubRec, &client->inMessages, messages)
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
+
+ TAILQ_FOREACH(msg, &client->inMessages, chain)
{
- if (MqttPacketId(pubRec) == MqttPacketId(packet) &&
- MqttPacketType(pubRec) == MqttPacketTypePubRec)
+ if (msg->id == id &&
+ msg->state == MqttMessageStateWaitPubRel)
{
break;
}
}
- if (!pubRec)
+ if (!msg)
{
- MqttPacket *pubComp;
-
- TAILQ_FOREACH(pubComp, &client->inMessages, messages)
- {
- if (MqttPacketId(pubComp) == MqttPacketId(packet) &&
- MqttPacketType(pubComp) == MqttPacketTypePubComp)
- {
- break;
- }
- }
-
- if (pubComp)
- {
- MqttClientQueuePacket(client, pubComp);
- }
- else
- {
- LOG_ERROR("PUBREC with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
- }
+ LOG_ERROR("no message found with id %d", (int) id);
+ return -1;
}
- else
- {
- MqttPacket *pubComp;
-
- TAILQ_REMOVE(&client->inMessages, pubRec, messages);
- MqttPacketFree(pubRec);
- pubComp = MqttPacketWithIdNew(MqttPacketTypePubComp,
- MqttPacketId(packet));
+ TAILQ_REMOVE(&client->inMessages, msg, chain);
+ MqttMessageFree(msg);
- TAILQ_INSERT_TAIL(&client->inMessages, pubComp, messages);
+ if (MqttClientSendPubComp(client, id) == -1)
+ return -1;
- MqttClientQueuePacket(client, pubComp);
- }
+ return 0;
}
-static void MqttClientHandlePubComp(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandlePubComp(MqttClient *client)
{
- MqttPacket *pubRel;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
+ uint16_t id;
+ MqttMessage *msg;
+
+ assert(client != NULL);
+
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
+
+ StreamReadUint16Be(&id, pss);
- TAILQ_FOREACH(pubRel, &client->outMessages, messages)
+ TAILQ_FOREACH(msg, &client->outMessages, chain)
{
- if (MqttPacketId(pubRel) == MqttPacketId(packet) &&
- MqttPacketType(pubRel) == MqttPacketTypePubRel)
+ if (msg->id == id && msg->state == MqttMessageStateWaitPubComp)
{
break;
}
}
- if (!pubRel)
+ if (!msg)
{
- LOG_ERROR("PUBREL with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ LOG_WARNING("no message found with id %d", (int) id);
+ return 0;
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, pubRel, messages);
- MqttPacketFree(pubRel);
- if (client->onPublish)
- {
- LOG_DEBUG("calling onPublish id:%d", MqttPacketId(packet));
- client->onPublish(client, MqttPacketId(packet));
- }
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+
+ MqttMessageFree(msg);
+
+ if (client->onPublish)
+ {
+ LOG_DEBUG("calling onPublish id:%d", id);
+ client->onPublish(client, id);
}
+
+ return 0;
}
-static void MqttClientHandleUnsubAck(MqttClient *client, MqttPacket *packet)
+static int MqttClientHandleUnsubAck(MqttClient *client)
{
- MqttPacket *sub;
+ uint16_t id;
+ StringStream ss;
+ Stream *pss = (Stream *) &ss;
assert(client != NULL);
- assert(packet != NULL);
- TAILQ_FOREACH(sub, &client->outMessages, messages)
- {
- if (MqttPacketId(sub) == MqttPacketId(packet) &&
- MqttPacketType(sub) == MqttPacketTypeUnsubscribe)
- {
- break;
- }
- }
+ StringStreamInitFromBstring(&ss, client->inPacket.payload);
- if (!sub)
+ StreamReadUint16Be(&id, pss);
+
+ if (client->onUnsubscribe)
{
- LOG_ERROR("UNSUBSCRIBE with id:%d not found", MqttPacketId(packet));
- client->stopped = 1;
+ client->onUnsubscribe(client, id);
}
- else
- {
- TAILQ_REMOVE(&client->outMessages, sub, messages);
- MqttPacketFree(sub);
- if (client->onUnsubscribe)
- {
- LOG_DEBUG("calling onUnsubscribe id:%d", MqttPacketId(packet));
- client->onUnsubscribe(client, MqttPacketId(packet));
- }
- }
+ return 0;
}
-static int MqttClientRecvPacket(MqttClient *client)
+static int MqttClientHandlePacket(MqttClient *client)
{
- MqttPacket *packet = NULL;
-
- if (MqttPacketDeserialize(&packet, (Stream *) &client->stream) == -1)
- return -1;
-
- LOG_DEBUG("received packet %s", MqttPacketName(packet->type));
+ int rc;
- switch (MqttPacketType(packet))
+ switch (client->inPacket.type)
{
case MqttPacketTypeConnAck:
- MqttClientHandleConnAck(client, (MqttPacketConnAck *) packet);
+ rc = MqttClientHandleConnAck(client);
break;
case MqttPacketTypePingResp:
- MqttClientHandlePingResp(client);
+ rc = MqttClientHandlePingResp(client);
break;
case MqttPacketTypeSubAck:
- MqttClientHandleSubAck(client, (MqttPacketSubAck *) packet);
+ rc = MqttClientHandleSubAck(client);
break;
- case MqttPacketTypePublish:
- MqttClientHandlePublish(client, (MqttPacketPublish *) packet);
+ case MqttPacketTypeUnsubAck:
+ rc = MqttClientHandleUnsubAck(client);
break;
case MqttPacketTypePubAck:
- MqttClientHandlePubAck(client, packet);
+ rc = MqttClientHandlePubAck(client);
break;
case MqttPacketTypePubRec:
- MqttClientHandlePubRec(client, packet);
+ rc = MqttClientHandlePubRec(client);
break;
- case MqttPacketTypePubRel:
- MqttClientHandlePubRel(client, packet);
+ case MqttPacketTypePubComp:
+ rc = MqttClientHandlePubComp(client);
break;
- case MqttPacketTypePubComp:
- MqttClientHandlePubComp(client, packet);
+ case MqttPacketTypePubRel:
+ rc = MqttClientHandlePubRel(client);
break;
- case MqttPacketTypeUnsubAck:
- MqttClientHandleUnsubAck(client, packet);
+ case MqttPacketTypePublish:
+ rc = MqttClientHandlePublish(client);
break;
default:
- LOG_DEBUG("unhandled packet type=%d", MqttPacketType(packet));
+ LOG_ERROR("packet not handled yet");
+ rc = -1;
break;
}
- MqttPacketFree(packet);
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+
+ client->inPacket.state = MqttPacketStateReadType;
+
+ return rc;
+}
+
+static int MqttClientRecvPacket(MqttClient *client)
+{
+ while (1)
+ {
+ switch (client->inPacket.state)
+ {
+ case MqttPacketStateReadType:
+ {
+ unsigned char typeAndFlags;
+ int rc;
+
+ if ((rc = StreamReadByte(&typeAndFlags, &client->stream.base)) != 1)
+ {
+ LOG_ERROR("failed reading packet type: %d", rc);
+ return -1;
+ }
+
+ client->inPacket.type = typeAndFlags >> 4;
+ client->inPacket.flags = typeAndFlags & 0x0F;
+
+ if (client->inPacket.type < MqttPacketTypeConnect ||
+ client->inPacket.type > MqttPacketTypeDisconnect)
+ {
+ LOG_ERROR("unknown packet type: %d", client->inPacket.type);
+ return -1;
+ }
+
+ client->inPacket.state = MqttPacketStateReadRemainingLength;
+
+ break;
+ }
+
+ case MqttPacketStateReadRemainingLength:
+ {
+ if (StreamReadRemainingLength(&client->inPacket.remainingLength,
+ &client->stream.base) == -1)
+ {
+ LOG_ERROR("failed to read remaining length");
+ return -1;
+ }
+ client->inPacket.state = MqttPacketStateReadPayload;
+ break;
+ }
+
+ case MqttPacketStateReadPayload:
+ {
+ if (client->inPacket.remainingLength > 0)
+ {
+ client->inPacket.payload = bfromcstr("");
+ ballocmin(client->inPacket.payload,
+ client->inPacket.remainingLength+1);
+ if (StreamRead(bdata(client->inPacket.payload),
+ client->inPacket.remainingLength,
+ &client->stream.base) == -1)
+ {
+ LOG_ERROR("failed reading packet payload");
+ bdestroy(client->inPacket.payload);
+ client->inPacket.payload = NULL;
+ return -1;
+ }
+ client->inPacket.payload->slen = client->inPacket.remainingLength;
+ }
+ client->inPacket.state = MqttPacketStateReadComplete;
+ break;
+ }
+
+ case MqttPacketStateReadComplete:
+ {
+ int type = client->inPacket.type;
+ LOG_DEBUG("received %s", MqttPacketName(type));
+ return MqttClientHandlePacket(client);
+ }
+ }
+ }
return 0;
}
@@ -1072,101 +1430,89 @@ static uint16_t MqttClientNextPacketId(MqttClient *client)
return id;
}
-static int64_t MqttPacketTimeSinceSent(MqttPacket *packet)
+static int64_t MqttMessageTimeSinceSent(MqttMessage *msg)
{
int64_t now = MqttGetCurrentTime();
- return now - packet->sentAt;
+ return now - msg->timestamp;
}
-static void MqttClientProcessInMessages(MqttClient *client)
+static int MqttMessageShouldResend(MqttClient *client, MqttMessage *msg)
{
- MqttPacket *packet, *next;
-
- LOG_DEBUG("processing inMessages");
-
- TAILQ_FOREACH_SAFE(packet, &client->inMessages, messages, next)
+ if (msg->timestamp > 0 &&
+ MqttMessageTimeSinceSent(msg) >= client->retryTimeout*1000)
{
- LOG_DEBUG("packet type:%s id:%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
-
- if (MqttPacketType(packet) == MqttPacketTypePubComp)
- {
- int64_t elapsed = MqttPacketTimeSinceSent(packet);
- if (packet->sentAt > 0 &&
- elapsed >= client->retryTimeout*1000)
- {
- LOG_DEBUG("freeing PUBCOMP with id:%d elapsed:%" PRId64,
- MqttPacketId(packet), elapsed);
-
- TAILQ_REMOVE(&client->inMessages, packet, messages);
-
- MqttPacketFree(packet);
- }
- }
+ return 1;
}
+
+ return 0;
}
-static int MqttPacketShouldResend(MqttClient *client, MqttPacket *packet)
+static void MqttClientProcessInMessages(MqttClient *client)
{
- if (packet->sentAt > 0 &&
- MqttPacketTimeSinceSent(packet) > client->retryTimeout*1000)
+ MqttMessage *msg, *next;
+
+ TAILQ_FOREACH_SAFE(msg, &client->inMessages, chain, next)
{
- return 1;
- }
+ switch (msg->state)
+ {
+ case MqttMessageStateWaitPubRel:
+ if (MqttMessageShouldResend(client, msg))
+ {
+ MqttClientSendPubRec(client, msg);
+ }
+ break;
- return 0;
+ default:
+ break;
+ }
+ }
}
static void MqttClientProcessOutMessages(MqttClient *client)
{
- MqttPacket *packet, *next;
+ MqttMessage *msg, *next;
+ MqttPacket *packet;
int inflight = MqttClientInflightMessageCount(client);
- LOG_DEBUG("processing outMessages inflight:%d", inflight);
-
- TAILQ_FOREACH_SAFE(packet, &client->outMessages, messages, next)
+ TAILQ_FOREACH_SAFE(msg, &client->outMessages, chain, next)
{
- LOG_DEBUG("packet type:%s id:%d state:%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet),
- packet->state);
-
- switch (packet->state)
+ switch (msg->state)
{
- case MessageStateQueued:
+ case MqttMessageStateQueued:
+ {
if (inflight >= client->maxInflight)
{
- LOG_DEBUG("cannot dequeue %s/%d",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
- break;
- }
- else
- {
- /* If there's less than maxInflight messages currently
- inflight, we can dequeue some messages by falling
- through to MessageStateSend. */
- LOG_DEBUG("dequeuing %s (%d)",
- MqttPacketName(MqttPacketType(packet)),
- MqttPacketId(packet));
- ++inflight;
+ continue;
}
-
- case MessageStateSend:
- packet->state = MessageStateSent;
+ /* State change from MqttMessageStatePublish happens after
+ the packet has been sent (in MqttClientSendPacket). */
+ msg->state = MqttMessageStatePublish;
+ packet = PublishToPacket(msg);
MqttClientQueuePacket(client, packet);
+ ++inflight;
break;
+ }
- case MessageStateSent:
- if (MqttPacketShouldResend(client, packet))
+ case MqttMessageStateWaitPubAck:
+ case MqttMessageStateWaitPubRec:
+ {
+ if (MqttMessageShouldResend(client, msg))
{
- packet->state = MessageStateSend;
+ msg->state = MqttMessageStatePublish;
+ packet = PublishToPacket(msg);
+ MqttClientQueuePacket(client, packet);
}
break;
+ }
- default:
+ case MqttMessageStateWaitPubComp:
+ {
+ if (MqttMessageShouldResend(client, msg))
+ {
+ MqttClientSendPubRel(client, msg);
+ }
break;
+ }
}
}
}
@@ -1182,30 +1528,22 @@ static void MqttClientClearQueues(MqttClient *client)
while (!SIMPLEQ_EMPTY(&client->sendQueue))
{
MqttPacket *packet = SIMPLEQ_FIRST(&client->sendQueue);
-
SIMPLEQ_REMOVE_HEAD(&client->sendQueue, sendQueue);
-
- if (TAILQ_NEXT(packet, messages) == NULL &&
- TAILQ_PREV(packet, MessageList, messages) == NULL &&
- TAILQ_FIRST(&client->inMessages) != packet &&
- TAILQ_FIRST(&client->outMessages) != packet)
- {
- MqttPacketFree(packet);
- }
+ MqttPacketFree(packet);
}
while (!TAILQ_EMPTY(&client->outMessages))
{
- MqttPacket *packet = TAILQ_FIRST(&client->outMessages);
- TAILQ_REMOVE(&client->outMessages, packet, messages);
- MqttPacketFree(packet);
+ MqttMessage *msg = TAILQ_FIRST(&client->outMessages);
+ TAILQ_REMOVE(&client->outMessages, msg, chain);
+ MqttMessageFree(msg);
}
while (!TAILQ_EMPTY(&client->inMessages))
{
- MqttPacket *packet = TAILQ_FIRST(&client->inMessages);
- TAILQ_REMOVE(&client->inMessages, packet, messages);
- MqttPacketFree(packet);
+ MqttMessage *msg = TAILQ_FIRST(&client->inMessages);
+ TAILQ_REMOVE(&client->inMessages, msg, chain);
+ MqttMessageFree(msg);
}
}
diff --git a/src/deserialize.c b/src/deserialize.c
deleted file mode 100644
index 96d7789..0000000
--- a/src/deserialize.c
+++ /dev/null
@@ -1,286 +0,0 @@
-#include "deserialize.h"
-#include "packet.h"
-#include "stream_mqtt.h"
-#include "log.h"
-
-#include <stdlib.h>
-#include <assert.h>
-
-typedef int (*MqttPacketDeserializeFunc)(MqttPacket **packet, Stream *stream);
-
-static int MqttPacketWithIdDeserialize(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketConnAckDeserialize(MqttPacketConnAck **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamRead(&(*packet)->connAckFlags, 1, stream) != 1)
- return -1;
-
- if (StreamRead(&(*packet)->returnCode, 1, stream) != 1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketSubAckDeserialize(MqttPacketSubAck **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- size_t i;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
- return -1;
-
- remainingLength -= 2;
-
- (*packet)->returnCode = (unsigned char *) malloc(
- sizeof(*(*packet)->returnCode) * remainingLength);
-
- for (i = 0; i < remainingLength; ++i)
- {
- if (StreamRead(&((*packet)->returnCode[i]), 1, stream) == -1)
- return -1;
- }
-
- return 0;
-}
-
-static int MqttPacketTypeUnsubAckDeserialize(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (remainingLength != 2)
- return -1;
-
- if (StreamReadUint16Be(&(*packet)->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketPublishDeserialize(MqttPacketPublish **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- size_t payloadSize = 0;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- if (StreamReadMqttString(&(*packet)->topicName, stream) == -1)
- return -1;
-
- LOG_DEBUG("remainingLength:%lu", remainingLength);
-
- payloadSize = remainingLength - blength((*packet)->topicName) - 2;
-
- LOG_DEBUG("qos:%d payloadSize:%lu", MqttPacketPublishQos(*packet),
- payloadSize);
-
- if (MqttPacketHasId((const MqttPacket *) *packet))
- {
- LOG_DEBUG("packet has id");
- payloadSize -= 2;
- if (StreamReadUint16Be(&((*packet)->base.id), stream) == -1)
- {
- return -1;
- }
- }
-
- LOG_DEBUG("reading payload payloadSize:%lu\n", payloadSize);
-
- /* Allocate extra byte for a NULL terminator. If the user tries to print
- the payload directly. */
-
- (*packet)->message = bfromcstralloc(payloadSize+1, "");
-
- if (StreamRead(bdata((*packet)->message), payloadSize, stream) == -1)
- return -1;
-
- (*packet)->message->slen = payloadSize;
- (*packet)->message->data[payloadSize] = '\0';
-
- return 0;
-}
-
-static int MqttPacketGenericDeserializer(MqttPacket **packet, Stream *stream)
-{
- size_t remainingLength = 0;
- char buffer[256];
-
- (void) packet;
-
- if (StreamReadRemainingLength(&remainingLength, stream) == -1)
- return -1;
-
- while (remainingLength > 0)
- {
- size_t l = sizeof(buffer);
-
- if (remainingLength < l)
- l = remainingLength;
-
- if (StreamRead(buffer, l, stream) != (int64_t) l)
- return -1;
-
- remainingLength -= l;
- }
-
- return 0;
-}
-
-static int ValidateFlags(int type, int flags)
-{
- int rv = 0;
-
- switch (type)
- {
- case MqttPacketTypePublish:
- {
- int qos = (flags >> 1) & 2;
- if (qos >= 0 && qos <= 2)
- rv = 1;
- break;
- }
-
- case MqttPacketTypePubRel:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeUnsubscribe:
- if (flags == 2)
- {
- rv = 1;
- }
- break;
-
- default:
- if (flags == 0)
- {
- rv = 1;
- }
- break;
- }
-
- return rv;
-}
-
-int MqttPacketDeserialize(MqttPacket **packet, Stream *stream)
-{
- MqttPacketDeserializeFunc deserializer = NULL;
- char typeAndFlags;
- int type;
- int flags;
- int rv;
-
- if (StreamRead(&typeAndFlags, 1, stream) != 1)
- return -1;
-
- type = (typeAndFlags & 0xF0) >> 4;
- flags = (typeAndFlags & 0x0F);
-
- if (!ValidateFlags(type, flags))
- {
- return -1;
- }
-
- switch (type)
- {
- case MqttPacketTypeConnect:
- break;
-
- case MqttPacketTypeConnAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketConnAckDeserialize;
- break;
-
- case MqttPacketTypePublish:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketPublishDeserialize;
- break;
-
- case MqttPacketTypePubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubRec:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubRel:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypePubComp:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketWithIdDeserialize;
- break;
-
- case MqttPacketTypeSubscribe:
- break;
-
- case MqttPacketTypeSubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketSubAckDeserialize;
- break;
-
- case MqttPacketTypeUnsubscribe:
- break;
-
- case MqttPacketTypeUnsubAck:
- deserializer = (MqttPacketDeserializeFunc) MqttPacketTypeUnsubAckDeserialize;
- break;
-
- case MqttPacketTypePingReq:
- break;
-
- case MqttPacketTypePingResp:
- break;
-
- case MqttPacketTypeDisconnect:
- break;
-
- default:
- return -1;
- }
-
- if (!deserializer)
- {
- deserializer = MqttPacketGenericDeserializer;
- }
-
- *packet = MqttPacketNew(type);
-
- if (!*packet)
- return -1;
-
- if (type == MqttPacketTypePublish)
- {
- MqttPacketPublishDup(*packet) = (flags >> 3) & 1;
- MqttPacketPublishQos(*packet) = (flags >> 1) & 3;
- MqttPacketPublishRetain(*packet) = flags & 1;
- }
-
- rv = deserializer(packet, stream);
-
- return rv;
-}
diff --git a/src/deserialize.h b/src/deserialize.h
deleted file mode 100644
index 8c29b3d..0000000
--- a/src/deserialize.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef DESERIALIZE_H
-#define DESERIALIZE_H
-
-#include "config.h"
-
-typedef struct MqttPacket MqttPacket;
-typedef struct Stream Stream;
-
-int MqttPacketDeserialize(MqttPacket **packet, Stream *stream);
-
-#endif
diff --git a/src/mqtt.h b/src/mqtt.h
index 2b84962..fdead0d 100644
--- a/src/mqtt.h
+++ b/src/mqtt.h
@@ -33,8 +33,8 @@ typedef void (*MqttClientOnConnectCallback)(MqttClient *client,
typedef void (*MqttClientOnSubscribeCallback)(MqttClient *client,
int id,
- const char *topicFilter,
- MqttSubscriptionStatus status);
+ int *qos,
+ int count);
typedef void (*MqttClientOnUnsubscribeCallback)(MqttClient *client, int id);
diff --git a/src/packet.c b/src/packet.c
index 47aa689..c833851 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -28,42 +28,16 @@ const char *MqttPacketName(int type)
}
}
-static MQTT_INLINE size_t MqttPacketStructSize(int type)
-{
- switch (type)
- {
- case MqttPacketTypeConnect: return sizeof(MqttPacketConnect);
- case MqttPacketTypeConnAck: return sizeof(MqttPacketConnAck);
- case MqttPacketTypePublish: return sizeof(MqttPacketPublish);
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp: return sizeof(MqttPacket);
- case MqttPacketTypeSubscribe: return sizeof(MqttPacketSubscribe);
- case MqttPacketTypeSubAck: return sizeof(MqttPacketSubAck);
- case MqttPacketTypeUnsubscribe: return sizeof(MqttPacketUnsubscribe);
- case MqttPacketTypeUnsubAck: return sizeof(MqttPacket);
- case MqttPacketTypePingReq: return sizeof(MqttPacket);
- case MqttPacketTypePingResp: return sizeof(MqttPacket);
- case MqttPacketTypeDisconnect: return sizeof(MqttPacket);
- default: return (size_t) -1;
- }
-}
-
MqttPacket *MqttPacketNew(int type)
{
MqttPacket *packet = NULL;
- packet = (MqttPacket *) calloc(1, MqttPacketStructSize(type));
+ packet = (MqttPacket *) calloc(1, sizeof(*packet));
if (!packet)
return NULL;
packet->type = type;
- /* this will make sure that TAILQ_PREV does not segfault if a message
- has not been added to a list at any point */
- packet->messages.tqe_prev = &packet->messages.tqe_next;
-
return packet;
}
@@ -78,52 +52,6 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id)
void MqttPacketFree(MqttPacket *packet)
{
- if (MqttPacketType(packet) == MqttPacketTypeConnect)
- {
- MqttPacketConnect *p = (MqttPacketConnect *) packet;
- bdestroy(p->clientId);
- bdestroy(p->willTopic);
- bdestroy(p->willMessage);
- bdestroy(p->userName);
- bdestroy(p->password);
- }
- else if (MqttPacketType(packet) == MqttPacketTypePublish)
- {
- MqttPacketPublish *p = (MqttPacketPublish *) packet;
- bdestroy(p->topicName);
- bdestroy(p->message);
- }
- else if (MqttPacketType(packet) == MqttPacketTypeSubscribe)
- {
- MqttPacketSubscribe *p = (MqttPacketSubscribe *) packet;
- bstrListDestroy(p->topicFilters);
- }
- else if (MqttPacketType(packet) == MqttPacketTypeUnsubscribe)
- {
- MqttPacketUnsubscribe *p = (MqttPacketUnsubscribe *) packet;
- bdestroy(p->topicFilter);
- }
+ bdestroy(packet->payload);
free(packet);
}
-
-int MqttPacketHasId(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypePublish:
- return MqttPacketPublishQos(packet) > 0;
-
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeSubAck:
- case MqttPacketTypeUnsubscribe:
- case MqttPacketTypeUnsubAck:
- return 1;
-
- default:
- return 0;
- }
-}
diff --git a/src/packet.h b/src/packet.h
index 7ab4f73..36dc81f 100644
--- a/src/packet.h
+++ b/src/packet.h
@@ -29,87 +29,33 @@ enum
MqttPacketTypeDisconnect = 0xE
};
+enum MqttPacketState
+{
+ MqttPacketStateReadType,
+ MqttPacketStateReadRemainingLength,
+ MqttPacketStateReadPayload,
+ MqttPacketStateReadComplete,
+
+ MqttPacketStateWriteType,
+ MqttPacketStateWriteRemainingLength,
+ MqttPacketStateWritePayload,
+ MqttPacketStateWriteComplete
+};
+
+struct MqttMessage;
+
typedef struct MqttPacket MqttPacket;
struct MqttPacket
{
int type;
- uint16_t id;
- int state;
int flags;
- int64_t sentAt;
+ int state;
+ uint16_t id;
+ size_t remainingLength;
+ bstring payload;
+ struct MqttMessage *message;
SIMPLEQ_ENTRY(MqttPacket) sendQueue;
- TAILQ_ENTRY(MqttPacket) messages;
-};
-
-#define MqttPacketType(packet) (((MqttPacket *) (packet))->type)
-
-#define MqttPacketId(packet) (((MqttPacket *) (packet))->id)
-
-#define MqttPacketSentAt(packet) (((MqttPacket *) (packet))->sentAt)
-
-typedef struct MqttPacketConnect MqttPacketConnect;
-
-struct MqttPacketConnect
-{
- MqttPacket base;
- char connectFlags;
- uint16_t keepAlive;
- bstring clientId;
- bstring willTopic;
- bstring willMessage;
- bstring userName;
- bstring password;
-};
-
-typedef struct MqttPacketConnAck MqttPacketConnAck;
-
-struct MqttPacketConnAck
-{
- MqttPacket base;
- unsigned char connAckFlags;
- unsigned char returnCode;
-};
-
-typedef struct MqttPacketPublish MqttPacketPublish;
-
-struct MqttPacketPublish
-{
- MqttPacket base;
- bstring topicName;
- bstring message;
- char qos;
- char dup;
- char retain;
-};
-
-#define MqttPacketPublishQos(p) (((MqttPacketPublish *) p)->qos)
-#define MqttPacketPublishDup(p) (((MqttPacketPublish *) p)->dup)
-#define MqttPacketPublishRetain(p) (((MqttPacketPublish *) p)->retain)
-
-typedef struct MqttPacketSubscribe MqttPacketSubscribe;
-
-struct MqttPacketSubscribe
-{
- MqttPacket base;
- struct bstrList *topicFilters;
- int *qos;
-};
-
-typedef struct MqttPacketSubAck MqttPacketSubAck;
-
-struct MqttPacketSubAck
-{
- MqttPacket base;
- unsigned char *returnCode;
-};
-
-typedef struct MqttPacketUnsubscribe MqttPacketUnsubscribe;
-
-struct MqttPacketUnsubscribe
-{
- MqttPacket base;
- bstring topicFilter;
};
const char *MqttPacketName(int type);
@@ -120,6 +66,4 @@ MqttPacket *MqttPacketWithIdNew(int type, uint16_t id);
void MqttPacketFree(MqttPacket *packet);
-int MqttPacketHasId(const MqttPacket *packet);
-
#endif
diff --git a/src/serialize.c b/src/serialize.c
deleted file mode 100644
index c1c8eb4..0000000
--- a/src/serialize.c
+++ /dev/null
@@ -1,326 +0,0 @@
-#include "serialize.h"
-#include "packet.h"
-#include "stream_mqtt.h"
-#include "log.h"
-
-#include <bstrlib/bstrlib.h>
-
-#include <stdlib.h>
-#include <assert.h>
-
-typedef int (*MqttPacketSerializeFunc)(const MqttPacket *packet,
- Stream *stream);
-
-static const struct tagbstring MqttProtocolId = bsStatic("MQTT");
-static const char MqttProtocolLevel = 0x04;
-
-static MQTT_INLINE size_t MqttStringLengthSerialized(const_bstring s)
-{
- return 2 + blength(s);
-}
-
-static size_t MqttPacketConnectGetRemainingLength(const MqttPacketConnect *packet)
-{
- size_t remainingLength = 0;
-
- remainingLength += MqttStringLengthSerialized(&MqttProtocolId) + 1 + 1 + 2;
-
- remainingLength += MqttStringLengthSerialized(packet->clientId);
-
- if (packet->connectFlags & 0x80)
- remainingLength += MqttStringLengthSerialized(packet->userName);
-
- if (packet->connectFlags & 0x40)
- remainingLength += MqttStringLengthSerialized(packet->password);
-
- if (packet->connectFlags & 0x04)
- remainingLength += MqttStringLengthSerialized(packet->willTopic) +
- MqttStringLengthSerialized(packet->willMessage);
-
- return remainingLength;
-}
-
-static size_t MqttPacketSubscribeGetRemainingLength(const MqttPacketSubscribe *packet)
-{
- size_t remaining = 2;
- int i;
-
- for (i = 0; i < packet->topicFilters->qty; ++i)
- {
- remaining += MqttStringLengthSerialized(packet->topicFilters->entry[i]);
- remaining += 1;
- }
-
- return remaining;
-}
-
-static size_t MqttPacketUnsubscribeGetRemainingLength(const MqttPacketUnsubscribe *packet)
-{
- return 2 + MqttStringLengthSerialized(packet->topicFilter);
-}
-
-static size_t MqttPacketPublishGetRemainingLength(const MqttPacketPublish *packet)
-{
- size_t remainingLength = 0;
-
- remainingLength += MqttStringLengthSerialized(packet->topicName);
-
- /* Packet id */
- if (MqttPacketPublishQos(packet) == 1 || MqttPacketPublishQos(packet) == 2)
- {
- remainingLength += 2;
- }
-
- remainingLength += blength(packet->message);
-
- return remainingLength;
-}
-
-static size_t MqttPacketGetRemainingLength(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypeConnect:
- return MqttPacketConnectGetRemainingLength(
- (MqttPacketConnect *) packet);
-
- case MqttPacketTypeSubscribe:
- return MqttPacketSubscribeGetRemainingLength(
- (MqttPacketSubscribe *) packet);
-
- case MqttPacketTypePublish:
- return MqttPacketPublishGetRemainingLength(
- (MqttPacketPublish *) packet);
-
- case MqttPacketTypePubAck:
- case MqttPacketTypePubRec:
- case MqttPacketTypePubRel:
- case MqttPacketTypePubComp:
- return 2;
-
- case MqttPacketTypeUnsubscribe:
- return MqttPacketUnsubscribeGetRemainingLength(
- (MqttPacketUnsubscribe *) packet);
-
- default:
- return 0;
- }
-}
-
-static int MqttPacketFlags(const MqttPacket *packet)
-{
- switch (packet->type)
- {
- case MqttPacketTypePublish:
- return ((MqttPacketPublishDup(packet) & 1) << 3) |
- ((MqttPacketPublishQos(packet) & 3) << 1) |
- (MqttPacketPublishRetain(packet) & 1);
-
- case MqttPacketTypePubRel:
- case MqttPacketTypeSubscribe:
- case MqttPacketTypeUnsubscribe:
- return 0x2;
-
- default:
- return 0;
- }
-}
-
-static int MqttPacketBaseSerialize(const MqttPacket *packet, Stream *stream)
-{
- unsigned char typeAndFlags;
- size_t remainingLength;
-
- typeAndFlags = ((packet->type & 0x0F) << 4) |
- (MqttPacketFlags(packet) & 0x0F);
- remainingLength = MqttPacketGetRemainingLength(packet);
-
- LOG_DEBUG("type:%02X (%s) flags:%02X", packet->type,
- MqttPacketName(packet->type), MqttPacketFlags(packet));
-
- if (StreamWrite(&typeAndFlags, 1, stream) != 1)
- return -1;
-
- if (StreamWriteRemainingLength(remainingLength, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketWithIdSerialize(const MqttPacket *packet, Stream *stream)
-{
- assert(MqttPacketHasId((const MqttPacket *) packet));
-
- if (MqttPacketBaseSerialize(packet, stream) == -1)
- return -1;
-
- if (StreamWriteUint16Be(packet->id, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketConnectSerialize(const MqttPacketConnect *packet, Stream *stream)
-{
- if (MqttPacketBaseSerialize(&packet->base, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(&MqttProtocolId, stream) == -1)
- return -1;
-
- if (StreamWrite(&MqttProtocolLevel, 1, stream) != 1)
- return -1;
-
- if (StreamWrite(&packet->connectFlags, 1, stream) != 1)
- return -1;
-
- if (StreamWriteUint16Be(packet->keepAlive, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->clientId, stream) == -1)
- return -1;
-
- if (packet->connectFlags & 0x04)
- {
- if (StreamWriteMqttString(packet->willTopic, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->willMessage, stream) == -1)
- return -1;
- }
-
- if (packet->connectFlags & 0x80)
- {
- if (StreamWriteMqttString(packet->userName, stream) == -1)
- return -1;
-
- if (packet->connectFlags & 0x40)
- {
- if (StreamWriteMqttString(packet->password, stream) == -1)
- return -1;
- }
- }
-
- return 0;
-}
-
-static int MqttPacketSubscribeSerialize(const MqttPacketSubscribe *packet, Stream *stream)
-{
- int i;
-
- if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- for (i = 0; i < packet->topicFilters->qty; ++i)
- {
- unsigned char qos = (unsigned char) packet->qos[i];
-
- if (StreamWriteMqttString(packet->topicFilters->entry[i], stream) == -1)
- return -1;
-
- if (StreamWrite(&qos, 1, stream) == -1)
- return -1;
- }
-
- return 0;
-}
-
-static int MqttPacketUnsubscribeSerialize(const MqttPacketUnsubscribe *packet, Stream *stream)
-{
- if (MqttPacketWithIdSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->topicFilter, stream) == -1)
- return -1;
-
- return 0;
-}
-
-static int MqttPacketPublishSerialize(const MqttPacketPublish *packet, Stream *stream)
-{
- if (MqttPacketBaseSerialize((const MqttPacket *) packet, stream) == -1)
- return -1;
-
- if (StreamWriteMqttString(packet->topicName, stream) == -1)
- return -1;
-
- LOG_DEBUG("qos:%d", MqttPacketPublishQos(packet));
-
- if (MqttPacketPublishQos(packet) > 0)
- {
- if (StreamWriteUint16Be(packet->base.id, stream) == -1)
- return -1;
- }
-
- if (StreamWrite(bdata(packet->message), blength(packet->message), stream) == -1)
- return -1;
-
- return 0;
-}
-
-int MqttPacketSerialize(const MqttPacket *packet, Stream *stream)
-{
- MqttPacketSerializeFunc f = NULL;
-
- switch (packet->type)
- {
- case MqttPacketTypeConnect:
- f = (MqttPacketSerializeFunc) MqttPacketConnectSerialize;
- break;
-
- case MqttPacketTypeConnAck:
- break;
-
- case MqttPacketTypePublish:
- f = (MqttPacketSerializeFunc) MqttPacketPublishSerialize;
- break;
-
- case MqttPacketTypePubAck:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubRec:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubRel:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypePubComp:
- f = (MqttPacketSerializeFunc) MqttPacketWithIdSerialize;
- break;
-
- case MqttPacketTypeSubscribe:
- f = (MqttPacketSerializeFunc) MqttPacketSubscribeSerialize;
- break;
-
- case MqttPacketTypeSubAck:
- break;
-
- case MqttPacketTypeUnsubscribe:
- f = (MqttPacketSerializeFunc) MqttPacketUnsubscribeSerialize;
- break;
-
- case MqttPacketTypeUnsubAck:
- break;
-
- case MqttPacketTypePingReq:
- f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
- break;
-
- case MqttPacketTypePingResp:
- break;
-
- case MqttPacketTypeDisconnect:
- f = (MqttPacketSerializeFunc) MqttPacketBaseSerialize;
- break;
-
- default:
- return -1;
- }
-
- assert(f != NULL && "no serializer");
-
- return f(packet, stream);
-}
diff --git a/src/serialize.h b/src/serialize.h
deleted file mode 100644
index 7eb988f..0000000
--- a/src/serialize.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef SERIALIZE_H
-#define SERIALIZE_H
-
-#include "config.h"
-
-typedef struct MqttPacket MqttPacket;
-typedef struct Stream Stream;
-
-int MqttPacketSerialize(const MqttPacket *packet, Stream *stream);
-
-#endif
diff --git a/src/stream.c b/src/stream.c
index fd154a1..1c46668 100644
--- a/src/stream.c
+++ b/src/stream.c
@@ -47,6 +47,11 @@ int64_t StreamReadUint16Be(uint16_t *v, Stream *stream)
return 2;
}
+int64_t StreamReadByte(unsigned char *byte, Stream *stream)
+{
+ return StreamRead(byte, sizeof(*byte), stream);
+}
+
int64_t StreamWrite(const void *ptr, size_t size, Stream *stream)
{
STREAM_CHECK_OP(stream, write);
@@ -65,6 +70,11 @@ int64_t StreamWriteUint16Be(uint16_t v, Stream *stream)
return StreamWrite(data, sizeof(data), stream);
}
+int64_t StreamWriteByte(unsigned char byte, Stream *stream)
+{
+ return StreamWrite(&byte, sizeof(byte), stream);
+}
+
int StreamSeek(Stream *stream, int64_t offset, int whence)
{
STREAM_CHECK_OP(stream, seek);
diff --git a/src/stream.h b/src/stream.h
index 839facb..50f1772 100644
--- a/src/stream.h
+++ b/src/stream.h
@@ -27,9 +27,11 @@ int StreamClose(Stream *stream);
int64_t StreamRead(void *ptr, size_t size, Stream *stream);
int64_t StreamReadUint16Be(uint16_t *v, Stream *stream);
+int64_t StreamReadByte(unsigned char *byte, Stream *stream);
int64_t StreamWrite(const void *ptr, size_t size, Stream *stream);
int64_t StreamWriteUint16Be(uint16_t v, Stream *stream);
+int64_t StreamWriteByte(unsigned char byte, Stream *stream);
int StreamSeek(Stream *stream, int64_t offset, int whence);
diff --git a/src/stream_mqtt.h b/src/stream_mqtt.h
index 9023430..a128524 100644
--- a/src/stream_mqtt.h
+++ b/src/stream_mqtt.h
@@ -2,6 +2,7 @@
#define STREAM_MQTT_H
#include "stream.h"
+#include "stringstream.h"
#include <bstrlib/bstrlib.h>
diff --git a/test/interop/CMakeLists.txt b/test/interop/CMakeLists.txt
index e907776..d4b43d7 100644
--- a/test/interop/CMakeLists.txt
+++ b/test/interop/CMakeLists.txt
@@ -17,3 +17,6 @@ ADD_INTEROP_TEST(keepalive_test)
ADD_INTEROP_TEST(redelivery_on_reconnect_test)
ADD_INTEROP_TEST(subscribe_failure_test)
ADD_INTEROP_TEST(dollar_topics_test)
+ADD_INTEROP_TEST(username_and_password_test)
+ADD_INTEROP_TEST(ping_test)
+ADD_INTEROP_TEST(unsubscribe_test)
diff --git a/test/interop/ping_test.c b/test/interop/ping_test.c
new file mode 100644
index 0000000..6d699da
--- /dev/null
+++ b/test/interop/ping_test.c
@@ -0,0 +1,27 @@
+#include "greatest.h"
+#include "testclient.h"
+#include "cleanup.c"
+#include "topics.c"
+
+TEST ping_test()
+{
+ TestClient *client;
+
+ client = TestClientNew("clienta");
+ ASSERT(TestClientConnect(client, "localhost", 1883, 1, 1));
+ ASSERT(TestClientWait(client, 5000));
+ TestClientDisconnect(client);
+ TestClientFree(client);
+
+ PASS();
+}
+
+GREATEST_MAIN_DEFS();
+
+int main(int argc, char **argv)
+{
+ GREATEST_MAIN_BEGIN();
+ cleanup();
+ RUN_TEST(ping_test);
+ GREATEST_MAIN_END();
+}
diff --git a/test/interop/testclient.c b/test/interop/testclient.c
index 8d616f6..09782b2 100644
--- a/test/interop/testclient.c
+++ b/test/interop/testclient.c
@@ -14,12 +14,15 @@ static void TestClientOnConnect(MqttClient *client,
}
static void TestClientOnSubscribe(MqttClient *client, int id,
- const char *filter,
- MqttSubscriptionStatus status)
+ int *qos, int count)
{
TestClient *testClient = (TestClient *) MqttClientGetUserData(client);
testClient->subId = id;
- testClient->subStatus[testClient->subCount++] = status;
+ for (testClient->subCount = 0; testClient->subCount < count;
+ ++testClient->subCount)
+ {
+ testClient->subStatus[testClient->subCount] = qos[testClient->subCount];
+ }
}
static void TestClientOnPublish(MqttClient *client, int id)
@@ -37,6 +40,12 @@ static void TestClientOnMessage(MqttClient *client, const char *topic,
SIMPLEQ_INSERT_TAIL(&testClient->messages, msg, chain);
}
+static void TestClientOnUnsubscribe(MqttClient *client, int id)
+{
+ TestClient *testClient = (TestClient *) MqttClientGetUserData(client);
+ testClient->unsubId = id;
+}
+
Message *MessageNew(const char *topic, const void *data, size_t size,
int qos, int retain)
{
@@ -69,6 +78,8 @@ TestClient *TestClientNew(const char *clientId)
{
TestClient *client = calloc(1, sizeof(*client));
+ client->clientId = clientId;
+
client->client = MqttClientNew(clientId);
MqttClientSetUserData(client->client, client);
@@ -79,6 +90,7 @@ TestClient *TestClientNew(const char *clientId)
MqttClientSetOnSubscribe(client->client, TestClientOnSubscribe);
MqttClientSetOnPublish(client->client, TestClientOnPublish);
MqttClientSetOnMessage(client->client, TestClientOnMessage);
+ MqttClientSetOnUnsubscribe(client->client, TestClientOnUnsubscribe);
SIMPLEQ_INIT(&client->messages);
@@ -235,8 +247,8 @@ int TestClientWait(TestClient *client, int timeout)
printf("TestClientWait timeout:%d rc:%d\n", timeout, rc);
int64_t now = MqttGetCurrentTime();
int64_t elapsed = now - start;
- timeout -= elapsed;
printf("TestClientWait elapsed:%d\n", (int) elapsed);
+ timeout = timeout - elapsed;
if (timeout <= 0)
{
break;
@@ -245,3 +257,27 @@ int TestClientWait(TestClient *client, int timeout)
return rc != -1;
}
+
+int TestClientUnsubscribe(TestClient *client, const char *topic)
+{
+ int id = MqttClientUnsubscribe(client->client, topic);
+ int rc;
+
+ client->unsubId = -1;
+
+ while ((rc = MqttClientRunOnce(client->client, -1)) != -1)
+ {
+ if (client->unsubId != -1)
+ {
+ if (client->unsubId != id)
+ {
+ printf(
+ "WARNING: unsubscribe id mismatch: expected %d, got %d\n",
+ id, client->unsubId);
+ }
+ break;
+ }
+ }
+
+ return rc != -1;
+}
diff --git a/test/interop/testclient.h b/test/interop/testclient.h
index 2aa229b..3665f5e 100644
--- a/test/interop/testclient.h
+++ b/test/interop/testclient.h
@@ -21,6 +21,8 @@ typedef struct TestClient TestClient;
struct TestClient
{
+ const char *clientId;
+
MqttClient *client;
/* OnConnect */
@@ -37,6 +39,9 @@ struct TestClient
/* OnMessage */
SIMPLEQ_HEAD(messages, Message) messages;
+
+ /* OnUnsubscribe */
+ int unsubId;
};
Message *MessageNew(const char *topic, const void *data, size_t size,
@@ -65,4 +70,6 @@ int TestClientMessageCount(TestClient *client);
int TestClientWait(TestClient *client, int timeout);
+int TestClientUnsubscribe(TestClient *client, const char *topic);
+
#endif
diff --git a/test/interop/unsubscribe_test.c b/test/interop/unsubscribe_test.c
new file mode 100644
index 0000000..a7e4668
--- /dev/null
+++ b/test/interop/unsubscribe_test.c
@@ -0,0 +1,31 @@
+#include "greatest.h"
+#include "testclient.h"
+#include "cleanup.c"
+#include "topics.c"
+
+TEST unsubscribe_test()
+{
+ TestClient *client;
+
+ client = TestClientNew("clienta");
+ ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1));
+ ASSERT(TestClientSubscribe(client, topics[0], 2));
+ ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg"));
+ ASSERT(TestClientUnsubscribe(client, topics[0]));
+ ASSERT(TestClientPublish(client, 2, 0, topics[0], "msg"));
+ TestClientDisconnect(client);
+ ASSERT_EQ(1, TestClientMessageCount(client));
+ TestClientFree(client);
+
+ PASS();
+}
+
+GREATEST_MAIN_DEFS();
+
+int main(int argc, char **argv)
+{
+ GREATEST_MAIN_BEGIN();
+ cleanup();
+ RUN_TEST(unsubscribe_test);
+ GREATEST_MAIN_END();
+}
diff --git a/test/interop/username_and_password_test.c b/test/interop/username_and_password_test.c
new file mode 100644
index 0000000..6e0eaab
--- /dev/null
+++ b/test/interop/username_and_password_test.c
@@ -0,0 +1,30 @@
+#include "greatest.h"
+#include "testclient.h"
+#include "cleanup.c"
+#include "topics.c"
+
+TEST username_and_password_test()
+{
+ TestClient *client;
+
+ client = TestClientNew("clienta");
+ ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", NULL));
+ ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1));
+ TestClientDisconnect(client);
+ ASSERT_EQ(0, MqttClientSetAuth(client->client, "myusername", "mypassword"));
+ ASSERT(TestClientConnect(client, "localhost", 1883, 60, 1));
+ TestClientDisconnect(client);
+ TestClientFree(client);
+
+ PASS();
+}
+
+GREATEST_MAIN_DEFS();
+
+int main(int argc, char **argv)
+{
+ GREATEST_MAIN_BEGIN();
+ cleanup();
+ RUN_TEST(username_and_password_test);
+ GREATEST_MAIN_END();
+}
diff --git a/tools/sub.c b/tools/sub.c
index ebf3372..e0577c9 100644
--- a/tools/sub.c
+++ b/tools/sub.c
@@ -21,11 +21,10 @@ void onConnect(MqttClient *client, MqttConnectionStatus status,
MqttClientSubscribe(client, options->topic, options->qos);
}
-void onSubscribe(MqttClient *client, int id, const char *filter,
- MqttSubscriptionStatus status)
+void onSubscribe(MqttClient *client, int id, int *qos, int count)
{
(void) client;
- printf("onSubscribe id=%d status=%d\n", id, status);
+ printf("onSubscribe id=%d\n", id);
}
void onMessage(MqttClient *client, const char *topic, const void *data,